commit d143d158e4
Author: mofeng-git
Date: 2025-12-28 18:19:16 +08:00
771 changed files with 220548 additions and 0 deletions

src/atx/controller.rs Normal file

@@ -0,0 +1,356 @@
//! ATX Controller
//!
//! High-level controller for ATX power management with flexible hardware binding.
//! Each action (power short, power long, reset) can be configured independently.
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use super::executor::{timing, AtxKeyExecutor};
use super::led::LedSensor;
use super::types::{AtxKeyConfig, AtxLedConfig, AtxState, PowerStatus};
use crate::error::{AppError, Result};
/// ATX power control configuration
#[derive(Debug, Clone)]
pub struct AtxControllerConfig {
/// Whether ATX is enabled
pub enabled: bool,
/// Power button configuration (used for both short and long press)
pub power: AtxKeyConfig,
/// Reset button configuration
pub reset: AtxKeyConfig,
/// LED sensing configuration
pub led: AtxLedConfig,
}
impl Default for AtxControllerConfig {
fn default() -> Self {
Self {
enabled: false,
power: AtxKeyConfig::default(),
reset: AtxKeyConfig::default(),
led: AtxLedConfig::default(),
}
}
}
/// Internal state holding all ATX components
/// Grouped together to reduce lock acquisitions
struct AtxInner {
config: AtxControllerConfig,
power_executor: Option<AtxKeyExecutor>,
reset_executor: Option<AtxKeyExecutor>,
led_sensor: Option<LedSensor>,
}
/// ATX Controller
///
/// Manages ATX power control through independent executors for each action.
/// Supports hot-reload of configuration.
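///
/// # Example
///
/// A minimal hot-reload sketch (assumes an async context and that `config`
/// and `new_config` were built elsewhere):
///
/// ```ignore
/// let controller = AtxController::new(config);
/// controller.init().await?;
/// // Later, when the user saves new settings:
/// controller.reload(new_config).await?;
/// ```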
pub struct AtxController {
/// Single lock for all internal state to reduce lock contention
inner: RwLock<AtxInner>,
}
impl AtxController {
/// Create a new ATX controller with the specified configuration
pub fn new(config: AtxControllerConfig) -> Self {
Self {
inner: RwLock::new(AtxInner {
config,
power_executor: None,
reset_executor: None,
led_sensor: None,
}),
}
}
/// Create a disabled ATX controller
pub fn disabled() -> Self {
Self::new(AtxControllerConfig::default())
}
/// Initialize the ATX controller and its executors
pub async fn init(&self) -> Result<()> {
let mut inner = self.inner.write().await;
if !inner.config.enabled {
info!("ATX disabled in configuration");
return Ok(());
}
info!("Initializing ATX controller");
// Initialize power executor
if inner.config.power.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.power.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver, inner.config.power.device, inner.config.power.pin
);
inner.power_executor = Some(executor);
}
}
// Initialize reset executor
if inner.config.reset.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.reset.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver, inner.config.reset.device, inner.config.reset.pin
);
inner.reset_executor = Some(executor);
}
}
// Initialize LED sensor
if inner.config.led.is_configured() {
let mut sensor = LedSensor::new(inner.config.led.clone());
if let Err(e) = sensor.init().await {
warn!("Failed to initialize LED sensor: {}", e);
} else {
info!(
"LED sensor initialized on {} pin {}",
inner.config.led.gpio_chip, inner.config.led.gpio_pin
);
inner.led_sensor = Some(sensor);
}
}
info!("ATX controller initialized successfully");
Ok(())
}
/// Reload the ATX controller with new configuration
///
/// This is called when configuration changes and supports hot-reload.
pub async fn reload(&self, new_config: AtxControllerConfig) -> Result<()> {
info!("Reloading ATX controller with new configuration");
// Shutdown existing executors
self.shutdown_internal().await?;
// Update configuration and re-initialize
{
let mut inner = self.inner.write().await;
inner.config = new_config;
}
// Re-initialize
self.init().await?;
info!("ATX controller reloaded successfully");
Ok(())
}
/// Get current ATX state (single lock acquisition)
pub async fn state(&self) -> AtxState {
let inner = self.inner.read().await;
let power_status = if let Some(sensor) = inner.led_sensor.as_ref() {
sensor.read().await.unwrap_or(PowerStatus::Unknown)
} else {
PowerStatus::Unknown
};
AtxState {
available: inner.config.enabled,
power_configured: inner
.power_executor
.as_ref()
.map(|e| e.is_initialized())
.unwrap_or(false),
reset_configured: inner
.reset_executor
.as_ref()
.map(|e| e.is_initialized())
.unwrap_or(false),
power_status,
led_supported: inner
.led_sensor
.as_ref()
.map(|s| s.is_initialized())
.unwrap_or(false),
}
}
/// Get current state as SystemEvent
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
let state = self.state().await;
crate::events::SystemEvent::AtxStateChanged {
power_status: state.power_status,
}
}
/// Check if ATX is available
pub async fn is_available(&self) -> bool {
let inner = self.inner.read().await;
inner.config.enabled
}
/// Check if power button is configured and initialized
pub async fn is_power_ready(&self) -> bool {
let inner = self.inner.read().await;
inner
.power_executor
.as_ref()
.map(|e| e.is_initialized())
.unwrap_or(false)
}
/// Check if reset button is configured and initialized
pub async fn is_reset_ready(&self) -> bool {
let inner = self.inner.read().await;
inner
.reset_executor
.as_ref()
.map(|e| e.is_initialized())
.unwrap_or(false)
}
/// Short press power button (turn on or graceful shutdown)
pub async fn power_short(&self) -> Result<()> {
let inner = self.inner.read().await;
let executor = inner
.power_executor
.as_ref()
.ok_or_else(|| AppError::Internal("Power button not configured".to_string()))?;
info!(
"ATX: Short press power button ({}ms)",
timing::SHORT_PRESS.as_millis()
);
executor.pulse(timing::SHORT_PRESS).await
}
/// Long press power button (force power off)
pub async fn power_long(&self) -> Result<()> {
let inner = self.inner.read().await;
let executor = inner
.power_executor
.as_ref()
.ok_or_else(|| AppError::Internal("Power button not configured".to_string()))?;
info!(
"ATX: Long press power button ({}ms)",
timing::LONG_PRESS.as_millis()
);
executor.pulse(timing::LONG_PRESS).await
}
/// Press reset button
pub async fn reset(&self) -> Result<()> {
let inner = self.inner.read().await;
let executor = inner
.reset_executor
.as_ref()
.ok_or_else(|| AppError::Internal("Reset button not configured".to_string()))?;
info!(
"ATX: Press reset button ({}ms)",
timing::RESET_PRESS.as_millis()
);
executor.pulse(timing::RESET_PRESS).await
}
/// Get current power status from LED sensor
pub async fn power_status(&self) -> Result<PowerStatus> {
let inner = self.inner.read().await;
match inner.led_sensor.as_ref() {
Some(sensor) => sensor.read().await,
None => Ok(PowerStatus::Unknown),
}
}
/// Shutdown the ATX controller
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down ATX controller");
self.shutdown_internal().await?;
info!("ATX controller shutdown complete");
Ok(())
}
/// Internal shutdown helper
async fn shutdown_internal(&self) -> Result<()> {
let mut inner = self.inner.write().await;
// Shutdown power executor
if let Some(mut executor) = inner.power_executor.take() {
executor.shutdown().await.ok();
}
// Shutdown reset executor
if let Some(mut executor) = inner.reset_executor.take() {
executor.shutdown().await.ok();
}
// Shutdown LED sensor
if let Some(mut sensor) = inner.led_sensor.take() {
sensor.shutdown().await.ok();
}
Ok(())
}
}
impl Drop for AtxController {
fn drop(&mut self) {
debug!("ATX controller dropped");
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_controller_config_default() {
let config = AtxControllerConfig::default();
assert!(!config.enabled);
assert!(!config.power.is_configured());
assert!(!config.reset.is_configured());
assert!(!config.led.is_configured());
}
#[test]
fn test_controller_creation() {
let controller = AtxController::disabled();
assert!(controller.inner.try_read().is_ok());
}
#[tokio::test]
async fn test_controller_disabled_state() {
let controller = AtxController::disabled();
let state = controller.state().await;
assert!(!state.available);
assert!(!state.power_configured);
assert!(!state.reset_configured);
}
#[tokio::test]
async fn test_controller_init_disabled() {
let controller = AtxController::disabled();
let result = controller.init().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_controller_is_available() {
let controller = AtxController::disabled();
assert!(!controller.is_available().await);
let config = AtxControllerConfig {
enabled: true,
..Default::default()
};
let controller = AtxController::new(config);
assert!(controller.is_available().await);
}
}

src/atx/executor.rs Normal file

@@ -0,0 +1,305 @@
//! ATX Key Executor
//!
//! Lightweight executor for a single ATX key operation.
//! Each executor handles one button (power or reset) with its own hardware binding.
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use super::types::{ActiveLevel, AtxDriverType, AtxKeyConfig};
use crate::error::{AppError, Result};
/// Timing constants for ATX operations
pub mod timing {
use std::time::Duration;
/// Short press duration (power on/graceful shutdown)
pub const SHORT_PRESS: Duration = Duration::from_millis(500);
/// Long press duration (force power off)
pub const LONG_PRESS: Duration = Duration::from_millis(5000);
/// Reset press duration
pub const RESET_PRESS: Duration = Duration::from_millis(500);
}
/// Executor for a single ATX key operation
///
/// Each executor manages one hardware button (power or reset).
/// It handles both GPIO and USB relay backends.
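///
/// # Example
///
/// A minimal sketch (assumes a GPIO-backed `AtxKeyConfig` and an async context):
///
/// ```ignore
/// let mut executor = AtxKeyExecutor::new(config);
/// executor.init().await?;
/// executor.pulse(timing::SHORT_PRESS).await?;
/// executor.shutdown().await?;
/// ```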
pub struct AtxKeyExecutor {
config: AtxKeyConfig,
gpio_handle: Mutex<Option<LineHandle>>,
/// Cached USB relay file handle to avoid repeated open/close syscalls
usb_relay_handle: Mutex<Option<File>>,
initialized: AtomicBool,
}
impl AtxKeyExecutor {
/// Create a new executor with the given configuration
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
gpio_handle: Mutex::new(None),
usb_relay_handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
/// Check if this executor is configured
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
/// Check if this executor is initialized
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
/// Initialize the executor
pub async fn init(&mut self) -> Result<()> {
if !self.config.is_configured() {
debug!("ATX key executor not configured, skipping init");
return Ok(());
}
match self.config.driver {
AtxDriverType::Gpio => self.init_gpio().await?,
AtxDriverType::UsbRelay => self.init_usb_relay().await?,
AtxDriverType::None => {}
}
self.initialized.store(true, Ordering::Relaxed);
Ok(())
}
/// Initialize GPIO backend
async fn init_gpio(&mut self) -> Result<()> {
info!(
"Initializing GPIO ATX executor on {} pin {}",
self.config.device, self.config.pin
);
let mut chip = Chip::new(&self.config.device)
.map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?;
let line = chip.get_line(self.config.pin).map_err(|e| {
AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e))
})?;
// Initial value depends on active level (start in inactive state)
let initial_value = match self.config.active_level {
ActiveLevel::High => 0, // Inactive = low
ActiveLevel::Low => 1, // Inactive = high
};
let handle = line
.request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx")
.map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?;
*self.gpio_handle.lock().unwrap() = Some(handle);
debug!("GPIO pin {} configured successfully", self.config.pin);
Ok(())
}
/// Initialize USB relay backend
async fn init_usb_relay(&self) -> Result<()> {
info!(
"Initializing USB relay ATX executor on {} channel {}",
self.config.device, self.config.pin
);
// Open and cache the device handle
let device = OpenOptions::new()
.read(true)
.write(true)
.open(&self.config.device)
.map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?;
*self.usb_relay_handle.lock().unwrap() = Some(device);
// Ensure relay is off initially
self.send_usb_relay_command(false)?;
debug!(
"USB relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
/// Pulse the button for the specified duration
pub async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_configured() {
return Err(AppError::Internal("ATX key not configured".to_string()));
}
if !self.is_initialized() {
return Err(AppError::Internal("ATX key not initialized".to_string()));
}
match self.config.driver {
AtxDriverType::Gpio => self.pulse_gpio(duration).await,
AtxDriverType::UsbRelay => self.pulse_usb_relay(duration).await,
AtxDriverType::None => Ok(()),
}
}
/// Pulse GPIO pin
async fn pulse_gpio(&self, duration: Duration) -> Result<()> {
let (active, inactive) = match self.config.active_level {
ActiveLevel::High => (1u8, 0u8),
ActiveLevel::Low => (0u8, 1u8),
};
// Set to active state
{
let guard = self.gpio_handle.lock().unwrap();
let handle = guard
.as_ref()
.ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?;
handle
.set_value(active)
.map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?;
}
// Wait for duration (no lock held)
sleep(duration).await;
// Set to inactive state
{
let guard = self.gpio_handle.lock().unwrap();
if let Some(handle) = guard.as_ref() {
handle.set_value(inactive).ok();
}
}
Ok(())
}
/// Pulse USB relay
async fn pulse_usb_relay(&self, duration: Duration) -> Result<()> {
// Turn relay on
self.send_usb_relay_command(true)?;
// Wait for duration
sleep(duration).await;
// Turn relay off
self.send_usb_relay_command(false)?;
Ok(())
}
/// Send USB relay command using cached handle
fn send_usb_relay_command(&self, on: bool) -> Result<()> {
let channel = self.config.pin as u8;
// Standard HID relay command format
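        // Layout as written here (byte order varies between relay firmwares,
        // so treat this as the scheme these boards expect):
        //   byte 0:    HID report ID (always 0x00)
        //   byte 1:    relay channel, 1-based
        //   byte 2:    command (0xFF = on, 0xFD = off)
        //   bytes 3-7: zero padding to an 8-byte report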
let cmd = if on {
[0x00, channel + 1, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00]
} else {
[0x00, channel + 1, 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00]
};
let mut guard = self.usb_relay_handle.lock().unwrap();
let device = guard
.as_mut()
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
device
.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("USB relay write failed: {}", e)))?;
Ok(())
}
/// Shutdown the executor
pub async fn shutdown(&mut self) -> Result<()> {
if !self.is_initialized() {
return Ok(());
}
match self.config.driver {
AtxDriverType::Gpio => {
// Release GPIO handle
*self.gpio_handle.lock().unwrap() = None;
}
AtxDriverType::UsbRelay => {
// Ensure relay is off before closing handle
let _ = self.send_usb_relay_command(false);
// Release USB relay handle
*self.usb_relay_handle.lock().unwrap() = None;
}
AtxDriverType::None => {}
}
self.initialized.store(false, Ordering::Relaxed);
debug!("ATX key executor shutdown complete");
Ok(())
}
}
impl Drop for AtxKeyExecutor {
fn drop(&mut self) {
// Ensure GPIO lines are released
*self.gpio_handle.lock().unwrap() = None;
// Ensure USB relay is off and handle released
if self.config.driver == AtxDriverType::UsbRelay && self.is_initialized() {
let _ = self.send_usb_relay_command(false);
}
*self.usb_relay_handle.lock().unwrap() = None;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_executor_creation() {
let config = AtxKeyConfig::default();
let executor = AtxKeyExecutor::new(config);
assert!(!executor.is_configured());
assert!(!executor.is_initialized());
}
#[test]
fn test_executor_with_gpio_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::Gpio,
device: "/dev/gpiochip0".to_string(),
pin: 5,
active_level: ActiveLevel::High,
};
let executor = AtxKeyExecutor::new(config);
assert!(executor.is_configured());
assert!(!executor.is_initialized());
}
#[test]
fn test_executor_with_usb_relay_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 0,
active_level: ActiveLevel::High, // Ignored for USB relay
};
let executor = AtxKeyExecutor::new(config);
assert!(executor.is_configured());
}
#[test]
fn test_timing_constants() {
assert_eq!(timing::SHORT_PRESS.as_millis(), 500);
assert_eq!(timing::LONG_PRESS.as_millis(), 5000);
assert_eq!(timing::RESET_PRESS.as_millis(), 500);
}
}

src/atx/led.rs Normal file

@@ -0,0 +1,154 @@
//! ATX LED Sensor
//!
//! Reads power LED status from GPIO to determine if the target system is powered on.
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use tracing::{debug, info};
use super::types::{AtxLedConfig, PowerStatus};
use crate::error::{AppError, Result};
/// LED sensor for reading power status
///
/// Uses GPIO to read the power LED state and determine if the system is on or off.
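///
/// # Example
///
/// A minimal read sketch (assumes a configured `AtxLedConfig` and an async context):
///
/// ```ignore
/// let mut sensor = LedSensor::new(config);
/// sensor.init().await?;
/// match sensor.read().await? {
///     PowerStatus::On => println!("host is powered on"),
///     PowerStatus::Off => println!("host is powered off"),
///     PowerStatus::Unknown => println!("no LED signal"),
/// }
/// ```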
pub struct LedSensor {
config: AtxLedConfig,
handle: Mutex<Option<LineHandle>>,
initialized: AtomicBool,
}
impl LedSensor {
/// Create a new LED sensor with the given configuration
pub fn new(config: AtxLedConfig) -> Self {
Self {
config,
handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
/// Check if the sensor is configured
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
/// Check if the sensor is initialized
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
/// Initialize the LED sensor
pub async fn init(&mut self) -> Result<()> {
if !self.config.is_configured() {
debug!("LED sensor not configured, skipping init");
return Ok(());
}
info!(
"Initializing LED sensor on {} pin {}",
self.config.gpio_chip, self.config.gpio_pin
);
let mut chip = Chip::new(&self.config.gpio_chip)
.map_err(|e| AppError::Internal(format!("LED GPIO chip failed: {}", e)))?;
let line = chip.get_line(self.config.gpio_pin).map_err(|e| {
AppError::Internal(format!("LED GPIO line {} failed: {}", self.config.gpio_pin, e))
})?;
let handle = line
.request(LineRequestFlags::INPUT, 0, "one-kvm-led")
.map_err(|e| AppError::Internal(format!("LED GPIO request failed: {}", e)))?;
*self.handle.lock().unwrap() = Some(handle);
self.initialized.store(true, Ordering::Relaxed);
debug!("LED sensor initialized successfully");
Ok(())
}
/// Read the current power status
pub async fn read(&self) -> Result<PowerStatus> {
if !self.is_configured() || !self.is_initialized() {
return Ok(PowerStatus::Unknown);
}
let guard = self.handle.lock().unwrap();
match guard.as_ref() {
Some(handle) => {
let value = handle
.get_value()
.map_err(|e| AppError::Internal(format!("LED read failed: {}", e)))?;
// Apply inversion if configured
let is_on = if self.config.inverted {
value == 0 // Active low: 0 means on
} else {
value == 1 // Active high: 1 means on
};
Ok(if is_on {
PowerStatus::On
} else {
PowerStatus::Off
})
}
None => Ok(PowerStatus::Unknown),
}
}
/// Shutdown the LED sensor
pub async fn shutdown(&mut self) -> Result<()> {
*self.handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
debug!("LED sensor shutdown complete");
Ok(())
}
}
impl Drop for LedSensor {
fn drop(&mut self) {
*self.handle.lock().unwrap() = None;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_led_sensor_creation() {
let config = AtxLedConfig::default();
let sensor = LedSensor::new(config);
assert!(!sensor.is_configured());
assert!(!sensor.is_initialized());
}
#[test]
fn test_led_sensor_with_config() {
let config = AtxLedConfig {
enabled: true,
gpio_chip: "/dev/gpiochip0".to_string(),
gpio_pin: 7,
inverted: false,
};
let sensor = LedSensor::new(config);
assert!(sensor.is_configured());
assert!(!sensor.is_initialized());
}
#[test]
fn test_led_sensor_inverted_config() {
let config = AtxLedConfig {
enabled: true,
gpio_chip: "/dev/gpiochip0".to_string(),
gpio_pin: 7,
inverted: true,
};
let sensor = LedSensor::new(config);
assert!(sensor.is_configured());
assert!(sensor.config.inverted);
}
}

src/atx/mod.rs Normal file

@@ -0,0 +1,107 @@
//! ATX Power Control Module
//!
//! Provides ATX power management functionality for IP-KVM.
//! Supports flexible hardware binding with independent configuration for each action.
//!
//! # Features
//!
//! - Power button control (short press for on/graceful shutdown, long press for force off)
//! - Reset button control
//! - Power status monitoring via LED sensing (GPIO only)
//! - Independent hardware binding for each action (GPIO or USB relay)
//! - Hot-reload configuration support
//!
//! # Hardware Support
//!
//! - **GPIO**: Uses Linux GPIO character device (/dev/gpiochipX) for direct hardware control
//! - **USB Relay**: Uses HID USB relay modules for isolated switching
//!
//! # Example
//!
//! ```ignore
//! use one_kvm::atx::{AtxController, AtxControllerConfig, AtxKeyConfig, AtxDriverType, ActiveLevel};
//!
//! let config = AtxControllerConfig {
//! enabled: true,
//! power: AtxKeyConfig {
//! driver: AtxDriverType::Gpio,
//! device: "/dev/gpiochip0".to_string(),
//! pin: 5,
//! active_level: ActiveLevel::High,
//! },
//! reset: AtxKeyConfig {
//! driver: AtxDriverType::UsbRelay,
//! device: "/dev/hidraw0".to_string(),
//! pin: 0,
//! active_level: ActiveLevel::High,
//! },
//! led: Default::default(),
//! };
//!
//! let controller = AtxController::new(config);
//! controller.init().await?;
//! controller.power_short().await?; // Turn on or graceful shutdown
//! ```
mod controller;
mod executor;
mod led;
mod types;
mod wol;
pub use controller::{AtxController, AtxControllerConfig};
pub use executor::timing;
pub use types::{
ActiveLevel, AtxAction, AtxDevices, AtxDriverType, AtxKeyConfig, AtxLedConfig,
AtxPowerRequest, AtxState, PowerStatus,
};
pub use wol::send_wol;
/// Discover available ATX devices on the system
///
/// Scans for GPIO chips and USB HID relay devices in a single pass.
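///
/// For example:
///
/// ```ignore
/// let devices = discover_devices();
/// for chip in &devices.gpio_chips {
///     println!("GPIO chip: {}", chip);
/// }
/// for relay in &devices.usb_relays {
///     println!("USB relay candidate: {}", relay);
/// }
/// ```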
pub fn discover_devices() -> AtxDevices {
let mut devices = AtxDevices::default();
// Single pass through /dev directory
if let Ok(entries) = std::fs::read_dir("/dev") {
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with("gpiochip") {
devices.gpio_chips.push(format!("/dev/{}", name_str));
} else if name_str.starts_with("hidraw") {
devices.usb_relays.push(format!("/dev/{}", name_str));
}
}
}
devices.gpio_chips.sort();
devices.usb_relays.sort();
devices
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_discover_devices() {
let devices = discover_devices();
        // Just verify the function runs and returns sorted lists; the exact
        // device counts are platform-dependent (and `len() >= 0` on a usize
        // is trivially true, so asserting it would be meaningless).
        assert!(devices.gpio_chips.windows(2).all(|w| w[0] <= w[1]));
        assert!(devices.usb_relays.windows(2).all(|w| w[0] <= w[1]));
}
#[test]
fn test_module_exports() {
// Verify all public exports are accessible
let _: AtxDriverType = AtxDriverType::None;
let _: ActiveLevel = ActiveLevel::High;
let _: AtxKeyConfig = AtxKeyConfig::default();
let _: AtxLedConfig = AtxLedConfig::default();
let _: AtxState = AtxState::default();
let _: AtxDevices = AtxDevices::default();
}
}

src/atx/types.rs Normal file

@@ -0,0 +1,270 @@
//! ATX data types and structures
//!
//! Defines the configuration and state types for the flexible ATX power control system.
//! Each ATX action (power, reset) can be independently configured with different hardware.
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Power status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PowerStatus {
/// Power is on
On,
/// Power is off
Off,
/// Power status unknown (no LED connected)
Unknown,
}
impl Default for PowerStatus {
fn default() -> Self {
Self::Unknown
}
}
/// Driver type for ATX key operations
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AtxDriverType {
/// GPIO control via Linux character device
Gpio,
/// USB HID relay module
UsbRelay,
/// Disabled / Not configured
None,
}
impl Default for AtxDriverType {
fn default() -> Self {
Self::None
}
}
/// Active level for GPIO pins
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ActiveLevel {
/// Active high (default for most cases)
High,
/// Active low (inverted)
Low,
}
impl Default for ActiveLevel {
fn default() -> Self {
Self::High
}
}
/// Configuration for a single ATX key (power or reset)
/// This is the "four-tuple" configuration: (driver, device, pin/channel, level)
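///
/// With `#[serde(default)]` every field is optional; a GPIO power key might
/// be serialized as (hypothetical values):
///
/// ```ignore
/// { "driver": "gpio", "device": "/dev/gpiochip0", "pin": 5, "active_level": "high" }
/// ```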
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct AtxKeyConfig {
/// Driver type (GPIO or USB Relay)
pub driver: AtxDriverType,
/// Device path:
/// - For GPIO: /dev/gpiochipX
/// - For USB Relay: /dev/hidrawX
pub device: String,
/// Pin or channel number:
/// - For GPIO: GPIO pin number
/// - For USB Relay: relay channel (0-based)
pub pin: u32,
/// Active level (only applicable to GPIO, ignored for USB Relay)
pub active_level: ActiveLevel,
}
impl Default for AtxKeyConfig {
fn default() -> Self {
Self {
driver: AtxDriverType::None,
device: String::new(),
pin: 0,
active_level: ActiveLevel::High,
}
}
}
impl AtxKeyConfig {
/// Check if this key is configured
pub fn is_configured(&self) -> bool {
self.driver != AtxDriverType::None && !self.device.is_empty()
}
}
/// LED sensing configuration (optional)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct AtxLedConfig {
/// Whether LED sensing is enabled
pub enabled: bool,
/// GPIO chip for LED sensing
pub gpio_chip: String,
/// GPIO pin for LED input
pub gpio_pin: u32,
/// Whether LED is active low (inverted logic)
pub inverted: bool,
}
impl Default for AtxLedConfig {
fn default() -> Self {
Self {
enabled: false,
gpio_chip: String::new(),
gpio_pin: 0,
inverted: false,
}
}
}
impl AtxLedConfig {
/// Check if LED sensing is configured
pub fn is_configured(&self) -> bool {
self.enabled && !self.gpio_chip.is_empty()
}
}
/// ATX state information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxState {
/// Whether ATX feature is available/enabled
pub available: bool,
/// Whether power button is configured
pub power_configured: bool,
/// Whether reset button is configured
pub reset_configured: bool,
/// Current power status
pub power_status: PowerStatus,
/// Whether power LED sensing is supported
pub led_supported: bool,
}
impl Default for AtxState {
fn default() -> Self {
Self {
available: false,
power_configured: false,
reset_configured: false,
power_status: PowerStatus::Unknown,
led_supported: false,
}
}
}
/// ATX power action request
#[derive(Debug, Clone, Deserialize)]
pub struct AtxPowerRequest {
/// Action to perform: "short", "long", "reset"
pub action: AtxAction,
}
/// ATX power action
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AtxAction {
/// Short press power button (turn on or graceful shutdown)
Short,
/// Long press power button (force power off)
Long,
/// Press reset button
Reset,
}
/// Available ATX devices for discovery
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxDevices {
/// Available GPIO chips (/dev/gpiochip*)
pub gpio_chips: Vec<String>,
/// Available USB HID relay devices (/dev/hidraw*)
pub usb_relays: Vec<String>,
}
impl Default for AtxDevices {
fn default() -> Self {
Self {
gpio_chips: Vec::new(),
usb_relays: Vec::new(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_power_status_default() {
assert_eq!(PowerStatus::default(), PowerStatus::Unknown);
}
#[test]
fn test_atx_driver_type_default() {
assert_eq!(AtxDriverType::default(), AtxDriverType::None);
}
#[test]
fn test_active_level_default() {
assert_eq!(ActiveLevel::default(), ActiveLevel::High);
}
#[test]
fn test_atx_key_config_default() {
let config = AtxKeyConfig::default();
assert_eq!(config.driver, AtxDriverType::None);
assert!(config.device.is_empty());
assert_eq!(config.pin, 0);
assert!(!config.is_configured());
}
#[test]
fn test_atx_key_config_is_configured() {
let mut config = AtxKeyConfig::default();
assert!(!config.is_configured());
config.driver = AtxDriverType::Gpio;
assert!(!config.is_configured()); // device still empty
config.device = "/dev/gpiochip0".to_string();
assert!(config.is_configured());
config.driver = AtxDriverType::None;
assert!(!config.is_configured()); // driver is None
}
#[test]
fn test_atx_led_config_default() {
let config = AtxLedConfig::default();
assert!(!config.enabled);
assert!(config.gpio_chip.is_empty());
assert!(!config.is_configured());
}
#[test]
fn test_atx_led_config_is_configured() {
let mut config = AtxLedConfig::default();
assert!(!config.is_configured());
config.enabled = true;
assert!(!config.is_configured()); // gpio_chip still empty
config.gpio_chip = "/dev/gpiochip0".to_string();
assert!(config.is_configured());
}
#[test]
fn test_atx_state_default() {
let state = AtxState::default();
assert!(!state.available);
assert!(!state.power_configured);
assert!(!state.reset_configured);
assert_eq!(state.power_status, PowerStatus::Unknown);
}
}

src/atx/wol.rs Normal file

@@ -0,0 +1,171 @@
//! Wake-on-LAN (WOL) implementation
//!
//! Sends magic packets to wake up remote machines.
use std::net::{SocketAddr, UdpSocket};
use tracing::{debug, info};
use crate::error::{AppError, Result};
/// WOL magic packet structure:
/// - 6 bytes of 0xFF
/// - 16 repetitions of the target MAC address (6 bytes each)
/// Total: 6 + 16 * 6 = 102 bytes
const MAGIC_PACKET_SIZE: usize = 102;
/// Parse MAC address string into bytes
/// Supports formats: "AA:BB:CC:DD:EE:FF" or "AA-BB-CC-DD-EE-FF"
fn parse_mac_address(mac: &str) -> Result<[u8; 6]> {
let mac = mac.trim().to_uppercase();
let parts: Vec<&str> = if mac.contains(':') {
mac.split(':').collect()
} else if mac.contains('-') {
mac.split('-').collect()
} else {
return Err(AppError::Config(format!("Invalid MAC address format: {}", mac)));
};
if parts.len() != 6 {
return Err(AppError::Config(format!(
"Invalid MAC address: expected 6 parts, got {}",
parts.len()
)));
}
let mut bytes = [0u8; 6];
for (i, part) in parts.iter().enumerate() {
bytes[i] = u8::from_str_radix(part, 16).map_err(|_| {
AppError::Config(format!("Invalid MAC address byte: {}", part))
})?;
}
Ok(bytes)
}
/// Build WOL magic packet
fn build_magic_packet(mac: &[u8; 6]) -> [u8; MAGIC_PACKET_SIZE] {
let mut packet = [0u8; MAGIC_PACKET_SIZE];
// First 6 bytes are 0xFF
for byte in packet.iter_mut().take(6) {
*byte = 0xFF;
}
// Next 96 bytes are 16 repetitions of the MAC address
for i in 0..16 {
let offset = 6 + i * 6;
packet[offset..offset + 6].copy_from_slice(mac);
}
packet
}
/// Send WOL magic packet
///
/// # Arguments
/// * `mac_address` - Target MAC address (e.g., "AA:BB:CC:DD:EE:FF")
/// * `interface` - Optional network interface name (e.g., "eth0"). If None, uses default routing.
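///
/// # Example
///
/// ```ignore
/// // Hypothetical MAC address and interface name:
/// send_wol("AA:BB:CC:DD:EE:FF", Some("eth0"))?;
/// send_wol("aa-bb-cc-dd-ee-ff", None)?; // dash format, default routing
/// ```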
pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
let mac = parse_mac_address(mac_address)?;
let packet = build_magic_packet(&mac);
info!("Sending WOL packet to {} via {:?}", mac_address, interface);
// Create UDP socket
let socket = UdpSocket::bind("0.0.0.0:0")
.map_err(|e| AppError::Internal(format!("Failed to create UDP socket: {}", e)))?;
// Enable broadcast
socket
.set_broadcast(true)
.map_err(|e| AppError::Internal(format!("Failed to enable broadcast: {}", e)))?;
// Bind to specific interface if specified
#[cfg(target_os = "linux")]
if let Some(iface) = interface {
if !iface.is_empty() {
use std::os::unix::io::AsRawFd;
let fd = socket.as_raw_fd();
let iface_bytes = iface.as_bytes();
// SO_BINDTODEVICE requires interface name as null-terminated string
let mut iface_buf = [0u8; 16]; // IFNAMSIZ is typically 16
let len = iface_bytes.len().min(15);
iface_buf[..len].copy_from_slice(&iface_bytes[..len]);
let ret = unsafe {
libc::setsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_BINDTODEVICE,
iface_buf.as_ptr() as *const libc::c_void,
(len + 1) as libc::socklen_t,
)
};
if ret < 0 {
let err = std::io::Error::last_os_error();
return Err(AppError::Internal(format!(
"Failed to bind to interface {}: {}",
iface, err
)));
}
debug!("Bound to interface: {}", iface);
}
}
// Send to broadcast address on port 9 (discard protocol, commonly used for WOL)
let broadcast_addr: SocketAddr = "255.255.255.255:9".parse().unwrap();
socket
.send_to(&packet, broadcast_addr)
.map_err(|e| AppError::Internal(format!("Failed to send WOL packet: {}", e)))?;
// Also try sending to port 7 (echo protocol, alternative WOL port)
let broadcast_addr_7: SocketAddr = "255.255.255.255:7".parse().unwrap();
let _ = socket.send_to(&packet, broadcast_addr_7);
info!("WOL packet sent successfully to {}", mac_address);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_mac_address_colon() {
let mac = parse_mac_address("AA:BB:CC:DD:EE:FF").unwrap();
assert_eq!(mac, [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
}
#[test]
fn test_parse_mac_address_dash() {
let mac = parse_mac_address("aa-bb-cc-dd-ee-ff").unwrap();
assert_eq!(mac, [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
}
#[test]
fn test_parse_mac_address_invalid() {
assert!(parse_mac_address("invalid").is_err());
assert!(parse_mac_address("AA:BB:CC:DD:EE").is_err());
assert!(parse_mac_address("AA:BB:CC:DD:EE:GG").is_err());
}
#[test]
fn test_build_magic_packet() {
let mac = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF];
let packet = build_magic_packet(&mac);
// Check header (6 bytes of 0xFF)
for i in 0..6 {
assert_eq!(packet[i], 0xFF);
}
// Check MAC repetitions
for i in 0..16 {
let offset = 6 + i * 6;
assert_eq!(&packet[offset..offset + 6], &mac);
}
}
}

src/audio/capture.rs Normal file

@@ -0,0 +1,390 @@
//! ALSA audio capture implementation
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
use alsa::{Direction, ValueOr, PCM};
use bytes::Bytes;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, error, info, warn};
use super::device::AudioDeviceInfo;
use crate::error::{AppError, Result};
/// Audio capture configuration
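///
/// With the defaults below (48 kHz stereo S16LE, 960-sample frames), one
/// frame is 960 samples × 2 channels × 2 bytes = 3840 bytes of PCM per 20 ms.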
#[derive(Debug, Clone)]
pub struct AudioConfig {
/// ALSA device name (e.g., "hw:0,0" or "default")
pub device_name: String,
/// Sample rate in Hz
pub sample_rate: u32,
/// Number of channels (1 = mono, 2 = stereo)
pub channels: u32,
/// Samples per frame (for Opus, typically 480 for 10ms at 48kHz)
pub frame_size: u32,
/// Buffer size in frames
pub buffer_frames: u32,
/// Period size in frames
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: "default".to_string(),
sample_rate: 48000,
channels: 2,
frame_size: 960, // 20ms at 48kHz (good for Opus)
buffer_frames: 4096,
period_frames: 960,
}
}
}
impl AudioConfig {
/// Create config for a specific device
pub fn for_device(device: &AudioDeviceInfo) -> Self {
let sample_rate = if device.sample_rates.contains(&48000) {
48000
} else {
*device.sample_rates.first().unwrap_or(&48000)
};
let channels = if device.channels.contains(&2) {
2
} else {
*device.channels.first().unwrap_or(&2)
};
Self {
device_name: device.name.clone(),
sample_rate,
channels,
frame_size: sample_rate / 50, // 20ms
..Default::default()
}
}
    /// Bytes per interleaved sample frame (2 bytes per 16-bit sample × channel count)
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
/// Bytes per frame
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
/// Audio frame data
#[derive(Debug, Clone)]
pub struct AudioFrame {
/// Raw PCM data (S16LE interleaved)
pub data: Bytes,
/// Sample rate
pub sample_rate: u32,
/// Number of channels
pub channels: u32,
/// Number of samples per channel
pub samples: u32,
/// Frame sequence number
pub sequence: u64,
/// Capture timestamp
pub timestamp: Instant,
}
impl AudioFrame {
pub fn new(data: Bytes, config: &AudioConfig, sequence: u64) -> Self {
Self {
samples: data.len() as u32 / config.bytes_per_sample(),
data,
sample_rate: config.sample_rate,
channels: config.channels,
sequence,
timestamp: Instant::now(),
}
}
}
/// Audio capture state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
Running,
Error,
}
/// Audio capture statistics
#[derive(Debug, Clone, Default)]
pub struct AudioStats {
pub frames_captured: u64,
pub frames_dropped: u64,
pub buffer_overruns: u64,
pub current_latency_ms: f32,
}
/// ALSA audio capturer
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
stats: Arc<Mutex<AudioStats>>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
impl AudioCapturer {
/// Create a new audio capturer
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(32);
Self {
config,
state: Arc::new(state_tx),
state_rx,
stats: Arc::new(Mutex::new(AudioStats::default())),
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
}
}
/// Get current state
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
/// Subscribe to audio frames
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
/// Get statistics
pub async fn stats(&self) -> AudioStats {
self.stats.lock().await.clone()
}
/// Start capturing
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
info!(
"Starting audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let stats = self.stats.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let handle = tokio::task::spawn_blocking(move || {
capture_loop(config, state, stats, frame_tx, stop_flag, sequence);
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
/// Stop capturing
pub async fn stop(&self) -> Result<()> {
info!("Stopping audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
/// Check if running
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
/// Main capture loop
fn capture_loop(
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
stats: Arc<Mutex<AudioStats>>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
) {
let result = run_capture(&config, &state, &stats, &frame_tx, &stop_flag, &sequence);
if let Err(e) = result {
error!("Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
stats: &Arc<Mutex<AudioStats>>,
frame_tx: &broadcast::Sender<AudioFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
) -> Result<()> {
// Open ALSA device
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
AppError::AudioError(format!(
"Failed to open audio device {}: {}",
config.device_name, e
))
})?;
// Configure hardware parameters
{
let hwp = HwParams::any(&pcm).map_err(|e| {
AppError::AudioError(format!("Failed to get HwParams: {}", e))
})?;
hwp.set_channels(config.channels).map_err(|e| {
AppError::AudioError(format!("Failed to set channels: {}", e))
})?;
hwp.set_rate(config.sample_rate, ValueOr::Nearest).map_err(|e| {
AppError::AudioError(format!("Failed to set sample rate: {}", e))
})?;
hwp.set_format(Format::s16()).map_err(|e| {
AppError::AudioError(format!("Failed to set format: {}", e))
})?;
hwp.set_access(Access::RWInterleaved).map_err(|e| {
AppError::AudioError(format!("Failed to set access: {}", e))
})?;
hwp.set_buffer_size_near(config.buffer_frames as Frames).map_err(|e| {
AppError::AudioError(format!("Failed to set buffer size: {}", e))
})?;
hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?;
pcm.hw_params(&hwp).map_err(|e| {
AppError::AudioError(format!("Failed to apply hw params: {}", e))
})?;
}
// Get actual configuration
let actual_rate = pcm.hw_params_current()
.map(|h| h.get_rate().unwrap_or(config.sample_rate))
.unwrap_or(config.sample_rate);
info!(
"Audio capture configured: {}Hz {}ch (requested {}Hz)",
actual_rate, config.channels, config.sample_rate
);
// Prepare for capture
pcm.prepare().map_err(|e| {
AppError::AudioError(format!("Failed to prepare PCM: {}", e))
})?;
let _ = state.send(CaptureState::Running);
// Allocate buffer - use u8 directly for zero-copy
let frame_bytes = config.bytes_per_frame();
let mut buffer = vec![0u8; frame_bytes];
// Capture loop
while !stop_flag.load(Ordering::Relaxed) {
// Check PCM state
match pcm.state() {
State::XRun => {
warn!("Audio buffer overrun, recovering");
if let Ok(mut s) = stats.try_lock() {
s.buffer_overruns += 1;
}
let _ = pcm.prepare();
continue;
}
State::Suspended => {
warn!("Audio device suspended, recovering");
let _ = pcm.resume();
continue;
}
_ => {}
}
        // Get an IO handle and read audio data directly as bytes.
        // io_bytes() is used instead of the typed io_checked::<i16>() so the
        // capture buffer can stay as raw u8 and be sliced straight into Bytes.
let io: IO<u8> = pcm.io_bytes();
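        // readi() returns the number of frames read (one frame = one sample
        // per channel), not bytes.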
match io.readi(&mut buffer) {
Ok(frames_read) => {
if frames_read == 0 {
continue;
}
// Calculate actual byte count
let byte_count = frames_read * config.channels as usize * 2;
// Directly use the buffer slice (already in correct byte format)
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new(
Bytes::copy_from_slice(&buffer[..byte_count]),
config,
seq,
);
// Send to subscribers
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
}
}
// Update stats
if let Ok(mut s) = stats.try_lock() {
s.frames_captured += 1;
}
}
Err(e) => {
// Check for buffer overrun (EPIPE = 32 on Linux)
let desc = e.to_string();
if desc.contains("EPIPE") || desc.contains("Broken pipe") {
// Buffer overrun
warn!("Audio buffer overrun");
if let Ok(mut s) = stats.try_lock() {
s.buffer_overruns += 1;
}
let _ = pcm.prepare();
} else {
error!("Audio read error: {}", e);
if let Ok(mut s) = stats.try_lock() {
s.frames_dropped += 1;
}
}
}
}
}
info!("Audio capture stopped");
Ok(())
}

src/audio/controller.rs Normal file

@@ -0,0 +1,495 @@
//! Audio controller for high-level audio management
//!
//! Provides device enumeration, selection, quality control, and streaming management.
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::info;
use super::capture::AudioConfig;
use super::device::{enumerate_audio_devices_with_current, AudioDeviceInfo};
use super::encoder::{OpusConfig, OpusFrame};
use super::monitor::{AudioHealthMonitor, AudioHealthStatus};
use super::streamer::{AudioStreamer, AudioStreamerConfig};
use crate::error::{AppError, Result};
use crate::events::{EventBus, SystemEvent};
/// Audio quality presets
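///
/// `from_str` also accepts the aliases "low" (for `Voice`) and "music" (for
/// `High`); any unrecognized string falls back to `Balanced`.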
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioQuality {
/// Low bandwidth voice (32kbps)
Voice,
/// Balanced quality (64kbps) - default
#[default]
Balanced,
/// High quality audio (128kbps)
High,
}
impl AudioQuality {
/// Get the bitrate for this quality level
pub fn bitrate(&self) -> u32 {
match self {
AudioQuality::Voice => 32000,
AudioQuality::Balanced => 64000,
AudioQuality::High => 128000,
}
}
/// Parse from string
pub fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"voice" | "low" => AudioQuality::Voice,
"high" | "music" => AudioQuality::High,
_ => AudioQuality::Balanced,
}
}
/// Convert to OpusConfig
pub fn to_opus_config(&self) -> OpusConfig {
match self {
AudioQuality::Voice => OpusConfig::voice(),
AudioQuality::Balanced => OpusConfig::default(),
AudioQuality::High => OpusConfig::music(),
}
}
}
impl std::fmt::Display for AudioQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AudioQuality::Voice => write!(f, "voice"),
AudioQuality::Balanced => write!(f, "balanced"),
AudioQuality::High => write!(f, "high"),
}
}
}
/// Audio controller configuration
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
/// These are optimal for Opus encoding and match WebRTC requirements.
#[derive(Debug, Clone)]
pub struct AudioControllerConfig {
/// Whether audio is enabled
pub enabled: bool,
/// Selected device name
pub device: String,
/// Audio quality preset
pub quality: AudioQuality,
}
impl Default for AudioControllerConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
quality: AudioQuality::Balanced,
}
}
}
/// Current audio status
#[derive(Debug, Clone, Serialize)]
pub struct AudioStatus {
/// Whether audio feature is enabled
pub enabled: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Currently selected device
pub device: Option<String>,
/// Current quality preset
pub quality: AudioQuality,
/// Number of connected subscribers
pub subscriber_count: usize,
/// Frames encoded
pub frames_encoded: u64,
/// Bytes output
pub bytes_output: u64,
/// Error message if any
pub error: Option<String>,
}
/// Audio controller
///
/// High-level interface for audio management, providing:
/// - Device enumeration and selection
/// - Quality control
/// - Stream start/stop
/// - Status reporting
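///
/// # Example
///
/// A minimal lifecycle sketch (assumes an async context and an enabled config):
///
/// ```ignore
/// let controller = AudioController::new(config);
/// controller.start_streaming().await?;
/// let mut rx = controller.subscribe_opus_async().await.expect("streamer running");
/// while let Ok(frame) = rx.recv().await {
///     // forward `frame` to a WebSocket client...
/// }
/// ```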
pub struct AudioController {
config: RwLock<AudioControllerConfig>,
streamer: RwLock<Option<Arc<AudioStreamer>>>,
devices: RwLock<Vec<AudioDeviceInfo>>,
event_bus: RwLock<Option<Arc<EventBus>>>,
last_error: RwLock<Option<String>>,
/// Health monitor for error tracking and recovery
monitor: Arc<AudioHealthMonitor>,
}
impl AudioController {
/// Create a new audio controller with configuration
pub fn new(config: AudioControllerConfig) -> Self {
Self {
config: RwLock::new(config),
streamer: RwLock::new(None),
devices: RwLock::new(Vec::new()),
event_bus: RwLock::new(None),
last_error: RwLock::new(None),
monitor: Arc::new(AudioHealthMonitor::with_defaults()),
}
}
/// Set event bus for publishing audio events
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
*self.event_bus.write().await = Some(event_bus.clone());
// Also set event bus on the monitor for health notifications
self.monitor.set_event_bus(event_bus).await;
}
/// Publish an event to the event bus
async fn publish_event(&self, event: SystemEvent) {
if let Some(ref bus) = *self.event_bus.read().await {
bus.publish(event);
}
}
/// List available audio capture devices
pub async fn list_devices(&self) -> Result<Vec<AudioDeviceInfo>> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
None
};
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
*self.devices.write().await = devices.clone();
Ok(devices)
}
/// Refresh device list and cache it
pub async fn refresh_devices(&self) -> Result<()> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
None
};
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
*self.devices.write().await = devices;
Ok(())
}
/// Get cached device list
pub async fn get_cached_devices(&self) -> Vec<AudioDeviceInfo> {
self.devices.read().await.clone()
}
/// Select audio device
pub async fn select_device(&self, device: &str) -> Result<()> {
// Validate device exists
let devices = self.list_devices().await?;
let found = devices.iter().any(|d| d.name == device || d.description.contains(device));
if !found && device != "default" {
return Err(AppError::AudioError(format!(
"Audio device not found: {}",
device
)));
}
// Update config
{
let mut config = self.config.write().await;
config.device = device.to_string();
}
// Publish event
self.publish_event(SystemEvent::AudioDeviceSelected {
device: device.to_string(),
})
.await;
info!("Audio device selected: {}", device);
// If streaming, restart with new device
if self.is_streaming().await {
self.stop_streaming().await?;
self.start_streaming().await?;
}
Ok(())
}
/// Set audio quality
pub async fn set_quality(&self, quality: AudioQuality) -> Result<()> {
// Update config
{
let mut config = self.config.write().await;
config.quality = quality;
}
// Update streamer if running
if let Some(ref streamer) = *self.streamer.read().await {
streamer.set_bitrate(quality.bitrate()).await?;
}
// Publish event
self.publish_event(SystemEvent::AudioQualityChanged {
quality: quality.to_string(),
})
.await;
info!("Audio quality set to: {:?} ({}bps)", quality, quality.bitrate());
Ok(())
}
/// Start audio streaming
pub async fn start_streaming(&self) -> Result<()> {
let config = self.config.read().await.clone();
if !config.enabled {
return Err(AppError::AudioError("Audio is disabled".to_string()));
}
// Check if already streaming
if self.is_streaming().await {
return Ok(());
}
info!("Starting audio streaming with device: {}", config.device);
// Clear any previous error
*self.last_error.write().await = None;
// Create streamer config (fixed 48kHz stereo)
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: config.device.clone(),
..Default::default()
},
opus: config.quality.to_opus_config(),
};
// Create and start streamer
let streamer = Arc::new(AudioStreamer::with_config(streamer_config));
if let Err(e) = streamer.start().await {
let error_msg = format!("Failed to start audio: {}", e);
*self.last_error.write().await = Some(error_msg.clone());
// Report error to health monitor
self.monitor
.report_error(Some(&config.device), &error_msg, "start_failed")
.await;
self.publish_event(SystemEvent::AudioStateChanged {
streaming: false,
device: None,
})
.await;
return Err(AppError::AudioError(error_msg));
}
*self.streamer.write().await = Some(streamer);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered(Some(&config.device)).await;
}
// Publish event
self.publish_event(SystemEvent::AudioStateChanged {
streaming: true,
device: Some(config.device),
})
.await;
info!("Audio streaming started");
Ok(())
}
/// Stop audio streaming
pub async fn stop_streaming(&self) -> Result<()> {
if let Some(streamer) = self.streamer.write().await.take() {
streamer.stop().await?;
}
// Publish event
self.publish_event(SystemEvent::AudioStateChanged {
streaming: false,
device: None,
})
.await;
info!("Audio streaming stopped");
Ok(())
}
/// Check if currently streaming
pub async fn is_streaming(&self) -> bool {
if let Some(ref streamer) = *self.streamer.read().await {
streamer.is_running()
} else {
false
}
}
/// Get current status
pub async fn status(&self) -> AudioStatus {
let config = self.config.read().await;
let streaming = self.is_streaming().await;
let error = self.last_error.read().await.clone();
let (subscriber_count, frames_encoded, bytes_output) = if let Some(ref streamer) =
*self.streamer.read().await
{
let stats = streamer.stats().await;
(stats.subscriber_count, stats.frames_encoded, stats.bytes_output)
} else {
(0, 0, 0)
};
AudioStatus {
enabled: config.enabled,
streaming,
device: if streaming || config.enabled {
Some(config.device.clone())
} else {
None
},
quality: config.quality,
subscriber_count,
frames_encoded,
bytes_output,
error,
}
}
/// Subscribe to Opus frames (for WebSocket clients)
pub fn subscribe_opus(&self) -> Option<broadcast::Receiver<OpusFrame>> {
// Use try_read to avoid blocking - this is called from sync context sometimes
if let Ok(guard) = self.streamer.try_read() {
guard.as_ref().map(|s| s.subscribe_opus())
} else {
None
}
}
/// Subscribe to Opus frames (async version)
pub async fn subscribe_opus_async(&self) -> Option<broadcast::Receiver<OpusFrame>> {
self.streamer.read().await.as_ref().map(|s| s.subscribe_opus())
}
/// Enable or disable audio
pub async fn set_enabled(&self, enabled: bool) -> Result<()> {
{
let mut config = self.config.write().await;
config.enabled = enabled;
}
if !enabled && self.is_streaming().await {
self.stop_streaming().await?;
}
info!("Audio enabled: {}", enabled);
Ok(())
}
/// Update full configuration
pub async fn update_config(&self, new_config: AudioControllerConfig) -> Result<()> {
let was_streaming = self.is_streaming().await;
let old_config = self.config.read().await.clone();
// Stop streaming if running
if was_streaming {
self.stop_streaming().await?;
}
// Update config
*self.config.write().await = new_config.clone();
// Restart streaming if it was running and still enabled
if was_streaming && new_config.enabled {
self.start_streaming().await?;
}
// Publish events for changes
if old_config.device != new_config.device {
self.publish_event(SystemEvent::AudioDeviceSelected {
device: new_config.device.clone(),
})
.await;
}
if old_config.quality != new_config.quality {
self.publish_event(SystemEvent::AudioQualityChanged {
quality: new_config.quality.to_string(),
})
.await;
}
Ok(())
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
self.stop_streaming().await
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<AudioHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> AudioHealthStatus {
self.monitor.status().await
}
/// Check if the audio is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Default for AudioController {
fn default() -> Self {
Self::new(AudioControllerConfig::default())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_audio_quality_bitrate() {
assert_eq!(AudioQuality::Voice.bitrate(), 32000);
assert_eq!(AudioQuality::Balanced.bitrate(), 64000);
assert_eq!(AudioQuality::High.bitrate(), 128000);
}
#[test]
fn test_audio_quality_from_str() {
assert_eq!(AudioQuality::from_str("voice"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("low"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("balanced"), AudioQuality::Balanced);
assert_eq!(AudioQuality::from_str("high"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("music"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("unknown"), AudioQuality::Balanced);
}
#[tokio::test]
async fn test_controller_default() {
let controller = AudioController::default();
let status = controller.status().await;
assert!(!status.enabled);
assert!(!status.streaming);
}
}

src/audio/device.rs Normal file

@@ -0,0 +1,234 @@
//! Audio device enumeration using ALSA
use alsa::pcm::HwParams;
use alsa::{Direction, PCM};
use serde::Serialize;
use tracing::{debug, info, warn};
use crate::error::{AppError, Result};
/// Audio device information
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
/// Device name (e.g., "hw:0,0" or "default")
pub name: String,
/// Human-readable description
pub description: String,
/// Card index
pub card_index: i32,
/// Device index
pub device_index: i32,
/// Supported sample rates
pub sample_rates: Vec<u32>,
/// Supported channel counts
pub channels: Vec<u32>,
/// Is this a capture device
pub is_capture: bool,
/// Is this an HDMI audio device (likely from capture card)
pub is_hdmi: bool,
}
impl AudioDeviceInfo {
/// Get ALSA device name
pub fn alsa_name(&self) -> String {
format!("hw:{},{}", self.card_index, self.device_index)
}
}
/// Enumerate available audio capture devices
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
/// Enumerate available audio capture devices, with option to include a currently-in-use device
///
/// # Arguments
/// * `current_device` - Optional device name that is currently in use. This device will be
/// included in the list even if it cannot be opened (because it's already open by us).
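///
/// ```ignore
/// // A minimal sketch; "hw:1,0" is a hypothetical in-use device name.
/// let devices = enumerate_audio_devices_with_current(Some("hw:1,0"))?;
/// for d in &devices {
///     println!("{}: {}", d.name, d.description);
/// }
/// ```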
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let mut devices = Vec::new();
    // Enumerate sound cards
    let cards = alsa::card::Iter::new();
for card_result in cards {
let card = match card_result {
Ok(c) => c,
Err(e) => {
debug!("Error iterating card: {}", e);
continue;
}
};
let card_index = card.get_index();
let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string());
let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone());
debug!("Found audio card {}: {}", card_index, card_longname);
// Check if this looks like an HDMI capture device
let is_hdmi = card_longname.to_lowercase().contains("hdmi")
|| card_longname.to_lowercase().contains("capture")
|| card_longname.to_lowercase().contains("usb");
// Try to open each device on this card for capture
for device_index in 0..8 {
let device_name = format!("hw:{},{}", card_index, device_index);
// Check if this is the currently-in-use device
let is_current_device = current_device == Some(device_name.as_str());
// Try to open for capture
match PCM::new(&device_name, Direction::Capture, false) {
Ok(pcm) => {
// Query capabilities
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
devices.push(AudioDeviceInfo {
name: device_name,
description: format!("{} - Device {}", card_longname, device_index),
card_index,
device_index,
sample_rates,
channels,
is_capture: true,
is_hdmi,
});
}
}
Err(_) => {
// Device doesn't exist or can't be opened for capture
// But if it's the current device, include it anyway (it's busy because we're using it)
if is_current_device {
debug!(
"Device {} is busy (in use by us), adding with default caps",
device_name
);
devices.push(AudioDeviceInfo {
name: device_name,
description: format!(
"{} - Device {} (in use)",
card_longname, device_index
),
card_index,
device_index,
// Use common default capabilities for HDMI capture devices
sample_rates: vec![44100, 48000],
channels: vec![2],
is_capture: true,
is_hdmi,
});
}
continue;
}
}
}
}
// Also check for "default" device
if let Ok(pcm) = PCM::new("default", Direction::Capture, false) {
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
devices.insert(
0,
AudioDeviceInfo {
name: "default".to_string(),
description: "Default Audio Device".to_string(),
card_index: -1,
device_index: -1,
sample_rates,
channels,
is_capture: true,
is_hdmi: false,
},
);
}
}
info!("Found {} audio capture devices", devices.len());
Ok(devices)
}
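// Usage sketch; "hw:1,0" stands in for whatever device the running
// capturer currently holds open:
#[allow(dead_code)]
fn list_devices_example() -> Result<()> {
    for dev in enumerate_audio_devices_with_current(Some("hw:1,0"))? {
        println!("{} ({}) hdmi={}", dev.alsa_name(), dev.description, dev.is_hdmi);
    }
    Ok(())
}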
/// Query device capabilities
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
let hwp = match HwParams::any(pcm) {
Ok(h) => h,
Err(_) => return (vec![], vec![]),
};
// Common sample rates to check
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
let mut supported_rates = Vec::new();
for rate in &common_rates {
if hwp.test_rate(*rate).is_ok() {
supported_rates.push(*rate);
}
}
// Check channel counts
let mut supported_channels = Vec::new();
for ch in 1..=8 {
if hwp.test_channels(ch).is_ok() {
supported_channels.push(ch);
}
}
(supported_rates, supported_channels)
}
/// Find the best audio device for capture
/// Prefers HDMI/capture devices over built-in microphones
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
if devices.is_empty() {
return Err(AppError::AudioError("No audio capture devices found".to_string()));
}
// First, look for HDMI/capture card devices that support 48kHz stereo
for device in &devices {
if device.is_hdmi
&& device.sample_rates.contains(&48000)
&& device.channels.contains(&2)
{
info!("Selected HDMI audio device: {}", device.description);
return Ok(device.clone());
}
}
// Then look for any device supporting 48kHz stereo
for device in &devices {
if device.sample_rates.contains(&48000) && device.channels.contains(&2) {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
}
}
// Fall back to first device
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback audio device: {} (may not support optimal settings)",
device.description
);
Ok(device)
}
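// Selection sketch: take the best device, falling back to the ALSA
// "default" name when enumeration yields nothing usable:
#[allow(dead_code)]
fn pick_device_example() -> String {
    find_best_audio_device()
        .map(|dev| dev.name)
        .unwrap_or_else(|_| "default".to_string())
}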
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enumerate_devices() {
// This test may not find devices in CI environment
let result = enumerate_audio_devices();
println!("Audio devices: {:?}", result);
// Just verify it doesn't panic
assert!(result.is_ok());
}
}

280
src/audio/encoder.rs Normal file
View File

@@ -0,0 +1,280 @@
//! Opus audio encoder for WebRTC
use audiopus::coder::GenericCtl;
use audiopus::{coder::Encoder, Application, Bitrate, Channels, SampleRate};
use bytes::Bytes;
use std::time::Instant;
use tracing::{info, trace};
use super::capture::AudioFrame;
use crate::error::{AppError, Result};
/// Opus encoder configuration
#[derive(Debug, Clone)]
pub struct OpusConfig {
/// Sample rate (must be 8000, 12000, 16000, 24000, or 48000)
pub sample_rate: u32,
/// Channels (1 or 2)
pub channels: u32,
/// Target bitrate in bps
pub bitrate: u32,
/// Application mode
pub application: OpusApplication,
/// Enable forward error correction
pub fec: bool,
}
impl Default for OpusConfig {
fn default() -> Self {
Self {
sample_rate: 48000,
channels: 2,
bitrate: 64000, // 64 kbps
application: OpusApplication::Audio,
fec: true,
}
}
}
impl OpusConfig {
/// Create config for voice (lower latency)
pub fn voice() -> Self {
Self {
application: OpusApplication::Voip,
bitrate: 32000,
..Default::default()
}
}
/// Create config for music (higher quality)
pub fn music() -> Self {
Self {
application: OpusApplication::Audio,
bitrate: 128000,
..Default::default()
}
}
fn to_audiopus_sample_rate(&self) -> SampleRate {
match self.sample_rate {
8000 => SampleRate::Hz8000,
12000 => SampleRate::Hz12000,
16000 => SampleRate::Hz16000,
24000 => SampleRate::Hz24000,
_ => SampleRate::Hz48000,
}
}
fn to_audiopus_channels(&self) -> Channels {
if self.channels == 1 {
Channels::Mono
} else {
Channels::Stereo
}
}
fn to_audiopus_application(&self) -> Application {
match self.application {
OpusApplication::Voip => Application::Voip,
OpusApplication::Audio => Application::Audio,
OpusApplication::LowDelay => Application::LowDelay,
}
}
}
/// Opus application mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpusApplication {
/// Voice over IP
Voip,
/// General audio
Audio,
/// Low delay mode
LowDelay,
}
/// Encoded Opus frame
#[derive(Debug, Clone)]
pub struct OpusFrame {
/// Encoded Opus data
pub data: Bytes,
/// Duration in milliseconds
pub duration_ms: u32,
/// Sequence number
pub sequence: u64,
/// Timestamp
pub timestamp: Instant,
/// RTP timestamp (samples)
pub rtp_timestamp: u32,
}
impl OpusFrame {
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
/// Opus encoder
pub struct OpusEncoder {
config: OpusConfig,
encoder: Encoder,
/// Output buffer
output_buffer: Vec<u8>,
/// Frame counter for RTP timestamp
frame_count: u64,
/// Samples per frame
samples_per_frame: u32,
}
impl OpusEncoder {
/// Create a new Opus encoder
pub fn new(config: OpusConfig) -> Result<Self> {
let sample_rate = config.to_audiopus_sample_rate();
let channels = config.to_audiopus_channels();
let application = config.to_audiopus_application();
let mut encoder = Encoder::new(sample_rate, channels, application).map_err(|e| {
AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e))
})?;
// Configure encoder
encoder
.set_bitrate(Bitrate::BitsPerSecond(config.bitrate as i32))
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
if config.fec {
encoder
.set_inband_fec(true)
.map_err(|e| AppError::AudioError(format!("Failed to enable FEC: {:?}", e)))?;
}
// Calculate samples per frame (20ms at sample_rate)
let samples_per_frame = config.sample_rate / 50;
info!(
"Opus encoder created: {}Hz {}ch {}bps",
config.sample_rate, config.channels, config.bitrate
);
Ok(Self {
config,
encoder,
output_buffer: vec![0u8; 4000], // Max Opus frame size
frame_count: 0,
samples_per_frame,
})
}
/// Create with default configuration
pub fn default_config() -> Result<Self> {
Self::new(OpusConfig::default())
}
/// Encode PCM audio data (S16LE interleaved)
pub fn encode(&mut self, pcm_data: &[i16]) -> Result<OpusFrame> {
let encoded_len = self
.encoder
.encode(pcm_data, &mut self.output_buffer)
.map_err(|e| AppError::AudioError(format!("Opus encode failed: {:?}", e)))?;
let samples = pcm_data.len() as u32 / self.config.channels;
let duration_ms = (samples * 1000) / self.config.sample_rate;
let rtp_timestamp = (self.frame_count * self.samples_per_frame as u64) as u32;
self.frame_count += 1;
trace!(
"Encoded {} samples to {} bytes Opus",
pcm_data.len(),
encoded_len
);
Ok(OpusFrame {
data: Bytes::copy_from_slice(&self.output_buffer[..encoded_len]),
duration_ms,
sequence: self.frame_count - 1,
timestamp: Instant::now(),
rtp_timestamp,
})
}
/// Encode from AudioFrame
///
/// Uses zero-copy conversion from bytes to i16 samples via bytemuck.
pub fn encode_frame(&mut self, frame: &AudioFrame) -> Result<OpusFrame> {
// Zero-copy: directly cast bytes to i16 slice
// AudioFrame.data is S16LE format, which matches native little-endian i16
let samples: &[i16] = bytemuck::cast_slice(&frame.data);
self.encode(samples)
}
/// Get encoder configuration
pub fn config(&self) -> &OpusConfig {
&self.config
}
/// Reset encoder state
pub fn reset(&mut self) -> Result<()> {
self.encoder
.reset_state()
.map_err(|e| AppError::AudioError(format!("Failed to reset encoder: {:?}", e)))?;
self.frame_count = 0;
Ok(())
}
/// Set bitrate dynamically
pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> {
self.encoder
.set_bitrate(Bitrate::BitsPerSecond(bitrate as i32))
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
Ok(())
}
}
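// Sketch: encode one 20 ms frame of stereo silence at 48 kHz. The buffer
// length 960 * 2 is samples_per_frame (48000 / 50) times two channels.
#[allow(dead_code)]
fn encode_20ms_example() -> Result<()> {
    let mut encoder = OpusEncoder::new(OpusConfig::default())?;
    let pcm = vec![0i16; 960 * 2];
    let frame = encoder.encode(&pcm)?;
    debug_assert_eq!(frame.duration_ms, 20);
    Ok(())
}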
/// Audio encoder statistics
#[derive(Debug, Clone, Default)]
pub struct EncoderStats {
pub frames_encoded: u64,
pub bytes_output: u64,
pub avg_frame_size: usize,
pub current_bitrate: u32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_opus_config_default() {
let config = OpusConfig::default();
assert_eq!(config.sample_rate, 48000);
assert_eq!(config.channels, 2);
assert_eq!(config.bitrate, 64000);
}
#[test]
fn test_create_encoder() {
let config = OpusConfig::default();
let encoder = OpusEncoder::new(config);
assert!(encoder.is_ok());
}
#[test]
fn test_encode_silence() {
let config = OpusConfig::default();
let mut encoder = OpusEncoder::new(config).unwrap();
// 20ms of stereo silence at 48kHz
let silence = vec![0i16; 960 * 2];
let result = encoder.encode(&silence);
assert!(result.is_ok());
let frame = result.unwrap();
assert!(!frame.is_empty());
assert!(frame.len() < silence.len() * 2); // Should be compressed
}
}

26
src/audio/mod.rs Normal file
View File

@@ -0,0 +1,26 @@
//! Audio capture and encoding module
//!
//! This module provides:
//! - ALSA audio capture
//! - Opus encoding for WebRTC
//! - Audio device enumeration
//! - Audio streaming pipeline
//! - High-level audio controller
//! - Shared audio pipeline for WebRTC multi-session support
//! - Device health monitoring
pub mod capture;
pub mod controller;
pub mod device;
pub mod encoder;
pub mod monitor;
pub mod shared_pipeline;
pub mod streamer;
pub use capture::{AudioCapturer, AudioConfig, AudioFrame};
pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus};
pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo};
pub use encoder::{OpusConfig, OpusEncoder, OpusFrame};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig};
pub use shared_pipeline::{SharedAudioPipeline, SharedAudioPipelineConfig, SharedAudioPipelineStats};
pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};

352
src/audio/monitor.rs Normal file
View File

@@ -0,0 +1,352 @@
//! Audio device health monitoring
//!
//! This module provides health monitoring for audio capture devices, including:
//! - Device connectivity checks
//! - Automatic reconnection on failure
//! - Error tracking and notification
//! - Log throttling to prevent log flooding
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::events::{EventBus, SystemEvent};
use crate::utils::LogThrottler;
/// Audio health status
#[derive(Debug, Clone, PartialEq)]
pub enum AudioHealthStatus {
/// Device is healthy and operational
Healthy,
/// Device has an error, attempting recovery
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
/// Number of recovery attempts made
retry_count: u32,
},
/// Device is disconnected or not available
Disconnected,
}
impl Default for AudioHealthStatus {
fn default() -> Self {
Self::Healthy
}
}
/// Audio health monitor configuration
#[derive(Debug, Clone)]
pub struct AudioMonitorConfig {
/// Retry interval when device is lost (milliseconds)
pub retry_interval_ms: u64,
/// Maximum retry attempts before giving up (0 = infinite)
pub max_retries: u32,
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for AudioMonitorConfig {
fn default() -> Self {
Self {
retry_interval_ms: 1000,
max_retries: 0, // infinite retry
log_throttle_secs: 5,
}
}
}
/// Audio health monitor
///
/// Monitors audio device health and manages error recovery.
/// Publishes WebSocket events when device status changes.
pub struct AudioHealthMonitor {
/// Current health status
status: RwLock<AudioHealthStatus>,
/// Event bus for notifications
events: RwLock<Option<Arc<EventBus>>>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
config: AudioMonitorConfig,
/// Whether monitoring is active (reserved for future use)
#[allow(dead_code)]
running: AtomicBool,
/// Current retry count
retry_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
}
impl AudioHealthMonitor {
/// Create a new audio health monitor with the specified configuration
pub fn new(config: AudioMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
Self {
status: RwLock::new(AudioHealthStatus::Healthy),
events: RwLock::new(None),
throttler: LogThrottler::with_secs(throttle_secs),
config,
running: AtomicBool::new(false),
retry_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
}
}
/// Create a new audio health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(AudioMonitorConfig::default())
}
/// Set the event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Report an error from audio operations
///
/// This method is called when an audio operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Publishes a WebSocket event if the error is new or changed
///
/// # Arguments
///
/// * `device` - The audio device name (if known)
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, device: Option<&str>, reason: &str, error_code: &str) {
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("audio_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!(
"Audio error: {} (code: {}, attempt: {})",
reason, error_code, count
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = AudioHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
retry_count: count,
};
// Publish event (only if error changed or first occurrence)
if error_changed || count == 1 {
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::AudioDeviceLost {
device: device.map(|s| s.to_string()),
reason: reason.to_string(),
error_code: error_code.to_string(),
});
}
}
}
/// Report that a reconnection attempt is starting
///
/// Publishes a reconnecting event to notify clients.
pub async fn report_reconnecting(&self) {
let attempt = self.retry_count.load(Ordering::Relaxed);
// Only publish every 5 attempts to avoid event spam
if attempt == 1 || attempt % 5 == 0 {
debug!("Audio reconnecting, attempt {}", attempt);
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::AudioReconnecting { attempt });
}
}
}
/// Report that the device has recovered
///
/// This method is called when the audio device successfully reconnects.
/// It resets the error state and publishes a recovery event.
///
/// # Arguments
///
/// * `device` - The audio device name
pub async fn report_recovered(&self, device: Option<&str>) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != AudioHealthStatus::Healthy {
let retry_count = self.retry_count.load(Ordering::Relaxed);
info!("Audio recovered after {} retries", retry_count);
// Reset state
self.retry_count.store(0, Ordering::Relaxed);
self.throttler.clear("audio_");
*self.last_error_code.write().await = None;
*self.status.write().await = AudioHealthStatus::Healthy;
// Publish recovery event
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::AudioRecovered {
device: device.map(|s| s.to_string()),
});
}
}
}
/// Get the current health status
pub async fn status(&self) -> AudioHealthStatus {
self.status.read().await.clone()
}
/// Get the current retry count
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = AudioHealthStatus::Healthy;
self.throttler.clear_all();
}
/// Get the configuration
pub fn config(&self) -> &AudioMonitorConfig {
&self.config
}
/// Check if we should continue retrying
///
/// Returns `false` once a non-zero `max_retries` has been reached.
pub fn should_retry(&self) -> bool {
if self.config.max_retries == 0 {
return true; // Infinite retry
}
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
}
/// Get the retry interval
pub fn retry_interval(&self) -> Duration {
Duration::from_millis(self.config.retry_interval_ms)
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
match &*self.status.read().await {
AudioHealthStatus::Error { reason, .. } => Some(reason.clone()),
_ => None,
}
}
}
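// Recovery-loop sketch built on the retry helpers above; `reopen_device`
// is a hypothetical stand-in for real ALSA reopen logic:
#[allow(dead_code)]
async fn recovery_loop_example(monitor: &AudioHealthMonitor) {
    let reopen_device = || false; // hypothetical: try to reopen the capture device
    while monitor.should_retry() {
        if reopen_device() {
            monitor.report_recovered(Some("hw:0,0")).await;
            break;
        }
        monitor
            .report_error(Some("hw:0,0"), "Open failed", "device_disconnected")
            .await;
        monitor.report_reconnecting().await;
        tokio::time::sleep(monitor.retry_interval()).await;
    }
}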
impl Default for AudioHealthMonitor {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_initial_status() {
let monitor = AudioHealthMonitor::with_defaults();
assert!(monitor.is_healthy().await);
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_report_error() {
let monitor = AudioHealthMonitor::with_defaults();
monitor
.report_error(Some("hw:0,0"), "Device not found", "device_disconnected")
.await;
assert!(monitor.is_error().await);
assert_eq!(monitor.retry_count(), 1);
if let AudioHealthStatus::Error {
reason,
error_code,
retry_count,
} = monitor.status().await
{
assert_eq!(reason, "Device not found");
assert_eq!(error_code, "device_disconnected");
assert_eq!(retry_count, 1);
} else {
panic!("Expected Error status");
}
}
#[tokio::test]
async fn test_report_recovered() {
let monitor = AudioHealthMonitor::with_defaults();
// First report an error
monitor
.report_error(Some("default"), "Capture failed", "capture_error")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered(Some("default")).await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_retry_count_increments() {
let monitor = AudioHealthMonitor::with_defaults();
for i in 1..=5 {
monitor
.report_error(None, "Error", "io_error")
.await;
assert_eq!(monitor.retry_count(), i);
}
}
#[tokio::test]
async fn test_reset() {
let monitor = AudioHealthMonitor::with_defaults();
monitor
.report_error(None, "Error", "io_error")
.await;
assert!(monitor.is_error().await);
monitor.reset().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.retry_count(), 0);
}
}

453
src/audio/shared_pipeline.rs Normal file
View File

@@ -0,0 +1,453 @@
//! Shared Audio Pipeline for WebRTC
//!
//! This module provides a shared audio encoding pipeline that can serve
//! multiple WebRTC sessions with a single encoder instance.
//!
//! # Architecture
//!
//! ```text
//! AudioCapturer (ALSA)
//! |
//! v (broadcast::Receiver<AudioFrame>)
//! SharedAudioPipeline (single Opus encoder)
//! |
//! v (broadcast::Sender<OpusFrame>)
//! ┌────┴────┬────────┬────────┐
//! v v v v
//! Session1 Session2 Session3 ...
//! (RTP) (RTP) (RTP) (RTP)
//! ```
//!
//! # Key Features
//!
//! - **Single encoder**: All sessions share one Opus encoder
//! - **Broadcast distribution**: Encoded frames are broadcast to all subscribers
//! - **Dynamic bitrate**: Bitrate can be changed at runtime
//! - **Statistics**: Tracks encoding performance metrics
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, Mutex, RwLock};
use tracing::{debug, error, info, trace, warn};
use super::capture::AudioFrame;
use super::encoder::{OpusConfig, OpusEncoder, OpusFrame};
use crate::error::{AppError, Result};
/// Shared audio pipeline configuration
#[derive(Debug, Clone)]
pub struct SharedAudioPipelineConfig {
/// Sample rate (must match audio capture)
pub sample_rate: u32,
/// Number of channels (1 or 2)
pub channels: u32,
/// Target bitrate in bps
pub bitrate: u32,
/// Opus application mode
pub application: OpusApplicationMode,
/// Enable forward error correction
pub fec: bool,
/// Broadcast channel capacity
pub channel_capacity: usize,
}
impl Default for SharedAudioPipelineConfig {
fn default() -> Self {
Self {
sample_rate: 48000,
channels: 2,
bitrate: 64000,
application: OpusApplicationMode::Audio,
fec: true,
channel_capacity: 64,
}
}
}
impl SharedAudioPipelineConfig {
/// Create config optimized for voice
pub fn voice() -> Self {
Self {
bitrate: 32000,
application: OpusApplicationMode::Voip,
..Default::default()
}
}
/// Create config optimized for music/high quality
pub fn high_quality() -> Self {
Self {
bitrate: 128000,
application: OpusApplicationMode::Audio,
..Default::default()
}
}
/// Convert to OpusConfig
pub fn to_opus_config(&self) -> OpusConfig {
OpusConfig {
sample_rate: self.sample_rate,
channels: self.channels,
bitrate: self.bitrate,
application: match self.application {
OpusApplicationMode::Voip => super::encoder::OpusApplication::Voip,
OpusApplicationMode::Audio => super::encoder::OpusApplication::Audio,
OpusApplicationMode::LowDelay => super::encoder::OpusApplication::LowDelay,
},
fec: self.fec,
}
}
}
/// Opus application mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpusApplicationMode {
/// Voice over IP - optimized for speech
Voip,
/// General audio - balanced quality
Audio,
/// Low delay mode - minimal latency
LowDelay,
}
/// Shared audio pipeline statistics
#[derive(Debug, Clone, Default)]
pub struct SharedAudioPipelineStats {
/// Frames received from audio capture
pub frames_received: u64,
/// Frames successfully encoded
pub frames_encoded: u64,
/// Frames dropped (encode errors)
pub frames_dropped: u64,
/// Total bytes encoded
pub bytes_encoded: u64,
/// Number of active subscribers
pub subscribers: u64,
/// Average encode time in milliseconds
pub avg_encode_time_ms: f32,
/// Current bitrate in bps
pub current_bitrate: u32,
/// Pipeline running time in seconds
pub running_time_secs: f64,
}
/// Shared Audio Pipeline
///
/// Provides a single Opus encoder that serves multiple WebRTC sessions.
/// All sessions receive the same encoded audio stream via broadcast channel.
pub struct SharedAudioPipeline {
/// Configuration
config: RwLock<SharedAudioPipelineConfig>,
/// Opus encoder (protected by mutex for encoding)
encoder: Mutex<Option<OpusEncoder>>,
/// Broadcast sender for encoded Opus frames
opus_tx: broadcast::Sender<OpusFrame>,
/// Running state
running: AtomicBool,
/// Statistics
stats: Mutex<SharedAudioPipelineStats>,
/// Start time for running time calculation
start_time: RwLock<Option<Instant>>,
/// Encode time accumulator for averaging
encode_time_sum_us: AtomicU64,
/// Encode count for averaging
encode_count: AtomicU64,
/// Stop signal (atomic for lock-free checking)
stop_flag: AtomicBool,
/// Encoding task handle
task_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
impl SharedAudioPipeline {
/// Create a new shared audio pipeline
pub fn new(config: SharedAudioPipelineConfig) -> Result<Arc<Self>> {
let (opus_tx, _) = broadcast::channel(config.channel_capacity);
Ok(Arc::new(Self {
config: RwLock::new(config),
encoder: Mutex::new(None),
opus_tx,
running: AtomicBool::new(false),
stats: Mutex::new(SharedAudioPipelineStats::default()),
start_time: RwLock::new(None),
encode_time_sum_us: AtomicU64::new(0),
encode_count: AtomicU64::new(0),
stop_flag: AtomicBool::new(false),
task_handle: Mutex::new(None),
}))
}
/// Create with default configuration
pub fn default_config() -> Result<Arc<Self>> {
Self::new(SharedAudioPipelineConfig::default())
}
/// Start the audio encoding pipeline
///
/// # Arguments
/// * `audio_rx` - Receiver for raw audio frames from AudioCapturer
pub async fn start(self: &Arc<Self>, audio_rx: broadcast::Receiver<AudioFrame>) -> Result<()> {
if self.running.load(Ordering::SeqCst) {
return Ok(());
}
let config = self.config.read().await.clone();
info!(
"Starting shared audio pipeline: {}Hz {}ch {}bps",
config.sample_rate, config.channels, config.bitrate
);
// Create encoder
let opus_config = config.to_opus_config();
let encoder = OpusEncoder::new(opus_config)?;
*self.encoder.lock().await = Some(encoder);
// Reset stats
{
let mut stats = self.stats.lock().await;
*stats = SharedAudioPipelineStats::default();
stats.current_bitrate = config.bitrate;
}
// Reset counters
self.encode_time_sum_us.store(0, Ordering::SeqCst);
self.encode_count.store(0, Ordering::SeqCst);
*self.start_time.write().await = Some(Instant::now());
self.stop_flag.store(false, Ordering::SeqCst);
self.running.store(true, Ordering::SeqCst);
// Start encoding task
let pipeline = self.clone();
let handle = tokio::spawn(async move {
pipeline.encoding_task(audio_rx).await;
});
*self.task_handle.lock().await = Some(handle);
info!("Shared audio pipeline started");
Ok(())
}
/// Stop the audio encoding pipeline
pub fn stop(&self) {
if !self.running.load(Ordering::SeqCst) {
return;
}
info!("Stopping shared audio pipeline");
// Signal stop (atomic, no lock needed)
self.stop_flag.store(true, Ordering::SeqCst);
self.running.store(false, Ordering::SeqCst);
}
/// Check if pipeline is running
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
/// Subscribe to encoded Opus frames
pub fn subscribe(&self) -> broadcast::Receiver<OpusFrame> {
self.opus_tx.subscribe()
}
/// Get number of active subscribers
pub fn subscriber_count(&self) -> usize {
self.opus_tx.receiver_count()
}
/// Get current statistics
pub async fn stats(&self) -> SharedAudioPipelineStats {
let mut stats = self.stats.lock().await.clone();
stats.subscribers = self.subscriber_count() as u64;
// Calculate average encode time
let count = self.encode_count.load(Ordering::SeqCst);
if count > 0 {
let sum_us = self.encode_time_sum_us.load(Ordering::SeqCst);
stats.avg_encode_time_ms = (sum_us as f64 / count as f64 / 1000.0) as f32;
}
// Calculate running time
if let Some(start) = *self.start_time.read().await {
stats.running_time_secs = start.elapsed().as_secs_f64();
}
stats
}
/// Set bitrate dynamically
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
// Update config
self.config.write().await.bitrate = bitrate;
// Update encoder if running
if let Some(ref mut encoder) = *self.encoder.lock().await {
encoder.set_bitrate(bitrate)?;
}
// Update stats
self.stats.lock().await.current_bitrate = bitrate;
info!("Shared audio pipeline bitrate changed to {}bps", bitrate);
Ok(())
}
/// Update configuration (requires restart)
pub async fn update_config(&self, config: SharedAudioPipelineConfig) -> Result<()> {
if self.is_running() {
return Err(AppError::AudioError(
"Cannot update config while pipeline is running".to_string(),
));
}
*self.config.write().await = config;
Ok(())
}
/// Internal encoding task
async fn encoding_task(self: Arc<Self>, mut audio_rx: broadcast::Receiver<AudioFrame>) {
info!("Audio encoding task started");
loop {
// Check stop flag (atomic, no async lock needed)
if self.stop_flag.load(Ordering::Relaxed) {
break;
}
// Receive audio frame with timeout
let recv_result = tokio::time::timeout(
std::time::Duration::from_secs(2),
audio_rx.recv(),
)
.await;
match recv_result {
Ok(Ok(audio_frame)) => {
// Update received count
{
let mut stats = self.stats.lock().await;
stats.frames_received += 1;
}
// Encode frame
let encode_start = Instant::now();
let encode_result = {
let mut encoder_guard = self.encoder.lock().await;
if let Some(ref mut encoder) = *encoder_guard {
Some(encoder.encode_frame(&audio_frame))
} else {
None
}
};
let encode_time = encode_start.elapsed();
// Update encode time stats
self.encode_time_sum_us
.fetch_add(encode_time.as_micros() as u64, Ordering::SeqCst);
self.encode_count.fetch_add(1, Ordering::SeqCst);
match encode_result {
Some(Ok(opus_frame)) => {
// Update stats
{
let mut stats = self.stats.lock().await;
stats.frames_encoded += 1;
stats.bytes_encoded += opus_frame.data.len() as u64;
}
// Broadcast to subscribers
if self.opus_tx.receiver_count() > 0 {
if let Err(e) = self.opus_tx.send(opus_frame) {
trace!("No audio subscribers: {}", e);
}
}
}
Some(Err(e)) => {
error!("Opus encode error: {}", e);
let mut stats = self.stats.lock().await;
stats.frames_dropped += 1;
}
None => {
warn!("Encoder not available");
break;
}
}
}
Ok(Err(broadcast::error::RecvError::Closed)) => {
info!("Audio source channel closed");
break;
}
Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
warn!("Audio pipeline lagged by {} frames", n);
let mut stats = self.stats.lock().await;
stats.frames_dropped += n;
}
Err(_) => {
// Timeout - check if still running
if !self.running.load(Ordering::SeqCst) {
break;
}
debug!("Audio receive timeout, continuing...");
}
}
}
// Cleanup
self.running.store(false, Ordering::SeqCst);
*self.encoder.lock().await = None;
let stats = self.stats().await;
info!(
"Audio encoding task ended: {} frames encoded, {} dropped, {:.1}s runtime",
stats.frames_encoded, stats.frames_dropped, stats.running_time_secs
);
}
}
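// Wiring sketch; the PCM channel stands in for AudioCapturer::subscribe():
#[allow(dead_code)]
async fn wiring_example() -> Result<()> {
    let (_pcm_tx, pcm_rx) = broadcast::channel::<AudioFrame>(64);
    let pipeline = SharedAudioPipeline::default_config()?;
    pipeline.start(pcm_rx).await?;
    // Every WebRTC session pulls the same encoded stream:
    let mut session_rx = pipeline.subscribe();
    // per session: while let Ok(frame) = session_rx.recv().await { packetize(frame) }
    let _ = session_rx.try_recv();
    pipeline.stop();
    Ok(())
}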
impl Drop for SharedAudioPipeline {
fn drop(&mut self) {
self.stop();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = SharedAudioPipelineConfig::default();
assert_eq!(config.sample_rate, 48000);
assert_eq!(config.channels, 2);
assert_eq!(config.bitrate, 64000);
}
#[test]
fn test_config_voice() {
let config = SharedAudioPipelineConfig::voice();
assert_eq!(config.bitrate, 32000);
assert_eq!(config.application, OpusApplicationMode::Voip);
}
#[test]
fn test_config_high_quality() {
let config = SharedAudioPipelineConfig::high_quality();
assert_eq!(config.bitrate, 128000);
}
#[tokio::test]
async fn test_pipeline_creation() {
let config = SharedAudioPipelineConfig::default();
let pipeline = SharedAudioPipeline::new(config);
assert!(pipeline.is_ok());
let pipeline = pipeline.unwrap();
assert!(!pipeline.is_running());
assert_eq!(pipeline.subscriber_count(), 0);
}
}

401
src/audio/streamer.rs Normal file
View File

@@ -0,0 +1,401 @@
//! Audio streaming pipeline
//!
//! Coordinates audio capture and Opus encoding, distributing encoded
//! frames to multiple subscribers via broadcast channel.
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex, RwLock};
use tracing::{error, info, trace, warn};
use super::capture::{AudioCapturer, AudioConfig, CaptureState};
use super::encoder::{OpusConfig, OpusEncoder, OpusFrame};
use crate::error::{AppError, Result};
/// Audio stream state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioStreamState {
/// Stream is stopped
Stopped,
/// Stream is starting up
Starting,
/// Stream is running
Running,
/// Stream encountered an error
Error,
}
impl Default for AudioStreamState {
fn default() -> Self {
Self::Stopped
}
}
/// Audio streamer configuration
#[derive(Debug, Clone)]
pub struct AudioStreamerConfig {
/// Audio capture configuration
pub capture: AudioConfig,
/// Opus encoder configuration
pub opus: OpusConfig,
}
impl Default for AudioStreamerConfig {
fn default() -> Self {
Self {
capture: AudioConfig::default(),
opus: OpusConfig::default(),
}
}
}
impl AudioStreamerConfig {
/// Create config for a specific device with default quality
pub fn for_device(device_name: &str) -> Self {
Self {
capture: AudioConfig {
device_name: device_name.to_string(),
..Default::default()
},
opus: OpusConfig::default(),
}
}
/// Create config with specified bitrate
pub fn with_bitrate(mut self, bitrate: u32) -> Self {
self.opus.bitrate = bitrate;
self
}
}
/// Audio stream statistics
#[derive(Debug, Clone, Default)]
pub struct AudioStreamStats {
/// Frames captured from ALSA
pub frames_captured: u64,
/// Frames encoded to Opus
pub frames_encoded: u64,
/// Total bytes output (Opus)
pub bytes_output: u64,
/// Current encoding bitrate
pub current_bitrate: u32,
/// Number of active subscribers
pub subscriber_count: usize,
/// Buffer overruns
pub buffer_overruns: u64,
}
/// Audio streamer
///
/// Manages the audio capture -> encode -> broadcast pipeline.
pub struct AudioStreamer {
config: RwLock<AudioStreamerConfig>,
state: watch::Sender<AudioStreamState>,
state_rx: watch::Receiver<AudioStreamState>,
capturer: RwLock<Option<Arc<AudioCapturer>>>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: broadcast::Sender<OpusFrame>,
stats: Arc<Mutex<AudioStreamStats>>,
sequence: AtomicU64,
stream_start_time: RwLock<Option<Instant>>,
stop_flag: Arc<AtomicBool>,
}
impl AudioStreamer {
/// Create a new audio streamer with default configuration
pub fn new() -> Self {
Self::with_config(AudioStreamerConfig::default())
}
/// Create a new audio streamer with specified configuration
pub fn with_config(config: AudioStreamerConfig) -> Self {
let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped);
let (opus_tx, _) = broadcast::channel(64);
Self {
config: RwLock::new(config),
state: state_tx,
state_rx,
capturer: RwLock::new(None),
encoder: Arc::new(Mutex::new(None)),
opus_tx,
stats: Arc::new(Mutex::new(AudioStreamStats::default())),
sequence: AtomicU64::new(0),
stream_start_time: RwLock::new(None),
stop_flag: Arc::new(AtomicBool::new(false)),
}
}
/// Get current state
pub fn state(&self) -> AudioStreamState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<AudioStreamState> {
self.state_rx.clone()
}
/// Subscribe to Opus frames
pub fn subscribe_opus(&self) -> broadcast::Receiver<OpusFrame> {
self.opus_tx.subscribe()
}
/// Get number of active subscribers
pub fn subscriber_count(&self) -> usize {
self.opus_tx.receiver_count()
}
/// Get current statistics
pub async fn stats(&self) -> AudioStreamStats {
let mut stats = self.stats.lock().await.clone();
stats.subscriber_count = self.subscriber_count();
stats
}
/// Update configuration (only when stopped)
pub async fn set_config(&self, config: AudioStreamerConfig) -> Result<()> {
if self.state() != AudioStreamState::Stopped {
return Err(AppError::AudioError(
"Cannot change config while streaming".to_string(),
));
}
*self.config.write().await = config;
Ok(())
}
/// Update bitrate dynamically (can be done while streaming)
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
// Update config
self.config.write().await.opus.bitrate = bitrate;
// Update encoder if running
if let Some(ref mut encoder) = *self.encoder.lock().await {
encoder.set_bitrate(bitrate)?;
}
// Update stats
self.stats.lock().await.current_bitrate = bitrate;
info!("Audio bitrate changed to {}bps", bitrate);
Ok(())
}
/// Start the audio stream
pub async fn start(&self) -> Result<()> {
if self.state() == AudioStreamState::Running {
return Ok(());
}
let _ = self.state.send(AudioStreamState::Starting);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.read().await.clone();
info!(
"Starting audio stream: {} @ {}Hz {}ch, {}bps Opus",
config.capture.device_name,
config.capture.sample_rate,
config.capture.channels,
config.opus.bitrate
);
// Create capturer
let capturer = Arc::new(AudioCapturer::new(config.capture.clone()));
*self.capturer.write().await = Some(capturer.clone());
// Create encoder
let encoder = OpusEncoder::new(config.opus.clone())?;
*self.encoder.lock().await = Some(encoder);
// Start capture
capturer.start().await?;
// Reset stats
{
let mut stats = self.stats.lock().await;
*stats = AudioStreamStats::default();
stats.current_bitrate = config.opus.bitrate;
}
// Record start time
*self.stream_start_time.write().await = Some(Instant::now());
self.sequence.store(0, Ordering::SeqCst);
// Start encoding task
let capturer_for_task = capturer.clone();
let encoder = self.encoder.clone();
let opus_tx = self.opus_tx.clone();
let stats = self.stats.clone();
let state = self.state.clone();
let stop_flag = self.stop_flag.clone();
tokio::spawn(async move {
Self::stream_task(capturer_for_task, encoder, opus_tx, stats, state, stop_flag).await;
});
Ok(())
}
/// Stop the audio stream
pub async fn stop(&self) -> Result<()> {
if self.state() == AudioStreamState::Stopped {
return Ok(());
}
info!("Stopping audio stream");
// Signal stop
self.stop_flag.store(true, Ordering::SeqCst);
// Stop capturer
if let Some(ref capturer) = *self.capturer.read().await {
capturer.stop().await?;
}
// Clear resources
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
*self.stream_start_time.write().await = None;
let _ = self.state.send(AudioStreamState::Stopped);
info!("Audio stream stopped");
Ok(())
}
/// Check if streaming
pub fn is_running(&self) -> bool {
self.state() == AudioStreamState::Running
}
/// Internal streaming task
async fn stream_task(
capturer: Arc<AudioCapturer>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: broadcast::Sender<OpusFrame>,
stats: Arc<Mutex<AudioStreamStats>>,
state: watch::Sender<AudioStreamState>,
stop_flag: Arc<AtomicBool>,
) {
let mut pcm_rx = capturer.subscribe();
let _ = state.send(AudioStreamState::Running);
info!("Audio stream task started");
loop {
// Check stop flag (atomic, no async lock needed)
if stop_flag.load(Ordering::Relaxed) {
break;
}
// Check capturer state
if capturer.state() == CaptureState::Error {
error!("Audio capture error, stopping stream");
let _ = state.send(AudioStreamState::Error);
break;
}
// Receive PCM frame with timeout
let recv_result = tokio::time::timeout(
std::time::Duration::from_secs(2),
pcm_rx.recv(),
)
.await;
match recv_result {
Ok(Ok(audio_frame)) => {
// Update capture stats
{
let mut s = stats.lock().await;
s.frames_captured += 1;
}
// Encode to Opus
let opus_result = {
let mut enc_guard = encoder.lock().await;
if let Some(ref mut enc) = *enc_guard {
Some(enc.encode_frame(&audio_frame))
} else {
None
}
};
match opus_result {
Some(Ok(opus_frame)) => {
// Update stats
{
let mut s = stats.lock().await;
s.frames_encoded += 1;
s.bytes_output += opus_frame.data.len() as u64;
}
// Broadcast to subscribers
if opus_tx.receiver_count() > 0 {
if let Err(e) = opus_tx.send(opus_frame) {
trace!("No audio subscribers: {}", e);
}
}
}
Some(Err(e)) => {
error!("Opus encode error: {}", e);
}
None => {
warn!("Encoder not available");
break;
}
}
}
Ok(Err(broadcast::error::RecvError::Closed)) => {
info!("Audio capture channel closed");
break;
}
Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
warn!("Audio receiver lagged by {} frames", n);
let mut s = stats.lock().await;
s.buffer_overruns += n;
}
Err(_) => {
// Timeout - check if still capturing
if capturer.state() != CaptureState::Running {
info!("Audio capture stopped, ending stream task");
break;
}
}
}
}
let _ = state.send(AudioStreamState::Stopped);
info!("Audio stream task ended");
}
}
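// Usage sketch; "hw:1,0" and the 96 kbps target are illustrative values:
#[allow(dead_code)]
async fn streamer_example() -> Result<()> {
    let config = AudioStreamerConfig::for_device("hw:1,0").with_bitrate(96000);
    let streamer = AudioStreamer::with_config(config);
    streamer.start().await?;
    let mut opus_rx = streamer.subscribe_opus();
    // A forwarding task would loop on recv() and hand frames to the transport.
    if let Ok(frame) = opus_rx.recv().await {
        tracing::debug!("first opus frame: {} bytes", frame.len());
    }
    streamer.stop().await
}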
impl Default for AudioStreamer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_streamer_config_default() {
let config = AudioStreamerConfig::default();
assert_eq!(config.capture.sample_rate, 48000);
assert_eq!(config.opus.bitrate, 64000);
}
#[test]
fn test_streamer_config_for_device() {
let config = AudioStreamerConfig::for_device("hw:0,0");
assert_eq!(config.capture.device_name, "hw:0,0");
}
#[tokio::test]
async fn test_streamer_state() {
let streamer = AudioStreamer::new();
assert_eq!(streamer.state(), AudioStreamState::Stopped);
}
}

142
src/auth/middleware.rs Normal file
View File

@@ -0,0 +1,142 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use axum_extra::extract::CookieJar;
use std::sync::Arc;
use crate::state::AppState;
/// Session cookie name
pub const SESSION_COOKIE: &str = "one_kvm_session";
/// Auth layer for extracting session from request
#[derive(Clone)]
pub struct AuthLayer;
/// Extract session ID from request
pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option<String> {
// First try cookie
if let Some(cookie) = cookies.get(SESSION_COOKIE) {
return Some(cookie.value().to_string());
}
// Then try Authorization header (Bearer token)
if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
if let Ok(auth_str) = auth_header.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
return Some(token.to_string());
}
}
}
None
}
/// Authentication middleware
pub async fn auth_middleware(
State(state): State<Arc<AppState>>,
cookies: CookieJar,
mut request: Request,
next: Next,
) -> Result<Response, StatusCode> {
// Check if system is initialized
if !state.config.is_initialized() {
// Allow access to setup endpoints when not initialized
let path = request.uri().path();
if path.starts_with("/api/setup") || path == "/api/info" || path.starts_with("/") && !path.starts_with("/api/") {
return Ok(next.run(request).await);
}
}
// Public endpoints that don't require auth
let path = request.uri().path();
if is_public_endpoint(path) {
return Ok(next.run(request).await);
}
// Extract session ID
let session_id = extract_session_id(&cookies, request.headers());
if let Some(session_id) = session_id {
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
// Add session to request extensions
request.extensions_mut().insert(session);
return Ok(next.run(request).await);
}
}
Err(StatusCode::UNAUTHORIZED)
}
/// Check if endpoint is public (no auth required)
fn is_public_endpoint(path: &str) -> bool {
// Note: paths here are relative to /api since middleware is applied before nest
matches!(
path,
"/"
| "/auth/login"
| "/info"
| "/health"
| "/setup"
| "/setup/init"
// Also check with /api prefix for direct access
| "/api/auth/login"
| "/api/info"
| "/api/health"
| "/api/setup"
| "/api/setup/init"
) || path.starts_with("/assets/")
|| path.starts_with("/static/")
|| path.ends_with(".js")
|| path.ends_with(".css")
|| path.ends_with(".ico")
|| path.ends_with(".png")
|| path.ends_with(".svg")
}
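// Sanity sketch for the matcher above ("/api/atx/power" is just an
// illustrative protected route):
#[allow(dead_code)]
fn public_endpoint_examples() {
    debug_assert!(is_public_endpoint("/api/auth/login"));
    debug_assert!(is_public_endpoint("/assets/app.js"));
    debug_assert!(!is_public_endpoint("/api/atx/power"));
}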
/// Require authentication - returns 401 if not authenticated
pub async fn require_auth(
State(state): State<Arc<AppState>>,
cookies: CookieJar,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let session_id = extract_session_id(&cookies, request.headers());
if let Some(session_id) = session_id {
if let Ok(Some(_session)) = state.sessions.get(&session_id).await {
return Ok(next.run(request).await);
}
}
Err(StatusCode::UNAUTHORIZED)
}
/// Require admin privileges - returns 403 if not admin
pub async fn require_admin(
State(state): State<Arc<AppState>>,
cookies: CookieJar,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let session_id = extract_session_id(&cookies, request.headers());
if let Some(session_id) = session_id {
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
// Get user and check admin status
if let Ok(Some(user)) = state.users.get(&session.user_id).await {
if user.is_admin {
return Ok(next.run(request).await);
}
// User is authenticated but not admin
return Err(StatusCode::FORBIDDEN);
}
}
}
// Not authenticated at all
Err(StatusCode::UNAUTHORIZED)
}

9
src/auth/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
mod password;
mod session;
mod user;
pub mod middleware;
pub use password::{hash_password, verify_password};
pub use session::{Session, SessionStore};
pub use user::{User, UserStore};
pub use middleware::{AuthLayer, SESSION_COOKIE, auth_middleware, require_admin};

41
src/auth/password.rs Normal file
View File

@@ -0,0 +1,41 @@
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2,
};
use crate::error::{AppError, Result};
/// Hash a password using Argon2
pub fn hash_password(password: &str) -> Result<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
argon2
.hash_password(password.as_bytes(), &salt)
.map(|hash| hash.to_string())
.map_err(|e| AppError::Internal(format!("Password hashing failed: {}", e)))
}
/// Verify a password against a hash
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| AppError::Internal(format!("Invalid password hash: {}", e)))?;
Ok(Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_password_hash_verify() {
let password = "test_password_123";
let hash = hash_password(password).unwrap();
assert!(verify_password(password, &hash).unwrap());
assert!(!verify_password("wrong_password", &hash).unwrap());
}
}

129
src/auth/session.rs Normal file
View File

@@ -0,0 +1,129 @@
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use uuid::Uuid;
use crate::error::Result;
/// Session data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
pub id: String,
pub user_id: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub data: Option<serde_json::Value>,
}
impl Session {
/// Check if session is expired
pub fn is_expired(&self) -> bool {
Utc::now() > self.expires_at
}
}
/// Session store backed by SQLite
#[derive(Clone)]
pub struct SessionStore {
pool: Pool<Sqlite>,
default_ttl: Duration,
}
impl SessionStore {
/// Create a new session store
pub fn new(pool: Pool<Sqlite>, ttl_secs: i64) -> Self {
Self {
pool,
default_ttl: Duration::seconds(ttl_secs),
}
}
/// Create a new session
pub async fn create(&self, user_id: &str) -> Result<Session> {
let session = Session {
id: Uuid::new_v4().to_string(),
user_id: user_id.to_string(),
created_at: Utc::now(),
expires_at: Utc::now() + self.default_ttl,
data: None,
};
sqlx::query(
r#"
INSERT INTO sessions (id, user_id, created_at, expires_at, data)
VALUES (?1, ?2, ?3, ?4, ?5)
"#,
)
.bind(&session.id)
.bind(&session.user_id)
.bind(session.created_at.to_rfc3339())
.bind(session.expires_at.to_rfc3339())
.bind(session.data.as_ref().map(|d| d.to_string()))
.execute(&self.pool)
.await?;
Ok(session)
}
/// Get a session by ID
pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
let row: Option<(String, String, String, String, Option<String>)> = sqlx::query_as(
"SELECT id, user_id, created_at, expires_at, data FROM sessions WHERE id = ?1",
)
.bind(session_id)
.fetch_optional(&self.pool)
.await?;
match row {
Some((id, user_id, created_at, expires_at, data)) => {
let session = Session {
id,
user_id,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
expires_at: DateTime::parse_from_rfc3339(&expires_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
data: data.and_then(|d| serde_json::from_str(&d).ok()),
};
if session.is_expired() {
self.delete(&session.id).await?;
Ok(None)
} else {
Ok(Some(session))
}
}
None => Ok(None),
}
}
/// Delete a session
pub async fn delete(&self, session_id: &str) -> Result<()> {
sqlx::query("DELETE FROM sessions WHERE id = ?1")
.bind(session_id)
.execute(&self.pool)
.await?;
Ok(())
}
/// Delete all expired sessions
pub async fn cleanup_expired(&self) -> Result<u64> {
let result = sqlx::query("DELETE FROM sessions WHERE expires_at < datetime('now')")
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
}
/// Extend session expiration
pub async fn extend(&self, session_id: &str) -> Result<()> {
let new_expires = Utc::now() + self.default_ttl;
sqlx::query("UPDATE sessions SET expires_at = ?1 WHERE id = ?2")
.bind(new_expires.to_rfc3339())
.bind(session_id)
.execute(&self.pool)
.await?;
Ok(())
}
}
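// Round-trip sketch against an in-memory SQLite database; the CREATE TABLE
// statement only mirrors the columns this store touches (the real schema
// lives in the project's migrations):
#[allow(dead_code)]
async fn session_roundtrip_example() -> Result<()> {
    let pool = sqlx::sqlite::SqlitePoolOptions::new()
        .connect("sqlite::memory:")
        .await?;
    sqlx::query(
        "CREATE TABLE sessions (id TEXT PRIMARY KEY, user_id TEXT, \
         created_at TEXT, expires_at TEXT, data TEXT)",
    )
    .execute(&pool)
    .await?;
    let store = SessionStore::new(pool, 3600);
    let session = store.create("user-1").await?;
    assert!(store.get(&session.id).await?.is_some());
    store.extend(&session.id).await?;
    store.delete(&session.id).await
}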

185
src/auth/user.rs Normal file
View File

@@ -0,0 +1,185 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use uuid::Uuid;
use crate::error::{AppError, Result};
use super::password::{hash_password, verify_password};
/// User row type from database
type UserRow = (String, String, String, i32, String, String);
/// User data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub id: String,
pub username: String,
#[serde(skip_serializing)]
pub password_hash: String,
pub is_admin: bool,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl User {
/// Convert from database row to User
fn from_row(row: UserRow) -> Self {
let (id, username, password_hash, is_admin, created_at, updated_at) = row;
Self {
id,
username,
password_hash,
is_admin: is_admin != 0,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
updated_at: DateTime::parse_from_rfc3339(&updated_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
}
}
}
/// User store backed by SQLite
#[derive(Clone)]
pub struct UserStore {
pool: Pool<Sqlite>,
}
impl UserStore {
/// Create a new user store
pub fn new(pool: Pool<Sqlite>) -> Self {
Self { pool }
}
/// Create a new user
pub async fn create(&self, username: &str, password: &str, is_admin: bool) -> Result<User> {
// Check if username already exists
if self.get_by_username(username).await?.is_some() {
return Err(AppError::BadRequest(format!(
"Username '{}' already exists",
username
)));
}
let password_hash = hash_password(password)?;
let now = Utc::now();
let user = User {
id: Uuid::new_v4().to_string(),
username: username.to_string(),
password_hash,
is_admin,
created_at: now,
updated_at: now,
};
sqlx::query(
r#"
INSERT INTO users (id, username, password_hash, is_admin, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
"#,
)
.bind(&user.id)
.bind(&user.username)
.bind(&user.password_hash)
.bind(user.is_admin as i32)
.bind(user.created_at.to_rfc3339())
.bind(user.updated_at.to_rfc3339())
.execute(&self.pool)
.await?;
Ok(user)
}
/// Get user by ID
pub async fn get(&self, user_id: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users WHERE id = ?1",
)
.bind(user_id)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Get user by username
pub async fn get_by_username(&self, username: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users WHERE username = ?1",
)
.bind(username)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Verify user credentials
pub async fn verify(&self, username: &str, password: &str) -> Result<Option<User>> {
let user = match self.get_by_username(username).await? {
Some(user) => user,
None => return Ok(None),
};
if verify_password(password, &user.password_hash)? {
Ok(Some(user))
} else {
Ok(None)
}
}
/// Update user password
pub async fn update_password(&self, user_id: &str, new_password: &str) -> Result<()> {
let password_hash = hash_password(new_password)?;
let now = Utc::now();
let result = sqlx::query(
"UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3",
)
.bind(&password_hash)
.bind(now.to_rfc3339())
.bind(user_id)
.execute(&self.pool)
.await?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("User not found".to_string()));
}
Ok(())
}
/// List all users
pub async fn list(&self) -> Result<Vec<User>> {
let rows: Vec<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users ORDER BY created_at",
)
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(User::from_row).collect())
}
/// Delete user by ID
pub async fn delete(&self, user_id: &str) -> Result<()> {
let result = sqlx::query("DELETE FROM users WHERE id = ?1")
.bind(user_id)
.execute(&self.pool)
.await?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("User not found".to_string()));
}
Ok(())
}
/// Check if any users exist
pub async fn has_users(&self) -> Result<bool> {
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
.fetch_one(&self.pool)
.await?;
Ok(count.0 > 0)
}
}
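// Credential-check sketch (assumes a `users` table matching UserRow, e.g.
// created by the project's migrations):
#[allow(dead_code)]
async fn verify_example(store: &UserStore) -> Result<()> {
    store.create("admin", "correct-horse-battery", true).await?;
    assert!(store.verify("admin", "correct-horse-battery").await?.is_some());
    assert!(store.verify("admin", "wrong").await?.is_none());
    Ok(())
}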

5
src/config/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
mod schema;
mod store;
pub use schema::*;
pub use store::ConfigStore;

416
src/config/schema.rs Normal file
View File

@@ -0,0 +1,416 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
// Re-export ExtensionsConfig from extensions module
pub use crate::extensions::ExtensionsConfig;
/// Main application configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AppConfig {
/// Whether initial setup has been completed
pub initialized: bool,
/// Authentication settings
pub auth: AuthConfig,
/// Video capture settings
pub video: VideoConfig,
/// HID (keyboard/mouse) settings
pub hid: HidConfig,
/// Mass Storage Device settings
pub msd: MsdConfig,
/// ATX power control settings
pub atx: AtxConfig,
/// Audio settings
pub audio: AudioConfig,
/// Streaming settings
pub stream: StreamConfig,
/// Web server settings
pub web: WebConfig,
/// Extensions settings (ttyd, gostc, easytier)
pub extensions: ExtensionsConfig,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
initialized: false,
auth: AuthConfig::default(),
video: VideoConfig::default(),
hid: HidConfig::default(),
msd: MsdConfig::default(),
atx: AtxConfig::default(),
audio: AudioConfig::default(),
stream: StreamConfig::default(),
web: WebConfig::default(),
extensions: ExtensionsConfig::default(),
}
}
}
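// Sketch: thanks to #[serde(default)], a partial document deserializes with
// the remaining fields defaulted (serde_json is already a crate dependency):
#[allow(dead_code)]
fn partial_config_example() {
    let cfg: AppConfig = serde_json::from_str(r#"{"initialized": true}"#).unwrap();
    assert!(cfg.initialized);
    assert_eq!(cfg.video.width, 1920);
}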
/// Authentication configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AuthConfig {
/// Session timeout in seconds
pub session_timeout_secs: u32,
/// Enable 2FA
pub totp_enabled: bool,
/// TOTP secret (encrypted)
pub totp_secret: Option<String>,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
session_timeout_secs: 3600 * 24, // 24 hours
totp_enabled: false,
totp_secret: None,
}
}
}
/// Video capture configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct VideoConfig {
/// Video device path (e.g., /dev/video0)
pub device: Option<String>,
/// Video pixel format (e.g., "MJPEG", "YUYV", "NV12")
pub format: Option<String>,
/// Resolution width
pub width: u32,
/// Resolution height
pub height: u32,
/// Frame rate
pub fps: u32,
/// JPEG quality (1-100)
pub quality: u32,
}
impl Default for VideoConfig {
fn default() -> Self {
Self {
device: None,
format: None, // Auto-detect or use MJPEG as default
width: 1920,
height: 1080,
fps: 30,
quality: 80,
}
}
}
/// HID backend type
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum HidBackend {
/// USB OTG HID gadget
Otg,
/// CH9329 serial HID controller
Ch9329,
/// Disabled
None,
}
impl Default for HidBackend {
fn default() -> Self {
Self::None
}
}
/// HID configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct HidConfig {
/// HID backend type
pub backend: HidBackend,
/// OTG keyboard device path
pub otg_keyboard: String,
/// OTG mouse device path
pub otg_mouse: String,
/// OTG UDC (USB Device Controller) name
pub otg_udc: Option<String>,
/// CH9329 serial port
pub ch9329_port: String,
/// CH9329 baud rate
pub ch9329_baudrate: u32,
/// Mouse mode: absolute or relative
pub mouse_absolute: bool,
}
impl Default for HidConfig {
fn default() -> Self {
Self {
backend: HidBackend::None,
otg_keyboard: "/dev/hidg0".to_string(),
otg_mouse: "/dev/hidg1".to_string(),
otg_udc: None,
ch9329_port: "/dev/ttyUSB0".to_string(),
ch9329_baudrate: 9600,
mouse_absolute: true,
}
}
}
/// MSD configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct MsdConfig {
/// Enable MSD functionality
pub enabled: bool,
/// Storage path for ISO/IMG images
pub images_path: String,
/// Path for Ventoy bootable drive file
pub drive_path: String,
/// Ventoy drive size in MB (minimum 1024 MB / 1 GB)
pub virtual_drive_size_mb: u32,
}
impl Default for MsdConfig {
fn default() -> Self {
Self {
enabled: true,
images_path: "./data/msd/images".to_string(),
drive_path: "./data/msd/ventoy.img".to_string(),
virtual_drive_size_mb: 16 * 1024, // 16GB default
}
}
}
// Re-export ATX types from atx module for configuration
pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig};
/// ATX power control configuration
///
/// Each ATX action (power, reset) can be independently configured with its own
/// hardware binding using the four-tuple: (driver, device, pin, active_level).
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AtxConfig {
/// Enable ATX functionality
pub enabled: bool,
/// Power button configuration (used for both short and long press)
pub power: AtxKeyConfig,
/// Reset button configuration
pub reset: AtxKeyConfig,
/// LED sensing configuration (optional)
pub led: AtxLedConfig,
/// Network interface for WOL packets (empty = auto)
pub wol_interface: String,
}
impl Default for AtxConfig {
fn default() -> Self {
Self {
enabled: false,
power: AtxKeyConfig::default(),
reset: AtxKeyConfig::default(),
led: AtxLedConfig::default(),
wol_interface: String::new(),
}
}
}
impl AtxConfig {
/// Convert to AtxControllerConfig for the controller
pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig {
crate::atx::AtxControllerConfig {
enabled: self.enabled,
power: self.power.clone(),
reset: self.reset.clone(),
led: self.led.clone(),
}
}
}
/// Audio configuration
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
/// These are optimal for Opus encoding and match WebRTC requirements.
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AudioConfig {
/// Enable audio capture
pub enabled: bool,
/// ALSA device name
pub device: String,
/// Audio quality preset: "voice", "balanced", "high"
pub quality: String,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
quality: "balanced".to_string(),
}
}
}
/// Stream mode
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum StreamMode {
/// WebRTC with H264/H265
WebRTC,
/// MJPEG over HTTP
Mjpeg,
}
impl Default for StreamMode {
fn default() -> Self {
Self::Mjpeg
}
}
/// Encoder type
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum EncoderType {
/// Auto-detect best encoder
Auto,
/// Software encoder (libx264)
Software,
/// VAAPI hardware encoder
Vaapi,
/// NVIDIA NVENC hardware encoder
Nvenc,
/// Intel Quick Sync hardware encoder
Qsv,
/// AMD AMF hardware encoder
Amf,
/// Rockchip MPP hardware encoder
Rkmpp,
/// V4L2 M2M hardware encoder
V4l2m2m,
}
impl Default for EncoderType {
fn default() -> Self {
Self::Auto
}
}
impl EncoderType {
/// Convert to EncoderBackend for registry queries
pub fn to_backend(&self) -> Option<crate::video::encoder::registry::EncoderBackend> {
use crate::video::encoder::registry::EncoderBackend;
match self {
EncoderType::Auto => None,
EncoderType::Software => Some(EncoderBackend::Software),
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
EncoderType::Qsv => Some(EncoderBackend::Qsv),
EncoderType::Amf => Some(EncoderBackend::Amf),
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
}
}
/// Get display name for UI
pub fn display_name(&self) -> &'static str {
match self {
EncoderType::Auto => "Auto (Recommended)",
EncoderType::Software => "Software (CPU)",
EncoderType::Vaapi => "VAAPI",
EncoderType::Nvenc => "NVIDIA NVENC",
EncoderType::Qsv => "Intel Quick Sync",
EncoderType::Amf => "AMD AMF",
EncoderType::Rkmpp => "Rockchip MPP",
EncoderType::V4l2m2m => "V4L2 M2M",
}
}
}
/// Streaming configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
/// Stream mode
pub mode: StreamMode,
/// Encoder type for H264/H265
pub encoder: EncoderType,
/// Target bitrate in kbps (for H264/H265)
pub bitrate_kbps: u32,
/// GOP size
pub gop_size: u32,
/// Custom STUN server (e.g., "stun:stun.l.google.com:19302")
pub stun_server: Option<String>,
/// Custom TURN server (e.g., "turn:turn.example.com:3478")
pub turn_server: Option<String>,
/// TURN username
pub turn_username: Option<String>,
/// TURN password (stored encrypted in DB, not exposed via API)
pub turn_password: Option<String>,
/// Auto-pause when no clients connected
#[typeshare(skip)]
pub auto_pause_enabled: bool,
/// Auto-pause delay (seconds)
#[typeshare(skip)]
pub auto_pause_delay_secs: u64,
/// Client timeout for cleanup (seconds)
#[typeshare(skip)]
pub client_timeout_secs: u64,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
mode: StreamMode::Mjpeg,
encoder: EncoderType::Auto,
bitrate_kbps: 8000,
gop_size: 30,
stun_server: Some("stun:stun.l.google.com:19302".to_string()),
turn_server: None,
turn_username: None,
turn_password: None,
auto_pause_enabled: false,
auto_pause_delay_secs: 10,
client_timeout_secs: 30,
}
}
}
/// Web server configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WebConfig {
/// HTTP port
pub http_port: u16,
/// HTTPS port
pub https_port: u16,
/// Bind address
pub bind_address: String,
/// Enable HTTPS
pub https_enabled: bool,
/// Custom SSL certificate path
pub ssl_cert_path: Option<String>,
/// Custom SSL key path
pub ssl_key_path: Option<String>,
}
impl Default for WebConfig {
fn default() -> Self {
Self {
http_port: 8080,
https_port: 8443,
bind_address: "0.0.0.0".to_string(),
https_enabled: false,
ssl_cert_path: None,
ssl_key_path: None,
}
}
}

264
src/config/store.rs Normal file
View File

@@ -0,0 +1,264 @@
use arc_swap::ArcSwap;
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use super::AppConfig;
use crate::error::{AppError, Result};
/// Configuration store backed by SQLite
///
/// Uses `ArcSwap` for lock-free reads, providing high performance
/// for frequent configuration access in hot paths.
#[derive(Clone)]
pub struct ConfigStore {
pool: Pool<Sqlite>,
/// Lock-free cache using ArcSwap for zero-cost reads
cache: Arc<ArcSwap<AppConfig>>,
change_tx: broadcast::Sender<ConfigChange>,
}
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}
impl ConfigStore {
/// Create a new configuration store
pub async fn new(db_path: &Path) -> Result<Self> {
// Ensure parent directory exists
if let Some(parent) = db_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePoolOptions::new()
// SQLite is single-writer, so two connections are sufficient for embedded devices:
// one for reads, one for writes, to avoid blocking
.max_connections(2)
// Set reasonable timeouts for embedded environments
.acquire_timeout(Duration::from_secs(5))
.idle_timeout(Duration::from_secs(300))
.connect(&db_url)
.await?;
// Initialize database schema
Self::init_schema(&pool).await?;
// Load or create default config
let config = Self::load_config(&pool).await?;
let cache = Arc::new(ArcSwap::from_pointee(config));
let (change_tx, _) = broadcast::channel(16);
Ok(Self {
pool,
cache,
change_tx,
})
}
/// Initialize database schema
async fn init_schema(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
is_admin INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
expires_at TEXT NOT NULL,
data TEXT
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
permissions TEXT NOT NULL,
expires_at TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
last_used TEXT
)
"#,
)
.execute(pool)
.await?;
Ok(())
}
/// Load configuration from database
async fn load_config(pool: &Pool<Sqlite>) -> Result<AppConfig> {
let row: Option<(String,)> = sqlx::query_as(
"SELECT value FROM config WHERE key = 'app_config'"
)
.fetch_optional(pool)
.await?;
match row {
Some((json,)) => {
serde_json::from_str(&json).map_err(|e| AppError::Config(e.to_string()))
}
None => {
// Create default config
let config = AppConfig::default();
Self::save_config_to_db(pool, &config).await?;
Ok(config)
}
}
}
/// Save configuration to database
async fn save_config_to_db(pool: &Pool<Sqlite>, config: &AppConfig) -> Result<()> {
let json = serde_json::to_string(config)?;
sqlx::query(
r#"
INSERT INTO config (key, value, updated_at)
VALUES ('app_config', ?1, datetime('now'))
ON CONFLICT(key) DO UPDATE SET value = ?1, updated_at = datetime('now')
"#,
)
.bind(&json)
.execute(pool)
.await?;
Ok(())
}
/// Get current configuration (lock-free, zero-copy)
///
/// Returns an `Arc<AppConfig>` for efficient sharing without cloning.
/// This is a lock-free operation with minimal overhead.
pub fn get(&self) -> Arc<AppConfig> {
self.cache.load_full()
}
/// Set entire configuration
pub async fn set(&self, config: AppConfig) -> Result<()> {
Self::save_config_to_db(&self.pool, &config).await?;
self.cache.store(Arc::new(config));
// Notify subscribers
let _ = self.change_tx.send(ConfigChange {
key: "app_config".to_string(),
});
Ok(())
}
/// Update configuration with a closure
///
/// Note: This uses a read-modify-write pattern. For concurrent updates,
/// the last write wins. This is acceptable for configuration changes
/// which are infrequent and typically user-initiated.
pub async fn update<F>(&self, f: F) -> Result<()>
where
F: FnOnce(&mut AppConfig),
{
// Load current config, clone it for modification
let current = self.cache.load();
let mut config = (**current).clone();
f(&mut config);
// Persist to database first
Self::save_config_to_db(&self.pool, &config).await?;
// Then update cache atomically
self.cache.store(Arc::new(config));
// Notify subscribers
let _ = self.change_tx.send(ConfigChange {
key: "app_config".to_string(),
});
Ok(())
}
/// Subscribe to configuration changes
pub fn subscribe(&self) -> broadcast::Receiver<ConfigChange> {
self.change_tx.subscribe()
}
/// Check if system is initialized (lock-free)
pub fn is_initialized(&self) -> bool {
self.cache.load().initialized
}
/// Get database pool for session management
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
}
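// Hedged usage sketch (illustration only; `watch_config` is hypothetical and
// not part of the original commit): a background task that reacts to config
// changes by re-reading the lock-free cache.
#[allow(dead_code)]
async fn watch_config(store: ConfigStore) {
    let mut rx = store.subscribe();
    while let Ok(change) = rx.recv().await {
        // get() is a lock-free ArcSwap load, cheap enough to call per event.
        let cfg = store.get();
        tracing::info!("config '{}' changed; http_port={}", change.key, cfg.web.http_port);
    }
}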
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[tokio::test]
async fn test_config_store() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let store = ConfigStore::new(&db_path).await.unwrap();
// Check default config (now lock-free, no await needed)
let config = store.get();
assert!(!config.initialized);
// Update config
store.update(|c| {
c.initialized = true;
c.web.http_port = 9000;
}).await.unwrap();
// Verify update
let config = store.get();
assert!(config.initialized);
assert_eq!(config.web.http_port, 9000);
// Create new store instance and verify persistence
let store2 = ConfigStore::new(&db_path).await.unwrap();
let config = store2.get();
assert!(config.initialized);
assert_eq!(config.web.http_port, 9000);
}
}

98
src/error.rs Normal file
View File

@@ -0,0 +1,98 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Serialize;
use thiserror::Error;
/// Application-wide error type
#[derive(Error, Debug)]
pub enum AppError {
#[error("Authentication failed: {0}")]
AuthError(String),
#[error("Not authenticated")]
Unauthorized,
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Internal error: {0}")]
Internal(String),
#[error("Configuration error: {0}")]
Config(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
#[error("Video error: {0}")]
VideoError(String),
#[error("Video device lost [{device}]: {reason}")]
VideoDeviceLost { device: String, reason: String },
#[error("Audio error: {0}")]
AudioError(String),
#[error("HID error [{backend}]: {reason} (code: {error_code})")]
HidError {
backend: String,
reason: String,
error_code: String,
},
#[error("WebRTC error: {0}")]
WebRtcError(String),
#[error("Service unavailable: {0}")]
ServiceUnavailable(String),
}
/// Error response body (unified success format)
#[derive(Serialize)]
pub struct ErrorResponse {
pub success: bool,
pub message: String,
}
impl AppError {
fn status_code(&self) -> StatusCode {
// Always return 200 OK - success/failure is indicated by the success field
StatusCode::OK
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let status = self.status_code();
let body = ErrorResponse {
success: false,
message: self.to_string(),
};
tracing::error!(
error_type = std::any::type_name_of_val(&self),
error_message = %body.message,
"Request failed"
);
(status, Json(body)).into_response()
}
}
/// Result type alias for handlers
pub type Result<T> = std::result::Result<T, AppError>;
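// Hedged sketch (illustration only; `find_image` is hypothetical): an axum
// handler using the Result alias. Any AppError converts into an HTTP 200
// response whose body is {"success": false, "message": "..."} per the
// IntoResponse impl above.
#[allow(dead_code)]
async fn find_image() -> Result<Json<serde_json::Value>> {
    Err(AppError::NotFound("no such image".into()))
}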

137
src/events/mod.rs Normal file
View File

@@ -0,0 +1,137 @@
//! Event system for real-time state notifications
//!
//! This module provides a global event bus for broadcasting system events
//! to WebSocket clients and other subscribers.
pub mod types;
pub use types::{
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent, VideoDeviceInfo,
};
use tokio::sync::broadcast;
/// Event channel capacity (ring buffer size)
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// Global event bus for broadcasting system events
///
/// The event bus uses tokio's broadcast channel to distribute events
/// to multiple subscribers. Events are delivered to all active subscribers.
///
/// # Example
///
/// ```no_run
/// use one_kvm::events::{EventBus, SystemEvent};
///
/// let bus = EventBus::new();
///
/// // Publish an event
/// bus.publish(SystemEvent::StreamStateChanged {
/// state: "streaming".to_string(),
/// device: Some("/dev/video0".to_string()),
/// });
///
/// // Subscribe to events
/// let mut rx = bus.subscribe();
/// tokio::spawn(async move {
/// while let Ok(event) = rx.recv().await {
/// println!("Received event: {:?}", event);
/// }
/// });
/// ```
pub struct EventBus {
tx: broadcast::Sender<SystemEvent>,
}
impl EventBus {
/// Create a new event bus
pub fn new() -> Self {
let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
Self { tx }
}
/// Publish an event to all subscribers
///
/// If there are no active subscribers, the event is silently dropped.
/// This is by design - events are fire-and-forget notifications.
pub fn publish(&self, event: SystemEvent) {
// If no subscribers, send returns Err which is normal
let _ = self.tx.send(event);
}
/// Subscribe to events
///
/// Returns a receiver that will receive all future events.
/// The receiver uses a ring buffer, so if a subscriber falls too far
/// behind, it will receive a `Lagged` error and miss some events.
pub fn subscribe(&self) -> broadcast::Receiver<SystemEvent> {
self.tx.subscribe()
}
/// Get the current number of active subscribers
///
/// Useful for monitoring and debugging.
pub fn subscriber_count(&self) -> usize {
self.tx.receiver_count()
}
}
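// Hedged sketch (illustration only): handling the Lagged error a slow
// subscriber can hit. broadcast::Receiver::recv returns RecvError::Lagged(n)
// after the ring buffer overwrites n missed events; a robust consumer logs
// the gap and keeps receiving.
#[allow(dead_code)]
async fn drain_events(bus: &EventBus) {
    use tokio::sync::broadcast::error::RecvError;
    let mut rx = bus.subscribe();
    loop {
        match rx.recv().await {
            Ok(event) => tracing::debug!("event: {}", event.event_name()),
            Err(RecvError::Lagged(missed)) => {
                tracing::warn!("subscriber lagged; {} events dropped", missed);
            }
            Err(RecvError::Closed) => break,
        }
    }
}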
impl Default for EventBus {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_publish_subscribe() {
let bus = EventBus::new();
let mut rx = bus.subscribe();
bus.publish(SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
});
let event = rx.recv().await.unwrap();
assert!(matches!(event, SystemEvent::StreamStateChanged { .. }));
}
#[tokio::test]
async fn test_multiple_subscribers() {
let bus = EventBus::new();
let mut rx1 = bus.subscribe();
let mut rx2 = bus.subscribe();
assert_eq!(bus.subscriber_count(), 2);
bus.publish(SystemEvent::SystemError {
module: "test".to_string(),
severity: "info".to_string(),
message: "test message".to_string(),
});
let event1 = rx1.recv().await.unwrap();
let event2 = rx2.recv().await.unwrap();
assert!(matches!(event1, SystemEvent::SystemError { .. }));
assert!(matches!(event2, SystemEvent::SystemError { .. }));
}
#[test]
fn test_no_subscribers() {
let bus = EventBus::new();
assert_eq!(bus.subscriber_count(), 0);
// Should not panic when publishing with no subscribers
bus.publish(SystemEvent::SystemError {
module: "test".to_string(),
severity: "info".to_string(),
message: "test".to_string(),
});
}
}

592
src/events/types.rs Normal file
View File

@@ -0,0 +1,592 @@
//! System event types
//!
//! Defines all event types that can be broadcast through the event bus.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::atx::PowerStatus;
use crate::msd::MsdMode;
// ============================================================================
// Device Info Structures (for system.device_info event)
// ============================================================================
/// Video device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoDeviceInfo {
/// Whether video device is available
pub available: bool,
/// Device path (e.g., /dev/video0)
pub device: Option<String>,
/// Pixel format (e.g., "MJPEG", "YUYV")
pub format: Option<String>,
/// Resolution (width, height)
pub resolution: Option<(u32, u32)>,
/// Frames per second
pub fps: u32,
/// Whether stream is currently active
pub online: bool,
/// Current streaming mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
pub stream_mode: String,
/// Whether video config is currently being changed (frontend should skip mode sync)
pub config_changing: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// HID device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HidDeviceInfo {
/// Whether HID backend is available
pub available: bool,
/// Backend type: "otg", "ch9329", "none"
pub backend: String,
/// Whether backend is initialized and ready
pub initialized: bool,
/// Whether absolute mouse positioning is supported
pub supports_absolute_mouse: bool,
/// Device path (e.g., serial port for CH9329)
pub device: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// MSD device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdDeviceInfo {
/// Whether MSD is available
pub available: bool,
/// Operating mode: "none", "image", "drive"
pub mode: String,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image ID
pub image_id: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// ATX device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxDeviceInfo {
/// Whether ATX controller is available
pub available: bool,
/// Backend type: "gpio", "usb_relay", "none"
pub backend: String,
/// Whether backend is initialized
pub initialized: bool,
/// Whether power is currently on
pub power_on: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// Audio device information
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioDeviceInfo {
/// Whether audio is enabled/available
pub available: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Current audio device name
pub device: Option<String>,
/// Quality preset: "voice", "balanced", "high"
pub quality: String,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// Per-client statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientStats {
/// Client ID
pub id: String,
/// Current FPS for this client (frames sent in last second)
pub fps: u32,
/// Connected duration (seconds)
pub connected_secs: u64,
}
/// System event enumeration
///
/// All events are tagged with their event name for serialization.
/// The `serde(tag = "event", content = "data")` attribute creates a
/// JSON structure like:
/// ```json
/// {
/// "event": "stream.state_changed",
/// "data": { "state": "streaming", "device": "/dev/video0" }
/// }
/// ```
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
pub enum SystemEvent {
// ============================================================================
// Video Stream Events
// ============================================================================
/// Stream state changed (e.g., started, stopped, error)
#[serde(rename = "stream.state_changed")]
StreamStateChanged {
/// Current state: "uninitialized", "ready", "streaming", "no_signal", "error"
state: String,
/// Device path if available
device: Option<String>,
},
/// Stream configuration is being changed
///
/// Sent before applying new configuration to notify clients that
/// the stream will be interrupted temporarily.
#[serde(rename = "stream.config_changing")]
StreamConfigChanging {
/// Reason for change: "device_switch", "resolution_change", "format_change"
reason: String,
},
/// Stream configuration has been applied successfully
///
/// Sent after new configuration is active. Clients can reconnect now.
#[serde(rename = "stream.config_applied")]
StreamConfigApplied {
/// Device path
device: String,
/// Resolution (width, height)
resolution: (u32, u32),
/// Pixel format: "mjpeg", "yuyv", etc.
format: String,
/// Frames per second
fps: u32,
},
/// Stream device was lost (disconnected or error)
#[serde(rename = "stream.device_lost")]
StreamDeviceLost {
/// Device path that was lost
device: String,
/// Reason for loss
reason: String,
},
/// Stream device is reconnecting
#[serde(rename = "stream.reconnecting")]
StreamReconnecting {
/// Device path being reconnected
device: String,
/// Retry attempt number
attempt: u32,
},
/// Stream device has recovered
#[serde(rename = "stream.recovered")]
StreamRecovered {
/// Device path that was recovered
device: String,
},
/// Stream statistics update (sent periodically for client stats)
#[serde(rename = "stream.stats_update")]
StreamStatsUpdate {
/// Number of connected clients
clients: u64,
/// Per-client statistics (client_id -> client stats)
/// Each client's FPS reflects the actual frames sent in the last second
clients_stat: HashMap<String, ClientStats>,
},
/// Stream mode changed (MJPEG <-> WebRTC)
///
/// Sent when the streaming mode is switched. Clients should disconnect
/// from the current stream and reconnect using the new mode.
#[serde(rename = "stream.mode_changed")]
StreamModeChanged {
/// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
mode: String,
/// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
previous_mode: String,
},
// ============================================================================
// HID Events
// ============================================================================
/// HID backend state changed
#[serde(rename = "hid.state_changed")]
HidStateChanged {
/// Backend type: "otg", "ch9329", "none"
backend: String,
/// Whether backend is initialized and ready
initialized: bool,
/// Error message if any, None if OK
error: Option<String>,
/// Error code for programmatic handling: "epipe", "eagain", "port_not_found", etc.
error_code: Option<String>,
},
/// HID backend is being switched
#[serde(rename = "hid.backend_switching")]
HidBackendSwitching {
/// Current backend
from: String,
/// New backend
to: String,
},
/// HID device lost (device file missing or I/O error)
#[serde(rename = "hid.device_lost")]
HidDeviceLost {
/// Backend type: "otg", "ch9329"
backend: String,
/// Device path that was lost (e.g., /dev/hidg0 or /dev/ttyUSB0)
device: Option<String>,
/// Human-readable reason for loss
reason: String,
/// Error code: "epipe", "eshutdown", "eagain", "enxio", "port_not_found", "io_error"
error_code: String,
},
/// HID device is reconnecting
#[serde(rename = "hid.reconnecting")]
HidReconnecting {
/// Backend type: "otg", "ch9329"
backend: String,
/// Current retry attempt number
attempt: u32,
},
/// HID device has recovered after error
#[serde(rename = "hid.recovered")]
HidRecovered {
/// Backend type: "otg", "ch9329"
backend: String,
},
// ============================================================================
// MSD (Mass Storage Device) Events
// ============================================================================
/// MSD state changed
#[serde(rename = "msd.state_changed")]
MsdStateChanged {
/// Operating mode
mode: MsdMode,
/// Whether storage is connected to target
connected: bool,
},
/// Image has been mounted
#[serde(rename = "msd.image_mounted")]
MsdImageMounted {
/// Image ID
image_id: String,
/// Image filename
image_name: String,
/// Image size in bytes
size: u64,
/// Mount as CD-ROM (read-only)
cdrom: bool,
},
/// Image has been unmounted
#[serde(rename = "msd.image_unmounted")]
MsdImageUnmounted,
/// File upload progress (for large file uploads)
#[serde(rename = "msd.upload_progress")]
MsdUploadProgress {
/// Upload operation ID
upload_id: String,
/// Filename being uploaded
filename: String,
/// Bytes uploaded so far
bytes_uploaded: u64,
/// Total file size
total_bytes: u64,
/// Progress percentage (0.0 - 100.0)
progress_pct: f32,
},
/// Image download progress (for URL downloads)
#[serde(rename = "msd.download_progress")]
MsdDownloadProgress {
/// Download operation ID
download_id: String,
/// Source URL
url: String,
/// Target filename
filename: String,
/// Bytes downloaded so far
bytes_downloaded: u64,
/// Total file size (None if unknown)
total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
progress_pct: Option<f32>,
/// Download status: "started", "in_progress", "completed", "failed"
status: String,
},
/// USB gadget connection status changed (host connected/disconnected)
#[serde(rename = "msd.usb_status_changed")]
MsdUsbStatusChanged {
/// Whether host is connected to USB device
connected: bool,
/// USB device state from kernel (e.g., "configured", "not attached")
device_state: String,
},
/// MSD operation error (configfs, image mount, etc.)
#[serde(rename = "msd.error")]
MsdError {
/// Human-readable reason for error
reason: String,
/// Error code: "configfs_error", "image_not_found", "mount_failed", "io_error"
error_code: String,
},
/// MSD has recovered after error
#[serde(rename = "msd.recovered")]
MsdRecovered,
// ============================================================================
// ATX (Power Control) Events
// ============================================================================
/// ATX power state changed
#[serde(rename = "atx.state_changed")]
AtxStateChanged {
/// Power status
power_status: PowerStatus,
},
/// ATX action was executed
#[serde(rename = "atx.action_executed")]
AtxActionExecuted {
/// Action: "short", "long", "reset"
action: String,
/// When the action was executed
timestamp: DateTime<Utc>,
},
// ============================================================================
// Audio Events
// ============================================================================
/// Audio state changed (streaming started/stopped)
#[serde(rename = "audio.state_changed")]
AudioStateChanged {
/// Whether audio is currently streaming
streaming: bool,
/// Current device (None if stopped)
device: Option<String>,
},
/// Audio device was selected
#[serde(rename = "audio.device_selected")]
AudioDeviceSelected {
/// Selected device name
device: String,
},
/// Audio quality was changed
#[serde(rename = "audio.quality_changed")]
AudioQualityChanged {
/// New quality setting: "voice", "balanced", "high"
quality: String,
},
/// Audio device lost (capture error or device disconnected)
#[serde(rename = "audio.device_lost")]
AudioDeviceLost {
/// Audio device name (e.g., "hw:0,0")
device: Option<String>,
/// Human-readable reason for loss
reason: String,
/// Error code: "device_busy", "device_disconnected", "capture_error", "io_error"
error_code: String,
},
/// Audio device is reconnecting
#[serde(rename = "audio.reconnecting")]
AudioReconnecting {
/// Current retry attempt number
attempt: u32,
},
/// Audio device has recovered after error
#[serde(rename = "audio.recovered")]
AudioRecovered {
/// Audio device name
device: Option<String>,
},
// ============================================================================
// System Events
// ============================================================================
/// A device was added (hot-plug)
#[serde(rename = "system.device_added")]
SystemDeviceAdded {
/// Device type: "video", "audio", "hid", etc.
device_type: String,
/// Device path
device_path: String,
/// Device name/description
device_name: String,
},
/// A device was removed (hot-unplug)
#[serde(rename = "system.device_removed")]
SystemDeviceRemoved {
/// Device type
device_type: String,
/// Device path that was removed
device_path: String,
},
/// System error or warning
#[serde(rename = "system.error")]
SystemError {
/// Module that generated the error: "stream", "hid", "msd", "atx"
module: String,
/// Severity: "warning", "error", "critical"
severity: String,
/// Error message
message: String,
},
/// Complete device information (sent on WebSocket connect and state changes)
#[serde(rename = "system.device_info")]
DeviceInfo {
/// Video device information
video: VideoDeviceInfo,
/// HID device information
hid: HidDeviceInfo,
/// MSD device information (None if MSD not enabled)
msd: Option<MsdDeviceInfo>,
/// ATX device information (None if ATX not enabled)
atx: Option<AtxDeviceInfo>,
/// Audio device information (None if audio not enabled)
audio: Option<AudioDeviceInfo>,
},
/// WebSocket error notification (for connection-level errors like lag)
#[serde(rename = "error")]
Error {
/// Error message
message: String,
},
}
impl SystemEvent {
/// Get the event name (for filtering/routing)
pub fn event_name(&self) -> &'static str {
match self {
Self::StreamStateChanged { .. } => "stream.state_changed",
Self::StreamConfigChanging { .. } => "stream.config_changing",
Self::StreamConfigApplied { .. } => "stream.config_applied",
Self::StreamDeviceLost { .. } => "stream.device_lost",
Self::StreamReconnecting { .. } => "stream.reconnecting",
Self::StreamRecovered { .. } => "stream.recovered",
Self::StreamStatsUpdate { .. } => "stream.stats_update",
Self::StreamModeChanged { .. } => "stream.mode_changed",
Self::HidStateChanged { .. } => "hid.state_changed",
Self::HidBackendSwitching { .. } => "hid.backend_switching",
Self::HidDeviceLost { .. } => "hid.device_lost",
Self::HidReconnecting { .. } => "hid.reconnecting",
Self::HidRecovered { .. } => "hid.recovered",
Self::MsdStateChanged { .. } => "msd.state_changed",
Self::MsdImageMounted { .. } => "msd.image_mounted",
Self::MsdImageUnmounted => "msd.image_unmounted",
Self::MsdUploadProgress { .. } => "msd.upload_progress",
Self::MsdDownloadProgress { .. } => "msd.download_progress",
Self::MsdUsbStatusChanged { .. } => "msd.usb_status_changed",
Self::MsdError { .. } => "msd.error",
Self::MsdRecovered => "msd.recovered",
Self::AtxStateChanged { .. } => "atx.state_changed",
Self::AtxActionExecuted { .. } => "atx.action_executed",
Self::AudioStateChanged { .. } => "audio.state_changed",
Self::AudioDeviceSelected { .. } => "audio.device_selected",
Self::AudioQualityChanged { .. } => "audio.quality_changed",
Self::AudioDeviceLost { .. } => "audio.device_lost",
Self::AudioReconnecting { .. } => "audio.reconnecting",
Self::AudioRecovered { .. } => "audio.recovered",
Self::SystemDeviceAdded { .. } => "system.device_added",
Self::SystemDeviceRemoved { .. } => "system.device_removed",
Self::SystemError { .. } => "system.error",
Self::DeviceInfo { .. } => "system.device_info",
Self::Error { .. } => "error",
}
}
/// Check if event name matches a topic pattern
///
/// Supports wildcards:
/// - `*` matches all events
/// - `stream.*` matches all stream events
/// - `stream.state_changed` matches exact event
pub fn matches_topic(&self, topic: &str) -> bool {
if topic == "*" {
return true;
}
let event_name = self.event_name();
if topic.ends_with(".*") {
let prefix = topic.trim_end_matches(".*");
event_name.starts_with(prefix)
} else {
event_name == topic
}
}
}
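// Hedged sketch (illustration only; `forward_matching` is hypothetical):
// filtering a broadcast stream by topic, the way a WebSocket subscription
// carrying a pattern like "stream.*" might use matches_topic.
#[allow(dead_code)]
async fn forward_matching(
    mut rx: tokio::sync::broadcast::Receiver<SystemEvent>,
    topic: String,
) {
    while let Ok(event) = rx.recv().await {
        if event.matches_topic(&topic) {
            // In the real server this would be serialized and sent to the client.
            tracing::debug!("forwarding {}", event.event_name());
        }
    }
}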
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_event_name() {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
};
assert_eq!(event.event_name(), "stream.state_changed");
let event = SystemEvent::MsdImageMounted {
image_id: "123".to_string(),
image_name: "ubuntu.iso".to_string(),
size: 1024,
cdrom: true,
};
assert_eq!(event.event_name(), "msd.image_mounted");
}
#[test]
fn test_matches_topic() {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: None,
};
assert!(event.matches_topic("*"));
assert!(event.matches_topic("stream.*"));
assert!(event.matches_topic("stream.state_changed"));
assert!(!event.matches_topic("msd.*"));
assert!(!event.matches_topic("stream.config_changed"));
}
#[test]
fn test_serialization() {
let event = SystemEvent::StreamConfigApplied {
device: "/dev/video0".to_string(),
resolution: (1920, 1080),
format: "mjpeg".to_string(),
fps: 30,
};
let json = serde_json::to_string(&event).unwrap();
assert!(json.contains("stream.config_applied"));
assert!(json.contains("/dev/video0"));
let deserialized: SystemEvent = serde_json::from_str(&json).unwrap();
assert!(matches!(deserialized, SystemEvent::StreamConfigApplied { .. }));
}
}

425
src/extensions/manager.rs Normal file
View File

@@ -0,0 +1,425 @@
//! Extension process manager
use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::RwLock;
use super::types::*;
/// Maximum number of log lines to keep per extension
const LOG_BUFFER_SIZE: usize = 200;
/// Number of log lines to buffer before flushing to shared storage
const LOG_BATCH_SIZE: usize = 16;
/// Unix socket path for ttyd
pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock";
/// Extension process with log buffer
struct ExtensionProcess {
child: Child,
logs: Arc<RwLock<VecDeque<String>>>,
}
/// Extension manager handles lifecycle of external processes
pub struct ExtensionManager {
processes: RwLock<HashMap<ExtensionId, ExtensionProcess>>,
/// Cached availability status (checked once at startup)
availability: HashMap<ExtensionId, bool>,
}
impl Default for ExtensionManager {
fn default() -> Self {
Self::new()
}
}
impl ExtensionManager {
/// Create a new extension manager with cached availability
pub fn new() -> Self {
// Check availability once at startup
let availability = ExtensionId::all()
.iter()
.map(|id| (*id, Path::new(id.binary_path()).exists()))
.collect();
Self {
processes: RwLock::new(HashMap::new()),
availability,
}
}
/// Check if the binary for an extension is available (cached)
pub fn check_available(&self, id: ExtensionId) -> bool {
*self.availability.get(&id).unwrap_or(&false)
}
/// Get the current status of an extension
pub async fn status(&self, id: ExtensionId) -> ExtensionStatus {
if !self.check_available(id) {
return ExtensionStatus::Unavailable;
}
let processes = self.processes.read().await;
match processes.get(&id) {
Some(proc) => {
if let Some(pid) = proc.child.id() {
ExtensionStatus::Running { pid }
} else {
ExtensionStatus::Stopped
}
}
None => ExtensionStatus::Stopped,
}
}
/// Start an extension with the given configuration
pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> {
if !self.check_available(id) {
return Err(format!(
"{} not found at {}",
id.display_name(),
id.binary_path()
));
}
// Stop existing process first
self.stop(id).await.ok();
// Build command arguments
let args = self.build_args(id, config).await?;
tracing::info!(
"Starting extension {}: {} {}",
id,
id.binary_path(),
args.join(" ")
);
let mut child = Command::new(id.binary_path())
.args(&args)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.map_err(|e| format!("Failed to start {}: {}", id.display_name(), e))?;
let logs = Arc::new(RwLock::new(VecDeque::with_capacity(LOG_BUFFER_SIZE)));
// Spawn log collector for stdout
if let Some(stdout) = child.stdout.take() {
let logs_clone = logs.clone();
let id_clone = id;
tokio::spawn(async move {
Self::collect_logs(id_clone, stdout, logs_clone).await;
});
}
// Spawn log collector for stderr
if let Some(stderr) = child.stderr.take() {
let logs_clone = logs.clone();
let id_clone = id;
tokio::spawn(async move {
Self::collect_logs(id_clone, stderr, logs_clone).await;
});
}
let pid = child.id();
tracing::info!("Extension {} started with PID {:?}", id, pid);
let mut processes = self.processes.write().await;
processes.insert(id, ExtensionProcess { child, logs });
Ok(())
}
/// Stop an extension
pub async fn stop(&self, id: ExtensionId) -> Result<(), String> {
let mut processes = self.processes.write().await;
if let Some(mut proc) = processes.remove(&id) {
tracing::info!("Stopping extension {}", id);
if let Err(e) = proc.child.kill().await {
tracing::warn!("Failed to kill {}: {}", id, e);
}
}
Ok(())
}
/// Get recent logs for an extension
pub async fn logs(&self, id: ExtensionId, lines: usize) -> Vec<String> {
let processes = self.processes.read().await;
if let Some(proc) = processes.get(&id) {
let logs = proc.logs.read().await;
let start = logs.len().saturating_sub(lines);
logs.range(start..).cloned().collect()
} else {
Vec::new()
}
}
/// Collect logs from a stream with batched writes to reduce lock contention
async fn collect_logs<R: tokio::io::AsyncRead + Unpin>(
id: ExtensionId,
reader: R,
logs: Arc<RwLock<VecDeque<String>>>,
) {
let reader = BufReader::new(reader);
let mut lines = reader.lines();
let mut local_buffer = Vec::with_capacity(LOG_BATCH_SIZE);
loop {
match lines.next_line().await {
Ok(Some(line)) => {
tracing::debug!("[{}] {}", id, line);
local_buffer.push(line);
// Flush when batch is full
if local_buffer.len() >= LOG_BATCH_SIZE {
Self::flush_logs(&logs, &mut local_buffer).await;
}
}
Ok(None) => {
// Stream ended, flush remaining logs
if !local_buffer.is_empty() {
Self::flush_logs(&logs, &mut local_buffer).await;
}
break;
}
Err(e) => {
tracing::warn!("[{}] Error reading log: {}", id, e);
break;
}
}
}
}
/// Flush buffered logs to shared storage
async fn flush_logs(logs: &RwLock<VecDeque<String>>, buffer: &mut Vec<String>) {
let mut logs = logs.write().await;
for line in buffer.drain(..) {
if logs.len() >= LOG_BUFFER_SIZE {
logs.pop_front();
}
logs.push_back(line);
}
}
/// Build command arguments for an extension
async fn build_args(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<Vec<String>, String> {
match id {
ExtensionId::Ttyd => {
let c = &config.ttyd;
// Prepare socket directory and clean up old socket (async)
Self::prepare_ttyd_socket().await?;
let mut args = vec![
"-i".to_string(), TTYD_SOCKET_PATH.to_string(), // Unix socket
"-b".to_string(), "/api/terminal".to_string(), // Base path for reverse proxy
"-W".to_string(), // Writable (allow input)
];
// Add credential if set (still useful as an additional security layer)
if let Some(ref cred) = c.credential {
if !cred.is_empty() {
args.extend(["-c".to_string(), cred.clone()]);
}
}
// Add shell as last argument
args.push(c.shell.clone());
Ok(args)
}
ExtensionId::Gostc => {
let c = &config.gostc;
if c.key.is_empty() {
return Err("GOSTC client key is required".into());
}
let mut args = Vec::new();
// Add TLS flag
if c.tls {
args.push("--tls=true".to_string());
}
// Add server address
if !c.addr.is_empty() {
args.extend(["-addr".to_string(), c.addr.clone()]);
}
// Add client key
args.extend(["-key".to_string(), c.key.clone()]);
Ok(args)
}
ExtensionId::Easytier => {
let c = &config.easytier;
if c.network_name.is_empty() {
return Err("EasyTier network name is required".into());
}
let mut args = vec![
"--network-name".to_string(),
c.network_name.clone(),
"--network-secret".to_string(),
c.network_secret.clone(),
];
// Add peer URLs
for peer in &c.peer_urls {
if !peer.is_empty() {
args.extend(["--peers".to_string(), peer.clone()]);
}
}
// Add virtual IP: use -d for DHCP if empty, or -i for specific IP
if let Some(ref ip) = c.virtual_ip {
if !ip.is_empty() {
// Use specific IP with -i (must include CIDR, e.g., 10.0.0.1/24)
args.extend(["-i".to_string(), ip.clone()]);
} else {
// Empty string means use DHCP
args.push("-d".to_string());
}
} else {
// None means use DHCP
args.push("-d".to_string());
}
Ok(args)
}
}
}
/// Prepare ttyd socket directory and clean up old socket file
async fn prepare_ttyd_socket() -> Result<(), String> {
let socket_path = Path::new(TTYD_SOCKET_PATH);
// Ensure socket directory exists
if let Some(socket_dir) = socket_path.parent() {
if !socket_dir.exists() {
tokio::fs::create_dir_all(socket_dir)
.await
.map_err(|e| format!("Failed to create socket directory: {}", e))?;
}
}
// Remove old socket file if exists
if tokio::fs::try_exists(TTYD_SOCKET_PATH).await.unwrap_or(false) {
tokio::fs::remove_file(TTYD_SOCKET_PATH)
.await
.map_err(|e| format!("Failed to remove old socket: {}", e))?;
}
Ok(())
}
/// Health check - restart crashed processes that should be running
pub async fn health_check(&self, config: &ExtensionsConfig) {
// Collect extensions that need restart check
let checks: Vec<_> = ExtensionId::all()
.iter()
.filter_map(|id| {
let should_run = match id {
ExtensionId::Ttyd => config.ttyd.enabled,
ExtensionId::Gostc => config.gostc.enabled && !config.gostc.key.is_empty(),
ExtensionId::Easytier => {
config.easytier.enabled && !config.easytier.network_name.is_empty()
}
};
if should_run && self.check_available(*id) {
Some(*id)
} else {
None
}
})
.collect();
// Check which ones need restart (single read lock)
let needs_restart: Vec<_> = {
let processes = self.processes.read().await;
checks
.into_iter()
.filter(|id| {
if let Some(proc) = processes.get(id) {
proc.child.id().is_none()
} else {
true
}
})
.collect()
};
// Restart all crashed extensions in parallel
let restart_futures: Vec<_> = needs_restart
.into_iter()
.map(|id| async move {
tracing::info!("Health check: restarting {}", id);
if let Err(e) = self.start(id, config).await {
tracing::error!("Failed to restart {}: {}", id, e);
}
})
.collect();
futures::future::join_all(restart_futures).await;
}
/// Start all enabled extensions in parallel
pub async fn start_enabled(&self, config: &ExtensionsConfig) {
use std::pin::Pin;
use futures::Future;
let mut start_futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + '_>>> = Vec::new();
// Collect enabled extensions
if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) {
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Ttyd, config).await {
tracing::error!("Failed to start ttyd: {}", e);
}
}));
}
if config.gostc.enabled
&& !config.gostc.key.is_empty()
&& self.check_available(ExtensionId::Gostc)
{
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Gostc, config).await {
tracing::error!("Failed to start gostc: {}", e);
}
}));
}
if config.easytier.enabled
&& !config.easytier.network_name.is_empty()
&& self.check_available(ExtensionId::Easytier)
{
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Easytier, config).await {
tracing::error!("Failed to start easytier: {}", e);
}
}));
}
// Start all in parallel
futures::future::join_all(start_futures).await;
}
/// Stop all running extensions in parallel
pub async fn stop_all(&self) {
let stop_futures: Vec<_> = ExtensionId::all()
.iter()
.map(|id| self.stop(*id))
.collect();
futures::future::join_all(stop_futures).await;
}
}
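// Hedged sketch (illustration only; the 30-second interval is an assumption,
// not taken from the original commit): a supervisor loop that periodically
// restarts crashed extensions via health_check.
#[allow(dead_code)]
async fn supervise(manager: Arc<ExtensionManager>, config: ExtensionsConfig) {
    let mut ticker = tokio::time::interval(std::time::Duration::from_secs(30));
    loop {
        ticker.tick().await;
        manager.health_check(&config).await;
    }
}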

7
src/extensions/mod.rs Normal file
View File

@@ -0,0 +1,7 @@
//! Extensions module - manage external processes like ttyd, gostc, easytier
mod manager;
mod types;
pub use manager::{ExtensionManager, TTYD_SOCKET_PATH};
pub use types::*;

251
src/extensions/types.rs Normal file
View File

@@ -0,0 +1,251 @@
//! Extension types and configurations
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Extension identifier (fixed set of supported extensions)
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ExtensionId {
/// Web terminal (ttyd)
Ttyd,
/// NAT traversal client (gostc)
Gostc,
/// P2P VPN (easytier)
Easytier,
}
impl ExtensionId {
/// Get the binary path for this extension
pub fn binary_path(&self) -> &'static str {
match self {
Self::Ttyd => "/usr/bin/ttyd",
Self::Gostc => "/usr/bin/gostc",
Self::Easytier => "/usr/bin/easytier-core",
}
}
/// Get the display name for this extension
pub fn display_name(&self) -> &'static str {
match self {
Self::Ttyd => "Web Terminal",
Self::Gostc => "GOSTC Tunnel",
Self::Easytier => "EasyTier VPN",
}
}
/// Get all extension IDs
pub fn all() -> &'static [ExtensionId] {
&[Self::Ttyd, Self::Gostc, Self::Easytier]
}
}
impl std::fmt::Display for ExtensionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Ttyd => write!(f, "ttyd"),
Self::Gostc => write!(f, "gostc"),
Self::Easytier => write!(f, "easytier"),
}
}
}
impl std::str::FromStr for ExtensionId {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"ttyd" => Ok(Self::Ttyd),
"gostc" => Ok(Self::Gostc),
"easytier" => Ok(Self::Easytier),
_ => Err(format!("Unknown extension: {}", s)),
}
}
}
/// Extension running status
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "state", content = "data", rename_all = "lowercase")]
pub enum ExtensionStatus {
/// Binary not found at expected path
Unavailable,
/// Extension is stopped
Stopped,
/// Extension is running
Running {
/// Process ID
pid: u32,
},
/// Extension failed to start
Failed {
/// Error message
error: String,
},
}
impl ExtensionStatus {
pub fn is_running(&self) -> bool {
matches!(self, Self::Running { .. })
}
}
/// ttyd configuration (Web Terminal)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TtydConfig {
/// Enable auto-start
pub enabled: bool,
/// Port to listen on
pub port: u16,
/// Shell to execute
pub shell: String,
/// Credential in format "user:password" (optional)
#[serde(skip_serializing_if = "Option::is_none")]
pub credential: Option<String>,
}
impl Default for TtydConfig {
fn default() -> Self {
Self {
enabled: false,
port: 7681,
shell: "/bin/bash".to_string(),
credential: None,
}
}
}
/// gostc configuration (NAT traversal based on FRP)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GostcConfig {
/// Enable auto-start
pub enabled: bool,
/// Server address (e.g., gostc.mofeng.run)
pub addr: String,
/// Client key from GOSTC management panel
#[serde(skip_serializing_if = "String::is_empty")]
pub key: String,
/// Enable TLS
pub tls: bool,
}
impl Default for GostcConfig {
fn default() -> Self {
Self {
enabled: false,
addr: "gostc.mofeng.run".to_string(),
key: String::new(),
tls: true,
}
}
}
/// EasyTier configuration (P2P VPN)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct EasytierConfig {
/// Enable auto-start
pub enabled: bool,
/// Network name
pub network_name: String,
/// Network secret/password
#[serde(skip_serializing_if = "String::is_empty")]
pub network_secret: String,
/// Peer node URLs
#[serde(skip_serializing_if = "Vec::is_empty")]
pub peer_urls: Vec<String>,
/// Virtual IP address (optional, auto-assigned if not set)
#[serde(skip_serializing_if = "Option::is_none")]
pub virtual_ip: Option<String>,
}
impl Default for EasytierConfig {
fn default() -> Self {
Self {
enabled: false,
network_name: String::new(),
network_secret: String::new(),
peer_urls: Vec::new(),
virtual_ip: None,
}
}
}
/// Combined extensions configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct ExtensionsConfig {
pub ttyd: TtydConfig,
pub gostc: GostcConfig,
pub easytier: EasytierConfig,
}
/// Extension info with status and config
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
}
/// ttyd extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtydInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: TtydConfig,
}
/// gostc extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GostcInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: GostcConfig,
}
/// easytier extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EasytierInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: EasytierConfig,
}
/// All extensions status response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionsStatus {
pub ttyd: TtydInfo,
pub gostc: GostcInfo,
pub easytier: EasytierInfo,
}
/// Extension logs response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionLogs {
pub id: ExtensionId,
pub logs: Vec<String>,
}

130
src/hid/backend.rs Normal file
View File

@@ -0,0 +1,130 @@
//! HID backend trait definition
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use super::types::{KeyboardEvent, MouseEvent};
use crate::error::Result;
/// Default CH9329 baud rate
fn default_ch9329_baud_rate() -> u32 {
9600
}
/// HID backend type
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum HidBackendType {
/// USB OTG gadget mode
Otg,
/// CH9329 serial HID controller
Ch9329 {
/// Serial port path
port: String,
/// Baud rate (default: 9600)
#[serde(default = "default_ch9329_baud_rate")]
baud_rate: u32,
},
/// No HID backend (disabled)
None,
}
impl Default for HidBackendType {
fn default() -> Self {
Self::None
}
}
impl HidBackendType {
/// Check if OTG backend is available on this system
pub fn otg_available() -> bool {
// Check for USB gadget support
std::path::Path::new("/sys/class/udc").exists()
}
/// Detect the best available backend
pub fn detect() -> Self {
// Check for OTG gadget support
if Self::otg_available() {
return Self::Otg;
}
// Check for common CH9329 serial ports
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
];
for port in &common_ports {
if std::path::Path::new(port).exists() {
return Self::Ch9329 {
port: port.to_string(),
baud_rate: 9600, // Use default baud rate for auto-detection
};
}
}
Self::None
}
/// Get backend name as string
pub fn name_str(&self) -> &str {
match self {
Self::Otg => "otg",
Self::Ch9329 { .. } => "ch9329",
Self::None => "none",
}
}
}
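// Hedged sketch (illustration only): the serde(tag = "type") representation
// used above. A CH9329 backend serializes as an internally-tagged JSON
// object, e.g. {"type":"ch9329","port":"/dev/ttyUSB0","baud_rate":9600}.
#[cfg(test)]
mod backend_type_serde_sketch {
    use super::*;

    #[test]
    fn ch9329_round_trips() {
        let json = r#"{"type":"ch9329","port":"/dev/ttyUSB0"}"#;
        let backend: HidBackendType = serde_json::from_str(json).unwrap();
        // A missing baud_rate falls back to default_ch9329_baud_rate() = 9600.
        assert_eq!(
            backend,
            HidBackendType::Ch9329 {
                port: "/dev/ttyUSB0".into(),
                baud_rate: 9600
            }
        );
    }
}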
/// HID backend trait
#[async_trait]
pub trait HidBackend: Send + Sync {
/// Get backend name
fn name(&self) -> &'static str;
/// Initialize the backend
async fn init(&self) -> Result<()>;
/// Send a keyboard event
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()>;
/// Send a mouse event
async fn send_mouse(&self, event: MouseEvent) -> Result<()>;
/// Reset all inputs (release all keys/buttons)
async fn reset(&self) -> Result<()>;
/// Shutdown the backend
async fn shutdown(&self) -> Result<()>;
/// Check if backend supports absolute mouse positioning
fn supports_absolute_mouse(&self) -> bool {
false
}
/// Get screen resolution (for absolute mouse)
fn screen_resolution(&self) -> Option<(u32, u32)> {
None
}
/// Set screen resolution (for absolute mouse)
fn set_screen_resolution(&mut self, _width: u32, _height: u32) {}
}
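// Hedged sketch (illustration only; `NullBackend` is hypothetical): a no-op
// backend showing the minimal surface needed to satisfy the HidBackend
// trait, e.g. as a stand-in for tests.
#[allow(dead_code)]
struct NullBackend;

#[async_trait]
impl HidBackend for NullBackend {
    fn name(&self) -> &'static str {
        "null"
    }
    async fn init(&self) -> Result<()> {
        Ok(())
    }
    async fn send_keyboard(&self, _event: KeyboardEvent) -> Result<()> {
        Ok(())
    }
    async fn send_mouse(&self, _event: MouseEvent) -> Result<()> {
        Ok(())
    }
    async fn reset(&self) -> Result<()> {
        Ok(())
    }
    async fn shutdown(&self) -> Result<()> {
        Ok(())
    }
}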
/// HID backend information
#[derive(Debug, Clone, Serialize)]
pub struct HidBackendInfo {
/// Backend name
pub name: String,
/// Backend type
pub backend_type: String,
/// Is initialized
pub initialized: bool,
/// Supports absolute mouse
pub absolute_mouse: bool,
/// Screen resolution (if absolute mouse)
pub resolution: Option<(u32, u32)>,
}

1324
src/hid/ch9329.rs Normal file

File diff suppressed because it is too large

281
src/hid/datachannel.rs Normal file
View File

@@ -0,0 +1,281 @@
//! DataChannel HID message parsing and handling
//!
//! Binary message format:
//! - Byte 0: Message type
//! - 0x01: Keyboard event
//! - 0x02: Mouse event
//! - Remaining bytes: Event data
//!
//! Keyboard event (type 0x01):
//! - Byte 1: Event type (0x00 = down, 0x01 = up)
//! - Byte 2: Key code (USB HID usage code or JS keyCode)
//! - Byte 3: Modifiers bitmask
//! - Bit 0: Left Ctrl
//! - Bit 1: Left Shift
//! - Bit 2: Left Alt
//! - Bit 3: Left Meta
//! - Bit 4: Right Ctrl
//! - Bit 5: Right Shift
//! - Bit 6: Right Alt
//! - Bit 7: Right Meta
//!
//! Mouse event (type 0x02):
//! - Byte 1: Event type
//! - 0x00: Move (relative)
//! - 0x01: MoveAbs (absolute)
//! - 0x02: Down
//! - 0x03: Up
//! - 0x04: Scroll
//! - Bytes 2-3: X coordinate (i16 LE for relative, u16 LE for absolute)
//! - Bytes 4-5: Y coordinate (i16 LE for relative, u16 LE for absolute)
//! - Byte 6: Button (0=left, 1=middle, 2=right, 3=back, 4=forward) or scroll delta (i8)
use tracing::{debug, warn};
use super::{
KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent, MouseEventType,
};
/// Message types
pub const MSG_KEYBOARD: u8 = 0x01;
pub const MSG_MOUSE: u8 = 0x02;
/// Keyboard event types
pub const KB_EVENT_DOWN: u8 = 0x00;
pub const KB_EVENT_UP: u8 = 0x01;
/// Mouse event types
pub const MS_EVENT_MOVE: u8 = 0x00;
pub const MS_EVENT_MOVE_ABS: u8 = 0x01;
pub const MS_EVENT_DOWN: u8 = 0x02;
pub const MS_EVENT_UP: u8 = 0x03;
pub const MS_EVENT_SCROLL: u8 = 0x04;
/// Parsed HID event from DataChannel
#[derive(Debug, Clone)]
pub enum HidChannelEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
}
/// Parse a binary HID message from DataChannel
pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.is_empty() {
warn!("Empty HID message");
return None;
}
let msg_type = data[0];
match msg_type {
MSG_KEYBOARD => parse_keyboard_message(&data[1..]),
MSG_MOUSE => parse_mouse_message(&data[1..]),
_ => {
warn!("Unknown HID message type: 0x{:02X}", msg_type);
None
}
}
}
/// Parse keyboard message payload
fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 3 {
warn!("Keyboard message too short: {} bytes", data.len());
return None;
}
let event_type = match data[0] {
KB_EVENT_DOWN => KeyEventType::Down,
KB_EVENT_UP => KeyEventType::Up,
_ => {
warn!("Unknown keyboard event type: 0x{:02X}", data[0]);
return None;
}
};
let key = data[1];
let modifiers_byte = data[2];
let modifiers = KeyboardModifiers {
left_ctrl: modifiers_byte & 0x01 != 0,
left_shift: modifiers_byte & 0x02 != 0,
left_alt: modifiers_byte & 0x04 != 0,
left_meta: modifiers_byte & 0x08 != 0,
right_ctrl: modifiers_byte & 0x10 != 0,
right_shift: modifiers_byte & 0x20 != 0,
right_alt: modifiers_byte & 0x40 != 0,
right_meta: modifiers_byte & 0x80 != 0,
};
debug!(
"Parsed keyboard: {:?} key=0x{:02X} modifiers=0x{:02X}",
event_type, key, modifiers_byte
);
Some(HidChannelEvent::Keyboard(KeyboardEvent {
event_type,
key,
modifiers,
}))
}
/// Parse mouse message payload
fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 6 {
warn!("Mouse message too short: {} bytes", data.len());
return None;
}
let event_type = match data[0] {
MS_EVENT_MOVE => MouseEventType::Move,
MS_EVENT_MOVE_ABS => MouseEventType::MoveAbs,
MS_EVENT_DOWN => MouseEventType::Down,
MS_EVENT_UP => MouseEventType::Up,
MS_EVENT_SCROLL => MouseEventType::Scroll,
_ => {
warn!("Unknown mouse event type: 0x{:02X}", data[0]);
return None;
}
};
    // Parse coordinates as i16 LE; absolute coordinates (0..=32767) fit in the
    // non-negative i16 range, so one decode path serves both modes
let x = i16::from_le_bytes([data[1], data[2]]) as i32;
let y = i16::from_le_bytes([data[3], data[4]]) as i32;
// Button or scroll delta
let (button, scroll) = match event_type {
MouseEventType::Down | MouseEventType::Up => {
let btn = match data[5] {
0 => Some(MouseButton::Left),
1 => Some(MouseButton::Middle),
2 => Some(MouseButton::Right),
3 => Some(MouseButton::Back),
4 => Some(MouseButton::Forward),
                _ => Some(MouseButton::Left), // unknown button codes fall back to Left
};
(btn, 0i8)
}
MouseEventType::Scroll => (None, data[5] as i8),
_ => (None, 0i8),
};
debug!(
"Parsed mouse: {:?} x={} y={} button={:?} scroll={}",
event_type, x, y, button, scroll
);
Some(HidChannelEvent::Mouse(MouseEvent {
event_type,
x,
y,
button,
scroll,
}))
}
/// Encode a keyboard event to binary format (for sending to client if needed)
pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
let event_type = match event.event_type {
KeyEventType::Down => KB_EVENT_DOWN,
KeyEventType::Up => KB_EVENT_UP,
};
let modifiers = event.modifiers.to_hid_byte();
vec![MSG_KEYBOARD, event_type, event.key, modifiers]
}
/// Encode a mouse event to binary format (for sending to client if needed)
pub fn encode_mouse_event(event: &MouseEvent) -> Vec<u8> {
let event_type = match event.event_type {
MouseEventType::Move => MS_EVENT_MOVE,
MouseEventType::MoveAbs => MS_EVENT_MOVE_ABS,
MouseEventType::Down => MS_EVENT_DOWN,
MouseEventType::Up => MS_EVENT_UP,
MouseEventType::Scroll => MS_EVENT_SCROLL,
};
let x_bytes = (event.x as i16).to_le_bytes();
let y_bytes = (event.y as i16).to_le_bytes();
let extra = match event.event_type {
MouseEventType::Down | MouseEventType::Up => {
event.button.as_ref().map(|b| match b {
MouseButton::Left => 0u8,
MouseButton::Middle => 1u8,
MouseButton::Right => 2u8,
MouseButton::Back => 3u8,
MouseButton::Forward => 4u8,
}).unwrap_or(0)
}
MouseEventType::Scroll => event.scroll as u8,
_ => 0,
};
vec![
MSG_MOUSE,
event_type,
x_bytes[0],
x_bytes[1],
y_bytes[0],
y_bytes[1],
extra,
]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_keyboard_down() {
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // A key with left ctrl
let event = parse_hid_message(&data).unwrap();
match event {
HidChannelEvent::Keyboard(kb) => {
assert!(matches!(kb.event_type, KeyEventType::Down));
assert_eq!(kb.key, 0x04);
assert!(kb.modifiers.left_ctrl);
assert!(!kb.modifiers.left_shift);
}
_ => panic!("Expected keyboard event"),
}
}
#[test]
fn test_parse_mouse_move() {
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
let event = parse_hid_message(&data).unwrap();
match event {
HidChannelEvent::Mouse(ms) => {
assert!(matches!(ms.event_type, MouseEventType::Move));
assert_eq!(ms.x, 10);
assert_eq!(ms.y, -10);
}
_ => panic!("Expected mouse event"),
}
}
#[test]
fn test_encode_keyboard() {
let event = KeyboardEvent {
event_type: KeyEventType::Down,
key: 0x04,
modifiers: KeyboardModifiers {
left_ctrl: true,
left_shift: false,
left_alt: false,
left_meta: false,
right_ctrl: false,
right_shift: false,
right_alt: false,
right_meta: false,
},
};
let encoded = encode_keyboard_event(&event);
assert_eq!(encoded, vec![MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]);
}
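    #[test]
    fn test_mouse_scroll_roundtrip() {
        // Sketch: encode a scroll event and parse it back through the same codec.
        let event = MouseEvent {
            event_type: MouseEventType::Scroll,
            x: 0,
            y: 0,
            button: None,
            scroll: -3,
        };
        let encoded = encode_mouse_event(&event);
        // -3 as u8 is 0xFD (two's complement)
        assert_eq!(encoded, vec![MSG_MOUSE, MS_EVENT_SCROLL, 0, 0, 0, 0, 0xFD]);
        match parse_hid_message(&encoded).unwrap() {
            HidChannelEvent::Mouse(ms) => {
                assert!(matches!(ms.event_type, MouseEventType::Scroll));
                assert_eq!(ms.scroll, -3);
            }
            _ => panic!("Expected mouse event"),
        }
    }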
}

430
src/hid/keymap.rs Normal file
View File

@@ -0,0 +1,430 @@
//! USB HID keyboard key codes mapping
//!
//! This module provides mapping between JavaScript key codes and USB HID usage codes.
//! Reference: USB HID Usage Tables 1.12, Section 10 (Keyboard/Keypad Page)
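//!
//! For example, JavaScript keyCode 65 ("A") maps to USB HID usage 0x04, and
//! keyCode 13 (Enter) maps to 0x28.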
/// USB HID key codes (Usage Page 0x07)
#[allow(dead_code)]
pub mod usb {
// Letters A-Z (0x04 - 0x1D)
pub const KEY_A: u8 = 0x04;
pub const KEY_B: u8 = 0x05;
pub const KEY_C: u8 = 0x06;
pub const KEY_D: u8 = 0x07;
pub const KEY_E: u8 = 0x08;
pub const KEY_F: u8 = 0x09;
pub const KEY_G: u8 = 0x0A;
pub const KEY_H: u8 = 0x0B;
pub const KEY_I: u8 = 0x0C;
pub const KEY_J: u8 = 0x0D;
pub const KEY_K: u8 = 0x0E;
pub const KEY_L: u8 = 0x0F;
pub const KEY_M: u8 = 0x10;
pub const KEY_N: u8 = 0x11;
pub const KEY_O: u8 = 0x12;
pub const KEY_P: u8 = 0x13;
pub const KEY_Q: u8 = 0x14;
pub const KEY_R: u8 = 0x15;
pub const KEY_S: u8 = 0x16;
pub const KEY_T: u8 = 0x17;
pub const KEY_U: u8 = 0x18;
pub const KEY_V: u8 = 0x19;
pub const KEY_W: u8 = 0x1A;
pub const KEY_X: u8 = 0x1B;
pub const KEY_Y: u8 = 0x1C;
pub const KEY_Z: u8 = 0x1D;
// Numbers 1-9, 0 (0x1E - 0x27)
pub const KEY_1: u8 = 0x1E;
pub const KEY_2: u8 = 0x1F;
pub const KEY_3: u8 = 0x20;
pub const KEY_4: u8 = 0x21;
pub const KEY_5: u8 = 0x22;
pub const KEY_6: u8 = 0x23;
pub const KEY_7: u8 = 0x24;
pub const KEY_8: u8 = 0x25;
pub const KEY_9: u8 = 0x26;
pub const KEY_0: u8 = 0x27;
// Control keys
pub const KEY_ENTER: u8 = 0x28;
pub const KEY_ESCAPE: u8 = 0x29;
pub const KEY_BACKSPACE: u8 = 0x2A;
pub const KEY_TAB: u8 = 0x2B;
pub const KEY_SPACE: u8 = 0x2C;
pub const KEY_MINUS: u8 = 0x2D;
pub const KEY_EQUAL: u8 = 0x2E;
pub const KEY_LEFT_BRACKET: u8 = 0x2F;
pub const KEY_RIGHT_BRACKET: u8 = 0x30;
pub const KEY_BACKSLASH: u8 = 0x31;
pub const KEY_HASH: u8 = 0x32; // Non-US # and ~
pub const KEY_SEMICOLON: u8 = 0x33;
pub const KEY_APOSTROPHE: u8 = 0x34;
pub const KEY_GRAVE: u8 = 0x35;
pub const KEY_COMMA: u8 = 0x36;
pub const KEY_PERIOD: u8 = 0x37;
pub const KEY_SLASH: u8 = 0x38;
pub const KEY_CAPS_LOCK: u8 = 0x39;
// Function keys F1-F12
pub const KEY_F1: u8 = 0x3A;
pub const KEY_F2: u8 = 0x3B;
pub const KEY_F3: u8 = 0x3C;
pub const KEY_F4: u8 = 0x3D;
pub const KEY_F5: u8 = 0x3E;
pub const KEY_F6: u8 = 0x3F;
pub const KEY_F7: u8 = 0x40;
pub const KEY_F8: u8 = 0x41;
pub const KEY_F9: u8 = 0x42;
pub const KEY_F10: u8 = 0x43;
pub const KEY_F11: u8 = 0x44;
pub const KEY_F12: u8 = 0x45;
// Special keys
pub const KEY_PRINT_SCREEN: u8 = 0x46;
pub const KEY_SCROLL_LOCK: u8 = 0x47;
pub const KEY_PAUSE: u8 = 0x48;
pub const KEY_INSERT: u8 = 0x49;
pub const KEY_HOME: u8 = 0x4A;
pub const KEY_PAGE_UP: u8 = 0x4B;
pub const KEY_DELETE: u8 = 0x4C;
pub const KEY_END: u8 = 0x4D;
pub const KEY_PAGE_DOWN: u8 = 0x4E;
pub const KEY_RIGHT_ARROW: u8 = 0x4F;
pub const KEY_LEFT_ARROW: u8 = 0x50;
pub const KEY_DOWN_ARROW: u8 = 0x51;
pub const KEY_UP_ARROW: u8 = 0x52;
// Numpad
pub const KEY_NUM_LOCK: u8 = 0x53;
pub const KEY_NUMPAD_DIVIDE: u8 = 0x54;
pub const KEY_NUMPAD_MULTIPLY: u8 = 0x55;
pub const KEY_NUMPAD_MINUS: u8 = 0x56;
pub const KEY_NUMPAD_PLUS: u8 = 0x57;
pub const KEY_NUMPAD_ENTER: u8 = 0x58;
pub const KEY_NUMPAD_1: u8 = 0x59;
pub const KEY_NUMPAD_2: u8 = 0x5A;
pub const KEY_NUMPAD_3: u8 = 0x5B;
pub const KEY_NUMPAD_4: u8 = 0x5C;
pub const KEY_NUMPAD_5: u8 = 0x5D;
pub const KEY_NUMPAD_6: u8 = 0x5E;
pub const KEY_NUMPAD_7: u8 = 0x5F;
pub const KEY_NUMPAD_8: u8 = 0x60;
pub const KEY_NUMPAD_9: u8 = 0x61;
pub const KEY_NUMPAD_0: u8 = 0x62;
pub const KEY_NUMPAD_DECIMAL: u8 = 0x63;
// Additional keys
pub const KEY_NON_US_BACKSLASH: u8 = 0x64;
pub const KEY_APPLICATION: u8 = 0x65; // Context menu
pub const KEY_POWER: u8 = 0x66;
pub const KEY_NUMPAD_EQUAL: u8 = 0x67;
// F13-F24
pub const KEY_F13: u8 = 0x68;
pub const KEY_F14: u8 = 0x69;
pub const KEY_F15: u8 = 0x6A;
pub const KEY_F16: u8 = 0x6B;
pub const KEY_F17: u8 = 0x6C;
pub const KEY_F18: u8 = 0x6D;
pub const KEY_F19: u8 = 0x6E;
pub const KEY_F20: u8 = 0x6F;
pub const KEY_F21: u8 = 0x70;
pub const KEY_F22: u8 = 0x71;
pub const KEY_F23: u8 = 0x72;
pub const KEY_F24: u8 = 0x73;
// Modifier keys (these are handled separately in the modifier byte)
pub const KEY_LEFT_CTRL: u8 = 0xE0;
pub const KEY_LEFT_SHIFT: u8 = 0xE1;
pub const KEY_LEFT_ALT: u8 = 0xE2;
pub const KEY_LEFT_META: u8 = 0xE3;
pub const KEY_RIGHT_CTRL: u8 = 0xE4;
pub const KEY_RIGHT_SHIFT: u8 = 0xE5;
pub const KEY_RIGHT_ALT: u8 = 0xE6;
pub const KEY_RIGHT_META: u8 = 0xE7;
}
/// JavaScript key codes (event.keyCode / event.code)
#[allow(dead_code)]
pub mod js {
// Letters
pub const KEY_A: u8 = 65;
pub const KEY_B: u8 = 66;
pub const KEY_C: u8 = 67;
pub const KEY_D: u8 = 68;
pub const KEY_E: u8 = 69;
pub const KEY_F: u8 = 70;
pub const KEY_G: u8 = 71;
pub const KEY_H: u8 = 72;
pub const KEY_I: u8 = 73;
pub const KEY_J: u8 = 74;
pub const KEY_K: u8 = 75;
pub const KEY_L: u8 = 76;
pub const KEY_M: u8 = 77;
pub const KEY_N: u8 = 78;
pub const KEY_O: u8 = 79;
pub const KEY_P: u8 = 80;
pub const KEY_Q: u8 = 81;
pub const KEY_R: u8 = 82;
pub const KEY_S: u8 = 83;
pub const KEY_T: u8 = 84;
pub const KEY_U: u8 = 85;
pub const KEY_V: u8 = 86;
pub const KEY_W: u8 = 87;
pub const KEY_X: u8 = 88;
pub const KEY_Y: u8 = 89;
pub const KEY_Z: u8 = 90;
// Numbers (top row)
pub const KEY_0: u8 = 48;
pub const KEY_1: u8 = 49;
pub const KEY_2: u8 = 50;
pub const KEY_3: u8 = 51;
pub const KEY_4: u8 = 52;
pub const KEY_5: u8 = 53;
pub const KEY_6: u8 = 54;
pub const KEY_7: u8 = 55;
pub const KEY_8: u8 = 56;
pub const KEY_9: u8 = 57;
// Function keys
pub const KEY_F1: u8 = 112;
pub const KEY_F2: u8 = 113;
pub const KEY_F3: u8 = 114;
pub const KEY_F4: u8 = 115;
pub const KEY_F5: u8 = 116;
pub const KEY_F6: u8 = 117;
pub const KEY_F7: u8 = 118;
pub const KEY_F8: u8 = 119;
pub const KEY_F9: u8 = 120;
pub const KEY_F10: u8 = 121;
pub const KEY_F11: u8 = 122;
pub const KEY_F12: u8 = 123;
// Control keys
pub const KEY_BACKSPACE: u8 = 8;
pub const KEY_TAB: u8 = 9;
pub const KEY_ENTER: u8 = 13;
pub const KEY_SHIFT: u8 = 16;
pub const KEY_CTRL: u8 = 17;
pub const KEY_ALT: u8 = 18;
pub const KEY_PAUSE: u8 = 19;
pub const KEY_CAPS_LOCK: u8 = 20;
pub const KEY_ESCAPE: u8 = 27;
pub const KEY_SPACE: u8 = 32;
pub const KEY_PAGE_UP: u8 = 33;
pub const KEY_PAGE_DOWN: u8 = 34;
pub const KEY_END: u8 = 35;
pub const KEY_HOME: u8 = 36;
pub const KEY_LEFT: u8 = 37;
pub const KEY_UP: u8 = 38;
pub const KEY_RIGHT: u8 = 39;
pub const KEY_DOWN: u8 = 40;
pub const KEY_INSERT: u8 = 45;
pub const KEY_DELETE: u8 = 46;
// Punctuation
pub const KEY_SEMICOLON: u8 = 186;
pub const KEY_EQUAL: u8 = 187;
pub const KEY_COMMA: u8 = 188;
pub const KEY_MINUS: u8 = 189;
pub const KEY_PERIOD: u8 = 190;
pub const KEY_SLASH: u8 = 191;
pub const KEY_GRAVE: u8 = 192;
pub const KEY_LEFT_BRACKET: u8 = 219;
pub const KEY_BACKSLASH: u8 = 220;
pub const KEY_RIGHT_BRACKET: u8 = 221;
pub const KEY_APOSTROPHE: u8 = 222;
// Numpad
pub const KEY_NUMPAD_0: u8 = 96;
pub const KEY_NUMPAD_1: u8 = 97;
pub const KEY_NUMPAD_2: u8 = 98;
pub const KEY_NUMPAD_3: u8 = 99;
pub const KEY_NUMPAD_4: u8 = 100;
pub const KEY_NUMPAD_5: u8 = 101;
pub const KEY_NUMPAD_6: u8 = 102;
pub const KEY_NUMPAD_7: u8 = 103;
pub const KEY_NUMPAD_8: u8 = 104;
pub const KEY_NUMPAD_9: u8 = 105;
pub const KEY_NUMPAD_MULTIPLY: u8 = 106;
pub const KEY_NUMPAD_ADD: u8 = 107;
pub const KEY_NUMPAD_SUBTRACT: u8 = 109;
pub const KEY_NUMPAD_DECIMAL: u8 = 110;
pub const KEY_NUMPAD_DIVIDE: u8 = 111;
// Lock keys
pub const KEY_NUM_LOCK: u8 = 144;
pub const KEY_SCROLL_LOCK: u8 = 145;
// Windows keys
pub const KEY_META_LEFT: u8 = 91;
pub const KEY_META_RIGHT: u8 = 92;
pub const KEY_CONTEXT_MENU: u8 = 93;
}
/// JavaScript keyCode to USB HID keyCode mapping table
/// Using a fixed-size array for O(1) lookup instead of HashMap
/// Index = JavaScript keyCode, Value = USB HID keyCode (0 means unmapped)
static JS_TO_USB_TABLE: [u8; 256] = {
let mut table = [0u8; 256];
// Letters A-Z (JS 65-90 -> USB 0x04-0x1D)
let mut i = 0u8;
while i < 26 {
table[(65 + i) as usize] = usb::KEY_A + i;
i += 1;
}
// Numbers 1-9, 0 (JS 49-57, 48 -> USB 0x1E-0x27)
table[49] = usb::KEY_1; // 1
table[50] = usb::KEY_2; // 2
table[51] = usb::KEY_3; // 3
table[52] = usb::KEY_4; // 4
table[53] = usb::KEY_5; // 5
table[54] = usb::KEY_6; // 6
table[55] = usb::KEY_7; // 7
table[56] = usb::KEY_8; // 8
table[57] = usb::KEY_9; // 9
table[48] = usb::KEY_0; // 0
// Function keys F1-F12 (JS 112-123 -> USB 0x3A-0x45)
table[112] = usb::KEY_F1;
table[113] = usb::KEY_F2;
table[114] = usb::KEY_F3;
table[115] = usb::KEY_F4;
table[116] = usb::KEY_F5;
table[117] = usb::KEY_F6;
table[118] = usb::KEY_F7;
table[119] = usb::KEY_F8;
table[120] = usb::KEY_F9;
table[121] = usb::KEY_F10;
table[122] = usb::KEY_F11;
table[123] = usb::KEY_F12;
// Control keys
table[13] = usb::KEY_ENTER; // Enter
table[27] = usb::KEY_ESCAPE; // Escape
table[8] = usb::KEY_BACKSPACE; // Backspace
table[9] = usb::KEY_TAB; // Tab
table[32] = usb::KEY_SPACE; // Space
table[20] = usb::KEY_CAPS_LOCK; // Caps Lock
// Punctuation (JS codes vary by browser/layout)
table[189] = usb::KEY_MINUS; // -
table[187] = usb::KEY_EQUAL; // =
table[219] = usb::KEY_LEFT_BRACKET; // [
table[221] = usb::KEY_RIGHT_BRACKET; // ]
table[220] = usb::KEY_BACKSLASH; // \
table[186] = usb::KEY_SEMICOLON; // ;
table[222] = usb::KEY_APOSTROPHE; // '
table[192] = usb::KEY_GRAVE; // `
table[188] = usb::KEY_COMMA; // ,
table[190] = usb::KEY_PERIOD; // .
table[191] = usb::KEY_SLASH; // /
// Navigation keys
table[45] = usb::KEY_INSERT;
table[46] = usb::KEY_DELETE;
table[36] = usb::KEY_HOME;
table[35] = usb::KEY_END;
table[33] = usb::KEY_PAGE_UP;
table[34] = usb::KEY_PAGE_DOWN;
// Arrow keys
table[39] = usb::KEY_RIGHT_ARROW;
table[37] = usb::KEY_LEFT_ARROW;
table[40] = usb::KEY_DOWN_ARROW;
table[38] = usb::KEY_UP_ARROW;
// Numpad
table[144] = usb::KEY_NUM_LOCK;
table[111] = usb::KEY_NUMPAD_DIVIDE;
table[106] = usb::KEY_NUMPAD_MULTIPLY;
table[109] = usb::KEY_NUMPAD_MINUS;
table[107] = usb::KEY_NUMPAD_PLUS;
table[96] = usb::KEY_NUMPAD_0;
table[97] = usb::KEY_NUMPAD_1;
table[98] = usb::KEY_NUMPAD_2;
table[99] = usb::KEY_NUMPAD_3;
table[100] = usb::KEY_NUMPAD_4;
table[101] = usb::KEY_NUMPAD_5;
table[102] = usb::KEY_NUMPAD_6;
table[103] = usb::KEY_NUMPAD_7;
table[104] = usb::KEY_NUMPAD_8;
table[105] = usb::KEY_NUMPAD_9;
table[110] = usb::KEY_NUMPAD_DECIMAL;
// Special keys
table[19] = usb::KEY_PAUSE;
table[145] = usb::KEY_SCROLL_LOCK;
table[93] = usb::KEY_APPLICATION; // Context menu
// Modifier keys
table[17] = usb::KEY_LEFT_CTRL;
table[16] = usb::KEY_LEFT_SHIFT;
table[18] = usb::KEY_LEFT_ALT;
table[91] = usb::KEY_LEFT_META; // Left Windows/Command
table[92] = usb::KEY_RIGHT_META; // Right Windows/Command
table
};
/// Convert JavaScript keyCode to USB HID keyCode
///
/// Uses a fixed-size lookup table for O(1) performance.
/// Returns None if the key code is not mapped.
#[inline]
pub fn js_to_usb(js_code: u8) -> Option<u8> {
let usb_code = JS_TO_USB_TABLE[js_code as usize];
if usb_code != 0 {
Some(usb_code)
} else {
None
}
}
/// Check if a key code is a modifier key
pub fn is_modifier_key(usb_code: u8) -> bool {
(0xE0..=0xE7).contains(&usb_code)
}
/// Get modifier bit for a modifier key
pub fn modifier_bit(usb_code: u8) -> Option<u8> {
match usb_code {
usb::KEY_LEFT_CTRL => Some(0x01),
usb::KEY_LEFT_SHIFT => Some(0x02),
usb::KEY_LEFT_ALT => Some(0x04),
usb::KEY_LEFT_META => Some(0x08),
usb::KEY_RIGHT_CTRL => Some(0x10),
usb::KEY_RIGHT_SHIFT => Some(0x20),
usb::KEY_RIGHT_ALT => Some(0x40),
usb::KEY_RIGHT_META => Some(0x80),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_letter_mapping() {
assert_eq!(js_to_usb(65), Some(usb::KEY_A)); // A
assert_eq!(js_to_usb(90), Some(usb::KEY_Z)); // Z
}
#[test]
fn test_number_mapping() {
assert_eq!(js_to_usb(48), Some(usb::KEY_0));
assert_eq!(js_to_usb(49), Some(usb::KEY_1));
}
#[test]
fn test_modifier_key() {
assert!(is_modifier_key(usb::KEY_LEFT_CTRL));
assert!(is_modifier_key(usb::KEY_RIGHT_SHIFT));
assert!(!is_modifier_key(usb::KEY_A));
}
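    #[test]
    fn test_modifier_bits() {
        // Bits follow the standard HID modifier byte layout (see modifier_bit).
        assert_eq!(modifier_bit(usb::KEY_LEFT_CTRL), Some(0x01));
        assert_eq!(modifier_bit(usb::KEY_RIGHT_META), Some(0x80));
        assert_eq!(modifier_bit(usb::KEY_A), None);
    }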
}

417
src/hid/mod.rs Normal file
View File

@@ -0,0 +1,417 @@
//! HID (Human Interface Device) control module
//!
//! This module provides keyboard and mouse control for remote KVM:
//! - USB OTG gadget mode (native Linux USB gadget)
//! - CH9329 serial HID controller
//!
//! Architecture:
//! ```text
//! Web Client --> WebSocket/DataChannel --> HID Events --> Backend --> Target PC
//!                                                            |
//!                                                      [OTG | CH9329]
//! ```
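//!
//! A minimal dispatch sketch (hypothetical `frame: &[u8]` from a client and a
//! `controller: HidController`; both are assumptions of this example):
//!
//! ```ignore
//! if let Some(event) = datachannel::parse_hid_message(&frame) {
//!     match event {
//!         datachannel::HidChannelEvent::Keyboard(kb) => controller.send_keyboard(kb).await?,
//!         datachannel::HidChannelEvent::Mouse(ms) => controller.send_mouse(ms).await?,
//!     }
//! }
//! ```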
pub mod backend;
pub mod ch9329;
pub mod datachannel;
pub mod keymap;
pub mod monitor;
pub mod otg;
pub mod types;
pub mod websocket;
pub use backend::{HidBackend, HidBackendType};
pub use monitor::{HidHealthMonitor, HidHealthStatus, HidMonitorConfig};
pub use otg::LedState;
pub use types::{
KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent, MouseEventType,
};
/// HID backend information
#[derive(Debug, Clone)]
pub struct HidInfo {
/// Backend name
pub name: &'static str,
/// Whether backend is initialized
pub initialized: bool,
/// Whether absolute mouse positioning is supported
pub supports_absolute_mouse: bool,
/// Screen resolution for absolute mouse
pub screen_resolution: Option<(u32, u32)>,
}
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::error::{AppError, Result};
use crate::otg::OtgService;
/// HID controller managing keyboard and mouse input
pub struct HidController {
/// OTG Service reference (only used when backend is OTG)
otg_service: Option<Arc<OtgService>>,
/// Active backend
backend: Arc<RwLock<Option<Box<dyn HidBackend>>>>,
/// Backend type (mutable for reload)
backend_type: RwLock<HidBackendType>,
/// Event bus for broadcasting state changes (optional)
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
/// Health monitor for error tracking and recovery
monitor: Arc<HidHealthMonitor>,
}
impl HidController {
/// Create a new HID controller with specified backend
///
/// For OTG backend, otg_service should be provided to support hot-reload
pub fn new(backend_type: HidBackendType, otg_service: Option<Arc<OtgService>>) -> Self {
Self {
otg_service,
backend: Arc::new(RwLock::new(None)),
backend_type: RwLock::new(backend_type),
events: tokio::sync::RwLock::new(None),
monitor: Arc::new(HidHealthMonitor::with_defaults()),
}
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<crate::events::EventBus>) {
*self.events.write().await = Some(events.clone());
// Also set event bus on the monitor for health notifications
self.monitor.set_event_bus(events).await;
}
/// Initialize the HID backend
pub async fn init(&self) -> Result<()> {
let backend_type = self.backend_type.read().await.clone();
let backend: Box<dyn HidBackend> = match backend_type {
HidBackendType::Otg => {
// Request HID functions from OtgService
let otg_service = self
.otg_service
.as_ref()
.ok_or_else(|| AppError::Internal("OtgService not available".into()))?;
info!("Requesting HID functions from OtgService");
let handles = otg_service.enable_hid().await?;
// Create OtgBackend from handles (no longer manages gadget itself)
info!("Creating OTG HID backend from device paths");
Box::new(otg::OtgBackend::from_handles(handles)?)
}
HidBackendType::Ch9329 { ref port, baud_rate } => {
info!("Initializing CH9329 HID backend on {} @ {} baud", port, baud_rate);
Box::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?)
}
HidBackendType::None => {
warn!("HID backend disabled");
return Ok(());
}
};
backend.init().await?;
*self.backend.write().await = Some(backend);
info!("HID backend initialized: {:?}", backend_type);
Ok(())
}
/// Shutdown the HID backend and release resources
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down HID controller");
// Close the backend
*self.backend.write().await = None;
// If OTG backend, notify OtgService to disable HID
let backend_type = self.backend_type.read().await.clone();
if matches!(backend_type, HidBackendType::Otg) {
if let Some(ref otg_service) = self.otg_service {
info!("Disabling HID functions in OtgService");
otg_service.disable_hid().await?;
}
}
info!("HID controller shutdown complete");
Ok(())
}
/// Send keyboard event
pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
let backend = self.backend.read().await;
match backend.as_ref() {
Some(b) => {
match b.send_keyboard(event).await {
Ok(_) => {
// Check if we were in an error state and now recovered
if self.monitor.is_error().await {
let backend_type = self.backend_type.read().await;
self.monitor.report_recovered(backend_type.name_str()).await;
}
Ok(())
}
Err(e) => {
// Report error to monitor, but skip temporary EAGAIN retries
// - "eagain_retry": within threshold, just temporary busy
// - "eagain": exceeded threshold, report as error
if let AppError::HidError { ref backend, ref reason, ref error_code } = e {
if error_code != "eagain_retry" {
self.monitor.report_error(backend, None, reason, error_code).await;
}
}
Err(e)
}
}
}
None => Err(AppError::BadRequest("HID backend not available".to_string())),
}
}
/// Send mouse event
pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
let backend = self.backend.read().await;
match backend.as_ref() {
Some(b) => {
match b.send_mouse(event).await {
Ok(_) => {
// Check if we were in an error state and now recovered
if self.monitor.is_error().await {
let backend_type = self.backend_type.read().await;
self.monitor.report_recovered(backend_type.name_str()).await;
}
Ok(())
}
Err(e) => {
// Report error to monitor, but skip temporary EAGAIN retries
// - "eagain_retry": within threshold, just temporary busy
// - "eagain": exceeded threshold, report as error
if let AppError::HidError { ref backend, ref reason, ref error_code } = e {
if error_code != "eagain_retry" {
self.monitor.report_error(backend, None, reason, error_code).await;
}
}
Err(e)
}
}
}
None => Err(AppError::BadRequest("HID backend not available".to_string())),
}
}
/// Reset all keys (release all pressed keys)
pub async fn reset(&self) -> Result<()> {
let backend = self.backend.read().await;
match backend.as_ref() {
Some(b) => b.reset().await,
None => Ok(()),
}
}
/// Check if backend is available
pub async fn is_available(&self) -> bool {
self.backend.read().await.is_some()
}
/// Get backend type
pub async fn backend_type(&self) -> HidBackendType {
self.backend_type.read().await.clone()
}
/// Get backend info
pub async fn info(&self) -> Option<HidInfo> {
let backend = self.backend.read().await;
backend.as_ref().map(|b| HidInfo {
name: b.name(),
initialized: true,
supports_absolute_mouse: b.supports_absolute_mouse(),
screen_resolution: b.screen_resolution(),
})
}
/// Get current state as SystemEvent
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
let backend = self.backend.read().await;
let backend_type = self.backend_type().await;
let (backend_name, initialized) = match backend.as_ref() {
Some(b) => (b.name(), true),
None => (backend_type.name_str(), false),
};
// Include error information from monitor
let (error, error_code) = match self.monitor.status().await {
HidHealthStatus::Error { reason, error_code, .. } => {
(Some(reason), Some(error_code))
}
_ => (None, None),
};
crate::events::SystemEvent::HidStateChanged {
backend: backend_name.to_string(),
initialized,
error,
error_code,
}
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<HidHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> HidHealthStatus {
self.monitor.status().await
}
/// Check if the HID backend is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
/// Reload the HID backend with new type
pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> {
info!("Reloading HID backend: {:?}", new_backend_type);
// Shutdown existing backend first
if let Some(backend) = self.backend.write().await.take() {
if let Err(e) = backend.shutdown().await {
warn!("Error shutting down old HID backend: {}", e);
}
}
// Create and initialize new backend
let new_backend: Option<Box<dyn HidBackend>> = match new_backend_type {
HidBackendType::Otg => {
info!("Initializing OTG HID backend");
// Get OtgService reference
let otg_service = match self.otg_service.as_ref() {
Some(svc) => svc,
None => {
warn!("OTG backend requires OtgService, but it's not available");
return Err(AppError::Config(
"OTG backend not available (OtgService missing)".to_string()
));
}
};
// Request HID functions from OtgService
match otg_service.enable_hid().await {
Ok(handles) => {
// Create OtgBackend from handles
match otg::OtgBackend::from_handles(handles) {
Ok(backend) => {
let boxed: Box<dyn HidBackend> = Box::new(backend);
match boxed.init().await {
Ok(_) => {
info!("OTG backend initialized successfully");
Some(boxed)
}
Err(e) => {
warn!("Failed to initialize OTG backend: {}", e);
// Cleanup: disable HID in OtgService
if let Err(e2) = otg_service.disable_hid().await {
warn!("Failed to cleanup HID after init failure: {}", e2);
}
None
}
}
}
Err(e) => {
warn!("Failed to create OTG backend: {}", e);
// Cleanup: disable HID in OtgService
if let Err(e2) = otg_service.disable_hid().await {
warn!("Failed to cleanup HID after creation failure: {}", e2);
}
None
}
}
}
Err(e) => {
warn!("Failed to enable HID in OtgService: {}", e);
None
}
}
}
HidBackendType::Ch9329 { ref port, baud_rate } => {
info!("Initializing CH9329 HID backend on {} @ {} baud", port, baud_rate);
match ch9329::Ch9329Backend::with_baud_rate(port, baud_rate) {
Ok(b) => {
let boxed = Box::new(b);
match boxed.init().await {
Ok(_) => Some(boxed),
Err(e) => {
warn!("Failed to initialize CH9329 backend: {}", e);
None
}
}
}
Err(e) => {
warn!("Failed to create CH9329 backend: {}", e);
None
}
}
}
HidBackendType::None => {
warn!("HID backend disabled");
None
}
};
*self.backend.write().await = new_backend;
if self.backend.read().await.is_some() {
info!("HID backend reloaded successfully: {:?}", new_backend_type);
// Update backend_type on success
*self.backend_type.write().await = new_backend_type.clone();
// Reset monitor state on successful reload
self.monitor.reset().await;
// Publish HID state changed event
let backend_name = new_backend_type.name_str().to_string();
self.publish_event(crate::events::SystemEvent::HidStateChanged {
backend: backend_name,
initialized: true,
error: None,
error_code: None,
})
.await;
Ok(())
} else {
warn!("HID backend reload resulted in no active backend");
// Update backend_type even on failure (to reflect the attempted change)
*self.backend_type.write().await = new_backend_type.clone();
// Publish event with initialized=false
self.publish_event(crate::events::SystemEvent::HidStateChanged {
backend: new_backend_type.name_str().to_string(),
initialized: false,
error: Some("Failed to initialize HID backend".to_string()),
error_code: Some("init_failed".to_string()),
})
.await;
Err(AppError::Internal(
"Failed to reload HID backend".to_string(),
))
}
}
/// Publish event to event bus if available
async fn publish_event(&self, event: crate::events::SystemEvent) {
if let Some(events) = self.events.read().await.as_ref() {
events.publish(event);
}
}
}
impl Default for HidController {
fn default() -> Self {
Self::new(HidBackendType::None, None)
}
}

429
src/hid/monitor.rs Normal file
View File

@@ -0,0 +1,429 @@
//! HID device health monitoring
//!
//! This module provides health monitoring for HID devices, including:
//! - Device connectivity checks
//! - Automatic reconnection on failure
//! - Error tracking and notification
//! - Log throttling to prevent log flooding
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, warn};
use crate::events::{EventBus, SystemEvent};
use crate::utils::LogThrottler;
/// HID health status
#[derive(Debug, Clone, PartialEq)]
pub enum HidHealthStatus {
/// Device is healthy and operational
Healthy,
/// Device has an error, attempting recovery
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
/// Number of recovery attempts made
retry_count: u32,
},
/// Device is disconnected
Disconnected,
}
impl Default for HidHealthStatus {
fn default() -> Self {
Self::Healthy
}
}
/// HID health monitor configuration
#[derive(Debug, Clone)]
pub struct HidMonitorConfig {
/// Health check interval in milliseconds
pub check_interval_ms: u64,
/// Retry interval when device is lost (milliseconds)
pub retry_interval_ms: u64,
/// Maximum retry attempts before giving up (0 = infinite)
pub max_retries: u32,
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
/// Recovery cooldown in milliseconds (suppress logs after recovery)
pub recovery_cooldown_ms: u64,
}
impl Default for HidMonitorConfig {
fn default() -> Self {
Self {
check_interval_ms: 1000,
retry_interval_ms: 1000,
max_retries: 0, // infinite retry
log_throttle_secs: 5,
recovery_cooldown_ms: 1000, // 1 second cooldown after recovery
}
}
}
/// HID health monitor
///
/// Monitors HID device health and manages error recovery.
/// Publishes WebSocket events when device status changes.
pub struct HidHealthMonitor {
/// Current health status
status: RwLock<HidHealthStatus>,
/// Event bus for notifications
events: RwLock<Option<Arc<EventBus>>>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
config: HidMonitorConfig,
/// Whether monitoring is active (reserved for future use)
#[allow(dead_code)]
running: AtomicBool,
/// Current retry count
retry_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
/// Last recovery timestamp (milliseconds since start, for cooldown)
last_recovery_ms: AtomicU64,
/// Start instant for timing
start_instant: Instant,
}
impl HidHealthMonitor {
/// Create a new HID health monitor with the specified configuration
pub fn new(config: HidMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
Self {
status: RwLock::new(HidHealthStatus::Healthy),
events: RwLock::new(None),
throttler: LogThrottler::with_secs(throttle_secs),
config,
running: AtomicBool::new(false),
retry_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
last_recovery_ms: AtomicU64::new(0),
start_instant: Instant::now(),
}
}
/// Create a new HID health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(HidMonitorConfig::default())
}
/// Set the event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Report an error from HID operations
///
/// This method is called when an HID operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling and cooldown respect)
/// 3. Publishes a WebSocket event if the error is new or changed
///
/// # Arguments
///
/// * `backend` - The HID backend type ("otg" or "ch9329")
/// * `device` - The device path (if known)
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(
&self,
backend: &str,
device: Option<&str>,
reason: &str,
error_code: &str,
) {
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if we're in cooldown period after recent recovery
let current_ms = self.start_instant.elapsed().as_millis() as u64;
let last_recovery = self.last_recovery_ms.load(Ordering::Relaxed);
let in_cooldown = last_recovery > 0 && current_ms < last_recovery + self.config.recovery_cooldown_ms;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (skip if in cooldown period unless error type changed)
let throttle_key = format!("hid_{}_{}", backend, error_code);
if !in_cooldown && (error_changed || self.throttler.should_log(&throttle_key)) {
warn!(
"HID {} error: {} (code: {}, attempt: {})",
backend, reason, error_code, count
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = HidHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
retry_count: count,
};
// Publish event (only if error changed or first occurrence, and not in cooldown)
if !in_cooldown && (error_changed || count == 1) {
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::HidDeviceLost {
backend: backend.to_string(),
device: device.map(|s| s.to_string()),
reason: reason.to_string(),
error_code: error_code.to_string(),
});
}
}
}
/// Report that a reconnection attempt is starting
///
/// Publishes a reconnecting event to notify clients.
///
/// # Arguments
///
/// * `backend` - The HID backend type
pub async fn report_reconnecting(&self, backend: &str) {
let attempt = self.retry_count.load(Ordering::Relaxed);
// Only publish every 5 attempts to avoid event spam
if attempt == 1 || attempt % 5 == 0 {
debug!("HID {} reconnecting, attempt {}", backend, attempt);
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::HidReconnecting {
backend: backend.to_string(),
attempt,
});
}
}
}
/// Report that the device has recovered
///
/// This method is called when the HID device successfully reconnects.
/// It resets the error state and publishes a recovery event.
///
/// # Arguments
///
/// * `backend` - The HID backend type
pub async fn report_recovered(&self, backend: &str) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != HidHealthStatus::Healthy {
let retry_count = self.retry_count.load(Ordering::Relaxed);
// Set cooldown timestamp
let current_ms = self.start_instant.elapsed().as_millis() as u64;
self.last_recovery_ms.store(current_ms, Ordering::Relaxed);
// Only log and publish events if there were multiple retries
// (avoid log spam for transient single-retry recoveries)
if retry_count > 1 {
debug!(
"HID {} recovered after {} retries",
backend, retry_count
);
// Publish recovery event
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::HidRecovered {
backend: backend.to_string(),
});
// Also publish state changed to indicate healthy state
events.publish(SystemEvent::HidStateChanged {
backend: backend.to_string(),
initialized: true,
error: None,
error_code: None,
});
}
}
// Reset state (always reset, even for single-retry recoveries)
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = HidHealthStatus::Healthy;
}
}
/// Get the current health status
pub async fn status(&self) -> HidHealthStatus {
self.status.read().await.clone()
}
/// Get the current retry count
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, HidHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, HidHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = HidHealthStatus::Healthy;
self.throttler.clear_all();
}
/// Get the configuration
pub fn config(&self) -> &HidMonitorConfig {
&self.config
}
/// Check if we should continue retrying
///
/// Returns `false` if max_retries is set and we've exceeded it.
pub fn should_retry(&self) -> bool {
if self.config.max_retries == 0 {
return true; // Infinite retry
}
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
}
/// Get the retry interval
pub fn retry_interval(&self) -> Duration {
Duration::from_millis(self.config.retry_interval_ms)
}
}
impl Default for HidHealthMonitor {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_initial_status() {
let monitor = HidHealthMonitor::with_defaults();
assert!(monitor.is_healthy().await);
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_report_error() {
let monitor = HidHealthMonitor::with_defaults();
monitor
.report_error("otg", Some("/dev/hidg0"), "Device not found", "enoent")
.await;
assert!(monitor.is_error().await);
assert_eq!(monitor.retry_count(), 1);
if let HidHealthStatus::Error {
reason,
error_code,
retry_count,
} = monitor.status().await
{
assert_eq!(reason, "Device not found");
assert_eq!(error_code, "enoent");
assert_eq!(retry_count, 1);
} else {
panic!("Expected Error status");
}
}
#[tokio::test]
async fn test_report_recovered() {
let monitor = HidHealthMonitor::with_defaults();
// First report an error
monitor
.report_error("ch9329", None, "Port not found", "port_not_found")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered("ch9329").await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_retry_count_increments() {
let monitor = HidHealthMonitor::with_defaults();
for i in 1..=5 {
monitor
.report_error("otg", None, "Error", "io_error")
.await;
assert_eq!(monitor.retry_count(), i);
}
}
#[tokio::test]
async fn test_should_retry_infinite() {
let monitor = HidHealthMonitor::new(HidMonitorConfig {
max_retries: 0, // infinite
..Default::default()
});
for _ in 0..100 {
monitor
.report_error("otg", None, "Error", "io_error")
.await;
assert!(monitor.should_retry());
}
}
#[tokio::test]
async fn test_should_retry_limited() {
let monitor = HidHealthMonitor::new(HidMonitorConfig {
max_retries: 3,
..Default::default()
});
assert!(monitor.should_retry());
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(monitor.should_retry()); // 1 < 3
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(monitor.should_retry()); // 2 < 3
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(!monitor.should_retry()); // 3 >= 3
}
#[tokio::test]
async fn test_reset() {
let monitor = HidHealthMonitor::with_defaults();
monitor
.report_error("otg", None, "Error", "io_error")
.await;
assert!(monitor.is_error().await);
monitor.reset().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.retry_count(), 0);
}
}

848
src/hid/otg.rs Normal file
View File

@@ -0,0 +1,848 @@
//! OTG USB Gadget HID backend
//!
//! This backend uses Linux USB Gadget API to emulate USB HID devices.
//! It opens and drives three HID gadget devices provided by OtgService:
//! - hidg0: Keyboard (8-byte reports, with LED feedback)
//! - hidg1: Relative Mouse (4-byte reports)
//! - hidg2: Absolute Mouse (6-byte reports)
//!
//! Requirements:
//! - USB OTG/Device controller (UDC)
//! - ConfigFS with USB gadget support
//! - Root privileges for gadget setup
//!
//! Error Recovery:
//! This module implements automatic device reconnection based on PiKVM's approach.
//! When ESHUTDOWN or EAGAIN errors occur (common during MSD operations), the device
//! file handles are closed and reopened on the next operation.
//! See: https://github.com/raspberrypi/linux/issues/4373
use async_trait::async_trait;
use parking_lot::Mutex;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::os::unix::fs::OpenOptionsExt;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use tracing::{debug, info, trace, warn};
use super::backend::HidBackend;
use super::keymap;
use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType};
use crate::error::{AppError, Result};
use crate::otg::{HidDevicePaths, wait_for_hid_devices};
/// Device type for ensure_device operations
#[derive(Debug, Clone, Copy)]
enum DeviceType {
Keyboard,
MouseRelative,
MouseAbsolute,
}
/// Keyboard LED state
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct LedState {
/// Num Lock LED
pub num_lock: bool,
/// Caps Lock LED
pub caps_lock: bool,
/// Scroll Lock LED
pub scroll_lock: bool,
/// Compose LED
pub compose: bool,
/// Kana LED
pub kana: bool,
}
impl LedState {
/// Create from raw byte
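    ///
    /// Example: `LedState::from_byte(0x03)` yields `num_lock = true` and
    /// `caps_lock = true`, and round-trips through [`to_byte`](Self::to_byte).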
pub fn from_byte(b: u8) -> Self {
Self {
num_lock: b & 0x01 != 0,
caps_lock: b & 0x02 != 0,
scroll_lock: b & 0x04 != 0,
compose: b & 0x08 != 0,
kana: b & 0x10 != 0,
}
}
/// Convert to raw byte
pub fn to_byte(&self) -> u8 {
let mut b = 0u8;
if self.num_lock { b |= 0x01; }
if self.caps_lock { b |= 0x02; }
if self.scroll_lock { b |= 0x04; }
if self.compose { b |= 0x08; }
if self.kana { b |= 0x10; }
b
}
}
/// OTG HID backend with 3 devices
///
/// This backend opens HID device files created by OtgService.
/// It does NOT manage the USB gadget itself - that's handled by OtgService.
///
/// ## Error Recovery
///
/// Based on PiKVM's implementation, this backend automatically handles:
/// - EAGAIN (errno 11): Resource temporarily unavailable - just retry later, don't close device
/// - ESHUTDOWN (errno 108): Transport endpoint shutdown - close and reopen device
///
/// When ESHUTDOWN occurs, the device file handle is closed and will be
/// reopened on the next operation attempt.
pub struct OtgBackend {
/// Keyboard device path (/dev/hidg0)
keyboard_path: PathBuf,
/// Relative mouse device path (/dev/hidg1)
mouse_rel_path: PathBuf,
/// Absolute mouse device path (/dev/hidg2)
mouse_abs_path: PathBuf,
/// Keyboard device file
keyboard_dev: Mutex<Option<File>>,
/// Relative mouse device file
mouse_rel_dev: Mutex<Option<File>>,
/// Absolute mouse device file
mouse_abs_dev: Mutex<Option<File>>,
/// Current keyboard state
keyboard_state: Mutex<KeyboardReport>,
/// Current mouse button state
mouse_buttons: AtomicU8,
/// Last known LED state (using parking_lot::RwLock for sync access)
led_state: parking_lot::RwLock<LedState>,
/// Screen resolution for absolute mouse (using parking_lot::RwLock for sync access)
screen_resolution: parking_lot::RwLock<Option<(u32, u32)>>,
/// UDC name for state checking (e.g., "fcc00000.usb")
udc_name: parking_lot::RwLock<Option<String>>,
/// Whether the device is currently online (UDC configured and devices accessible)
online: AtomicBool,
/// Last error log time for throttling (using parking_lot for sync)
last_error_log: parking_lot::Mutex<std::time::Instant>,
/// Error count since last successful operation (for log throttling)
error_count: AtomicU8,
/// Consecutive EAGAIN count (for offline threshold detection)
eagain_count: AtomicU8,
}
/// Threshold for consecutive EAGAIN errors before reporting offline
const EAGAIN_OFFLINE_THRESHOLD: u8 = 3;
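// Example: three back-to-back EAGAIN write failures (e.g. while the host has
// the gadget suspended) flip `online` to false and surface an "eagain" error;
// any successful write resets the counter via `reset_error_count()`.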
impl OtgBackend {
/// Create OTG backend from device paths provided by OtgService
///
/// This is the ONLY way to create an OtgBackend - it no longer manages
/// the USB gadget itself. The gadget must already be set up by OtgService.
pub fn from_handles(paths: HidDevicePaths) -> Result<Self> {
Ok(Self {
keyboard_path: paths.keyboard,
mouse_rel_path: paths.mouse_relative,
mouse_abs_path: paths.mouse_absolute,
keyboard_dev: Mutex::new(None),
mouse_rel_dev: Mutex::new(None),
mouse_abs_dev: Mutex::new(None),
keyboard_state: Mutex::new(KeyboardReport::default()),
mouse_buttons: AtomicU8::new(0),
led_state: parking_lot::RwLock::new(LedState::default()),
screen_resolution: parking_lot::RwLock::new(Some((1920, 1080))),
udc_name: parking_lot::RwLock::new(None),
online: AtomicBool::new(false),
last_error_log: parking_lot::Mutex::new(std::time::Instant::now()),
error_count: AtomicU8::new(0),
eagain_count: AtomicU8::new(0),
})
}
/// Log throttled error message (max once per second)
fn log_throttled_error(&self, msg: &str) {
let mut last_log = self.last_error_log.lock();
let now = std::time::Instant::now();
if now.duration_since(*last_log).as_secs() >= 1 {
let count = self.error_count.swap(0, Ordering::Relaxed);
if count > 1 {
warn!("{} (repeated {} times)", msg, count);
} else {
warn!("{}", msg);
}
*last_log = now;
} else {
self.error_count.fetch_add(1, Ordering::Relaxed);
}
}
/// Reset error count on successful operation
fn reset_error_count(&self) {
self.error_count.store(0, Ordering::Relaxed);
// Also reset EAGAIN count - successful operation means device is working
self.eagain_count.store(0, Ordering::Relaxed);
}
/// Set the UDC name for state checking
pub fn set_udc_name(&self, udc: &str) {
*self.udc_name.write() = Some(udc.to_string());
}
/// Check if the UDC is in "configured" state
///
/// This is based on PiKVM's `__is_udc_configured()` method.
/// The UDC state file indicates whether the USB host has enumerated and configured the gadget.
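    ///
    /// Typical `/sys/class/udc/<udc>/state` values include "not attached",
    /// "default", and "configured"; only "configured" means the host has
    /// fully enumerated the gadget.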
pub fn is_udc_configured(&self) -> bool {
let udc_name = self.udc_name.read();
if let Some(ref udc) = *udc_name {
let state_path = format!("/sys/class/udc/{}/state", udc);
match fs::read_to_string(&state_path) {
Ok(content) => {
let state = content.trim().to_lowercase();
trace!("UDC {} state: {}", udc, state);
state == "configured"
}
Err(e) => {
debug!("Failed to read UDC state from {}: {}", state_path, e);
// If we can't read the state, assume it might be configured
// to avoid blocking operations unnecessarily
true
}
}
} else {
// No UDC name set, try to auto-detect
if let Some(udc) = Self::find_udc() {
drop(udc_name);
*self.udc_name.write() = Some(udc.clone());
let state_path = format!("/sys/class/udc/{}/state", udc);
fs::read_to_string(&state_path)
.map(|s| s.trim().to_lowercase() == "configured")
.unwrap_or(true)
} else {
true
}
}
}
/// Find the first available UDC
fn find_udc() -> Option<String> {
let udc_path = PathBuf::from("/sys/class/udc");
if let Ok(entries) = fs::read_dir(&udc_path) {
for entry in entries.flatten() {
if let Some(name) = entry.file_name().to_str() {
return Some(name.to_string());
}
}
}
None
}
/// Check if device is online
pub fn is_online(&self) -> bool {
self.online.load(Ordering::Relaxed)
}
/// Ensure a device is open and ready for I/O
///
/// This method is based on PiKVM's `__ensure_device()` pattern:
/// 1. Check if device path exists, close handle if not
/// 2. If handle is None but path exists, reopen the device
/// 3. Return whether the device is ready for I/O
fn ensure_device(&self, device_type: DeviceType) -> Result<()> {
let (path, dev_mutex) = match device_type {
DeviceType::Keyboard => (&self.keyboard_path, &self.keyboard_dev),
DeviceType::MouseRelative => (&self.mouse_rel_path, &self.mouse_rel_dev),
DeviceType::MouseAbsolute => (&self.mouse_abs_path, &self.mouse_abs_dev),
};
// Check if device path exists
if !path.exists() {
// Close the device if open (device was removed)
let mut dev = dev_mutex.lock();
if dev.is_some() {
debug!("Device path {} no longer exists, closing handle", path.display());
*dev = None;
}
self.online.store(false, Ordering::Relaxed);
return Err(AppError::HidError {
backend: "otg".to_string(),
reason: format!("Device not found: {}", path.display()),
error_code: "enoent".to_string(),
});
}
// If device is not open, try to open it
let mut dev = dev_mutex.lock();
if dev.is_none() {
match Self::open_device(path) {
Ok(file) => {
info!("Reopened HID device: {}", path.display());
*dev = Some(file);
}
Err(e) => {
warn!("Failed to reopen HID device {}: {}", path.display(), e);
return Err(e);
}
}
}
self.online.store(true, Ordering::Relaxed);
Ok(())
}
/// Close a device (used when ESHUTDOWN is received)
#[allow(dead_code)]
fn close_device(&self, device_type: DeviceType) {
let dev_mutex = match device_type {
DeviceType::Keyboard => &self.keyboard_dev,
DeviceType::MouseRelative => &self.mouse_rel_dev,
DeviceType::MouseAbsolute => &self.mouse_abs_dev,
};
let mut dev = dev_mutex.lock();
if dev.is_some() {
debug!("Closing {:?} device handle for recovery", device_type);
*dev = None;
}
}
/// Close all device handles (for recovery)
#[allow(dead_code)]
fn close_all_devices(&self) {
self.close_device(DeviceType::Keyboard);
self.close_device(DeviceType::MouseRelative);
self.close_device(DeviceType::MouseAbsolute);
self.online.store(false, Ordering::Relaxed);
}
/// Open a HID device file with read/write access
fn open_device(path: &PathBuf) -> Result<File> {
OpenOptions::new()
.read(true)
.write(true)
.custom_flags(libc::O_NONBLOCK)
.open(path)
.map_err(|e| {
AppError::Internal(format!("Failed to open HID device {}: {}", path.display(), e))
})
}
/// Convert I/O error to HidError with appropriate error code
fn io_error_to_hid_error(e: std::io::Error, operation: &str) -> AppError {
let error_code = match e.raw_os_error() {
Some(32) => "epipe", // EPIPE - broken pipe
Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown
Some(11) => "eagain", // EAGAIN - resource temporarily unavailable
Some(6) => "enxio", // ENXIO - no such device or address
Some(19) => "enodev", // ENODEV - no such device
Some(5) => "eio", // EIO - I/O error
Some(2) => "enoent", // ENOENT - no such file or directory
_ => "io_error",
};
AppError::HidError {
backend: "otg".to_string(),
reason: format!("{}: {}", operation, e),
error_code: error_code.to_string(),
}
}
/// Check if all HID device files exist
pub fn check_devices_exist(&self) -> bool {
self.keyboard_path.exists()
&& self.mouse_rel_path.exists()
&& self.mouse_abs_path.exists()
}
/// Get list of missing device paths
pub fn get_missing_devices(&self) -> Vec<String> {
let mut missing = Vec::new();
if !self.keyboard_path.exists() {
missing.push(self.keyboard_path.display().to_string());
}
if !self.mouse_rel_path.exists() {
missing.push(self.mouse_rel_path.display().to_string());
}
if !self.mouse_abs_path.exists() {
missing.push(self.mouse_abs_path.display().to_string());
}
missing
}
/// Send keyboard report (8 bytes)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// EAGAIN errors are treated as temporary - device stays open.
fn send_keyboard_report(&self, report: &KeyboardReport) -> Result<()> {
// Ensure device is ready
self.ensure_device(DeviceType::Keyboard)?;
let mut dev = self.keyboard_dev.lock();
if let Some(ref mut file) = *dev {
let data = report.to_bytes();
match file.write_all(&data) {
Ok(_) => {
self.online.store(true, Ordering::Relaxed);
self.reset_error_count();
trace!("Sent keyboard report: {:02X?}", data);
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
// ESHUTDOWN - endpoint closed, need to reopen device
self.online.store(false, Ordering::Relaxed);
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Keyboard ESHUTDOWN, closing for recovery");
*dev = None;
Err(Self::io_error_to_hid_error(e, "Failed to write keyboard report"))
}
Some(11) => {
// EAGAIN - temporary busy, track consecutive count
self.log_throttled_error("HID keyboard busy (EAGAIN)");
let count = self.eagain_count.fetch_add(1, Ordering::Relaxed) + 1;
if count >= EAGAIN_OFFLINE_THRESHOLD {
// Exceeded threshold, report as offline
self.online.store(false, Ordering::Relaxed);
Err(AppError::HidError {
backend: "otg".to_string(),
reason: format!("Device busy ({} consecutive EAGAIN)", count),
error_code: "eagain".to_string(),
})
} else {
// Within threshold, return retry error (won't trigger offline event)
Err(AppError::HidError {
backend: "otg".to_string(),
reason: "Device temporarily busy".to_string(),
error_code: "eagain_retry".to_string(),
})
}
}
_ => {
self.online.store(false, Ordering::Relaxed);
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Keyboard write error: {}", e);
Err(Self::io_error_to_hid_error(e, "Failed to write keyboard report"))
}
}
}
}
} else {
Err(AppError::HidError {
backend: "otg".to_string(),
reason: "Keyboard device not opened".to_string(),
error_code: "not_opened".to_string(),
})
}
}
/// Send relative mouse report (4 bytes: buttons, dx, dy, wheel)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// EAGAIN errors are treated as temporary - device stays open.
fn send_mouse_report_relative(&self, buttons: u8, dx: i8, dy: i8, wheel: i8) -> Result<()> {
// Ensure device is ready
self.ensure_device(DeviceType::MouseRelative)?;
let mut dev = self.mouse_rel_dev.lock();
if let Some(ref mut file) = *dev {
let data = [buttons, dx as u8, dy as u8, wheel as u8];
match file.write_all(&data) {
Ok(_) => {
self.online.store(true, Ordering::Relaxed);
self.reset_error_count();
trace!("Sent relative mouse report: {:02X?}", data);
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
self.online.store(false, Ordering::Relaxed);
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Relative mouse ESHUTDOWN, closing for recovery");
*dev = None;
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
}
Some(11) => {
// EAGAIN - temporary busy, track consecutive count
self.log_throttled_error("HID relative mouse busy (EAGAIN)");
let count = self.eagain_count.fetch_add(1, Ordering::Relaxed) + 1;
if count >= EAGAIN_OFFLINE_THRESHOLD {
// Exceeded threshold, report as offline
self.online.store(false, Ordering::Relaxed);
Err(AppError::HidError {
backend: "otg".to_string(),
reason: format!("Device busy ({} consecutive EAGAIN)", count),
error_code: "eagain".to_string(),
})
} else {
// Within threshold, return retry error (won't trigger offline event)
Err(AppError::HidError {
backend: "otg".to_string(),
reason: "Device temporarily busy".to_string(),
error_code: "eagain_retry".to_string(),
})
}
}
_ => {
self.online.store(false, Ordering::Relaxed);
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Relative mouse write error: {}", e);
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
}
}
}
}
} else {
Err(AppError::HidError {
backend: "otg".to_string(),
reason: "Relative mouse device not opened".to_string(),
error_code: "not_opened".to_string(),
})
}
}
/// Send absolute mouse report (6 bytes: buttons, x_lo, x_hi, y_lo, y_hi, wheel)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// EAGAIN errors are treated as temporary - device stays open.
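    ///
    /// Layout sketch: for `x = 0x1234`, `y = 0x5678`, the wire bytes are
    /// `[buttons, 0x34, 0x12, 0x78, 0x56, wheel]` (little-endian 16-bit coordinates).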
fn send_mouse_report_absolute(&self, buttons: u8, x: u16, y: u16, wheel: i8) -> Result<()> {
// Ensure device is ready
self.ensure_device(DeviceType::MouseAbsolute)?;
let mut dev = self.mouse_abs_dev.lock();
if let Some(ref mut file) = *dev {
let data = [
buttons,
(x & 0xFF) as u8,
(x >> 8) as u8,
(y & 0xFF) as u8,
(y >> 8) as u8,
wheel as u8,
];
match file.write_all(&data) {
Ok(_) => {
self.online.store(true, Ordering::Relaxed);
self.reset_error_count();
trace!("Sent absolute mouse report: {:02X?}", data);
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
self.online.store(false, Ordering::Relaxed);
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Absolute mouse ESHUTDOWN, closing for recovery");
*dev = None;
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
}
Some(11) => {
// EAGAIN - temporary busy, track consecutive count
self.log_throttled_error("HID absolute mouse busy (EAGAIN)");
let count = self.eagain_count.fetch_add(1, Ordering::Relaxed) + 1;
if count >= EAGAIN_OFFLINE_THRESHOLD {
// Exceeded threshold, report as offline
self.online.store(false, Ordering::Relaxed);
Err(AppError::HidError {
backend: "otg".to_string(),
reason: format!("Device busy ({} consecutive EAGAIN)", count),
error_code: "eagain".to_string(),
})
} else {
// Within threshold, return retry error (won't trigger offline event)
Err(AppError::HidError {
backend: "otg".to_string(),
reason: "Device temporarily busy".to_string(),
error_code: "eagain_retry".to_string(),
})
}
}
_ => {
self.online.store(false, Ordering::Relaxed);
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Absolute mouse write error: {}", e);
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
}
}
}
}
} else {
Err(AppError::HidError {
backend: "otg".to_string(),
reason: "Absolute mouse device not opened".to_string(),
error_code: "not_opened".to_string(),
})
}
}
/// Read keyboard LED state (non-blocking)
pub fn read_led_state(&self) -> Result<Option<LedState>> {
let mut dev = self.keyboard_dev.lock();
if let Some(ref mut file) = *dev {
let mut buf = [0u8; 1];
match file.read(&mut buf) {
Ok(1) => {
let state = LedState::from_byte(buf[0]);
// Update LED state (using parking_lot RwLock)
*self.led_state.write() = state;
Ok(Some(state))
}
Ok(_) => Ok(None), // No data available
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
Err(e) => Err(AppError::Internal(format!("Failed to read LED state: {}", e))),
}
} else {
Ok(None)
}
}
/// Get last known LED state
pub fn led_state(&self) -> LedState {
*self.led_state.read()
}
}
#[async_trait]
impl HidBackend for OtgBackend {
fn name(&self) -> &'static str {
"OTG USB Gadget"
}
async fn init(&self) -> Result<()> {
info!("Initializing OTG HID backend");
// Auto-detect UDC name for state checking
if let Some(udc) = Self::find_udc() {
info!("Auto-detected UDC: {}", udc);
self.set_udc_name(&udc);
}
// Wait for devices to appear (they should already exist from OtgService)
let device_paths = vec![
self.keyboard_path.clone(),
self.mouse_rel_path.clone(),
self.mouse_abs_path.clone(),
];
if !wait_for_hid_devices(&device_paths, 2000).await {
return Err(AppError::Internal("HID devices did not appear".into()));
}
// Open keyboard device
if self.keyboard_path.exists() {
let file = Self::open_device(&self.keyboard_path)?;
*self.keyboard_dev.lock() = Some(file);
info!("Keyboard device opened: {}", self.keyboard_path.display());
} else {
warn!("Keyboard device not found: {}", self.keyboard_path.display());
}
// Open relative mouse device
if self.mouse_rel_path.exists() {
let file = Self::open_device(&self.mouse_rel_path)?;
*self.mouse_rel_dev.lock() = Some(file);
info!("Relative mouse device opened: {}", self.mouse_rel_path.display());
} else {
warn!("Relative mouse device not found: {}", self.mouse_rel_path.display());
}
// Open absolute mouse device
if self.mouse_abs_path.exists() {
let file = Self::open_device(&self.mouse_abs_path)?;
*self.mouse_abs_dev.lock() = Some(file);
info!("Absolute mouse device opened: {}", self.mouse_abs_path.display());
} else {
warn!("Absolute mouse device not found: {}", self.mouse_abs_path.display());
}
// Mark backend online; any missing devices were logged above and will be reopened on demand
self.online.store(true, Ordering::Relaxed);
Ok(())
}
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
// Convert JS keycode to USB HID if needed
let usb_key = keymap::js_to_usb(event.key).unwrap_or(event.key);
// Handle modifier keys separately
if keymap::is_modifier_key(usb_key) {
let mut state = self.keyboard_state.lock();
if let Some(bit) = keymap::modifier_bit(usb_key) {
match event.event_type {
KeyEventType::Down => state.modifiers |= bit,
KeyEventType::Up => state.modifiers &= !bit,
}
}
let report = state.clone();
drop(state);
self.send_keyboard_report(&report)?;
} else {
let mut state = self.keyboard_state.lock();
// Update modifiers from event
state.modifiers = event.modifiers.to_hid_byte();
match event.event_type {
KeyEventType::Down => {
state.add_key(usb_key);
}
KeyEventType::Up => {
state.remove_key(usb_key);
}
}
let report = state.clone();
drop(state);
self.send_keyboard_report(&report)?;
}
Ok(())
}
async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
let buttons = self.mouse_buttons.load(Ordering::Relaxed);
match event.event_type {
MouseEventType::Move => {
// Relative movement - use hidg1
let dx = event.x.clamp(-127, 127) as i8;
let dy = event.y.clamp(-127, 127) as i8;
self.send_mouse_report_relative(buttons, dx, dy, 0)?;
}
MouseEventType::MoveAbs => {
// Absolute movement - use hidg2
// Frontend sends 0-32767 range directly (standard HID absolute mouse range)
let x = event.x.clamp(0, 32767) as u16;
let y = event.y.clamp(0, 32767) as u16;
self.send_mouse_report_absolute(buttons, x, y, 0)?;
}
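// Note (added sketch, not from the original source): a client typically
// derives these values from pixel coordinates, e.g.
//   x_hid = x_px * 32767 / (screen_width - 1),
// so that the full screen spans the HID absolute range.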
MouseEventType::Down => {
if let Some(button) = event.button {
let bit = button.to_hid_bit();
let new_buttons = self.mouse_buttons.fetch_or(bit, Ordering::Relaxed) | bit;
// Send on relative device for button clicks
self.send_mouse_report_relative(new_buttons, 0, 0, 0)?;
}
}
MouseEventType::Up => {
if let Some(button) = event.button {
let bit = button.to_hid_bit();
let new_buttons = self.mouse_buttons.fetch_and(!bit, Ordering::Relaxed) & !bit;
self.send_mouse_report_relative(new_buttons, 0, 0, 0)?;
}
}
MouseEventType::Scroll => {
self.send_mouse_report_relative(buttons, 0, 0, event.scroll)?;
}
}
Ok(())
}
async fn reset(&self) -> Result<()> {
// Reset keyboard
{
let mut state = self.keyboard_state.lock();
state.clear();
let report = state.clone();
drop(state);
self.send_keyboard_report(&report)?;
}
// Reset mouse
self.mouse_buttons.store(0, Ordering::Relaxed);
self.send_mouse_report_relative(0, 0, 0, 0)?;
self.send_mouse_report_absolute(0, 0, 0, 0)?;
info!("HID state reset");
Ok(())
}
async fn shutdown(&self) -> Result<()> {
// Reset before closing
self.reset().await?;
// Close devices
*self.keyboard_dev.lock() = None;
*self.mouse_rel_dev.lock() = None;
*self.mouse_abs_dev.lock() = None;
// Gadget cleanup is handled by OtgService, not here
info!("OTG backend shutdown");
Ok(())
}
fn supports_absolute_mouse(&self) -> bool {
self.mouse_abs_path.exists()
}
fn screen_resolution(&self) -> Option<(u32, u32)> {
*self.screen_resolution.read()
}
fn set_screen_resolution(&mut self, width: u32, height: u32) {
*self.screen_resolution.write() = Some((width, height));
}
}
/// Check if OTG HID gadget is available
pub fn is_otg_available() -> bool {
// Check for existing HID devices (they should be created by OtgService)
let kb = PathBuf::from("/dev/hidg0");
let mouse_rel = PathBuf::from("/dev/hidg1");
let mouse_abs = PathBuf::from("/dev/hidg2");
kb.exists() && mouse_rel.exists() && mouse_abs.exists()
}
/// Implement Drop for OtgBackend to close device files
impl Drop for OtgBackend {
fn drop(&mut self) {
// Close device files
// Note: Gadget cleanup is handled by OtgService, not here
*self.keyboard_dev.lock() = None;
*self.mouse_rel_dev.lock() = None;
*self.mouse_abs_dev.lock() = None;
debug!("OtgBackend dropped, device files closed");
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_otg_availability_check() {
// This just tests the function runs without panicking
let _available = is_otg_available();
}
#[test]
fn test_led_state() {
let state = LedState::from_byte(0b00000011);
assert!(state.num_lock);
assert!(state.caps_lock);
assert!(!state.scroll_lock);
assert_eq!(state.to_byte(), 0b00000011);
}
#[test]
fn test_report_sizes() {
// Keyboard report is 8 bytes
let kb_report = KeyboardReport::default();
assert_eq!(kb_report.to_bytes().len(), 8);
}
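// Added sketch: exercises the 6-key rollover limit implied by the 8-byte
// keyboard report layout (six key slots after modifiers and reserved bytes).
#[test]
fn test_keyboard_report_capacity() {
let mut report = KeyboardReport::default();
for key in 1..=6u8 {
assert!(report.add_key(key));
}
// Seventh key is rejected: all six slots are occupied
assert!(!report.add_key(7));
}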
}

382
src/hid/types.rs Normal file

@@ -0,0 +1,382 @@
//! HID event types for keyboard and mouse
use serde::{Deserialize, Serialize};
/// Keyboard event type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KeyEventType {
/// Key pressed down
Down,
/// Key released
Up,
}
/// Keyboard modifier flags
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeyboardModifiers {
/// Left Control
#[serde(default)]
pub left_ctrl: bool,
/// Left Shift
#[serde(default)]
pub left_shift: bool,
/// Left Alt
#[serde(default)]
pub left_alt: bool,
/// Left Meta (Windows/Super key)
#[serde(default)]
pub left_meta: bool,
/// Right Control
#[serde(default)]
pub right_ctrl: bool,
/// Right Shift
#[serde(default)]
pub right_shift: bool,
/// Right Alt (AltGr)
#[serde(default)]
pub right_alt: bool,
/// Right Meta
#[serde(default)]
pub right_meta: bool,
}
impl KeyboardModifiers {
/// Convert to USB HID modifier byte
pub fn to_hid_byte(&self) -> u8 {
let mut byte = 0u8;
if self.left_ctrl {
byte |= 0x01;
}
if self.left_shift {
byte |= 0x02;
}
if self.left_alt {
byte |= 0x04;
}
if self.left_meta {
byte |= 0x08;
}
if self.right_ctrl {
byte |= 0x10;
}
if self.right_shift {
byte |= 0x20;
}
if self.right_alt {
byte |= 0x40;
}
if self.right_meta {
byte |= 0x80;
}
byte
}
/// Create from USB HID modifier byte
pub fn from_hid_byte(byte: u8) -> Self {
Self {
left_ctrl: byte & 0x01 != 0,
left_shift: byte & 0x02 != 0,
left_alt: byte & 0x04 != 0,
left_meta: byte & 0x08 != 0,
right_ctrl: byte & 0x10 != 0,
right_shift: byte & 0x20 != 0,
right_alt: byte & 0x40 != 0,
right_meta: byte & 0x80 != 0,
}
}
/// Check if any modifier is active
pub fn any(&self) -> bool {
self.left_ctrl
|| self.left_shift
|| self.left_alt
|| self.left_meta
|| self.right_ctrl
|| self.right_shift
|| self.right_alt
|| self.right_meta
}
}
/// Keyboard event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyboardEvent {
/// Event type (down/up)
#[serde(rename = "type")]
pub event_type: KeyEventType,
/// Key code (USB HID usage code or JavaScript key code)
pub key: u8,
/// Modifier keys state
#[serde(default)]
pub modifiers: KeyboardModifiers,
}
impl KeyboardEvent {
/// Create a key down event
pub fn key_down(key: u8, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Down,
key,
modifiers,
}
}
/// Create a key up event
pub fn key_up(key: u8, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Up,
key,
modifiers,
}
}
}
/// Mouse button
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseButton {
Left,
Right,
Middle,
Back,
Forward,
}
impl MouseButton {
/// Convert to USB HID button bit
pub fn to_hid_bit(&self) -> u8 {
match self {
MouseButton::Left => 0x01,
MouseButton::Right => 0x02,
MouseButton::Middle => 0x04,
MouseButton::Back => 0x08,
MouseButton::Forward => 0x10,
}
}
}
/// Mouse event type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseEventType {
/// Mouse moved (relative movement)
Move,
/// Mouse moved (absolute position)
MoveAbs,
/// Button pressed
Down,
/// Button released
Up,
/// Mouse wheel scroll
Scroll,
}
/// Mouse event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MouseEvent {
/// Event type
#[serde(rename = "type")]
pub event_type: MouseEventType,
/// X coordinate or delta
#[serde(default)]
pub x: i32,
/// Y coordinate or delta
#[serde(default)]
pub y: i32,
/// Button (for down/up events)
#[serde(default)]
pub button: Option<MouseButton>,
/// Scroll delta (for scroll events)
#[serde(default)]
pub scroll: i8,
}
impl MouseEvent {
/// Create a relative move event
pub fn move_rel(dx: i32, dy: i32) -> Self {
Self {
event_type: MouseEventType::Move,
x: dx,
y: dy,
button: None,
scroll: 0,
}
}
/// Create an absolute move event
pub fn move_abs(x: i32, y: i32) -> Self {
Self {
event_type: MouseEventType::MoveAbs,
x,
y,
button: None,
scroll: 0,
}
}
/// Create a button down event
pub fn button_down(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Down,
x: 0,
y: 0,
button: Some(button),
scroll: 0,
}
}
/// Create a button up event
pub fn button_up(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Up,
x: 0,
y: 0,
button: Some(button),
scroll: 0,
}
}
/// Create a scroll event
pub fn scroll(delta: i8) -> Self {
Self {
event_type: MouseEventType::Scroll,
x: 0,
y: 0,
button: None,
scroll: delta,
}
}
}
/// Combined HID event (keyboard or mouse)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "device", rename_all = "lowercase")]
pub enum HidEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
}
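// Illustrative wire format, derived from the serde attributes above (a sketch,
// not an authoritative schema): a keyboard key-down for HID usage 0x04 with
// Left-Ctrl serializes as
//   {"device":"keyboard","type":"down","key":4,"modifiers":{"left_ctrl":true,...}}
// and a relative mouse move as
//   {"device":"mouse","type":"move","x":10,"y":-5,"button":null,"scroll":0}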
/// USB HID keyboard report (8 bytes)
#[derive(Debug, Clone, Default)]
pub struct KeyboardReport {
/// Modifier byte
pub modifiers: u8,
/// Reserved byte
pub reserved: u8,
/// Key codes (up to 6 simultaneous keys)
pub keys: [u8; 6],
}
impl KeyboardReport {
/// Convert to bytes for USB HID
pub fn to_bytes(&self) -> [u8; 8] {
[
self.modifiers,
self.reserved,
self.keys[0],
self.keys[1],
self.keys[2],
self.keys[3],
self.keys[4],
self.keys[5],
]
}
/// Add a key to the report
pub fn add_key(&mut self, key: u8) -> bool {
for slot in &mut self.keys {
if *slot == 0 {
*slot = key;
return true;
}
}
false // All slots full
}
/// Remove a key from the report
pub fn remove_key(&mut self, key: u8) {
for slot in &mut self.keys {
if *slot == key {
*slot = 0;
}
}
// Compact: sort descending so freed (zero) slots move to the end
// (note: this also reorders the remaining keys by value)
self.keys.sort_by(|a, b| b.cmp(a));
}
/// Clear all keys
pub fn clear(&mut self) {
self.modifiers = 0;
self.keys = [0; 6];
}
}
/// USB HID mouse report
#[derive(Debug, Clone, Default)]
pub struct MouseReport {
/// Button state
pub buttons: u8,
/// X movement (-127 to 127)
pub x: i8,
/// Y movement (-127 to 127)
pub y: i8,
/// Wheel movement (-127 to 127)
pub wheel: i8,
}
impl MouseReport {
/// Convert to bytes for USB HID (relative mouse)
pub fn to_bytes_relative(&self) -> [u8; 4] {
[
self.buttons,
self.x as u8,
self.y as u8,
self.wheel as u8,
]
}
/// Convert to bytes for USB HID (absolute mouse)
pub fn to_bytes_absolute(&self, x: u16, y: u16) -> [u8; 6] {
[
self.buttons,
(x & 0xFF) as u8,
(x >> 8) as u8,
(y & 0xFF) as u8,
(y >> 8) as u8,
self.wheel as u8,
]
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_modifier_conversion() {
let mods = KeyboardModifiers {
left_ctrl: true,
left_shift: true,
..Default::default()
};
assert_eq!(mods.to_hid_byte(), 0x03);
let mods2 = KeyboardModifiers::from_hid_byte(0x03);
assert!(mods2.left_ctrl);
assert!(mods2.left_shift);
assert!(!mods2.left_alt);
}
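// Added sketch: every modifier bit is symmetric between from_hid_byte() and
// to_hid_byte(), so the full byte range round-trips losslessly.
#[test]
fn test_modifier_full_roundtrip() {
for byte in 0u8..=255 {
let mods = KeyboardModifiers::from_hid_byte(byte);
assert_eq!(mods.to_hid_byte(), byte);
}
}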
#[test]
fn test_keyboard_report() {
let mut report = KeyboardReport::default();
assert!(report.add_key(0x04)); // 'A'
assert!(report.add_key(0x05)); // 'B'
assert_eq!(report.keys[0], 0x04);
assert_eq!(report.keys[1], 0x05);
report.remove_key(0x04);
assert_eq!(report.keys[0], 0x05);
}
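// Added sketch: verifies the little-endian coordinate packing of the 6-byte
// absolute mouse report (low byte first, matching to_bytes_absolute above).
#[test]
fn test_mouse_report_absolute_bytes() {
let report = MouseReport {
buttons: 0x01,
wheel: -1,
..Default::default()
};
let bytes = report.to_bytes_absolute(0x1234, 0x7FFF);
assert_eq!(bytes, [0x01, 0x34, 0x12, 0xFF, 0x7F, 0xFF]);
}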
}

160
src/hid/websocket.rs Normal file

@@ -0,0 +1,160 @@
//! WebSocket HID channel for HTTP/MJPEG mode
//!
//! This provides an alternative to WebRTC DataChannel for HID input
//! when using MJPEG streaming mode.
//!
//! Uses binary protocol only (same format as DataChannel):
//! - Keyboard: [0x01, event_type, key, modifiers] (4 bytes)
//! - Mouse: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, button/scroll] (7 bytes)
//!
//! See datachannel.rs for detailed protocol specification.
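//!
//! Concrete frames (mirroring the unit tests at the bottom of this file, not
//! an additional spec): a key-down for HID usage 0x04 with Left-Ctrl held is
//! `[MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]`; a relative move of (10, -10)
//! is `[MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]`.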
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::Response,
};
use futures::{SinkExt, StreamExt};
use std::sync::Arc;
use tracing::{debug, error, info, warn};
use super::datachannel::{parse_hid_message, HidChannelEvent};
use crate::state::AppState;
use crate::utils::LogThrottler;
/// Binary response codes
const RESP_OK: u8 = 0x00;
const RESP_ERR_HID_UNAVAILABLE: u8 = 0x01;
const RESP_ERR_INVALID_MESSAGE: u8 = 0x02;
#[allow(dead_code)]
const RESP_ERR_SEND_FAILED: u8 = 0x03;
/// WebSocket HID upgrade handler
pub async fn ws_hid_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
ws.on_upgrade(move |socket| handle_hid_socket(socket, state))
}
/// Handle HID WebSocket connection
async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
// Log throttler for error messages (5 second interval)
let log_throttler = LogThrottler::with_secs(5);
info!("WebSocket HID connection established (binary protocol)");
// Check if HID controller is available and send initial status
let hid_available = state.hid.is_available().await;
let initial_response = if hid_available {
vec![RESP_OK]
} else {
vec![RESP_ERR_HID_UNAVAILABLE]
};
if sender.send(Message::Binary(initial_response)).await.is_err() {
error!("Failed to send initial HID status");
return;
}
// Process incoming messages (binary only)
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Binary(data)) => {
// Check HID availability before processing each message
let hid_available = state.hid.is_available().await;
if !hid_available {
if log_throttler.should_log("hid_unavailable") {
warn!("HID controller not available, ignoring message");
}
// Send error response (optional, for client awareness)
let _ = sender.send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE])).await;
continue;
}
if let Err(e) = handle_binary_message(&data, &state).await {
// Log with throttling to avoid spam
if log_throttler.should_log("binary_hid_error") {
warn!("Binary HID message error: {}", e);
}
// Don't send error response for every failed message to reduce overhead
}
}
Ok(Message::Text(text)) => {
// Text messages are no longer supported
if log_throttler.should_log("text_message_rejected") {
debug!("Received text message (not supported): {} bytes", text.len());
}
let _ = sender.send(Message::Binary(vec![RESP_ERR_INVALID_MESSAGE])).await;
}
Ok(Message::Ping(data)) => {
let _ = sender.send(Message::Pong(data)).await;
}
Ok(Message::Close(_)) => {
info!("WebSocket HID connection closed by client");
break;
}
Err(e) => {
error!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
info!("WebSocket HID connection ended");
}
/// Handle binary HID message (same format as DataChannel)
async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), String> {
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
match event {
HidChannelEvent::Keyboard(kb_event) => {
state
.hid
.send_keyboard(kb_event)
.await
.map_err(|e| e.to_string())?;
}
HidChannelEvent::Mouse(ms_event) => {
state
.hid
.send_mouse(ms_event)
.await
.map_err(|e| e.to_string())?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::hid::datachannel::{MSG_KEYBOARD, MSG_MOUSE, KB_EVENT_DOWN, MS_EVENT_MOVE};
#[test]
fn test_response_codes() {
assert_eq!(RESP_OK, 0x00);
assert_eq!(RESP_ERR_HID_UNAVAILABLE, 0x01);
assert_eq!(RESP_ERR_INVALID_MESSAGE, 0x02);
assert_eq!(RESP_ERR_SEND_FAILED, 0x03);
}
#[test]
fn test_keyboard_message_format() {
// Keyboard message: [0x01, event_type, key, modifiers]
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // 'A' key with left ctrl
let event = parse_hid_message(&data);
assert!(event.is_some());
}
#[test]
fn test_mouse_message_format() {
// Mouse message: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, extra]
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
let event = parse_hid_message(&data);
assert!(event.is_some());
}
}

24
src/lib.rs Normal file

@@ -0,0 +1,24 @@
//! One-KVM - Lightweight IP-KVM solution
//!
//! This crate provides the core functionality for One-KVM,
//! a remote KVM (Keyboard, Video, Mouse) solution written in Rust.
pub mod atx;
pub mod audio;
pub mod auth;
pub mod config;
pub mod error;
pub mod events;
pub mod extensions;
pub mod hid;
pub mod modules;
pub mod msd;
pub mod otg;
pub mod state;
pub mod stream;
pub mod utils;
pub mod video;
pub mod web;
pub mod webrtc;
pub use error::{AppError, Result};

667
src/main.rs Normal file

@@ -0,0 +1,667 @@
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
use clap::{Parser, ValueEnum};
use tokio::sync::broadcast;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use rustls::crypto::{ring, CryptoProvider};
use one_kvm::atx::AtxController;
use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality};
use one_kvm::auth::{SessionStore, UserStore};
use one_kvm::config::{self, AppConfig, ConfigStore};
use one_kvm::events::EventBus;
use one_kvm::extensions::ExtensionManager;
use one_kvm::hid::{HidBackendType, HidController};
use one_kvm::msd::MsdController;
use one_kvm::otg::OtgService;
use one_kvm::state::AppState;
use one_kvm::video::format::{PixelFormat, Resolution};
use one_kvm::video::{Streamer, VideoStreamManager};
use one_kvm::web;
use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig};
/// Log level for the application
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
enum LogLevel {
Error,
Warn,
#[default]
Info,
Verbose,
Debug,
Trace,
}
/// One-KVM command line arguments
#[derive(Parser, Debug)]
#[command(name = "one-kvm")]
#[command(version, about = "An open and lightweight IP-KVM solution", long_about = None)]
struct CliArgs {
/// Listen address
#[arg(short = 'a', long, value_name = "ADDRESS", default_value = "0.0.0.0")]
address: String,
/// HTTP port (used when HTTPS is disabled)
#[arg(short = 'p', long, value_name = "PORT", default_value = "8080")]
http_port: u16,
/// HTTPS port (used when HTTPS is enabled)
#[arg(long, value_name = "PORT", default_value = "8443")]
https_port: u16,
/// Enable HTTPS
#[arg(long, default_value = "false")]
enable_https: bool,
/// Path to SSL certificate file (generates self-signed if not provided)
#[arg(long, value_name = "FILE", requires = "ssl_key")]
ssl_cert: Option<PathBuf>,
/// Path to SSL private key file
#[arg(long, value_name = "FILE", requires = "ssl_cert")]
ssl_key: Option<PathBuf>,
/// Data directory path (default: ONE_KVM_DATA_DIR env var, else /etc/one-kvm)
#[arg(short = 'd', long, value_name = "DIR")]
data_dir: Option<PathBuf>,
/// Log level (error, warn, info, verbose, debug, trace)
#[arg(short = 'l', long, value_name = "LEVEL", default_value = "info")]
log_level: LogLevel,
/// Increase verbosity (-v for verbose, -vv for debug, -vvv for trace)
#[arg(short = 'v', long, action = clap::ArgAction::Count)]
verbose: u8,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Parse command line arguments
let args = CliArgs::parse();
// Initialize logging with CLI arguments
init_logging(args.log_level, args.verbose);
// Install default crypto provider (required by rustls 0.23+)
CryptoProvider::install_default(ring::default_provider())
.expect("Failed to install rustls crypto provider");
tracing::info!(
"Starting One-KVM v{}",
env!("CARGO_PKG_VERSION")
);
// Determine data directory (CLI arg takes precedence)
let data_dir = args.data_dir.unwrap_or_else(get_data_dir);
tracing::info!("Data directory: {}", data_dir.display());
// Ensure data directory exists
tokio::fs::create_dir_all(&data_dir).await?;
// Initialize configuration store
let db_path = data_dir.join("one-kvm.db");
let config_store = ConfigStore::new(&db_path).await?;
let mut config = (*config_store.get()).clone();
// Apply CLI argument overrides to config
config.web.bind_address = args.address;
config.web.http_port = args.http_port;
config.web.https_port = args.https_port;
config.web.https_enabled = args.enable_https;
if let Some(cert_path) = args.ssl_cert {
config.web.ssl_cert_path = Some(cert_path.to_string_lossy().to_string());
}
if let Some(key_path) = args.ssl_key {
config.web.ssl_key_path = Some(key_path.to_string_lossy().to_string());
}
// Log final configuration
if config.web.https_enabled {
tracing::info!(
"Server will listen on: https://{}:{}",
config.web.bind_address,
config.web.https_port
);
} else {
tracing::info!(
"Server will listen on: http://{}:{}",
config.web.bind_address,
config.web.http_port
);
}
// Initialize session store
let session_store = SessionStore::new(
config_store.pool().clone(),
config.auth.session_timeout_secs as i64,
);
// Initialize user store
let user_store = UserStore::new(config_store.pool().clone());
// Create shutdown channel
let (shutdown_tx, _) = broadcast::channel::<()>(1);
// Create event bus for real-time notifications
let events = Arc::new(EventBus::new());
tracing::info!("Event bus initialized");
// Parse video configuration once (avoid duplication)
let (video_format, video_resolution) = parse_video_config(&config);
tracing::debug!("Parsed video config: {} @ {}x{}", video_format, video_resolution.width, video_resolution.height);
// Create video streamer and initialize with config if device is set
let streamer = Streamer::new();
streamer.set_event_bus(events.clone()).await;
if let Some(ref device_path) = config.video.device {
if let Err(e) = streamer
.apply_video_config(device_path, video_format, video_resolution, config.video.fps)
.await
{
tracing::warn!("Failed to initialize video with config: {}, will auto-detect", e);
} else {
tracing::info!(
"Video configured: {} @ {}x{} {}",
device_path, video_resolution.width, video_resolution.height, video_format
);
}
}
// Create WebRTC streamer
let webrtc_streamer = {
let webrtc_config = WebRtcStreamerConfig {
resolution: video_resolution,
input_format: video_format,
fps: config.video.fps,
bitrate_kbps: config.stream.bitrate_kbps,
gop_size: config.stream.gop_size,
encoder_backend: config.stream.encoder.to_backend(),
webrtc: {
let mut stun_servers = vec![];
let mut turn_servers = vec![];
// Add STUN server from config
if let Some(ref stun) = config.stream.stun_server {
if !stun.is_empty() {
stun_servers.push(stun.clone());
tracing::info!("WebRTC STUN server configured: {}", stun);
}
}
// Add TURN server from config
if let Some(ref turn) = config.stream.turn_server {
if !turn.is_empty() {
let username = config.stream.turn_username.clone().unwrap_or_default();
let credential = config.stream.turn_password.clone().unwrap_or_default();
turn_servers.push(one_kvm::webrtc::config::TurnServer {
url: turn.clone(),
username: username.clone(),
credential,
});
tracing::info!("WebRTC TURN server configured: {} (user: {})", turn, username);
}
}
one_kvm::webrtc::config::WebRtcConfig {
stun_servers,
turn_servers,
..Default::default()
}
},
..Default::default()
};
WebRtcStreamer::with_config(webrtc_config)
};
tracing::info!("WebRTC streamer created (supports H264, extensible to VP8/VP9/H265)");
// Create OTG Service (single instance for centralized USB gadget management)
let otg_service = Arc::new(OtgService::new());
tracing::info!("OTG Service created");
// Pre-enable OTG functions to avoid gadget recreation (prevents kernel crashes)
let will_use_otg_hid = matches!(config.hid.backend, config::HidBackend::Otg);
let will_use_msd = config.msd.enabled || will_use_otg_hid;
if will_use_otg_hid {
if !config.msd.enabled {
tracing::info!("OTG HID enabled, automatically enabling MSD functionality");
}
if let Err(e) = otg_service.enable_hid().await {
tracing::warn!("Failed to pre-enable HID: {}", e);
}
}
if will_use_msd {
if let Err(e) = otg_service.enable_msd().await {
tracing::warn!("Failed to pre-enable MSD: {}", e);
}
}
// Create HID controller based on config
let hid_backend = match config.hid.backend {
config::HidBackend::Otg => HidBackendType::Otg,
config::HidBackend::Ch9329 => HidBackendType::Ch9329 {
port: config.hid.ch9329_port.clone(),
baud_rate: config.hid.ch9329_baudrate,
},
config::HidBackend::None => HidBackendType::None,
};
let hid = Arc::new(HidController::new(
hid_backend,
Some(otg_service.clone()), // Always pass OtgService to support hot-reload to OTG
));
hid.set_event_bus(events.clone()).await;
if let Err(e) = hid.init().await {
tracing::warn!("Failed to initialize HID backend: {}", e);
}
// Create MSD controller (optional, based on config)
let msd = if config.msd.enabled {
// Initialize Ventoy resources from data directory
let ventoy_resource_dir = ventoy_img::get_resource_dir(&data_dir);
if ventoy_resource_dir.exists() {
if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) {
tracing::warn!("Failed to initialize Ventoy resources: {}", e);
tracing::info!("Ventoy resource files should be placed in: {}", ventoy_resource_dir.display());
tracing::info!("Required files: {:?}", ventoy_img::required_files());
} else {
tracing::info!("Ventoy resources initialized from {}", ventoy_resource_dir.display());
}
} else {
tracing::warn!("Ventoy resource directory not found: {}", ventoy_resource_dir.display());
tracing::info!("Create the directory and place the following files: {:?}", ventoy_img::required_files());
}
let controller = MsdController::new(
otg_service.clone(),
&config.msd.images_path,
&config.msd.drive_path,
);
if let Err(e) = controller.init().await {
tracing::warn!("Failed to initialize MSD controller: {}", e);
None
} else {
controller.set_event_bus(events.clone()).await;
Some(controller)
}
} else {
tracing::info!("MSD disabled in configuration");
None
};
// Create ATX controller (optional, based on config)
let atx = if config.atx.enabled {
let controller_config = config.atx.to_controller_config();
let controller = AtxController::new(controller_config);
if let Err(e) = controller.init().await {
tracing::warn!("Failed to initialize ATX controller: {}", e);
None
} else {
Some(controller)
}
} else {
tracing::info!("ATX disabled in configuration");
None
};
// Create Audio controller
let audio = {
let audio_config = AudioControllerConfig {
enabled: config.audio.enabled,
device: config.audio.device.clone(),
quality: AudioQuality::from_str(&config.audio.quality),
};
let controller = AudioController::new(audio_config);
controller.set_event_bus(events.clone()).await;
if config.audio.enabled {
tracing::info!(
"Audio enabled: {}, quality={}",
config.audio.device,
config.audio.quality
);
} else {
tracing::info!("Audio disabled in configuration");
}
Arc::new(controller)
};
// Create Extension manager (ttyd, gostc, easytier)
let extensions = Arc::new(ExtensionManager::new());
tracing::info!("Extension manager initialized");
// Wire up WebRTC streamer with HID controller
// This enables WebRTC DataChannel to process HID events
webrtc_streamer.set_hid_controller(hid.clone()).await;
// Wire up WebRTC streamer with Audio controller
// This enables WebRTC audio track to receive Opus frames
webrtc_streamer.set_audio_controller(audio.clone()).await;
if config.audio.enabled {
if let Err(e) = webrtc_streamer.set_audio_enabled(true).await {
tracing::warn!("Failed to enable WebRTC audio: {}", e);
} else {
tracing::info!("WebRTC audio enabled");
}
}
// Set up frame source from video streamer (if capturer is available)
// The frame source allows WebRTC sessions to receive live video frames
if let Some(frame_tx) = streamer.frame_sender().await {
// Synchronize WebRTC config with actual capture format before connecting
let (actual_format, actual_resolution, actual_fps) = streamer.current_video_config().await;
tracing::info!(
"Initial video config from capturer: {}x{} {:?} @ {}fps",
actual_resolution.width, actual_resolution.height, actual_format, actual_fps
);
webrtc_streamer.update_video_config(actual_resolution, actual_format, actual_fps).await;
webrtc_streamer.set_video_source(frame_tx).await;
tracing::info!("WebRTC streamer connected to video frame source");
} else {
tracing::warn!("Video capturer not ready, WebRTC will connect to frame source when available");
}
// Create video stream manager (unified MJPEG/WebRTC management)
// Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance
let stream_manager = VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone());
stream_manager.set_event_bus(events.clone()).await;
stream_manager.set_config_store(config_store.clone()).await;
// Initialize stream manager with configured mode
let initial_mode = config.stream.mode.clone();
if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await {
tracing::warn!("Failed to initialize stream manager with mode {:?}: {}", initial_mode, e);
} else {
tracing::info!("Video stream manager initialized with mode: {:?}", initial_mode);
}
// Create application state
let state = AppState::new(
config_store.clone(),
session_store,
user_store,
otg_service,
stream_manager,
hid,
msd,
atx,
audio,
extensions.clone(),
events.clone(),
shutdown_tx.clone(),
data_dir.clone(),
);
// Start enabled extensions
{
let ext_config = config_store.get();
extensions.start_enabled(&ext_config.extensions).await;
}
// Start extension health check task (every 30 seconds)
{
let extensions_clone = extensions.clone();
let config_store_clone = config_store.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
loop {
interval.tick().await;
let config = config_store_clone.get();
extensions_clone.health_check(&config.extensions).await;
}
});
tracing::info!("Extension health check task started");
}
// Start device info broadcast task
// This monitors state change events and broadcasts DeviceInfo to all clients
spawn_device_info_broadcaster(state.clone(), events);
// Create router
let app = web::create_router(state.clone());
// Determine bind address based on HTTPS setting
let bind_addr: SocketAddr = if config.web.https_enabled {
format!("{}:{}", config.web.bind_address, config.web.https_port).parse()?
} else {
format!("{}:{}", config.web.bind_address, config.web.http_port).parse()?
};
// Setup graceful shutdown
let shutdown_signal = async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to install CTRL+C handler");
tracing::info!("Shutdown signal received");
let _ = shutdown_tx.send(());
};
// Start server
if config.web.https_enabled {
// Generate self-signed certificate if no custom cert provided
let tls_config = if let (Some(cert_path), Some(key_path)) =
(&config.web.ssl_cert_path, &config.web.ssl_key_path)
{
RustlsConfig::from_pem_file(cert_path, key_path).await?
} else {
let cert_dir = data_dir.join("certs");
let cert_path = cert_dir.join("server.crt");
let key_path = cert_dir.join("server.key");
// Check if certificate already exists, only generate if missing
if !cert_path.exists() || !key_path.exists() {
tracing::info!("Generating new self-signed TLS certificate");
let cert = generate_self_signed_cert()?;
tokio::fs::create_dir_all(&cert_dir).await?;
tokio::fs::write(&cert_path, cert.cert.pem()).await?;
tokio::fs::write(&key_path, cert.key_pair.serialize_pem()).await?;
} else {
tracing::info!("Using existing TLS certificate from {}", cert_dir.display());
}
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
};
tracing::info!("Starting HTTPS server on {}", bind_addr);
let server = axum_server::bind_rustls(bind_addr, tls_config)
.serve(app.into_make_service());
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = server => {
if let Err(e) = result {
tracing::error!("HTTPS server error: {}", e);
}
cleanup(&state).await;
}
}
} else {
tracing::info!("Starting HTTP server on {}", bind_addr);
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
let server = axum::serve(listener, app);
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = server => {
if let Err(e) = result {
tracing::error!("HTTP server error: {}", e);
}
cleanup(&state).await;
}
}
}
tracing::info!("Server shutdown complete");
Ok(())
}
/// Initialize logging with tracing
fn init_logging(level: LogLevel, verbose_count: u8) {
// Verbose count overrides log level
let effective_level = match verbose_count {
0 => level,
1 => LogLevel::Verbose,
2 => LogLevel::Debug,
_ => LogLevel::Trace,
};
// Build filter string based on effective level
let filter = match effective_level {
LogLevel::Error => "one_kvm=error,tower_http=error",
LogLevel::Warn => "one_kvm=warn,tower_http=warn",
LogLevel::Info => "one_kvm=info,tower_http=info",
LogLevel::Verbose => "one_kvm=debug,tower_http=info",
LogLevel::Debug => "one_kvm=debug,tower_http=debug",
LogLevel::Trace => "one_kvm=trace,tower_http=debug",
};
// Environment variable takes highest priority
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| filter.into());
tracing_subscriber::registry()
.with(env_filter)
.with(tracing_subscriber::fmt::layer())
.init();
}
/// Get the application data directory
fn get_data_dir() -> PathBuf {
// Check environment variable first
if let Ok(path) = std::env::var("ONE_KVM_DATA_DIR") {
return PathBuf::from(path);
}
// Default to system configuration directory
PathBuf::from("/etc/one-kvm")
}
/// Parse video format and resolution from config (avoids code duplication)
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
let format = config
.video
.format
.as_ref()
.and_then(|f: &String| f.parse::<PixelFormat>().ok())
.unwrap_or(PixelFormat::Mjpeg);
let resolution = Resolution::new(config.video.width, config.video.height);
(format, resolution)
}
/// Generate a self-signed TLS certificate
fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey> {
use rcgen::generate_simple_self_signed;
let subject_alt_names = vec![
"localhost".to_string(),
"127.0.0.1".to_string(),
"::1".to_string(),
];
let certified_key = generate_simple_self_signed(subject_alt_names)?;
Ok(certified_key)
}
/// Spawn a background task that monitors state change events
/// and broadcasts DeviceInfo to all WebSocket clients with debouncing
fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
use one_kvm::events::SystemEvent;
use std::time::{Duration, Instant};
let mut rx = events.subscribe();
const DEBOUNCE_MS: u64 = 100;
tokio::spawn(async move {
let mut last_broadcast = Instant::now() - Duration::from_millis(DEBOUNCE_MS);
let mut pending_broadcast = false;
loop {
// Use timeout to handle pending broadcasts
let recv_result = if pending_broadcast {
let remaining = DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64);
tokio::time::timeout(Duration::from_millis(remaining), rx.recv()).await
} else {
Ok(rx.recv().await)
};
match recv_result {
Ok(Ok(event)) => {
let should_broadcast = matches!(
event,
SystemEvent::StreamStateChanged { .. }
| SystemEvent::StreamConfigApplied { .. }
| SystemEvent::HidStateChanged { .. }
| SystemEvent::MsdStateChanged { .. }
| SystemEvent::AtxStateChanged { .. }
| SystemEvent::AudioStateChanged { .. }
);
if should_broadcast {
pending_broadcast = true;
}
}
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(n))) => {
tracing::warn!("DeviceInfo broadcaster lagged by {} events", n);
pending_broadcast = true;
}
Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
tracing::info!("Event bus closed, stopping DeviceInfo broadcaster");
break;
}
Err(_timeout) => {
// Debounce timeout reached, broadcast now
}
}
// Broadcast if pending and debounce time has passed
if pending_broadcast && last_broadcast.elapsed() >= Duration::from_millis(DEBOUNCE_MS) {
state.publish_device_info().await;
tracing::trace!("Broadcasted DeviceInfo (debounced)");
last_broadcast = Instant::now();
pending_broadcast = false;
}
}
});
tracing::info!("DeviceInfo broadcaster task started (debounce: {}ms)", DEBOUNCE_MS);
}
/// Clean up subsystems on shutdown
async fn cleanup(state: &Arc<AppState>) {
// Stop all extensions
state.extensions.stop_all().await;
tracing::info!("Extensions stopped");
// Stop video
if let Err(e) = state.stream_manager.stop().await {
tracing::warn!("Failed to stop streamer: {}", e);
}
// Shutdown HID
if let Err(e) = state.hid.shutdown().await {
tracing::warn!("Failed to shutdown HID: {}", e);
}
// Shutdown MSD
if let Some(msd) = state.msd.write().await.as_mut() {
if let Err(e) = msd.shutdown().await {
tracing::warn!("Failed to shutdown MSD: {}", e);
}
}
// Shutdown ATX
if let Some(atx) = state.atx.write().await.as_mut() {
if let Err(e) = atx.shutdown().await {
tracing::warn!("Failed to shutdown ATX: {}", e);
}
}
// Shutdown Audio
if let Err(e) = state.audio.shutdown().await {
tracing::warn!("Failed to shutdown audio: {}", e);
}
}

49
src/modules/mod.rs Normal file

@@ -0,0 +1,49 @@
//! Module management for One-KVM
//!
//! This module provides infrastructure for managing feature modules
//! (video streaming, HID control, MSD, ATX) as independent async tasks.
use std::future::Future;
use std::pin::Pin;
use tokio::sync::broadcast;
/// Module status
#[derive(Debug, Clone, PartialEq)]
pub enum ModuleStatus {
Stopped,
Starting,
Running,
Stopping,
Error(String),
}
/// Trait for feature modules
pub trait Module: Send + Sync {
/// Module name
fn name(&self) -> &'static str;
/// Current status
fn status(&self) -> ModuleStatus;
/// Start the module
fn start(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
/// Stop the module
fn stop(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
}
/// Module manager for coordinating feature modules
pub struct ModuleManager {
shutdown_rx: broadcast::Receiver<()>,
}
impl ModuleManager {
pub fn new(shutdown_rx: broadcast::Receiver<()>) -> Self {
Self { shutdown_rx }
}
/// Wait for shutdown signal
pub async fn wait_for_shutdown(&mut self) {
let _ = self.shutdown_rx.recv().await;
}
}
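// A minimal sketch (added for illustration; `DummyModule` and this test are
// assumptions, not part of the original crate) of how a feature module can
// implement the boxed-future `Module` trait.
#[cfg(test)]
mod tests {
use super::*;
struct DummyModule {
status: ModuleStatus,
}
impl Module for DummyModule {
fn name(&self) -> &'static str {
"dummy"
}
fn status(&self) -> ModuleStatus {
self.status.clone()
}
fn start(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>> {
Box::pin(async move {
self.status = ModuleStatus::Running;
Ok(())
})
}
fn stop(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>> {
Box::pin(async move {
self.status = ModuleStatus::Stopped;
Ok(())
})
}
}
#[tokio::test]
async fn test_dummy_module_lifecycle() {
let mut module = DummyModule {
status: ModuleStatus::Stopped,
};
module.start().await.unwrap();
assert_eq!(module.status(), ModuleStatus::Running);
module.stop().await.unwrap();
assert_eq!(module.status(), ModuleStatus::Stopped);
}
}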

597
src/msd/controller.rs Normal file

@@ -0,0 +1,597 @@
//! MSD Controller
//!
//! Manages the mass storage device lifecycle including:
//! - Image mounting and unmounting
//! - Virtual drive management
//! - State tracking
//! - Image downloads from URL
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use super::image::ImageManager;
use super::monitor::{MsdHealthMonitor, MsdHealthStatus};
use super::types::{DownloadProgress, DownloadStatus, DriveInfo, ImageInfo, MsdMode, MsdState};
use crate::error::{AppError, Result};
use crate::otg::{MsdFunction, MsdLunConfig, OtgService};
/// USB Gadget path (system constant)
const GADGET_PATH: &str = "/sys/kernel/config/usb_gadget/one-kvm";
/// MSD Controller
pub struct MsdController {
/// OTG Service reference
otg_service: Arc<OtgService>,
/// MSD function manager (provided by OtgService)
msd_function: RwLock<Option<MsdFunction>>,
/// Current state
state: RwLock<MsdState>,
/// Images storage path
images_path: PathBuf,
/// Virtual drive path
drive_path: PathBuf,
/// Event bus for broadcasting state changes (optional)
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
/// Active downloads (download_id -> CancellationToken)
downloads: Arc<RwLock<HashMap<String, CancellationToken>>>,
/// Operation mutex lock (prevents concurrent operations)
operation_lock: Arc<RwLock<()>>,
/// Health monitor for error tracking and recovery
monitor: Arc<MsdHealthMonitor>,
}
impl MsdController {
/// Create new MSD controller
///
/// # Parameters
/// * `otg_service` - OTG service for gadget management
/// * `images_path` - Directory path for storing ISO/IMG files
/// * `drive_path` - File path for the virtual FAT32 drive
pub fn new(
otg_service: Arc<OtgService>,
images_path: impl Into<PathBuf>,
drive_path: impl Into<PathBuf>,
) -> Self {
Self {
otg_service,
msd_function: RwLock::new(None),
state: RwLock::new(MsdState::default()),
images_path: images_path.into(),
drive_path: drive_path.into(),
events: tokio::sync::RwLock::new(None),
downloads: Arc::new(RwLock::new(HashMap::new())),
operation_lock: Arc::new(RwLock::new(())),
monitor: Arc::new(MsdHealthMonitor::with_defaults()),
}
}
/// Initialize the MSD controller
pub async fn init(&self) -> Result<()> {
info!("Initializing MSD controller");
// 1. Ensure images directory exists
if let Err(e) = std::fs::create_dir_all(&self.images_path) {
warn!("Failed to create images directory: {}", e);
}
// 2. Request MSD function from OtgService
info!("Requesting MSD function from OtgService");
let msd_func = self.otg_service.enable_msd().await?;
// 3. Store function handle
*self.msd_function.write().await = Some(msd_func);
// 4. Update state
let mut state = self.state.write().await;
state.available = true;
// 5. Check for existing virtual drive
if self.drive_path.exists() {
if let Ok(metadata) = std::fs::metadata(&self.drive_path) {
state.drive_info = Some(DriveInfo {
size: metadata.len(),
used: 0,
free: metadata.len(),
initialized: true,
path: self.drive_path.clone(),
});
debug!("Found existing virtual drive: {}", self.drive_path.display());
}
}
info!("MSD controller initialized");
Ok(())
}
/// Get current state as SystemEvent
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
let state = self.state.read().await;
crate::events::SystemEvent::MsdStateChanged {
mode: state.mode.clone(),
connected: state.connected,
}
}
/// Get current MSD state
pub async fn state(&self) -> MsdState {
self.state.read().await.clone()
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: std::sync::Arc<crate::events::EventBus>) {
*self.events.write().await = Some(events.clone());
// Also set event bus on the monitor for health notifications
self.monitor.set_event_bus(events).await;
}
/// Publish an event to the event bus
async fn publish_event(&self, event: crate::events::SystemEvent) {
if let Some(ref bus) = *self.events.read().await {
bus.publish(event);
}
}
/// Check if MSD is available
pub async fn is_available(&self) -> bool {
self.state.read().await.available
}
/// Connect an image file
///
/// # Parameters
/// * `image` - Image info to mount
/// * `cdrom` - Mount as CD-ROM (read-only, removable)
/// * `read_only` - Mount as read-only
pub async fn connect_image(&self, image: &ImageInfo, cdrom: bool, read_only: bool) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor.report_error("MSD not available", "not_available").await;
return Err(err);
}
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Verify image exists
if !image.path.exists() {
let error_msg = format!("Image file not found: {}", image.path.display());
self.monitor.report_error(&error_msg, "image_not_found").await;
return Err(AppError::Internal(error_msg));
}
// Configure LUN
let config = if cdrom {
MsdLunConfig::cdrom(image.path.clone())
} else {
MsdLunConfig::disk(image.path.clone(), read_only)
};
let gadget_path = PathBuf::from(GADGET_PATH);
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor.report_error(&error_msg, "configfs_error").await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor.report_error("MSD function not initialized", "not_initialized").await;
return Err(err);
}
state.connected = true;
state.mode = MsdMode::Image;
state.current_image = Some(image.clone());
info!(
"Connected image: {} (cdrom={}, ro={})",
image.name, cdrom, read_only
);
// Release the lock before publishing events
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
// Publish events
self.publish_event(crate::events::SystemEvent::MsdImageMounted {
image_id: image.id.clone(),
image_name: image.name.clone(),
size: image.size,
cdrom,
})
.await;
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
mode: MsdMode::Image,
connected: true,
})
.await;
Ok(())
}
/// Connect the virtual drive
pub async fn connect_drive(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor.report_error("MSD not available", "not_available").await;
return Err(err);
}
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Check drive exists
if !self.drive_path.exists() {
let err = AppError::Internal(
"Virtual drive not initialized. Call init first.".to_string(),
);
self.monitor.report_error("Virtual drive not initialized", "drive_not_found").await;
return Err(err);
}
// Configure LUN as read-write disk
let config = MsdLunConfig::disk(self.drive_path.clone(), false);
let gadget_path = PathBuf::from(GADGET_PATH);
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor.report_error(&error_msg, "configfs_error").await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor.report_error("MSD function not initialized", "not_initialized").await;
return Err(err);
}
state.connected = true;
state.mode = MsdMode::Drive;
state.current_image = None;
info!("Connected virtual drive: {}", self.drive_path.display());
// Release the lock before publishing event
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
// Publish event
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
mode: MsdMode::Drive,
connected: true,
})
.await;
Ok(())
}
/// Disconnect current storage
pub async fn disconnect(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.connected {
debug!("Nothing connected, skipping disconnect");
return Ok(());
}
let gadget_path = PathBuf::from(GADGET_PATH);
if let Some(ref msd) = *self.msd_function.read().await {
msd.disconnect_lun_async(&gadget_path, 0).await?;
}
state.connected = false;
state.mode = MsdMode::None;
state.current_image = None;
info!("Disconnected storage");
// Release the lock before publishing events
drop(state);
drop(_op_guard);
// Publish events
self.publish_event(crate::events::SystemEvent::MsdImageUnmounted)
.await;
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
mode: MsdMode::None,
connected: false,
})
.await;
Ok(())
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
/// Get virtual drive path
pub fn drive_path(&self) -> &PathBuf {
&self.drive_path
}
/// Check if currently connected
pub async fn is_connected(&self) -> bool {
self.state.read().await.connected
}
/// Get current mode
pub async fn mode(&self) -> MsdMode {
self.state.read().await.mode.clone()
}
/// Update drive info
pub async fn update_drive_info(&self, info: DriveInfo) {
let mut state = self.state.write().await;
state.drive_info = Some(info);
}
/// Start downloading an image from URL
///
/// Returns the download_id that can be used to track or cancel the download.
/// Progress is reported via MsdDownloadProgress events.
pub async fn download_image(
&self,
url: String,
filename: Option<String>,
) -> Result<DownloadProgress> {
let download_id = uuid::Uuid::new_v4().to_string();
let cancel_token = CancellationToken::new();
// Register download
{
let mut downloads = self.downloads.write().await;
downloads.insert(download_id.clone(), cancel_token.clone());
}
// Extract filename for initial response
let display_filename = filename.clone().unwrap_or_else(|| {
url.rsplit('/')
.next()
.unwrap_or("download")
.to_string()
});
// Create initial progress
let initial_progress = DownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
filename: display_filename.clone(),
bytes_downloaded: 0,
total_bytes: None,
progress_pct: None,
status: DownloadStatus::Started,
error: None,
};
// Publish started event
self.publish_event(crate::events::SystemEvent::MsdDownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
filename: display_filename.clone(),
bytes_downloaded: 0,
total_bytes: None,
progress_pct: None,
status: "started".to_string(),
})
.await;
// Clone what we need for the spawned task
let images_path = self.images_path.clone();
let events = self.events.read().await.clone();
let downloads = self.downloads.clone();
let download_id_clone = download_id.clone();
let url_clone = url.clone();
// Spawn download task
tokio::spawn(async move {
let manager = ImageManager::new(images_path);
// Create progress callback
let events_for_callback = events.clone();
let download_id_for_callback = download_id_clone.clone();
let url_for_callback = url_clone.clone();
let filename_for_callback = display_filename.clone();
let progress_callback = move |downloaded: u64, total: Option<u64>| {
let progress_pct = total.map(|t| (downloaded as f32 / t as f32) * 100.0);
if let Some(ref bus) = events_for_callback {
bus.publish(crate::events::SystemEvent::MsdDownloadProgress {
download_id: download_id_for_callback.clone(),
url: url_for_callback.clone(),
filename: filename_for_callback.clone(),
bytes_downloaded: downloaded,
total_bytes: total,
progress_pct,
status: "in_progress".to_string(),
});
}
};
// Run download
let result = manager
.download_from_url(&url_clone, filename, progress_callback)
.await;
// Remove from active downloads
{
let mut downloads_guard = downloads.write().await;
downloads_guard.remove(&download_id_clone);
}
// Publish completion event
match result {
Ok(image_info) => {
if let Some(ref bus) = events {
bus.publish(crate::events::SystemEvent::MsdDownloadProgress {
download_id: download_id_clone,
url: url_clone,
filename: image_info.name,
bytes_downloaded: image_info.size,
total_bytes: Some(image_info.size),
progress_pct: Some(100.0),
status: "completed".to_string(),
});
}
}
Err(e) => {
warn!("Download failed: {}", e);
if let Some(ref bus) = events {
bus.publish(crate::events::SystemEvent::MsdDownloadProgress {
download_id: download_id_clone,
url: url_clone,
filename: display_filename,
bytes_downloaded: 0,
total_bytes: None,
progress_pct: None,
status: format!("failed: {}", e),
});
}
}
}
});
Ok(initial_progress)
}
/// Cancel an active download
pub async fn cancel_download(&self, download_id: &str) -> Result<()> {
let mut downloads = self.downloads.write().await;
if let Some(token) = downloads.remove(download_id) {
token.cancel();
info!("Download cancelled: {}", download_id);
Ok(())
} else {
Err(AppError::NotFound(format!(
"Download not found: {}",
download_id
)))
}
}
/// Get list of active download IDs
pub async fn active_downloads(&self) -> Vec<String> {
let downloads = self.downloads.read().await;
downloads.keys().cloned().collect()
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down MSD controller");
// 1. Disconnect if connected
if let Err(e) = self.disconnect().await {
warn!("Error disconnecting during shutdown: {}", e);
}
// 2. Notify OtgService to disable MSD
info!("Disabling MSD function in OtgService");
self.otg_service.disable_msd().await?;
// 3. Clear local state
*self.msd_function.write().await = None;
let mut state = self.state.write().await;
state.available = false;
info!("MSD controller shutdown complete");
Ok(())
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<MsdHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> MsdHealthStatus {
self.monitor.status().await
}
/// Check if the MSD is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Drop for MsdController {
fn drop(&mut self) {
// Cleanup is handled by OtgGadgetManager when the gadget is torn down
// Individual controllers don't need to clean up the ConfigFS themselves
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_controller_creation() {
let temp_dir = TempDir::new().unwrap();
let otg_service = Arc::new(OtgService::new());
let images_path = temp_dir.path().join("images");
let drive_path = temp_dir.path().join("ventoy.img");
let controller = MsdController::new(otg_service, &images_path, &drive_path);
// Check that MSD is not initialized (msd_function is None)
let state = controller.state().await;
assert!(!state.available);
assert!(controller.images_path.ends_with("images"));
assert!(controller.drive_path.ends_with("ventoy.img"));
}
#[tokio::test]
async fn test_state_default() {
let temp_dir = TempDir::new().unwrap();
let otg_service = Arc::new(OtgService::new());
let images_path = temp_dir.path().join("images");
let drive_path = temp_dir.path().join("ventoy.img");
let controller = MsdController::new(otg_service, &images_path, &drive_path);
let state = controller.state().await;
assert!(!state.available);
assert!(!state.connected);
assert_eq!(state.mode, MsdMode::None);
}
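// Added sketch: the download registry starts empty and cancelling an unknown
// download id reports NotFound (no OTG hardware is touched).
#[tokio::test]
async fn test_download_registry() {
let temp_dir = TempDir::new().unwrap();
let otg_service = Arc::new(OtgService::new());
let controller = MsdController::new(
otg_service,
temp_dir.path().join("images"),
temp_dir.path().join("ventoy.img"),
);
assert!(controller.active_downloads().await.is_empty());
assert!(controller.cancel_download("no-such-id").await.is_err());
}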
}

654
src/msd/image.rs Normal file

@@ -0,0 +1,654 @@
//! Image file manager
//!
//! Handles ISO/IMG image file operations:
//! - List available images
//! - Upload new images
//! - Delete images
//! - Metadata management
//! - Download from URL
use chrono::Utc;
use futures::StreamExt;
use std::fs::{self, File};
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use tokio::io::AsyncWriteExt;
use tracing::info;
use super::types::ImageInfo;
use crate::error::{AppError, Result};
/// Maximum image size (32 GB)
const MAX_IMAGE_SIZE: u64 = 32 * 1024 * 1024 * 1024;
/// Progress report throttle interval (milliseconds)
const PROGRESS_THROTTLE_MS: u64 = 200;
/// Progress report throttle bytes threshold (512 KB)
const PROGRESS_THROTTLE_BYTES: u64 = 512 * 1024;
/// Image Manager
pub struct ImageManager {
/// Images storage directory
images_path: PathBuf,
}
impl ImageManager {
/// Create a new image manager
pub fn new(images_path: PathBuf) -> Self {
Self { images_path }
}
/// Ensure images directory exists
pub fn ensure_dir(&self) -> Result<()> {
fs::create_dir_all(&self.images_path).map_err(|e| {
AppError::Internal(format!("Failed to create images directory: {}", e))
})?;
Ok(())
}
/// List all available images
pub fn list(&self) -> Result<Vec<ImageInfo>> {
self.ensure_dir()?;
let mut images = Vec::new();
for entry in fs::read_dir(&self.images_path).map_err(|e| {
AppError::Internal(format!("Failed to read images directory: {}", e))
})? {
let entry = entry.map_err(|e| {
AppError::Internal(format!("Failed to read directory entry: {}", e))
})?;
let path = entry.path();
if path.is_file() {
if let Some(info) = self.get_image_info(&path) {
images.push(info);
}
}
}
// Sort by creation time (newest first)
images.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(images)
}
/// Get image info from path
fn get_image_info(&self, path: &Path) -> Option<ImageInfo> {
let metadata = fs::metadata(path).ok()?;
let name = path.file_name()?.to_string_lossy().to_string();
// Use filename hash as ID (stable across restarts)
let id = format!("{:x}", md5_hash(&name));
let created_at = metadata
.created()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
.map(|d| {
chrono::DateTime::from_timestamp(d.as_secs() as i64, 0)
.unwrap_or_else(|| Utc::now().into())
})
.unwrap_or_else(Utc::now);
Some(ImageInfo {
id,
name,
path: path.to_path_buf(),
size: metadata.len(),
created_at,
})
}
/// Get image by ID
pub fn get(&self, id: &str) -> Result<ImageInfo> {
for image in self.list()? {
if image.id == id {
return Ok(image);
}
}
Err(AppError::NotFound(format!("Image not found: {}", id)))
}
/// Get image by name
pub fn get_by_name(&self, name: &str) -> Result<ImageInfo> {
let path = self.images_path.join(name);
self.get_image_info(&path)
.ok_or_else(|| AppError::NotFound(format!("Image not found: {}", name)))
}
/// Create a new image from bytes
pub fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
self.ensure_dir()?;
// Validate name
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Check size
if data.len() as u64 > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
)));
}
// Write file
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
name
)));
}
let mut file = File::create(&path).map_err(|e| {
AppError::Internal(format!("Failed to create image file: {}", e))
})?;
file.write_all(data).map_err(|e| {
// Try to clean up on error
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
info!("Created image: {} ({} bytes)", name, data.len());
self.get_by_name(&name)
}
/// Create a new image from a file stream (for chunked uploads)
pub fn create_from_stream<R: Read>(
&self,
name: &str,
reader: &mut R,
expected_size: Option<u64>,
) -> Result<ImageInfo> {
self.ensure_dir()?;
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
if let Some(size) = expected_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
)));
}
}
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
name
)));
}
// Create file and copy data
let mut file = File::create(&path).map_err(|e| {
AppError::Internal(format!("Failed to create image file: {}", e))
})?;
let bytes_written = io::copy(reader, &mut file).map_err(|e| {
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
info!("Created image: {} ({} bytes)", name, bytes_written);
self.get_by_name(&name)
}
/// Create a new image from an async multipart field (streaming, memory-efficient)
///
/// This method streams data directly to disk without buffering the entire file in memory,
/// making it suitable for large files (multi-GB ISOs).
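///
/// A sketch of calling this from an axum upload handler (hypothetical handler;
/// `multipart` is an `axum::extract::Multipart` and `manager` an `ImageManager`):
///
/// ```ignore
/// while let Some(field) = multipart.next_field().await? {
///     let name = field.file_name().unwrap_or("upload.iso").to_string();
///     // Streams the field to disk chunk by chunk; no full buffering.
///     let info = manager.create_from_multipart_field(&name, field).await?;
///     tracing::info!("uploaded {} ({} bytes)", info.name, info.size);
/// }
/// ```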
pub async fn create_from_multipart_field(
&self,
name: &str,
mut field: axum::extract::multipart::Field<'_>,
) -> Result<ImageInfo> {
self.ensure_dir()?;
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Use a temporary file during upload
let temp_name = format!(".upload_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_name);
let final_path = self.images_path.join(&name);
// Check if final file already exists
if final_path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
name
)));
}
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
let mut bytes_written: u64 = 0;
// Stream chunks directly to disk
while let Some(chunk) = field.chunk().await.map_err(|e| {
AppError::Internal(format!("Failed to read upload chunk: {}", e))
})? {
// Check size limit
bytes_written += chunk.len() as u64;
if bytes_written > MAX_IMAGE_SIZE {
// Cleanup and return error
drop(file);
let _ = tokio::fs::remove_file(&temp_path).await;
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
)));
}
// Write chunk to file
file.write_all(&chunk).await.map_err(|e| {
AppError::Internal(format!("Failed to write chunk: {}", e))
})?;
}
// Flush and close file
file.flush().await.map_err(|e| {
AppError::Internal(format!("Failed to flush file: {}", e))
})?;
drop(file);
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
let _ = std::fs::remove_file(&temp_path);
AppError::Internal(format!("Failed to rename temp file: {}", e))
})?;
info!("Created image (streaming): {} ({} bytes)", name, bytes_written);
self.get_by_name(&name)
}
/// Delete an image by ID
pub fn delete(&self, id: &str) -> Result<()> {
let image = self.get(id)?;
fs::remove_file(&image.path).map_err(|e| {
AppError::Internal(format!("Failed to delete image: {}", e))
})?;
info!("Deleted image: {}", image.name);
Ok(())
}
/// Delete an image by name
pub fn delete_by_name(&self, name: &str) -> Result<()> {
let path = self.images_path.join(name);
if !path.exists() {
return Err(AppError::NotFound(format!("Image not found: {}", name)));
}
fs::remove_file(&path).map_err(|e| {
AppError::Internal(format!("Failed to delete image: {}", e))
})?;
info!("Deleted image: {}", name);
Ok(())
}
/// Get total storage used
pub fn used_space(&self) -> u64 {
self.list()
.map(|images| images.iter().map(|i| i.size).sum())
.unwrap_or(0)
}
/// Check if storage has space for new image
pub fn has_space(&self, size: u64) -> bool {
// For now, just check against max size
// In the future, could check disk space
size <= MAX_IMAGE_SIZE
}
/// Download image from URL with progress callback
///
/// # Arguments
/// * `url` - The URL to download from
/// * `filename` - Optional custom filename (extracted from URL or Content-Disposition if not provided)
/// * `progress_callback` - Callback function called with (bytes_downloaded, total_bytes)
///
/// # Returns
/// * `Ok(ImageInfo)` - The downloaded image info
/// * `Err(AppError)` - If download fails
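///
/// # Example (illustrative; requires a tokio runtime and network access)
///
/// ```ignore
/// let info = manager
///     .download_from_url(
///         "https://example.com/alpine.iso", // hypothetical URL
///         None,                             // derive filename from headers/URL
///         |done, total| tracing::debug!("progress: {} / {:?}", done, total),
///     )
///     .await?;
/// ```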
pub async fn download_from_url<F>(
&self,
url: &str,
filename: Option<String>,
progress_callback: F,
) -> Result<ImageInfo>
where
F: Fn(u64, Option<u64>) + Send + 'static,
{
self.ensure_dir()?;
// Validate URL
let parsed_url = reqwest::Url::parse(url)
.map_err(|e| AppError::BadRequest(format!("Invalid URL: {}", e)))?;
info!("Starting download from: {}", url);
// Create HTTP client with timeout
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(3600)) // 1 hour timeout for large files
.connect_timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| AppError::Internal(format!("Failed to create HTTP client: {}", e)))?;
// Send HEAD request first to get content info
let head_response = client
.head(url)
.send()
.await
.map_err(|e| AppError::Internal(format!("Failed to connect: {}", e)))?;
if !head_response.status().is_success() {
return Err(AppError::Internal(format!(
"Server returned error: {}",
head_response.status()
)));
}
let total_size = head_response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
// Check file size
if let Some(size) = total_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::BadRequest(format!(
"File too large: {} bytes (max {} GB)",
size,
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
)));
}
}
// Determine filename
let final_filename = if let Some(name) = filename {
sanitize_filename(&name)
} else {
// Try Content-Disposition header first
let from_header = head_response
.headers()
.get(reqwest::header::CONTENT_DISPOSITION)
.and_then(|v| v.to_str().ok())
.and_then(|s| extract_filename_from_content_disposition(s));
if let Some(name) = from_header {
sanitize_filename(&name)
} else {
// Fall back to URL path
let path = parsed_url.path();
let name = path.rsplit('/').next().unwrap_or("download");
let name = urlencoding::decode(name).unwrap_or_else(|_| name.into());
sanitize_filename(&name)
}
};
if final_filename.is_empty() {
return Err(AppError::BadRequest("Could not determine filename".to_string()));
}
// Check if file already exists
let final_path = self.images_path.join(&final_filename);
if final_path.exists() {
return Err(AppError::BadRequest(format!(
"Image already exists: {}",
final_filename
)));
}
// Create temporary file for download
let temp_filename = format!(".download_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_filename);
// Start actual download
let response = client
.get(url)
.send()
.await
.map_err(|e| AppError::Internal(format!("Download failed: {}", e)))?;
if !response.status().is_success() {
return Err(AppError::Internal(format!(
"Download failed: HTTP {}",
response.status()
)));
}
// Get actual content length from response (may differ from HEAD)
let content_length = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.or(total_size);
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
// Stream download with progress (throttled)
let mut stream = response.bytes_stream();
let mut downloaded: u64 = 0;
let mut last_report_time = Instant::now();
let mut last_reported_bytes: u64 = 0;
let throttle_interval = Duration::from_millis(PROGRESS_THROTTLE_MS);
// Report initial progress
progress_callback(0, content_length);
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result
.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?;
file.write_all(&chunk)
.await
.map_err(|e| {
// Cleanup on error
let _ = std::fs::remove_file(&temp_path);
AppError::Internal(format!("Failed to write data: {}", e))
})?;
downloaded += chunk.len() as u64;
// Throttled progress reporting: report if enough time or bytes have passed
let now = Instant::now();
let time_elapsed = now.duration_since(last_report_time) >= throttle_interval;
let bytes_elapsed = downloaded - last_reported_bytes >= PROGRESS_THROTTLE_BYTES;
if time_elapsed || bytes_elapsed {
progress_callback(downloaded, content_length);
last_report_time = now;
last_reported_bytes = downloaded;
}
}
// Always report final progress
if downloaded != last_reported_bytes {
progress_callback(downloaded, content_length);
}
// Ensure all data is flushed
file.flush()
.await
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
drop(file);
// Verify downloaded size
let metadata = tokio::fs::metadata(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to read file metadata: {}", e)))?;
if let Some(expected) = content_length {
if metadata.len() != expected {
let _ = tokio::fs::remove_file(&temp_path).await;
return Err(AppError::Internal(format!(
"Download incomplete: got {} bytes, expected {}",
metadata.len(),
expected
)));
}
}
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
let _ = std::fs::remove_file(&temp_path);
AppError::Internal(format!("Failed to move file: {}", e))
})?;
info!(
"Download complete: {} ({} bytes)",
final_filename,
metadata.len()
);
// Return image info
self.get_by_name(&final_filename)
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
}
/// Simple polynomial hash for generating stable image IDs
/// (not MD5 and not cryptographic; sufficient for de-duplicating filenames)
fn stable_hash(s: &str) -> u64 {
let mut hash: u64 = 0;
for (i, byte) in s.bytes().enumerate() {
hash = hash.wrapping_add((byte as u64).wrapping_mul((i as u64).wrapping_add(1)));
hash = hash.wrapping_mul(31);
}
hash
}
/// Sanitize filename to prevent path traversal
fn sanitize_filename(name: &str) -> String {
let name = name.trim();
let name = name.replace(['/', '\\', '\0', ':', '*', '?', '"', '<', '>', '|'], "_");
// Remove leading dots (hidden files)
let name = name.trim_start_matches('.');
// Limit length, truncating at a char boundary so multi-byte UTF-8
// names cannot cause a panic
if name.len() > 255 {
    let mut end = 255;
    while !name.is_char_boundary(end) {
        end -= 1;
    }
    name[..end].to_string()
} else {
    name.to_string()
}
}
/// Extract filename from Content-Disposition header
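///
/// Expected mappings (illustrative):
///
/// ```ignore
/// assert_eq!(
///     extract_filename_from_content_disposition(r#"attachment; filename="debian.iso""#),
///     Some("debian.iso".to_string())
/// );
/// // RFC 5987 form: percent-encoded, prefixed with charset''
/// assert_eq!(
///     extract_filename_from_content_disposition("attachment; filename*=UTF-8''ubuntu%2022.04.iso"),
///     Some("ubuntu 22.04.iso".to_string())
/// );
/// ```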
fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
// Handle both:
// Content-Disposition: attachment; filename="example.iso"
// Content-Disposition: attachment; filename*=UTF-8''example.iso
// Try filename* first (RFC 5987)
if let Some(pos) = header.find("filename*=") {
let start = pos + 10;
let value = &header[start..];
// Format: charset'language'value
if let Some(quote_start) = value.find("''") {
let encoded = value[quote_start + 2..].split(';').next()?;
let decoded = urlencoding::decode(encoded.trim()).ok()?;
let name = decoded.trim_matches('"').to_string();
if !name.is_empty() {
return Some(name);
}
}
}
// Try filename next
if let Some(pos) = header.find("filename=") {
let start = pos + 9;
let value = &header[start..];
let name = value.split(';').next()?;
let name = name.trim().trim_matches('"').to_string();
if !name.is_empty() {
return Some(name);
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn test_sanitize_filename() {
assert_eq!(sanitize_filename("test.iso"), "test.iso");
assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); // .. becomes empty after trim_start_matches('.')
assert_eq!(sanitize_filename("test/file.iso"), "test_file.iso");
assert_eq!(sanitize_filename(".hidden.iso"), "hidden.iso");
}
#[test]
fn test_image_manager_list_empty() {
let temp_dir = TempDir::new().unwrap();
let manager = ImageManager::new(temp_dir.path().to_path_buf());
let images = manager.list().unwrap();
assert!(images.is_empty());
}
#[test]
fn test_image_manager_create() {
let temp_dir = TempDir::new().unwrap();
let manager = ImageManager::new(temp_dir.path().to_path_buf());
let data = vec![0u8; 1024];
let image = manager.create("test.iso", &data).unwrap();
assert_eq!(image.name, "test.iso");
assert_eq!(image.size, 1024);
}
#[test]
fn test_image_manager_delete() {
let temp_dir = TempDir::new().unwrap();
let manager = ImageManager::new(temp_dir.path().to_path_buf());
let data = vec![0u8; 1024];
let image = manager.create("test.iso", &data).unwrap();
manager.delete(&image.id).unwrap();
assert!(manager.list().unwrap().is_empty());
}
}

33
src/msd/mod.rs Normal file
View File

@@ -0,0 +1,33 @@
//! MSD (Mass Storage Device) module
//!
//! Provides virtual USB storage functionality with two modes:
//! - Image mounting: Mount ISO/IMG files for system installation
//! - Ventoy drive: Bootable exFAT drive for multiple ISO files
//!
//! Architecture:
//! ```text
//! Web API --> MSD Controller --> ConfigFS Mass Storage --> Target PC
//!                   |
//!             ┌──────┴──────┐
//!             │             │
//!       Image Manager Ventoy Drive
//!         (ISO/IMG)  (Bootable exFAT)
//! ```
pub mod controller;
pub mod ventoy_drive;
pub mod image;
pub mod monitor;
pub mod types;
pub use controller::MsdController;
pub use ventoy_drive::VentoyDrive;
pub use image::ImageManager;
pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig};
pub use types::{
DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest,
ImageInfo, MsdConnectRequest, MsdMode, MsdState,
};
// Re-export from otg module for backward compatibility
pub use crate::otg::{MsdFunction, MsdLunConfig};

284
src/msd/monitor.rs Normal file
View File

@@ -0,0 +1,284 @@
//! MSD (Mass Storage Device) health monitoring
//!
//! This module provides health monitoring for MSD operations, including:
//! - ConfigFS operation error tracking
//! - Image mount/unmount error tracking
//! - Error notification
//! - Log throttling to prevent log flooding
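//!
//! A minimal usage sketch (assumes a tokio context; wiring an `EventBus` is
//! optional and omitted here):
//!
//! ```ignore
//! let monitor = MsdHealthMonitor::with_defaults();
//! monitor.report_error("ConfigFS write failed", "configfs_error").await;
//! assert!(monitor.is_error().await);
//! monitor.report_recovered().await;
//! assert!(monitor.is_healthy().await);
//! ```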
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::events::{EventBus, SystemEvent};
use crate::utils::LogThrottler;
/// MSD health status
#[derive(Debug, Clone, PartialEq)]
pub enum MsdHealthStatus {
/// Device is healthy and operational
Healthy,
/// Device has an error
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
},
}
impl Default for MsdHealthStatus {
fn default() -> Self {
Self::Healthy
}
}
/// MSD health monitor configuration
#[derive(Debug, Clone)]
pub struct MsdMonitorConfig {
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for MsdMonitorConfig {
fn default() -> Self {
Self {
log_throttle_secs: 5,
}
}
}
/// MSD health monitor
///
/// Monitors MSD operation health and manages error notifications.
/// Publishes WebSocket events when operation status changes.
pub struct MsdHealthMonitor {
/// Current health status
status: RwLock<MsdHealthStatus>,
/// Event bus for notifications
events: RwLock<Option<Arc<EventBus>>>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
#[allow(dead_code)]
config: MsdMonitorConfig,
/// Whether monitoring is active (reserved for future use)
#[allow(dead_code)]
running: AtomicBool,
/// Error count (for tracking)
error_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
}
impl MsdHealthMonitor {
/// Create a new MSD health monitor with the specified configuration
pub fn new(config: MsdMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
Self {
status: RwLock::new(MsdHealthStatus::Healthy),
events: RwLock::new(None),
throttler: LogThrottler::with_secs(throttle_secs),
config,
running: AtomicBool::new(false),
error_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
}
}
/// Create a new MSD health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(MsdMonitorConfig::default())
}
/// Set the event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Report an error from MSD operations
///
/// This method is called when an MSD operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Publishes a WebSocket event if the error is new or changed
///
/// # Arguments
///
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, reason: &str, error_code: &str) {
let count = self.error_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("msd_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!("MSD error: {} (code: {}, count: {})", reason, error_code, count);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = MsdHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
};
// Publish event (only if error changed or first occurrence)
if error_changed || count == 1 {
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::MsdError {
reason: reason.to_string(),
error_code: error_code.to_string(),
});
}
}
}
/// Report that the MSD has recovered from error
///
/// This method is called when an MSD operation succeeds after errors.
/// It resets the error state and publishes a recovery event.
pub async fn report_recovered(&self) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != MsdHealthStatus::Healthy {
let error_count = self.error_count.load(Ordering::Relaxed);
info!("MSD recovered after {} errors", error_count);
// Reset state
self.error_count.store(0, Ordering::Relaxed);
self.throttler.clear_all();
*self.last_error_code.write().await = None;
*self.status.write().await = MsdHealthStatus::Healthy;
// Publish recovery event
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::MsdRecovered);
}
}
}
/// Get the current health status
pub async fn status(&self) -> MsdHealthStatus {
self.status.read().await.clone()
}
/// Get the current error count
pub fn error_count(&self) -> u32 {
self.error_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.error_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = MsdHealthStatus::Healthy;
self.throttler.clear_all();
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
match &*self.status.read().await {
MsdHealthStatus::Error { reason, .. } => Some(reason.clone()),
_ => None,
}
}
}
impl Default for MsdHealthMonitor {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_initial_status() {
let monitor = MsdHealthMonitor::with_defaults();
assert!(monitor.is_healthy().await);
assert!(!monitor.is_error().await);
assert_eq!(monitor.error_count(), 0);
}
#[tokio::test]
async fn test_report_error() {
let monitor = MsdHealthMonitor::with_defaults();
monitor
.report_error("ConfigFS write failed", "configfs_error")
.await;
assert!(monitor.is_error().await);
assert_eq!(monitor.error_count(), 1);
if let MsdHealthStatus::Error { reason, error_code } = monitor.status().await {
assert_eq!(reason, "ConfigFS write failed");
assert_eq!(error_code, "configfs_error");
} else {
panic!("Expected Error status");
}
}
#[tokio::test]
async fn test_report_recovered() {
let monitor = MsdHealthMonitor::with_defaults();
// First report an error
monitor
.report_error("Image not found", "image_not_found")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.error_count(), 0);
}
#[tokio::test]
async fn test_error_count_increments() {
let monitor = MsdHealthMonitor::with_defaults();
for i in 1..=5 {
monitor.report_error("Error", "io_error").await;
assert_eq!(monitor.error_count(), i);
}
}
#[tokio::test]
async fn test_reset() {
let monitor = MsdHealthMonitor::with_defaults();
monitor.report_error("Error", "io_error").await;
assert!(monitor.is_error().await);
monitor.reset().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.error_count(), 0);
}
}

229
src/msd/types.rs Normal file
View File

@@ -0,0 +1,229 @@
//! MSD data types and structures
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// MSD operating mode
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MsdMode {
/// No storage connected
None,
/// Image file mounted (ISO/IMG)
Image,
/// Virtual Ventoy drive (exFAT) connected
Drive,
}
impl Default for MsdMode {
fn default() -> Self {
Self::None
}
}
/// Image file metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageInfo {
/// Unique image ID
pub id: String,
/// Display name
pub name: String,
/// File path on disk
#[serde(skip_serializing)]
pub path: PathBuf,
/// File size in bytes
pub size: u64,
/// Creation timestamp
pub created_at: DateTime<Utc>,
}
impl ImageInfo {
/// Create new image info
pub fn new(id: String, name: String, path: PathBuf, size: u64) -> Self {
Self {
id,
name,
path,
size,
created_at: Utc::now(),
}
}
/// Format size for display
pub fn size_display(&self) -> String {
const KB: u64 = 1024;
const MB: u64 = KB * 1024;
const GB: u64 = MB * 1024;
if self.size >= GB {
format!("{:.2} GB", self.size as f64 / GB as f64)
} else if self.size >= MB {
format!("{:.2} MB", self.size as f64 / MB as f64)
} else if self.size >= KB {
format!("{:.2} KB", self.size as f64 / KB as f64)
} else {
format!("{} B", self.size)
}
}
}
/// MSD state information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdState {
/// Whether MSD feature is available
pub available: bool,
/// Current mode
pub mode: MsdMode,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image (if mode is Image)
pub current_image: Option<ImageInfo>,
/// Virtual drive info (if mode is Drive)
pub drive_info: Option<DriveInfo>,
}
impl Default for MsdState {
fn default() -> Self {
Self {
available: false,
mode: MsdMode::None,
connected: false,
current_image: None,
drive_info: None,
}
}
}
/// Virtual drive information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveInfo {
/// Drive size in bytes
pub size: u64,
/// Used space in bytes
pub used: u64,
/// Free space in bytes
pub free: u64,
/// Whether drive is initialized
pub initialized: bool,
/// Drive file path
#[serde(skip_serializing)]
pub path: PathBuf,
}
impl DriveInfo {
/// Create new drive info
pub fn new(path: PathBuf, size: u64) -> Self {
Self {
size,
used: 0,
free: size,
initialized: false,
path,
}
}
}
/// File entry in virtual drive
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveFile {
/// File name
pub name: String,
/// Relative path from drive root
pub path: String,
/// File size in bytes (0 for directories)
pub size: u64,
/// Whether this is a directory
pub is_dir: bool,
/// Last modified timestamp
pub modified: Option<DateTime<Utc>>,
}
/// MSD connect request
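///
/// A round-trip sketch of the expected wire format (assumes the common
/// `serde_json` crate; field values are illustrative):
///
/// ```ignore
/// let req: MsdConnectRequest = serde_json::from_str(
///     r#"{"mode":"image","image_id":"a1b2c3","cdrom":true,"read_only":true}"#,
/// )?;
/// assert_eq!(req.mode, MsdMode::Image);
/// ```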
#[derive(Debug, Clone, Deserialize)]
pub struct MsdConnectRequest {
/// Connection mode: "image" or "drive"
pub mode: MsdMode,
/// Image ID to mount (required for image mode)
pub image_id: Option<String>,
/// Mount as CD-ROM (optional, defaults based on image type)
#[serde(default)]
pub cdrom: Option<bool>,
/// Mount as read-only
#[serde(default)]
pub read_only: Option<bool>,
}
/// Virtual drive init request
#[derive(Debug, Clone, Deserialize)]
pub struct DriveInitRequest {
/// Drive size in megabytes (defaults to 16 GB)
#[serde(default = "default_drive_size")]
pub size_mb: u32,
/// Optional custom path for Ventoy installation
pub ventoy_path: Option<String>,
}
fn default_drive_size() -> u32 {
16 * 1024 // 16GB
}
/// Image download request
#[derive(Debug, Clone, Deserialize)]
pub struct ImageDownloadRequest {
/// URL to download from
pub url: String,
/// Optional custom filename
pub filename: Option<String>,
}
/// Download status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DownloadStatus {
/// Download has started
Started,
/// Download is in progress
InProgress,
/// Download completed successfully
Completed,
/// Download failed
Failed,
}
/// Download progress information
#[derive(Debug, Clone, Serialize)]
pub struct DownloadProgress {
/// Unique download ID
pub download_id: String,
/// Source URL
pub url: String,
/// Target filename
pub filename: String,
/// Bytes downloaded so far
pub bytes_downloaded: u64,
/// Total file size (None if unknown)
pub total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
pub progress_pct: Option<f32>,
/// Download status
pub status: DownloadStatus,
/// Error message if failed
pub error: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_size_display() {
let info = ImageInfo::new(
"test".into(),
"test.iso".into(),
PathBuf::from("/tmp/test.iso"),
1024 * 1024 * 1024 * 2, // 2 GB
);
assert!(info.size_display().contains("GB"));
}
}

690
src/msd/ventoy_drive.rs Normal file
View File

@@ -0,0 +1,690 @@
//! Ventoy Virtual Drive
//!
//! Replaces FAT32 VirtualDrive with a Ventoy bootable image.
//! Provides a bootable USB with exFAT data partition for ISO files.
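//!
//! A minimal usage sketch (hypothetical image path; assumes a tokio context):
//!
//! ```ignore
//! use std::path::PathBuf;
//!
//! let drive = VentoyDrive::new(PathBuf::from("/var/lib/one-kvm/ventoy.img"));
//! if !drive.exists() {
//!     drive.init(4096).await?; // 4 GB image (clamped to the allowed range)
//! }
//! drive.mkdir("/isos").await?;
//! for f in drive.list_files("/").await? {
//!     println!("{} ({} bytes, dir: {})", f.path, f.size, f.is_dir);
//! }
//! ```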
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::info;
use ventoy_img::{FileInfo as VentoyFileInfo, VentoyError, VentoyImage};
use super::types::{DriveFile, DriveInfo};
use crate::error::{AppError, Result};
/// Chunk size for streaming reads (64 KB)
const STREAM_CHUNK_SIZE: usize = 64 * 1024;
/// Minimum drive size (1 GB) - Ventoy requires space for boot partition
const MIN_DRIVE_SIZE_MB: u32 = 1024;
/// Maximum drive size (128 GB)
const MAX_DRIVE_SIZE_MB: u32 = 128 * 1024;
/// Default drive label
const DEFAULT_LABEL: &str = "ONE-KVM";
/// Ventoy Drive Manager
///
/// Thread-safe wrapper around VentoyImage providing async file operations.
/// Uses spawn_blocking for all ventoy-img-rs operations since they are synchronous.
/// Uses RwLock to allow concurrent read operations while serializing writes.
pub struct VentoyDrive {
/// Drive image path
path: PathBuf,
/// RwLock for concurrent reads, exclusive writes
/// (ventoy-img-rs operations are synchronous and not thread-safe)
lock: Arc<RwLock<()>>,
}
impl VentoyDrive {
/// Create new Ventoy drive manager
pub fn new(path: PathBuf) -> Self {
Self {
path,
lock: Arc::new(RwLock::new(())),
}
}
/// Check if drive image exists
pub fn exists(&self) -> bool {
self.path.exists()
}
/// Get drive path
pub fn path(&self) -> &PathBuf {
&self.path
}
/// Initialize a new Ventoy drive image
///
/// Creates a bootable Ventoy image with the specified size.
/// The image includes boot partitions and an exFAT data partition.
pub async fn init(&self, size_mb: u32) -> Result<DriveInfo> {
let size_mb = size_mb.clamp(MIN_DRIVE_SIZE_MB, MAX_DRIVE_SIZE_MB);
let size_str = format!("{}M", size_mb);
let path = self.path.clone();
let _lock = self.lock.write().await; // Write lock for initialization
info!("Creating {} MB Ventoy drive at {}", size_mb, path.display());
// Run Ventoy creation in blocking task
let info = tokio::task::spawn_blocking(move || {
VentoyImage::create(&path, &size_str, DEFAULT_LABEL)
.map_err(ventoy_to_app_error)?;
// Get file metadata for DriveInfo
let metadata = std::fs::metadata(&path).map_err(|e| {
AppError::Internal(format!("Failed to read drive metadata: {}", e))
})?;
Ok::<DriveInfo, AppError>(DriveInfo {
size: metadata.len(),
used: 0,
free: metadata.len(), // Approximate - exFAT overhead not calculated
initialized: true,
path,
})
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))??;
info!("Ventoy drive created successfully");
Ok(info)
}
/// Get drive information
pub async fn info(&self) -> Result<DriveInfo> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let _lock = self.lock.read().await; // Read lock for info query
tokio::task::spawn_blocking(move || {
let metadata = std::fs::metadata(&path).map_err(|e| {
AppError::Internal(format!("Failed to read drive metadata: {}", e))
})?;
// Open image to get file list and calculate used space
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
let files = image
.list_files_recursive()
.map_err(ventoy_to_app_error)?;
let used: u64 = files
.iter()
.filter(|f| !f.is_directory)
.map(|f| f.size)
.sum();
// Note: This is approximate since we don't have exact exFAT overhead
let size = metadata.len();
let free = size.saturating_sub(used);
Ok(DriveInfo {
size,
used,
free,
initialized: true,
path,
})
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// List files at a given path (or root if empty/"/")
pub async fn list_files(&self, dir_path: &str) -> Result<Vec<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.read().await; // Read lock for listing
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
let files = if dir_path.is_empty() || dir_path == "/" {
image.list_files()
} else {
image.list_files_at(&dir_path)
}
.map_err(ventoy_to_app_error)?;
Ok(files
.into_iter()
.map(|f| ventoy_file_to_drive_file(f, &dir_path))
.collect())
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Write a file to the drive from multipart upload (streaming)
///
/// Streams the file directly into the Ventoy image's exFAT partition.
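///
/// A sketch from an axum handler (`field` comes from `Multipart::next_field`):
///
/// ```ignore
/// let written = drive
///     .write_file_from_multipart_field("/isos/alpine.iso", field)
///     .await?;
/// tracing::info!("stored {} bytes", written);
/// ```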
pub async fn write_file_from_multipart_field(
&self,
file_path: &str,
mut field: axum::extract::multipart::Field<'_>,
) -> Result<u64> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, stream to a temporary file (to get the size)
let temp_dir = self.path.parent().unwrap_or(Path::new("/tmp"));
let temp_name = format!(".upload_ventoy_{}", uuid::Uuid::new_v4());
let temp_path = temp_dir.join(&temp_name);
// Stream upload to temp file
let mut temp_file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
let mut bytes_written: u64 = 0;
while let Some(chunk) = field.chunk().await.map_err(|e| {
AppError::Internal(format!("Failed to read upload chunk: {}", e))
})? {
bytes_written += chunk.len() as u64;
tokio::io::AsyncWriteExt::write_all(&mut temp_file, &chunk)
.await
.map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?;
}
tokio::io::AsyncWriteExt::flush(&mut temp_file)
.await
.map_err(|e| AppError::Internal(format!("Failed to flush temp file: {}", e)))?;
drop(temp_file);
// Now copy from temp file to Ventoy image
let path = self.path.clone();
let file_path = file_path.to_string();
let temp_path_clone = temp_path.clone();
let _lock = self.lock.write().await; // Write lock for file write
let result = tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use add_file_to_path which handles streaming internally
image
.add_file_to_path(
&temp_path_clone,
&file_path,
true, // create_parents
true, // overwrite
)
.map_err(ventoy_to_app_error)?;
Ok::<(), AppError>(())
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?;
// Cleanup temp file
let _ = tokio::fs::remove_file(&temp_path).await;
result?;
Ok(bytes_written)
}
/// Read a file from the drive (for download)
pub async fn read_file(&self, file_path: &str) -> Result<Vec<u8>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let file_path = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file read
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
image
.read_file(&file_path)
.map_err(ventoy_to_app_error)
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Get file information without reading content
///
/// Returns file size, name, and other metadata.
/// Returns None if the file doesn't exist.
pub async fn get_file_info(&self, file_path: &str) -> Result<Option<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let file_path_owned = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file info
let info = tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
image
.get_file_info(&file_path_owned)
.map_err(ventoy_to_app_error)
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))??;
Ok(info.map(|f| DriveFile {
name: f.name,
path: f.path,
size: f.size,
is_dir: f.is_directory,
modified: None,
}))
}
/// Read a file from the drive as a stream (for large file downloads)
///
/// Returns an async channel receiver that yields chunks of file data.
/// This avoids loading the entire file into memory.
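///
/// A consumer sketch (e.g., feeding an HTTP response body):
///
/// ```ignore
/// let (size, mut rx) = drive.read_file_stream("/isos/alpine.iso").await?;
/// while let Some(chunk) = rx.recv().await {
///     let bytes = chunk?; // io::Error from the blocking reader surfaces here
///     // forward `bytes` to the client...
/// }
/// ```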
pub async fn read_file_stream(
&self,
file_path: &str,
) -> Result<(
u64,
tokio::sync::mpsc::Receiver<std::result::Result<bytes::Bytes, std::io::Error>>,
)> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, get the file size
let file_info = self
.get_file_info(file_path)
.await?
.ok_or_else(|| AppError::NotFound(format!("File not found: {}", file_path)))?;
if file_info.is_dir {
return Err(AppError::BadRequest(format!(
"'{}' is a directory",
file_path
)));
}
let file_size = file_info.size;
let path = self.path.clone();
let file_path_owned = file_path.to_string();
let lock = self.lock.clone();
// Create a channel for streaming data
let (tx, rx) = tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, std::io::Error>>(8);
// Spawn blocking task to read and send chunks
tokio::task::spawn_blocking(move || {
// Hold read lock for the entire read operation
let rt = tokio::runtime::Handle::current();
let _lock = rt.block_on(lock.read()); // Read lock for streaming
let image = match VentoyImage::open(&path) {
Ok(img) => img,
Err(e) => {
let _ = rt.block_on(tx.send(Err(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
))));
return;
}
};
// Create a channel writer that sends chunks
let mut chunk_writer = ChannelWriter::new(tx.clone(), rt.clone());
// Stream the file through the writer
if let Err(e) = image.read_file_to_writer(&file_path_owned, &mut chunk_writer) {
let _ = rt.block_on(tx.send(Err(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
))));
}
});
Ok((file_size, rx))
}
/// Create a directory
pub async fn mkdir(&self, dir_path: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.write().await; // Write lock for mkdir
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
image
.create_directory(&dir_path, true)
.map_err(ventoy_to_app_error)
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Delete a file or directory
pub async fn delete(&self, path_to_delete: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let path_to_delete = path_to_delete.to_string();
let _lock = self.lock.write().await; // Write lock for delete
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use recursive delete to handle directories
image
.remove_recursive(&path_to_delete)
.map_err(ventoy_to_app_error)
})
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
}
/// Convert VentoyError to AppError
fn ventoy_to_app_error(err: VentoyError) -> AppError {
match err {
VentoyError::Io(e) => AppError::Io(e),
VentoyError::InvalidSize(s) => AppError::BadRequest(format!("Invalid size: {}", s)),
VentoyError::SizeParseError(s) => {
AppError::BadRequest(format!("Size parse error: {}", s))
}
VentoyError::FilesystemError(s) => {
AppError::Internal(format!("Filesystem error: {}", s))
}
VentoyError::ImageError(s) => AppError::Internal(format!("Image error: {}", s)),
VentoyError::FileNotFound(s) => AppError::NotFound(format!("File not found: {}", s)),
VentoyError::ResourceNotFound(s) => {
AppError::Internal(format!("Resource not found: {}", s))
}
VentoyError::PartitionError(s) => {
AppError::Internal(format!("Partition error: {}", s))
}
}
}
/// Convert VentoyFileInfo to DriveFile
fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFile {
let full_path = if parent_path.is_empty() || parent_path == "/" {
format!("/{}", info.name)
} else {
format!("{}/{}", parent_path.trim_end_matches('/'), info.name)
};
DriveFile {
name: info.name,
path: full_path,
size: info.size,
is_dir: info.is_directory,
modified: None, // Ventoy FileInfo doesn't include timestamps
}
}
/// A writer that sends chunks to an async channel
///
/// This bridges the sync Write trait with async channels for streaming.
struct ChannelWriter {
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
rt: tokio::runtime::Handle,
buffer: Vec<u8>,
}
impl ChannelWriter {
fn new(
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
rt: tokio::runtime::Handle,
) -> Self {
Self {
tx,
rt,
buffer: Vec::with_capacity(STREAM_CHUNK_SIZE),
}
}
fn flush_buffer(&mut self) -> std::io::Result<()> {
if self.buffer.is_empty() {
return Ok(());
}
let chunk = bytes::Bytes::copy_from_slice(&self.buffer);
self.buffer.clear();
self.rt
.block_on(self.tx.send(Ok(chunk)))
.map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel closed"))
}
}
impl std::io::Write for ChannelWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut written = 0;
while written < buf.len() {
let space = STREAM_CHUNK_SIZE - self.buffer.len();
let to_copy = std::cmp::min(space, buf.len() - written);
self.buffer.extend_from_slice(&buf[written..written + to_copy]);
written += to_copy;
if self.buffer.len() >= STREAM_CHUNK_SIZE {
self.flush_buffer()?;
}
}
Ok(written)
}
fn flush(&mut self) -> std::io::Result<()> {
self.flush_buffer()
}
}
impl Drop for ChannelWriter {
fn drop(&mut self) {
// Flush any remaining data when the writer is dropped
let _ = self.flush_buffer();
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_drive_init() {
let temp_dir = TempDir::new().unwrap();
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path);
let info = drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
assert!(info.initialized);
assert!(drive.exists());
}
#[tokio::test]
async fn test_drive_mkdir() {
let temp_dir = TempDir::new().unwrap();
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path);
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
drive.mkdir("/isos").await.unwrap();
let files = drive.list_files("/").await.unwrap();
assert_eq!(files.len(), 1);
assert!(files[0].is_dir);
assert_eq!(files[0].name, "isos");
}
#[tokio::test]
async fn test_drive_file_write_and_read() {
let temp_dir = TempDir::new().unwrap();
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Write a test file
let test_content = b"Hello, Ventoy!";
let test_file_path = temp_dir.path().join("test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive using ventoy-img directly
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
image.add_file(&test_file_path).unwrap();
})
.await
.unwrap();
// Read file from drive
let read_data = drive.read_file("/test.txt").await.unwrap();
assert_eq!(read_data, test_content);
}
#[tokio::test]
async fn test_drive_get_file_info() {
let temp_dir = TempDir::new().unwrap();
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a directory
drive.mkdir("/mydir").await.unwrap();
// Write a test file
let test_content = b"Test file content for info check";
let test_file_path = temp_dir.path().join("info_test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
image.add_file(&test_file_path).unwrap();
})
.await
.unwrap();
// Test get_file_info for file
let file_info = drive.get_file_info("/info_test.txt").await.unwrap();
assert!(file_info.is_some());
let file_info = file_info.unwrap();
assert_eq!(file_info.name, "info_test.txt");
assert_eq!(file_info.size, test_content.len() as u64);
assert!(!file_info.is_dir);
// Test get_file_info for directory
let dir_info = drive.get_file_info("/mydir").await.unwrap();
assert!(dir_info.is_some());
let dir_info = dir_info.unwrap();
assert_eq!(dir_info.name, "mydir");
assert!(dir_info.is_dir);
// Test get_file_info for non-existent file
let not_found = drive.get_file_info("/nonexistent.txt").await.unwrap();
assert!(not_found.is_none());
}
#[tokio::test]
async fn test_drive_stream_read() {
let temp_dir = TempDir::new().unwrap();
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create test data that spans multiple chunks (>64KB)
let test_size = 200 * 1024; // 200 KB
let test_content: Vec<u8> = (0..test_size).map(|i| (i % 256) as u8).collect();
let test_file_path = temp_dir.path().join("large_file.bin");
std::fs::write(&test_file_path, &test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
let file_path_clone = test_file_path.clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
image.add_file(&file_path_clone).unwrap();
})
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/large_file.bin").await.unwrap();
assert_eq!(file_size, test_size as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.len(), test_content.len());
assert_eq!(received_data, test_content);
}
#[tokio::test]
async fn test_drive_stream_read_small_file() {
let temp_dir = TempDir::new().unwrap();
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a small test file
let test_content = b"Small file for streaming test";
let test_file_path = temp_dir.path().join("small.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
image.add_file(&test_file_path).unwrap();
})
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/small.txt").await.unwrap();
assert_eq!(file_size, test_content.len() as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.as_slice(), test_content);
}
}

138
src/otg/configfs.rs Normal file
View File

@@ -0,0 +1,138 @@
//! ConfigFS file operations for USB Gadget
use std::fs::{self, File, OpenOptions};
use std::io::Write;
use std::path::Path;
use crate::error::{AppError, Result};
/// ConfigFS base path for USB gadgets
pub const CONFIGFS_PATH: &str = "/sys/kernel/config/usb_gadget";
/// Default gadget name
pub const DEFAULT_GADGET_NAME: &str = "one-kvm";
/// USB Vendor ID (Linux Foundation)
pub const USB_VENDOR_ID: u16 = 0x1d6b;
/// USB Product ID (Multifunction Composite Gadget)
pub const USB_PRODUCT_ID: u16 = 0x0104;
/// USB device version
pub const USB_BCD_DEVICE: u16 = 0x0100;
/// USB spec version (USB 2.0)
pub const USB_BCD_USB: u16 = 0x0200;
/// Check if ConfigFS is available
pub fn is_configfs_available() -> bool {
Path::new(CONFIGFS_PATH).exists()
}
/// Find available UDC (USB Device Controller)
pub fn find_udc() -> Option<String> {
let udc_path = Path::new("/sys/class/udc");
if !udc_path.exists() {
return None;
}
fs::read_dir(udc_path)
.ok()?
.filter_map(|e| e.ok())
.map(|e| e.file_name().to_string_lossy().to_string())
.next()
}
/// Write string content to a file
///
/// For sysfs files, this function appends a newline and flushes
/// to ensure the kernel processes the write immediately.
///
/// IMPORTANT: sysfs attributes require a single atomic write() syscall.
/// The kernel processes the value on the first write(), so we must
/// build the complete buffer (including newline) before writing.
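///
/// Example (illustrative gadget attribute):
///
/// ```ignore
/// use std::path::Path;
///
/// write_file(Path::new("/sys/kernel/config/usb_gadget/one-kvm/idVendor"), "0x1d6b")?;
/// ```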
pub fn write_file(path: &Path, content: &str) -> Result<()> {
// For sysfs files (especially write-only ones like forced_eject),
// we need to use simple O_WRONLY without O_TRUNC
// O_TRUNC may fail on special files or require read permission
let mut file = OpenOptions::new()
.write(true)
.open(path)
.or_else(|e| {
// If open fails, try create (for regular files)
if path.exists() {
Err(e)
} else {
File::create(path)
}
})
.map_err(|e| AppError::Internal(format!("Failed to open {}: {}", path.display(), e)))?;
// Build complete buffer with newline, then write in single syscall.
// This is critical for sysfs - multiple write() calls may cause
// the kernel to only process partial data or return EINVAL.
let data: std::borrow::Cow<[u8]> = if content.ends_with('\n') {
content.as_bytes().into()
} else {
let mut buf = content.as_bytes().to_vec();
buf.push(b'\n');
buf.into()
};
file.write_all(&data)
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
// Explicitly flush to ensure sysfs processes the write
file.flush()
.map_err(|e| AppError::Internal(format!("Failed to flush {}: {}", path.display(), e)))?;
Ok(())
}
/// Write binary content to a file
pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
let mut file = File::create(path)
.map_err(|e| AppError::Internal(format!("Failed to create {}: {}", path.display(), e)))?;
file.write_all(data)
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
Ok(())
}
/// Read string content from a file
pub fn read_file(path: &Path) -> Result<String> {
fs::read_to_string(path)
.map(|s| s.trim().to_string())
.map_err(|e| AppError::Internal(format!("Failed to read {}: {}", path.display(), e)))
}
/// Create directory if not exists
pub fn create_dir(path: &Path) -> Result<()> {
fs::create_dir_all(path)
.map_err(|e| AppError::Internal(format!("Failed to create directory {}: {}", path.display(), e)))
}
/// Remove directory
pub fn remove_dir(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_dir(path)
.map_err(|e| AppError::Internal(format!("Failed to remove directory {}: {}", path.display(), e)))?;
}
Ok(())
}
/// Remove file
pub fn remove_file(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_file(path)
.map_err(|e| AppError::Internal(format!("Failed to remove file {}: {}", path.display(), e)))?;
}
Ok(())
}
/// Create symlink
pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> {
std::os::unix::fs::symlink(src, dest)
.map_err(|e| AppError::Internal(format!("Failed to create symlink {} -> {}: {}", dest.display(), src.display(), e)))
}

91
src/otg/endpoint.rs Normal file
View File

@@ -0,0 +1,91 @@
//! USB Endpoint allocation management
use crate::error::{AppError, Result};
/// Default maximum endpoints for typical UDC
pub const DEFAULT_MAX_ENDPOINTS: u8 = 16;
/// Endpoint allocator - manages UDC endpoint resources
#[derive(Debug, Clone)]
pub struct EndpointAllocator {
max_endpoints: u8,
used_endpoints: u8,
}
impl EndpointAllocator {
/// Create a new endpoint allocator
pub fn new(max_endpoints: u8) -> Self {
Self {
max_endpoints,
used_endpoints: 0,
}
}
/// Allocate endpoints for a function
pub fn allocate(&mut self, count: u8) -> Result<()> {
if self.used_endpoints + count > self.max_endpoints {
return Err(AppError::Internal(format!(
"Not enough endpoints: need {}, available {}",
count,
self.available()
)));
}
self.used_endpoints += count;
Ok(())
}
/// Release endpoints
pub fn release(&mut self, count: u8) {
self.used_endpoints = self.used_endpoints.saturating_sub(count);
}
/// Get available endpoint count
pub fn available(&self) -> u8 {
self.max_endpoints.saturating_sub(self.used_endpoints)
}
/// Get used endpoint count
pub fn used(&self) -> u8 {
self.used_endpoints
}
/// Get maximum endpoint count
pub fn max(&self) -> u8 {
self.max_endpoints
}
/// Check if can allocate
pub fn can_allocate(&self, count: u8) -> bool {
self.available() >= count
}
}
impl Default for EndpointAllocator {
fn default() -> Self {
Self::new(DEFAULT_MAX_ENDPOINTS)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_allocator() {
let mut alloc = EndpointAllocator::new(8);
assert_eq!(alloc.available(), 8);
alloc.allocate(2).unwrap();
assert_eq!(alloc.available(), 6);
assert_eq!(alloc.used(), 2);
alloc.allocate(4).unwrap();
assert_eq!(alloc.available(), 2);
// Should fail - not enough endpoints
assert!(alloc.allocate(3).is_err());
alloc.release(2);
assert_eq!(alloc.available(), 4);
}
}

42
src/otg/function.rs Normal file
View File

@@ -0,0 +1,42 @@
//! USB Gadget Function trait definition
use std::path::Path;
use crate::error::Result;
/// Function metadata
#[derive(Debug, Clone)]
pub struct FunctionMeta {
/// Function name (e.g., "hid.usb0")
pub name: String,
/// Human-readable description
pub description: String,
/// Number of endpoints used
pub endpoints: u8,
/// Whether the function is enabled
pub enabled: bool,
}
/// USB Gadget Function trait
pub trait GadgetFunction: Send + Sync {
/// Get function name (e.g., "hid.usb0", "mass_storage.usb0")
fn name(&self) -> &str;
/// Get number of endpoints required
fn endpoints_required(&self) -> u8;
/// Get function metadata
fn meta(&self) -> FunctionMeta;
/// Create function directory and configuration in ConfigFS
fn create(&self, gadget_path: &Path) -> Result<()>;
/// Link function to configuration
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()>;
/// Unlink function from configuration
fn unlink(&self, config_path: &Path) -> Result<()>;
/// Cleanup function directory
fn cleanup(&self, gadget_path: &Path) -> Result<()>;
}

226
src/otg/hid.rs Normal file
View File

@@ -0,0 +1,226 @@
//! HID Function implementation for USB Gadget
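//!
//! A wiring sketch (requires root and a mounted ConfigFS; paths are
//! illustrative):
//!
//! ```ignore
//! use std::path::{Path, PathBuf};
//!
//! let kb = HidFunction::keyboard(0);
//! let gadget = Path::new("/sys/kernel/config/usb_gadget/one-kvm");
//! kb.create(gadget)?;
//! kb.link(&gadget.join("configs/c.1"), gadget)?;
//! assert_eq!(kb.device_path(), PathBuf::from("/dev/hidg0"));
//! ```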
use std::path::{Path, PathBuf};
use tracing::debug;
use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file};
use super::function::{FunctionMeta, GadgetFunction};
use super::report_desc::{KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
use crate::error::Result;
/// HID function type
#[derive(Debug, Clone)]
pub enum HidFunctionType {
/// Keyboard with LED feedback support
/// Uses 2 endpoints: IN (reports) + OUT (LED status)
Keyboard,
/// Relative mouse (traditional mouse movement)
/// Uses 1 endpoint: IN
MouseRelative,
/// Absolute mouse (touchscreen-like positioning)
/// Uses 1 endpoint: IN
MouseAbsolute,
}
impl HidFunctionType {
/// Get endpoints required for this function type
pub fn endpoints(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 2, // IN + OUT for LED
HidFunctionType::MouseRelative => 1,
HidFunctionType::MouseAbsolute => 1,
}
}
/// Get HID protocol
pub fn protocol(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1, // Keyboard
HidFunctionType::MouseRelative => 2, // Mouse
HidFunctionType::MouseAbsolute => 2, // Mouse
}
}
/// Get HID subclass
pub fn subclass(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1, // Boot interface
HidFunctionType::MouseRelative => 1, // Boot interface
HidFunctionType::MouseAbsolute => 0, // No boot interface (absolute not in boot protocol)
}
}
/// Get report length in bytes
pub fn report_length(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 8,
HidFunctionType::MouseRelative => 4,
HidFunctionType::MouseAbsolute => 6,
}
}
/// Get report descriptor
pub fn report_desc(&self) -> &'static [u8] {
match self {
HidFunctionType::Keyboard => KEYBOARD_WITH_LED,
HidFunctionType::MouseRelative => MOUSE_RELATIVE,
HidFunctionType::MouseAbsolute => MOUSE_ABSOLUTE,
}
}
/// Get description
pub fn description(&self) -> &'static str {
match self {
HidFunctionType::Keyboard => "Keyboard",
HidFunctionType::MouseRelative => "Relative Mouse",
HidFunctionType::MouseAbsolute => "Absolute Mouse",
}
}
}
/// HID Function for USB Gadget
#[derive(Debug, Clone)]
pub struct HidFunction {
/// Instance number (usb0, usb1, ...)
instance: u8,
/// Function type
func_type: HidFunctionType,
/// Cached function name (avoids repeated allocation)
name: String,
}
impl HidFunction {
/// Create a keyboard function
pub fn keyboard(instance: u8) -> Self {
Self {
instance,
func_type: HidFunctionType::Keyboard,
name: format!("hid.usb{}", instance),
}
}
/// Create a relative mouse function
pub fn mouse_relative(instance: u8) -> Self {
Self {
instance,
func_type: HidFunctionType::MouseRelative,
name: format!("hid.usb{}", instance),
}
}
/// Create an absolute mouse function
pub fn mouse_absolute(instance: u8) -> Self {
Self {
instance,
func_type: HidFunctionType::MouseAbsolute,
name: format!("hid.usb{}", instance),
}
}
/// Get function path in gadget
fn function_path(&self, gadget_path: &Path) -> PathBuf {
gadget_path.join("functions").join(self.name())
}
/// Get expected device path (e.g., /dev/hidg0)
pub fn device_path(&self) -> PathBuf {
PathBuf::from(format!("/dev/hidg{}", self.instance))
}
}
impl GadgetFunction for HidFunction {
fn name(&self) -> &str {
&self.name
}
fn endpoints_required(&self) -> u8 {
self.func_type.endpoints()
}
fn meta(&self) -> FunctionMeta {
FunctionMeta {
name: self.name().to_string(),
description: self.func_type.description().to_string(),
endpoints: self.endpoints_required(),
enabled: true,
}
}
fn create(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
create_dir(&func_path)?;
// Set HID parameters
write_file(&func_path.join("protocol"), &self.func_type.protocol().to_string())?;
write_file(&func_path.join("subclass"), &self.func_type.subclass().to_string())?;
write_file(&func_path.join("report_length"), &self.func_type.report_length().to_string())?;
// For keyboard, enable OUT endpoint for LED feedback
// no_out_endpoint: 0 = enable OUT endpoint, 1 = disable
        if matches!(self.func_type, HidFunctionType::Keyboard) {
            // The attribute may be missing on some kernels; attempt the
            // write and ignore any failure
            let _ = write_file(&func_path.join("no_out_endpoint"), "0");
        }
// Write report descriptor
write_bytes(&func_path.join("report_desc"), self.func_type.report_desc())?;
debug!("Created HID function: {} at {}", self.name(), func_path.display());
Ok(())
}
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
let link_path = config_path.join(self.name());
if !link_path.exists() {
create_symlink(&func_path, &link_path)?;
debug!("Linked HID function {} to config", self.name());
}
Ok(())
}
fn unlink(&self, config_path: &Path) -> Result<()> {
let link_path = config_path.join(self.name());
remove_file(&link_path)?;
debug!("Unlinked HID function {}", self.name());
Ok(())
}
fn cleanup(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
remove_dir(&func_path)?;
debug!("Cleaned up HID function {}", self.name());
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hid_function_types() {
assert_eq!(HidFunctionType::Keyboard.endpoints(), 2);
assert_eq!(HidFunctionType::MouseRelative.endpoints(), 1);
assert_eq!(HidFunctionType::MouseAbsolute.endpoints(), 1);
assert_eq!(HidFunctionType::Keyboard.report_length(), 8);
assert_eq!(HidFunctionType::MouseRelative.report_length(), 4);
assert_eq!(HidFunctionType::MouseAbsolute.report_length(), 6);
}
#[test]
fn test_hid_function_names() {
let kb = HidFunction::keyboard(0);
assert_eq!(kb.name(), "hid.usb0");
assert_eq!(kb.device_path(), PathBuf::from("/dev/hidg0"));
let mouse = HidFunction::mouse_relative(1);
assert_eq!(mouse.name(), "hid.usb1");
}
}

394
src/otg/manager.rs Normal file
View File

@@ -0,0 +1,394 @@
//! OTG Gadget Manager - unified management for USB Gadget functions
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tracing::{debug, error, info, warn};
use super::configfs::{
create_dir, find_udc, is_configfs_available, remove_dir, write_file, CONFIGFS_PATH,
DEFAULT_GADGET_NAME, USB_BCD_DEVICE, USB_BCD_USB, USB_PRODUCT_ID, USB_VENDOR_ID,
};
use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS};
use super::function::{FunctionMeta, GadgetFunction};
use super::hid::HidFunction;
use super::msd::MsdFunction;
use crate::error::{AppError, Result};
/// OTG Gadget Manager - unified management for HID and MSD
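///
/// Typical lifecycle, as a sketch (assumes ConfigFS is mounted and a UDC is
/// present; error handling elided):
///
/// ```text
/// let mut manager = OtgGadgetManager::new();
/// let kb_path = manager.add_keyboard()?;        // -> /dev/hidg0
/// let abs_path = manager.add_mouse_absolute()?; // -> /dev/hidg1
/// manager.setup()?;                             // create ConfigFS entries
/// manager.bind()?;                              // attach to the UDC
/// // ... later, on shutdown:
/// manager.cleanup()?;
/// ```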
pub struct OtgGadgetManager {
/// Gadget name
gadget_name: String,
/// Gadget path in ConfigFS
gadget_path: PathBuf,
/// Configuration path
config_path: PathBuf,
/// Endpoint allocator
endpoint_allocator: EndpointAllocator,
/// HID instance counter
hid_instance: u8,
/// MSD instance counter
msd_instance: u8,
/// Registered functions
functions: Vec<Box<dyn GadgetFunction>>,
/// Function metadata
meta: HashMap<String, FunctionMeta>,
/// Bound UDC name
bound_udc: Option<String>,
/// Whether gadget was created by us
created_by_us: bool,
}
impl OtgGadgetManager {
/// Create a new gadget manager with default settings
pub fn new() -> Self {
Self::with_config(DEFAULT_GADGET_NAME, DEFAULT_MAX_ENDPOINTS)
}
/// Create a new gadget manager with custom configuration
pub fn with_config(gadget_name: &str, max_endpoints: u8) -> Self {
let gadget_path = PathBuf::from(CONFIGFS_PATH).join(gadget_name);
let config_path = gadget_path.join("configs/c.1");
Self {
gadget_name: gadget_name.to_string(),
gadget_path,
config_path,
endpoint_allocator: EndpointAllocator::new(max_endpoints),
hid_instance: 0,
msd_instance: 0,
// Pre-allocate for typical use: 3 HID (keyboard, rel mouse, abs mouse) + 1 MSD
functions: Vec::with_capacity(4),
meta: HashMap::with_capacity(4),
bound_udc: None,
created_by_us: false,
}
}
/// Check if ConfigFS is available
pub fn is_available() -> bool {
is_configfs_available()
}
/// Find available UDC
pub fn find_udc() -> Option<String> {
find_udc()
}
/// Check if gadget exists
pub fn gadget_exists(&self) -> bool {
self.gadget_path.exists()
}
/// Check if gadget is bound to UDC
pub fn is_bound(&self) -> bool {
let udc_file = self.gadget_path.join("UDC");
if let Ok(content) = fs::read_to_string(&udc_file) {
!content.trim().is_empty()
} else {
false
}
}
/// Add keyboard function
/// Returns the expected device path (e.g., /dev/hidg0)
pub fn add_keyboard(&mut self) -> Result<PathBuf> {
let func = HidFunction::keyboard(self.hid_instance);
let device_path = func.device_path();
self.add_function(Box::new(func))?;
self.hid_instance += 1;
Ok(device_path)
}
/// Add relative mouse function
pub fn add_mouse_relative(&mut self) -> Result<PathBuf> {
let func = HidFunction::mouse_relative(self.hid_instance);
let device_path = func.device_path();
self.add_function(Box::new(func))?;
self.hid_instance += 1;
Ok(device_path)
}
/// Add absolute mouse function
pub fn add_mouse_absolute(&mut self) -> Result<PathBuf> {
let func = HidFunction::mouse_absolute(self.hid_instance);
let device_path = func.device_path();
self.add_function(Box::new(func))?;
self.hid_instance += 1;
Ok(device_path)
}
/// Add MSD function (returns MsdFunction handle for LUN configuration)
pub fn add_msd(&mut self) -> Result<MsdFunction> {
let func = MsdFunction::new(self.msd_instance);
let func_clone = func.clone();
self.add_function(Box::new(func))?;
self.msd_instance += 1;
Ok(func_clone)
}
/// Add a generic function
fn add_function(&mut self, func: Box<dyn GadgetFunction>) -> Result<()> {
let endpoints = func.endpoints_required();
// Check endpoint availability
if !self.endpoint_allocator.can_allocate(endpoints) {
return Err(AppError::Internal(format!(
"Not enough endpoints for function {}: need {}, available {}",
func.name(),
endpoints,
self.endpoint_allocator.available()
)));
}
// Allocate endpoints
self.endpoint_allocator.allocate(endpoints)?;
// Store metadata
self.meta.insert(func.name().to_string(), func.meta());
// Store function
self.functions.push(func);
Ok(())
}
/// Setup the gadget (create directories and configure)
pub fn setup(&mut self) -> Result<()> {
info!("Setting up OTG USB Gadget: {}", self.gadget_name);
// Check ConfigFS availability
if !Self::is_available() {
return Err(AppError::Internal(
"ConfigFS not available. Is it mounted at /sys/kernel/config?".to_string(),
));
}
// Check if gadget already exists and is bound
if self.gadget_exists() {
if self.is_bound() {
info!("Gadget already exists and is bound, skipping setup");
return Ok(());
}
warn!("Gadget exists but not bound, will reconfigure");
self.cleanup()?;
}
// Create gadget directory
create_dir(&self.gadget_path)?;
self.created_by_us = true;
// Set device descriptors
self.set_device_descriptors()?;
// Create strings
self.create_strings()?;
// Create configuration
self.create_configuration()?;
// Create and link all functions
for func in &self.functions {
func.create(&self.gadget_path)?;
func.link(&self.config_path, &self.gadget_path)?;
}
info!("OTG USB Gadget setup complete");
Ok(())
}
/// Bind gadget to UDC
pub fn bind(&mut self) -> Result<()> {
let udc = Self::find_udc().ok_or_else(|| {
AppError::Internal("No USB Device Controller (UDC) found".to_string())
})?;
info!("Binding gadget to UDC: {}", udc);
write_file(&self.gadget_path.join("UDC"), &udc)?;
self.bound_udc = Some(udc);
Ok(())
}
/// Unbind gadget from UDC
pub fn unbind(&mut self) -> Result<()> {
if self.is_bound() {
write_file(&self.gadget_path.join("UDC"), "")?;
self.bound_udc = None;
info!("Unbound gadget from UDC");
}
Ok(())
}
/// Cleanup all resources
pub fn cleanup(&mut self) -> Result<()> {
if !self.gadget_exists() {
return Ok(());
}
info!("Cleaning up OTG USB Gadget: {}", self.gadget_name);
// Unbind from UDC first
let _ = self.unbind();
// Unlink and cleanup functions
for func in self.functions.iter().rev() {
let _ = func.unlink(&self.config_path);
}
// Remove config strings
let config_strings = self.config_path.join("strings/0x409");
let _ = remove_dir(&config_strings);
let _ = remove_dir(&self.config_path);
// Cleanup functions
for func in self.functions.iter().rev() {
let _ = func.cleanup(&self.gadget_path);
}
// Remove gadget strings
let gadget_strings = self.gadget_path.join("strings/0x409");
let _ = remove_dir(&gadget_strings);
// Remove gadget directory
if let Err(e) = remove_dir(&self.gadget_path) {
warn!("Could not remove gadget directory: {}", e);
}
self.created_by_us = false;
info!("OTG USB Gadget cleanup complete");
Ok(())
}
/// Set USB device descriptors
fn set_device_descriptors(&self) -> Result<()> {
write_file(&self.gadget_path.join("idVendor"), &format!("0x{:04x}", USB_VENDOR_ID))?;
write_file(&self.gadget_path.join("idProduct"), &format!("0x{:04x}", USB_PRODUCT_ID))?;
write_file(&self.gadget_path.join("bcdDevice"), &format!("0x{:04x}", USB_BCD_DEVICE))?;
write_file(&self.gadget_path.join("bcdUSB"), &format!("0x{:04x}", USB_BCD_USB))?;
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; // Composite device
write_file(&self.gadget_path.join("bDeviceSubClass"), "0x00")?;
write_file(&self.gadget_path.join("bDeviceProtocol"), "0x00")?;
debug!("Set device descriptors");
Ok(())
}
/// Create USB strings
fn create_strings(&self) -> Result<()> {
let strings_path = self.gadget_path.join("strings/0x409");
create_dir(&strings_path)?;
write_file(&strings_path.join("serialnumber"), "0123456789")?;
write_file(&strings_path.join("manufacturer"), "One-KVM")?;
write_file(&strings_path.join("product"), "One-KVM HID Device")?;
debug!("Created USB strings");
Ok(())
}
/// Create configuration
fn create_configuration(&self) -> Result<()> {
create_dir(&self.config_path)?;
// Create config strings
let strings_path = self.config_path.join("strings/0x409");
create_dir(&strings_path)?;
write_file(&strings_path.join("configuration"), "Config 1: HID + MSD")?;
// Set max power (500mA)
write_file(&self.config_path.join("MaxPower"), "500")?;
debug!("Created configuration c.1");
Ok(())
}
/// Get function metadata
pub fn get_meta(&self) -> &HashMap<String, FunctionMeta> {
&self.meta
}
/// Get endpoint usage info
pub fn endpoint_info(&self) -> (u8, u8) {
(self.endpoint_allocator.used(), self.endpoint_allocator.max())
}
/// Get gadget path
pub fn gadget_path(&self) -> &PathBuf {
&self.gadget_path
}
}
impl Default for OtgGadgetManager {
fn default() -> Self {
Self::new()
}
}
impl Drop for OtgGadgetManager {
fn drop(&mut self) {
if self.created_by_us {
if let Err(e) = self.cleanup() {
error!("Failed to cleanup OTG gadget on drop: {}", e);
}
}
}
}
/// Wait for HID devices to become available
///
/// Uses exponential backoff starting from 10ms, capped at 100ms,
/// to reduce CPU usage while still providing fast response.
pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
// Exponential backoff: start at 10ms, double each time, cap at 100ms
let mut delay_ms = 10u64;
const MAX_DELAY_MS: u64 = 100;
while start.elapsed() < timeout {
if device_paths.iter().all(|p| p.exists()) {
return true;
}
// Calculate remaining time to avoid overshooting timeout
let remaining = timeout.saturating_sub(start.elapsed());
let sleep_duration = std::time::Duration::from_millis(delay_ms).min(remaining);
if sleep_duration.is_zero() {
break;
}
tokio::time::sleep(sleep_duration).await;
// Exponential backoff with cap
delay_ms = (delay_ms * 2).min(MAX_DELAY_MS);
}
false
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_manager_creation() {
let manager = OtgGadgetManager::new();
assert_eq!(manager.gadget_name, DEFAULT_GADGET_NAME);
assert!(!manager.gadget_exists()); // Won't exist in test environment
}
#[test]
fn test_endpoint_tracking() {
let mut manager = OtgGadgetManager::with_config("test", 8);
// Keyboard uses 2 endpoints
let _ = manager.add_keyboard();
assert_eq!(manager.endpoint_allocator.used(), 2);
// Mouse uses 1 endpoint each
let _ = manager.add_mouse_relative();
let _ = manager.add_mouse_absolute();
assert_eq!(manager.endpoint_allocator.used(), 4);
}
}

35
src/otg/mod.rs Normal file
View File

@@ -0,0 +1,35 @@
//! OTG USB Gadget unified management module
//!
//! This module provides unified management for USB Gadget functions:
//! - HID (Keyboard, Mouse)
//! - MSD (Mass Storage Device)
//!
//! Architecture:
//! ```text
//! OtgService (high-level coordination)
//!   └── OtgGadgetManager (gadget lifecycle)
//!        ├── EndpointAllocator (manages UDC endpoints)
//!        ├── HidFunction (keyboard, mouse_rel, mouse_abs)
//!        └── MsdFunction (mass storage)
//! ```
//!
//! The recommended way to use this module is through `OtgService`, which provides
//! a high-level interface for enabling/disabling HID and MSD functions independently.
//! Both `HidController` and `MsdController` should share the same `OtgService` instance.
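//!
//! Sharing sketch (both controllers hold the same `Arc<OtgService>`):
//!
//! ```text
//! let otg = Arc::new(OtgService::new());
//! let hid_paths = otg.enable_hid().await?;  // gadget created and bound with HID
//! let msd_func = otg.enable_msd().await?;   // gadget recreated with HID + MSD
//! ```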
pub mod configfs;
pub mod endpoint;
pub mod function;
pub mod hid;
pub mod manager;
pub mod msd;
pub mod report_desc;
pub mod service;
pub use endpoint::EndpointAllocator;
pub use function::{FunctionMeta, GadgetFunction};
pub use hid::{HidFunction, HidFunctionType};
pub use manager::{wait_for_hid_devices, OtgGadgetManager};
pub use msd::{MsdFunction, MsdLunConfig};
pub use report_desc::{KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
pub use service::{HidDevicePaths, OtgService, OtgServiceState};

411
src/otg/msd.rs Normal file
View File

@@ -0,0 +1,411 @@
//! MSD (Mass Storage Device) Function implementation for USB Gadget
use std::fs;
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};
use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_file};
use super::function::{FunctionMeta, GadgetFunction};
use crate::error::{AppError, Result};
/// MSD LUN configuration
#[derive(Debug, Clone)]
pub struct MsdLunConfig {
/// File/image path to expose
pub file: PathBuf,
/// Mount as CD-ROM
pub cdrom: bool,
/// Read-only mode
pub ro: bool,
/// Removable media
pub removable: bool,
/// Disable Force Unit Access
pub nofua: bool,
}
impl Default for MsdLunConfig {
fn default() -> Self {
Self {
file: PathBuf::new(),
cdrom: false,
ro: false,
removable: true,
nofua: true,
}
}
}
impl MsdLunConfig {
/// Create CD-ROM configuration
pub fn cdrom(file: PathBuf) -> Self {
Self {
file,
cdrom: true,
ro: true,
removable: true,
nofua: true,
}
}
/// Create disk configuration
pub fn disk(file: PathBuf, read_only: bool) -> Self {
Self {
file,
cdrom: false,
ro: read_only,
removable: true,
nofua: true,
}
}
}
/// MSD Function for USB Gadget
#[derive(Debug, Clone)]
pub struct MsdFunction {
/// Instance number (usb0, usb1, ...)
instance: u8,
/// Cached function name (avoids repeated allocation)
name: String,
}
impl MsdFunction {
/// Create a new MSD function
pub fn new(instance: u8) -> Self {
Self {
instance,
name: format!("mass_storage.usb{}", instance),
}
}
/// Get function path in gadget
fn function_path(&self, gadget_path: &Path) -> PathBuf {
gadget_path.join("functions").join(self.name())
}
/// Get LUN path
fn lun_path(&self, gadget_path: &Path, lun: u8) -> PathBuf {
self.function_path(gadget_path).join(format!("lun.{}", lun))
}
/// Configure a LUN with specified settings (async version)
///
/// This is the preferred method for async contexts. It runs the blocking
/// file I/O and USB timing operations in a separate thread pool.
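///
/// Call sketch (the `msd` handle, gadget path, and image path are illustrative):
///
/// ```text
/// let cfg = MsdLunConfig::cdrom(PathBuf::from("/data/images/install.iso"));
/// msd.configure_lun_async(&gadget_path, 0, &cfg).await?;
/// ```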
pub async fn configure_lun_async(
&self,
gadget_path: &Path,
lun: u8,
config: &MsdLunConfig,
) -> Result<()> {
let gadget_path = gadget_path.to_path_buf();
let config = config.clone();
let this = self.clone();
tokio::task::spawn_blocking(move || this.configure_lun(&gadget_path, lun, &config))
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Configure a LUN with specified settings
/// Note: This should be called after the gadget is set up
///
/// This implementation is based on PiKVM's MSD drive configuration.
/// Key improvements:
/// - Uses forced_eject when available (safer than clearing file directly)
/// - Reduced sleep times to minimize HID interference
/// - Better retry logic for EBUSY errors
///
/// **Note**: This is a blocking function. In async contexts, prefer
/// `configure_lun_async` to avoid blocking the runtime.
pub fn configure_lun(&self, gadget_path: &Path, lun: u8, config: &MsdLunConfig) -> Result<()> {
let lun_path = self.lun_path(gadget_path, lun);
if !lun_path.exists() {
create_dir(&lun_path)?;
}
// Batch read all current values to minimize syscalls
let read_attr = |attr: &str| -> String {
fs::read_to_string(lun_path.join(attr))
.unwrap_or_default()
.trim()
.to_string()
};
let current_cdrom = read_attr("cdrom");
let current_ro = read_attr("ro");
let current_removable = read_attr("removable");
let current_nofua = read_attr("nofua");
// Prepare new values
let new_cdrom = if config.cdrom { "1" } else { "0" };
let new_ro = if config.ro { "1" } else { "0" };
let new_removable = if config.removable { "1" } else { "0" };
let new_nofua = if config.nofua { "1" } else { "0" };
// Disconnect current file first using forced_eject if available (PiKVM approach)
let forced_eject_path = lun_path.join("forced_eject");
if forced_eject_path.exists() {
// forced_eject is safer - it forcibly detaches regardless of host state
debug!("Using forced_eject to clear LUN {}", lun);
let _ = write_file(&forced_eject_path, "1");
} else {
// Fallback to clearing file directly
let _ = write_file(&lun_path.join("file"), "");
}
// Brief yield to allow USB stack to process the disconnect
// Reduced from 200ms to 50ms - let USB protocol handle timing
std::thread::sleep(std::time::Duration::from_millis(50));
// Write only changed attributes
let cdrom_changed = current_cdrom != new_cdrom;
if cdrom_changed {
debug!("Updating LUN {} cdrom: {} -> {}", lun, current_cdrom, new_cdrom);
write_file(&lun_path.join("cdrom"), new_cdrom)?;
}
if current_ro != new_ro {
debug!("Updating LUN {} ro: {} -> {}", lun, current_ro, new_ro);
write_file(&lun_path.join("ro"), new_ro)?;
}
if current_removable != new_removable {
debug!("Updating LUN {} removable: {} -> {}", lun, current_removable, new_removable);
write_file(&lun_path.join("removable"), new_removable)?;
}
if current_nofua != new_nofua {
debug!("Updating LUN {} nofua: {} -> {}", lun, current_nofua, new_nofua);
write_file(&lun_path.join("nofua"), new_nofua)?;
}
// If cdrom mode changed, brief yield for USB host
if cdrom_changed {
debug!("CDROM mode changed, brief yield for USB host");
std::thread::sleep(std::time::Duration::from_millis(50));
}
// Set file path (this triggers the actual mount) - with retry on EBUSY
if config.file.exists() {
let file_path = config.file.to_string_lossy();
let mut last_error = None;
for attempt in 0..5 {
match write_file(&lun_path.join("file"), file_path.as_ref()) {
Ok(_) => {
info!(
"LUN {} configured with file: {} (cdrom={}, ro={})",
lun,
config.file.display(),
config.cdrom,
config.ro
);
return Ok(());
}
Err(e) => {
// Check if it's EBUSY (error code 16)
let is_busy = e.to_string().contains("Device or resource busy")
|| e.to_string().contains("os error 16");
if is_busy && attempt < 4 {
warn!(
"LUN {} file write busy, retrying (attempt {}/5)",
lun,
attempt + 1
);
// Exponential backoff: 50, 100, 200, 400ms
std::thread::sleep(std::time::Duration::from_millis(50 << attempt));
last_error = Some(e);
continue;
}
return Err(e);
}
}
}
// If we get here, all retries failed
if let Some(e) = last_error {
return Err(e);
}
} else if !config.file.as_os_str().is_empty() {
warn!("LUN {} file does not exist: {}", lun, config.file.display());
}
Ok(())
}
/// Disconnect LUN (async version)
///
/// Preferred for async contexts.
pub async fn disconnect_lun_async(&self, gadget_path: &Path, lun: u8) -> Result<()> {
let gadget_path = gadget_path.to_path_buf();
let this = self.clone();
tokio::task::spawn_blocking(move || this.disconnect_lun(&gadget_path, lun))
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Disconnect LUN (clear file)
///
/// This method uses forced_eject when available, which is safer than
/// directly clearing the file path. Based on PiKVM's implementation.
/// See: https://docs.kernel.org/usb/mass-storage.html
pub fn disconnect_lun(&self, gadget_path: &Path, lun: u8) -> Result<()> {
let lun_path = self.lun_path(gadget_path, lun);
if lun_path.exists() {
// Prefer forced_eject if available (PiKVM approach)
// forced_eject forcibly detaches the backing file regardless of host state
let forced_eject_path = lun_path.join("forced_eject");
if forced_eject_path.exists() {
debug!("Using forced_eject to disconnect LUN {} at {:?}", lun, forced_eject_path);
match write_file(&forced_eject_path, "1") {
Ok(_) => debug!("forced_eject write succeeded"),
Err(e) => {
warn!("forced_eject write failed: {}, falling back to clearing file", e);
write_file(&lun_path.join("file"), "")?;
}
}
} else {
// Fallback to clearing file directly
write_file(&lun_path.join("file"), "")?;
}
info!("LUN {} disconnected", lun);
}
Ok(())
}
/// Get current LUN file path
pub fn get_lun_file(&self, gadget_path: &Path, lun: u8) -> Option<PathBuf> {
let lun_path = self.lun_path(gadget_path, lun);
let file_path = lun_path.join("file");
if let Ok(content) = fs::read_to_string(&file_path) {
let content = content.trim();
if !content.is_empty() {
return Some(PathBuf::from(content));
}
}
None
}
/// Check if LUN is connected
pub fn is_lun_connected(&self, gadget_path: &Path, lun: u8) -> bool {
self.get_lun_file(gadget_path, lun).is_some()
}
}
impl GadgetFunction for MsdFunction {
fn name(&self) -> &str {
&self.name
}
fn endpoints_required(&self) -> u8 {
2 // IN + OUT for bulk transfers
}
fn meta(&self) -> FunctionMeta {
FunctionMeta {
name: self.name().to_string(),
description: if self.instance == 0 {
"Mass Storage Drive".to_string()
} else {
format!("Extra Drive #{}", self.instance)
},
endpoints: self.endpoints_required(),
enabled: true,
}
}
fn create(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
create_dir(&func_path)?;
// Set stall to 0 (workaround for some hosts)
let stall_path = func_path.join("stall");
if stall_path.exists() {
let _ = write_file(&stall_path, "0");
}
// LUN 0 is created automatically, but ensure it exists
let lun0_path = func_path.join("lun.0");
if !lun0_path.exists() {
create_dir(&lun0_path)?;
}
// Set default LUN 0 parameters
let _ = write_file(&lun0_path.join("cdrom"), "0");
let _ = write_file(&lun0_path.join("ro"), "0");
let _ = write_file(&lun0_path.join("removable"), "1");
let _ = write_file(&lun0_path.join("nofua"), "1");
debug!("Created MSD function: {}", self.name());
Ok(())
}
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
let link_path = config_path.join(self.name());
if !link_path.exists() {
create_symlink(&func_path, &link_path)?;
debug!("Linked MSD function {} to config", self.name());
}
Ok(())
}
fn unlink(&self, config_path: &Path) -> Result<()> {
let link_path = config_path.join(self.name());
remove_file(&link_path)?;
debug!("Unlinked MSD function {}", self.name());
Ok(())
}
fn cleanup(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
// Disconnect all LUNs first
for lun in 0..8 {
let _ = self.disconnect_lun(gadget_path, lun);
}
// Remove function directory
if let Err(e) = remove_dir(&func_path) {
warn!("Could not remove MSD function directory: {}", e);
}
debug!("Cleaned up MSD function {}", self.name());
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_lun_config_cdrom() {
let config = MsdLunConfig::cdrom(PathBuf::from("/tmp/test.iso"));
assert!(config.cdrom);
assert!(config.ro);
assert!(config.removable);
}
#[test]
fn test_lun_config_disk() {
let config = MsdLunConfig::disk(PathBuf::from("/tmp/test.img"), false);
assert!(!config.cdrom);
assert!(!config.ro);
assert!(config.removable);
}
#[test]
fn test_msd_function_name() {
let msd = MsdFunction::new(0);
assert_eq!(msd.name(), "mass_storage.usb0");
assert_eq!(msd.endpoints_required(), 2);
}
}

160
src/otg/report_desc.rs Normal file
View File

@@ -0,0 +1,160 @@
//! HID Report Descriptors
/// Keyboard HID Report Descriptor with LED output support
/// Report format (8 bytes input):
/// [0] Modifier keys (8 bits)
/// [1] Reserved
/// [2-7] Key codes (6 keys)
/// LED output (1 byte):
/// Bit 0: Num Lock
/// Bit 1: Caps Lock
/// Bit 2: Scroll Lock
/// Bit 3: Compose
/// Bit 4: Kana
pub const KEYBOARD_WITH_LED: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x06, // Usage (Keyboard)
0xA1, 0x01, // Collection (Application)
// Modifier keys input (8 bits)
0x05, 0x07, // Usage Page (Key Codes)
0x19, 0xE0, // Usage Minimum (224) - Left Control
0x29, 0xE7, // Usage Maximum (231) - Right GUI
0x15, 0x00, // Logical Minimum (0)
0x25, 0x01, // Logical Maximum (1)
0x75, 0x01, // Report Size (1)
0x95, 0x08, // Report Count (8)
0x81, 0x02, // Input (Data, Variable, Absolute) - Modifier byte
// Reserved byte
0x95, 0x01, // Report Count (1)
0x75, 0x08, // Report Size (8)
0x81, 0x01, // Input (Constant) - Reserved byte
// LED output (5 bits)
0x95, 0x05, // Report Count (5)
0x75, 0x01, // Report Size (1)
0x05, 0x08, // Usage Page (LEDs)
0x19, 0x01, // Usage Minimum (1) - Num Lock
0x29, 0x05, // Usage Maximum (5) - Kana
0x91, 0x02, // Output (Data, Variable, Absolute) - LED bits
// LED padding (3 bits)
0x95, 0x01, // Report Count (1)
0x75, 0x03, // Report Size (3)
0x91, 0x01, // Output (Constant) - Padding
// Key array (6 bytes)
0x95, 0x06, // Report Count (6)
0x75, 0x08, // Report Size (8)
0x15, 0x00, // Logical Minimum (0)
0x26, 0xFF, 0x00, // Logical Maximum (255)
0x05, 0x07, // Usage Page (Key Codes)
0x19, 0x00, // Usage Minimum (0)
0x2A, 0xFF, 0x00, // Usage Maximum (255)
0x81, 0x00, // Input (Data, Array) - Key array (6 keys)
0xC0, // End Collection
];
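/// Illustrative sketch (not used by the gadget code itself): packing the
/// 8-byte input report that `KEYBOARD_WITH_LED` declares. Key codes follow
/// the USB HID usage tables (e.g. 0x04 = 'A').
#[allow(dead_code)]
fn example_keyboard_report(modifiers: u8, keys: [u8; 6]) -> [u8; 8] {
    let mut report = [0u8; 8];
    report[0] = modifiers; // bit 0 = Left Control ... bit 7 = Right GUI
    // report[1] is the reserved byte, left as 0
    report[2..8].copy_from_slice(&keys);
    report
}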
/// Relative Mouse HID Report Descriptor (4 bytes report)
/// Report format:
/// [0] Buttons (5 bits) + padding (3 bits)
/// [1] X movement (signed 8-bit)
/// [2] Y movement (signed 8-bit)
/// [3] Wheel (signed 8-bit)
pub const MOUSE_RELATIVE: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x02, // Usage (Mouse)
0xA1, 0x01, // Collection (Application)
0x09, 0x01, // Usage (Pointer)
0xA1, 0x00, // Collection (Physical)
// Buttons (5 bits)
0x05, 0x09, // Usage Page (Button)
0x19, 0x01, // Usage Minimum (1)
0x29, 0x05, // Usage Maximum (5) - 5 buttons
0x15, 0x00, // Logical Minimum (0)
0x25, 0x01, // Logical Maximum (1)
0x95, 0x05, // Report Count (5)
0x75, 0x01, // Report Size (1)
0x81, 0x02, // Input (Data, Variable, Absolute) - Button bits
// Padding (3 bits)
0x95, 0x01, // Report Count (1)
0x75, 0x03, // Report Size (3)
0x81, 0x01, // Input (Constant) - Padding
// X, Y movement
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x30, // Usage (X)
0x09, 0x31, // Usage (Y)
0x15, 0x81, // Logical Minimum (-127)
0x25, 0x7F, // Logical Maximum (127)
0x75, 0x08, // Report Size (8)
0x95, 0x02, // Report Count (2)
0x81, 0x06, // Input (Data, Variable, Relative) - X, Y
// Wheel
0x09, 0x38, // Usage (Wheel)
0x15, 0x81, // Logical Minimum (-127)
0x25, 0x7F, // Logical Maximum (127)
0x75, 0x08, // Report Size (8)
0x95, 0x01, // Report Count (1)
0x81, 0x06, // Input (Data, Variable, Relative) - Wheel
0xC0, // End Collection
0xC0, // End Collection
];
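/// Illustrative sketch: packing the 4-byte relative report described by
/// `MOUSE_RELATIVE`. Movement and wheel are signed deltas in -127..=127.
#[allow(dead_code)]
fn example_relative_mouse_report(buttons: u8, dx: i8, dy: i8, wheel: i8) -> [u8; 4] {
    // Only the low 5 button bits are defined; the rest is padding
    [buttons & 0x1F, dx as u8, dy as u8, wheel as u8]
}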
/// Absolute Mouse HID Report Descriptor (6 bytes report)
/// Report format:
/// [0] Buttons (5 bits) + padding (3 bits)
/// [1-2] X position (16-bit, 0-32767)
/// [3-4] Y position (16-bit, 0-32767)
/// [5] Wheel (signed 8-bit)
pub const MOUSE_ABSOLUTE: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x02, // Usage (Mouse)
0xA1, 0x01, // Collection (Application)
0x09, 0x01, // Usage (Pointer)
0xA1, 0x00, // Collection (Physical)
// Buttons (5 bits)
0x05, 0x09, // Usage Page (Button)
0x19, 0x01, // Usage Minimum (1)
0x29, 0x05, // Usage Maximum (5) - 5 buttons
0x15, 0x00, // Logical Minimum (0)
0x25, 0x01, // Logical Maximum (1)
0x95, 0x05, // Report Count (5)
0x75, 0x01, // Report Size (1)
0x81, 0x02, // Input (Data, Variable, Absolute) - Button bits
// Padding (3 bits)
0x95, 0x01, // Report Count (1)
0x75, 0x03, // Report Size (3)
0x81, 0x01, // Input (Constant) - Padding
// X position (16-bit absolute)
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x30, // Usage (X)
0x16, 0x00, 0x00, // Logical Minimum (0)
0x26, 0xFF, 0x7F, // Logical Maximum (32767)
0x75, 0x10, // Report Size (16)
0x95, 0x01, // Report Count (1)
0x81, 0x02, // Input (Data, Variable, Absolute) - X
// Y position (16-bit absolute)
0x09, 0x31, // Usage (Y)
0x16, 0x00, 0x00, // Logical Minimum (0)
0x26, 0xFF, 0x7F, // Logical Maximum (32767)
0x75, 0x10, // Report Size (16)
0x95, 0x01, // Report Count (1)
0x81, 0x02, // Input (Data, Variable, Absolute) - Y
// Wheel
0x09, 0x38, // Usage (Wheel)
0x15, 0x81, // Logical Minimum (-127)
0x25, 0x7F, // Logical Maximum (127)
0x75, 0x08, // Report Size (8)
0x95, 0x01, // Report Count (1)
0x81, 0x06, // Input (Data, Variable, Relative) - Wheel
0xC0, // End Collection
0xC0, // End Collection
];
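/// Illustrative sketch: packing the 6-byte absolute report described by
/// `MOUSE_ABSOLUTE`. X/Y are little-endian u16 values in 0..=32767; callers
/// are assumed to scale screen coordinates into that range.
#[allow(dead_code)]
fn example_absolute_mouse_report(buttons: u8, x: u16, y: u16, wheel: i8) -> [u8; 6] {
    let mut report = [0u8; 6];
    report[0] = buttons & 0x1F; // 5 button bits, upper 3 bits are padding
    report[1..3].copy_from_slice(&x.to_le_bytes());
    report[3..5].copy_from_slice(&y.to_le_bytes());
    report[5] = wheel as u8;
    report
}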
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_report_descriptor_sizes() {
assert!(!KEYBOARD_WITH_LED.is_empty());
assert!(!MOUSE_RELATIVE.is_empty());
assert!(!MOUSE_ABSOLUTE.is_empty());
}
}

503
src/otg/service.rs Normal file
View File

@@ -0,0 +1,503 @@
//! OTG Service - unified gadget lifecycle management
//!
//! This module provides centralized management for USB OTG gadget functions.
//! It solves the ownership problem where both HID and MSD need access to the
//! same USB gadget but should be independently configurable.
//!
//! Architecture:
//! ```text
//! ┌─────────────────────────┐
//! │       OtgService        │
//! │  ┌───────────────────┐  │
//! │  │ OtgGadgetManager  │  │
//! │  └───────────────────┘  │
//! │      ↓         ↓        │
//! │  ┌─────┐   ┌─────┐      │
//! │  │ HID │   │ MSD │      │
//! │  └─────┘   └─────┘      │
//! └─────────────────────────┘
//!      ↑           ↑
//!  HidController  MsdController
//! ```
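//!
//! Enable/disable calls are idempotent; a toggling sketch:
//!
//! ```text
//! let service = OtgService::new();
//! service.enable_hid().await?;   // gadget with HID only
//! service.enable_msd().await?;   // gadget recreated with HID + MSD
//! service.disable_hid().await?;  // gadget recreated with MSD only
//! ```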
use std::path::PathBuf;
use std::sync::atomic::{AtomicU8, Ordering};
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info, warn};
use super::manager::{wait_for_hid_devices, OtgGadgetManager};
use super::msd::MsdFunction;
use crate::error::{AppError, Result};
/// Bitflags for requested functions (lock-free)
const FLAG_HID: u8 = 0b01;
const FLAG_MSD: u8 = 0b10;
/// HID device paths
#[derive(Debug, Clone)]
pub struct HidDevicePaths {
pub keyboard: PathBuf,
pub mouse_relative: PathBuf,
pub mouse_absolute: PathBuf,
}
impl Default for HidDevicePaths {
fn default() -> Self {
Self {
keyboard: PathBuf::from("/dev/hidg0"),
mouse_relative: PathBuf::from("/dev/hidg1"),
mouse_absolute: PathBuf::from("/dev/hidg2"),
}
}
}
/// OTG Service state
#[derive(Debug, Clone, Default)]
pub struct OtgServiceState {
/// Whether the gadget is created and bound
pub gadget_active: bool,
/// Whether HID functions are enabled
pub hid_enabled: bool,
/// Whether MSD function is enabled
pub msd_enabled: bool,
/// HID device paths (set after gadget setup)
pub hid_paths: Option<HidDevicePaths>,
/// Error message if setup failed
pub error: Option<String>,
}
/// OTG Service - unified gadget lifecycle management
///
/// This service owns the OtgGadgetManager and provides a high-level interface
/// for enabling/disabling HID and MSD functions. It ensures proper coordination
/// between the two subsystems and handles gadget lifecycle management.
pub struct OtgService {
/// The underlying gadget manager
manager: Mutex<Option<OtgGadgetManager>>,
/// Current state
state: RwLock<OtgServiceState>,
/// MSD function handle (for runtime LUN configuration)
msd_function: RwLock<Option<MsdFunction>>,
/// Requested functions flags (atomic, lock-free read/write)
requested_flags: AtomicU8,
}
impl OtgService {
/// Create a new OTG service
pub fn new() -> Self {
Self {
manager: Mutex::new(None),
state: RwLock::new(OtgServiceState::default()),
msd_function: RwLock::new(None),
requested_flags: AtomicU8::new(0),
}
}
/// Check if HID is requested (lock-free)
#[inline]
fn is_hid_requested(&self) -> bool {
self.requested_flags.load(Ordering::Acquire) & FLAG_HID != 0
}
/// Check if MSD is requested (lock-free)
#[inline]
fn is_msd_requested(&self) -> bool {
self.requested_flags.load(Ordering::Acquire) & FLAG_MSD != 0
}
/// Set HID requested flag (lock-free)
#[inline]
fn set_hid_requested(&self, requested: bool) {
if requested {
self.requested_flags.fetch_or(FLAG_HID, Ordering::Release);
} else {
self.requested_flags.fetch_and(!FLAG_HID, Ordering::Release);
}
}
/// Set MSD requested flag (lock-free)
#[inline]
fn set_msd_requested(&self, requested: bool) {
if requested {
self.requested_flags.fetch_or(FLAG_MSD, Ordering::Release);
} else {
self.requested_flags.fetch_and(!FLAG_MSD, Ordering::Release);
}
}
/// Check if OTG is available on this system
pub fn is_available() -> bool {
OtgGadgetManager::is_available() && OtgGadgetManager::find_udc().is_some()
}
/// Get current service state
pub async fn state(&self) -> OtgServiceState {
self.state.read().await.clone()
}
/// Check if gadget is active
pub async fn is_gadget_active(&self) -> bool {
self.state.read().await.gadget_active
}
/// Check if HID is enabled
pub async fn is_hid_enabled(&self) -> bool {
self.state.read().await.hid_enabled
}
/// Check if MSD is enabled
pub async fn is_msd_enabled(&self) -> bool {
self.state.read().await.msd_enabled
}
/// Get gadget path (for MSD LUN configuration)
pub async fn gadget_path(&self) -> Option<PathBuf> {
let manager = self.manager.lock().await;
manager.as_ref().map(|m| m.gadget_path().clone())
}
/// Get HID device paths
pub async fn hid_device_paths(&self) -> Option<HidDevicePaths> {
self.state.read().await.hid_paths.clone()
}
/// Get MSD function handle (for LUN configuration)
pub async fn msd_function(&self) -> Option<MsdFunction> {
self.msd_function.read().await.clone()
}
/// Enable HID functions
///
/// This will create the gadget if not already created, add HID functions,
/// and bind the gadget to UDC.
pub async fn enable_hid(&self) -> Result<HidDevicePaths> {
info!("Enabling HID functions via OtgService");
// Mark HID as requested (lock-free)
self.set_hid_requested(true);
// Check if already enabled
{
let state = self.state.read().await;
if state.hid_enabled {
if let Some(ref paths) = state.hid_paths {
info!("HID already enabled, returning existing paths");
return Ok(paths.clone());
}
}
}
        // Recreate the gadget with the currently requested function set
self.recreate_gadget().await?;
// Get HID paths from state
let state = self.state.read().await;
state
.hid_paths
.clone()
.ok_or_else(|| AppError::Internal("HID paths not set after gadget setup".to_string()))
}
/// Disable HID functions
///
/// This will unbind the gadget, remove HID functions, and optionally
/// recreate the gadget with only MSD if MSD is still enabled.
pub async fn disable_hid(&self) -> Result<()> {
info!("Disabling HID functions via OtgService");
// Mark HID as not requested (lock-free)
self.set_hid_requested(false);
// Check if HID is enabled
{
let state = self.state.read().await;
if !state.hid_enabled {
info!("HID already disabled");
return Ok(());
}
}
// Recreate gadget without HID (or destroy if MSD also disabled)
self.recreate_gadget().await
}
/// Enable MSD function
///
/// This will create the gadget if not already created, add MSD function,
/// and bind the gadget to UDC.
pub async fn enable_msd(&self) -> Result<MsdFunction> {
info!("Enabling MSD function via OtgService");
// Mark MSD as requested (lock-free)
self.set_msd_requested(true);
// Check if already enabled
{
let state = self.state.read().await;
if state.msd_enabled {
let msd = self.msd_function.read().await;
if let Some(ref func) = *msd {
info!("MSD already enabled, returning existing function");
return Ok(func.clone());
}
}
}
        // Recreate the gadget with the currently requested function set
self.recreate_gadget().await?;
// Get MSD function
let msd = self.msd_function.read().await;
msd.clone()
.ok_or_else(|| AppError::Internal("MSD function not set after gadget setup".to_string()))
}
/// Disable MSD function
///
/// This will unbind the gadget, remove MSD function, and optionally
/// recreate the gadget with only HID if HID is still enabled.
pub async fn disable_msd(&self) -> Result<()> {
info!("Disabling MSD function via OtgService");
// Mark MSD as not requested (lock-free)
self.set_msd_requested(false);
// Check if MSD is enabled
{
let state = self.state.read().await;
if !state.msd_enabled {
info!("MSD already disabled");
return Ok(());
}
}
// Recreate gadget without MSD (or destroy if HID also disabled)
self.recreate_gadget().await
}
/// Recreate the gadget with currently requested functions
///
/// This is called whenever the set of enabled functions changes.
/// It will:
/// 1. Check if recreation is needed (function set changed)
/// 2. If needed: cleanup existing gadget
/// 3. Create new gadget with requested functions
/// 4. Setup and bind
async fn recreate_gadget(&self) -> Result<()> {
// Read requested flags atomically (lock-free)
let hid_requested = self.is_hid_requested();
let msd_requested = self.is_msd_requested();
info!(
"Recreating gadget with: HID={}, MSD={}",
hid_requested, msd_requested
);
// Check if gadget already matches requested state
{
let state = self.state.read().await;
if state.gadget_active
&& state.hid_enabled == hid_requested
&& state.msd_enabled == msd_requested
{
info!("Gadget already has requested functions, skipping recreate");
return Ok(());
}
}
// Cleanup existing gadget
{
let mut manager = self.manager.lock().await;
if let Some(mut m) = manager.take() {
info!("Cleaning up existing gadget before recreate");
if let Err(e) = m.cleanup() {
warn!("Error cleaning up existing gadget: {}", e);
}
}
}
// Clear MSD function
*self.msd_function.write().await = None;
// Update state to inactive
{
let mut state = self.state.write().await;
state.gadget_active = false;
state.hid_enabled = false;
state.msd_enabled = false;
state.hid_paths = None;
state.error = None;
}
// If nothing requested, we're done
if !hid_requested && !msd_requested {
info!("No functions requested, gadget destroyed");
return Ok(());
}
// Check if OTG is available
if !Self::is_available() {
let error = "OTG not available: ConfigFS not mounted or no UDC found".to_string();
let mut state = self.state.write().await;
state.error = Some(error.clone());
return Err(AppError::Internal(error));
}
// Create new gadget manager
let mut manager = OtgGadgetManager::new();
let mut hid_paths = None;
// Add HID functions if requested
if hid_requested {
match (
manager.add_keyboard(),
manager.add_mouse_relative(),
manager.add_mouse_absolute(),
) {
(Ok(kb), Ok(rel), Ok(abs)) => {
hid_paths = Some(HidDevicePaths {
keyboard: kb,
mouse_relative: rel,
mouse_absolute: abs,
});
debug!("HID functions added to gadget");
}
(Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => {
let error = format!("Failed to add HID functions: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
return Err(AppError::Internal(error));
}
}
}
// Add MSD function if requested
let msd_func = if msd_requested {
match manager.add_msd() {
Ok(func) => {
debug!("MSD function added to gadget");
Some(func)
}
Err(e) => {
let error = format!("Failed to add MSD function: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
return Err(AppError::Internal(error));
}
}
} else {
None
};
// Setup gadget
if let Err(e) = manager.setup() {
let error = format!("Failed to setup gadget: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
return Err(AppError::Internal(error));
}
// Bind to UDC
if let Err(e) = manager.bind() {
let error = format!("Failed to bind gadget to UDC: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
// Cleanup on failure
let _ = manager.cleanup();
return Err(AppError::Internal(error));
}
// Wait for HID devices to appear
if let Some(ref paths) = hid_paths {
let device_paths = vec![
paths.keyboard.clone(),
paths.mouse_relative.clone(),
paths.mouse_absolute.clone(),
];
if !wait_for_hid_devices(&device_paths, 2000).await {
warn!("HID devices did not appear after gadget setup");
}
}
// Store manager and update state
{
*self.manager.lock().await = Some(manager);
}
{
*self.msd_function.write().await = msd_func;
}
{
let mut state = self.state.write().await;
state.gadget_active = true;
state.hid_enabled = hid_requested;
state.msd_enabled = msd_requested;
state.hid_paths = hid_paths;
state.error = None;
}
info!("Gadget created successfully");
Ok(())
}
/// Shutdown the OTG service and cleanup all resources
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down OTG service");
// Mark nothing as requested (lock-free)
self.requested_flags.store(0, Ordering::Release);
// Cleanup gadget
let mut manager = self.manager.lock().await;
if let Some(mut m) = manager.take() {
if let Err(e) = m.cleanup() {
warn!("Error cleaning up gadget during shutdown: {}", e);
}
}
// Clear state
*self.msd_function.write().await = None;
{
let mut state = self.state.write().await;
*state = OtgServiceState::default();
}
info!("OTG service shutdown complete");
Ok(())
}
}
impl Default for OtgService {
fn default() -> Self {
Self::new()
}
}
impl Drop for OtgService {
fn drop(&mut self) {
// Gadget cleanup is handled by OtgGadgetManager's Drop
debug!("OtgService dropping");
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_service_creation() {
        // Just test that creation doesn't panic; availability depends on the environment
        let _service = OtgService::new();
        let _ = OtgService::is_available();
}
#[tokio::test]
async fn test_initial_state() {
let service = OtgService::new();
let state = service.state().await;
assert!(!state.gadget_active);
assert!(!state.hid_enabled);
assert!(!state.msd_enabled);
}
}

227
src/state.rs Normal file
View File

@@ -0,0 +1,227 @@
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use crate::atx::AtxController;
use crate::audio::AudioController;
use crate::auth::{SessionStore, UserStore};
use crate::config::ConfigStore;
use crate::events::{AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, MsdDeviceInfo, SystemEvent, VideoDeviceInfo};
use crate::extensions::ExtensionManager;
use crate::hid::HidController;
use crate::msd::MsdController;
use crate::otg::OtgService;
use crate::video::VideoStreamManager;
/// Application-wide state shared across handlers
///
/// # Video Streaming
///
/// All video operations should go through `stream_manager`:
/// - `stream_manager.webrtc_streamer()` - WebRTC streaming (H264, extensible to VP8/VP9/H265)
/// - `stream_manager.mjpeg_handler()` - MJPEG stream handler
/// - `stream_manager.streamer()` - Low-level video capture
/// - `stream_manager.start()` / `stream_manager.stop()` - Stream control
/// - `stream_manager.stats()` - Stream statistics
/// - `stream_manager.list_devices()` - List video devices
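///
/// Access sketch (handler context assumed; signatures abbreviated):
///
/// ```text
/// state.stream_manager.start()   // begin capture
/// state.stream_manager.stats()   // stream statistics
/// ```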
pub struct AppState {
/// Configuration store
pub config: ConfigStore,
/// Session store
pub sessions: SessionStore,
/// User store
pub users: UserStore,
/// OTG Service - centralized USB gadget lifecycle management
/// This is the single owner of OtgGadgetManager, coordinating HID and MSD functions
pub otg_service: Arc<OtgService>,
/// Video stream manager (unified MJPEG/WebRTC management)
/// This is the single entry point for all video operations.
pub stream_manager: Arc<VideoStreamManager>,
/// HID controller
pub hid: Arc<HidController>,
/// MSD controller (optional, may not be initialized)
pub msd: Arc<RwLock<Option<MsdController>>>,
/// ATX controller (optional, may not be initialized)
pub atx: Arc<RwLock<Option<AtxController>>>,
/// Audio controller
pub audio: Arc<AudioController>,
/// Extension manager (ttyd, gostc, easytier)
pub extensions: Arc<ExtensionManager>,
/// Event bus for real-time notifications
pub events: Arc<EventBus>,
/// Shutdown signal sender
pub shutdown_tx: broadcast::Sender<()>,
/// Data directory path
data_dir: std::path::PathBuf,
}
impl AppState {
/// Create new application state
pub fn new(
config: ConfigStore,
sessions: SessionStore,
users: UserStore,
otg_service: Arc<OtgService>,
stream_manager: Arc<VideoStreamManager>,
hid: Arc<HidController>,
msd: Option<MsdController>,
atx: Option<AtxController>,
audio: Arc<AudioController>,
extensions: Arc<ExtensionManager>,
events: Arc<EventBus>,
shutdown_tx: broadcast::Sender<()>,
data_dir: std::path::PathBuf,
) -> Arc<Self> {
Arc::new(Self {
config,
sessions,
users,
otg_service,
stream_manager,
hid,
msd: Arc::new(RwLock::new(msd)),
atx: Arc::new(RwLock::new(atx)),
audio,
extensions,
events,
shutdown_tx,
data_dir,
})
}
/// Get data directory path
pub fn data_dir(&self) -> &std::path::PathBuf {
&self.data_dir
}
/// Subscribe to shutdown signal
pub fn shutdown_signal(&self) -> broadcast::Receiver<()> {
self.shutdown_tx.subscribe()
}
/// Get complete device information for WebSocket clients
///
/// This method collects the current state of all devices (video, HID, MSD, ATX, Audio)
/// and returns a DeviceInfo event that can be sent to clients.
/// Uses tokio::join! to collect all device info in parallel for better performance.
pub async fn get_device_info(&self) -> SystemEvent {
// Collect all device info in parallel
let (video, hid, msd, atx, audio) = tokio::join!(
self.collect_video_info(),
self.collect_hid_info(),
self.collect_msd_info(),
self.collect_atx_info(),
self.collect_audio_info(),
);
SystemEvent::DeviceInfo {
video,
hid,
msd,
atx,
audio,
}
}
/// Publish DeviceInfo event to all connected WebSocket clients
pub async fn publish_device_info(&self) {
let device_info = self.get_device_info().await;
self.events.publish(device_info);
}
/// Collect video device information
async fn collect_video_info(&self) -> VideoDeviceInfo {
// Use stream_manager to get video info (includes stream_mode)
self.stream_manager.get_video_info().await
}
/// Collect HID device information
async fn collect_hid_info(&self) -> HidDeviceInfo {
let info = self.hid.info().await;
let backend_type = self.hid.backend_type().await;
match info {
Some(hid_info) => HidDeviceInfo {
available: true,
backend: hid_info.name.to_string(),
initialized: hid_info.initialized,
supports_absolute_mouse: hid_info.supports_absolute_mouse,
device: match backend_type {
crate::hid::HidBackendType::Ch9329 { ref port, .. } => Some(port.clone()),
_ => None,
},
error: None,
},
None => HidDeviceInfo {
available: false,
backend: backend_type.name_str().to_string(),
initialized: false,
supports_absolute_mouse: false,
device: match backend_type {
crate::hid::HidBackendType::Ch9329 { ref port, .. } => Some(port.clone()),
_ => None,
},
error: Some("HID backend not available".to_string()),
},
}
}
/// Collect MSD device information (optional)
async fn collect_msd_info(&self) -> Option<MsdDeviceInfo> {
let msd_guard = self.msd.read().await;
let msd = msd_guard.as_ref()?;
let state = msd.state().await;
Some(MsdDeviceInfo {
available: state.available,
mode: match state.mode {
crate::msd::MsdMode::None => "none",
crate::msd::MsdMode::Image => "image",
crate::msd::MsdMode::Drive => "drive",
}
.to_string(),
connected: state.connected,
image_id: state.current_image.map(|img| img.id),
error: None,
})
}
/// Collect ATX device information (optional)
async fn collect_atx_info(&self) -> Option<AtxDeviceInfo> {
// Predefined backend strings to avoid repeated allocations
const BACKEND_POWER_ONLY: &str = "power: configured, reset: none";
const BACKEND_RESET_ONLY: &str = "power: none, reset: configured";
const BACKEND_BOTH: &str = "power: configured, reset: configured";
const BACKEND_NONE: &str = "none";
let atx_guard = self.atx.read().await;
let atx = atx_guard.as_ref()?;
let state = atx.state().await;
Some(AtxDeviceInfo {
available: state.available,
backend: match (state.power_configured, state.reset_configured) {
(true, true) => BACKEND_BOTH,
(true, false) => BACKEND_POWER_ONLY,
(false, true) => BACKEND_RESET_ONLY,
(false, false) => BACKEND_NONE,
}
.to_string(),
initialized: state.power_configured || state.reset_configured,
power_on: state.power_status == crate::atx::PowerStatus::On,
error: None,
})
}
/// Collect Audio device information (optional)
async fn collect_audio_info(&self) -> Option<AudioDeviceInfo> {
let status = self.audio.status().await;
Some(AudioDeviceInfo {
available: status.enabled,
streaming: status.streaming,
device: status.device,
quality: status.quality.to_string(),
error: status.error,
})
}
}

564
src/stream/mjpeg.rs Normal file
View File

@@ -0,0 +1,564 @@
//! MJPEG stream handler
//!
//! Manages video frame distribution and per-client statistics.
use arc_swap::ArcSwap;
use parking_lot::Mutex as ParkingMutex;
use parking_lot::RwLock as ParkingRwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::broadcast;
use tracing::{debug, info, warn};
use crate::video::encoder::JpegEncoder;
use crate::video::encoder::traits::{Encoder, EncoderConfig};
use crate::video::format::PixelFormat;
use crate::video::VideoFrame;
/// Client ID type (UUID string)
pub type ClientId = String;
/// Per-client session information
#[derive(Debug, Clone)]
pub struct ClientSession {
/// Unique client ID
pub id: ClientId,
/// Connection timestamp
pub connected_at: Instant,
/// Last activity timestamp (frame sent)
pub last_activity: Instant,
/// Frames sent to this client
pub frames_sent: u64,
/// FPS calculator (1-second rolling window)
pub fps_calculator: FpsCalculator,
}
impl ClientSession {
/// Create a new client session
pub fn new(id: ClientId) -> Self {
let now = Instant::now();
Self {
id,
connected_at: now,
last_activity: now,
frames_sent: 0,
fps_calculator: FpsCalculator::new(),
}
}
/// Get connection duration
pub fn connected_duration(&self) -> Duration {
self.last_activity.duration_since(self.connected_at)
}
/// Get idle duration
pub fn idle_duration(&self) -> Duration {
Instant::now().duration_since(self.last_activity)
}
}
/// Rolling window FPS calculator
#[derive(Debug, Clone)]
pub struct FpsCalculator {
/// Frame timestamps in last window
frame_times: VecDeque<Instant>,
/// Window duration (default 1 second)
window: Duration,
/// Cached count of frames in current window (optimization to avoid O(n) filtering)
count_in_window: usize,
}
impl FpsCalculator {
/// Create a new FPS calculator with 1-second window
pub fn new() -> Self {
Self {
frame_times: VecDeque::with_capacity(120), // Max 120fps tracking
window: Duration::from_secs(1),
count_in_window: 0,
}
}
/// Record a frame sent
pub fn record_frame(&mut self) {
let now = Instant::now();
self.frame_times.push_back(now);
// Remove frames outside window and maintain count
let cutoff = now - self.window;
while let Some(&oldest) = self.frame_times.front() {
if oldest < cutoff {
self.frame_times.pop_front();
} else {
break;
}
}
// Update cached count
self.count_in_window = self.frame_times.len();
}
/// Calculate current FPS (frames in last 1 second window)
pub fn current_fps(&self) -> u32 {
// Return cached count instead of filtering entire deque (O(1) instead of O(n))
self.count_in_window as u32
}
}
impl Default for FpsCalculator {
fn default() -> Self {
Self::new()
}
}
/// Auto-pause configuration
#[derive(Debug, Clone)]
pub struct AutoPauseConfig {
/// Enable auto-pause when no clients
pub enabled: bool,
/// Delay before pausing (default 10s)
pub shutdown_delay_secs: u64,
/// Client timeout for cleanup (default 30s)
pub client_timeout_secs: u64,
}
impl Default for AutoPauseConfig {
fn default() -> Self {
Self {
enabled: false,
shutdown_delay_secs: 10,
client_timeout_secs: 30,
}
}
}
/// MJPEG stream handler
/// Manages video frame distribution to HTTP clients
pub struct MjpegStreamHandler {
/// Current frame (latest) - using ArcSwap for lock-free reads
current_frame: ArcSwap<Option<VideoFrame>>,
/// Frame update notification
frame_notify: broadcast::Sender<()>,
/// Whether stream is online
online: AtomicBool,
/// Frame sequence counter
sequence: AtomicU64,
/// Per-client sessions (ClientId -> ClientSession)
/// Use parking_lot::RwLock for better performance
clients: ParkingRwLock<HashMap<ClientId, ClientSession>>,
/// Auto-pause configuration
auto_pause_config: ParkingRwLock<AutoPauseConfig>,
/// Last frame timestamp
last_frame_ts: ParkingRwLock<Option<Instant>>,
/// Dropped same frames count
dropped_same_frames: AtomicU64,
/// Maximum consecutive same frames to drop (0 = disabled)
max_drop_same_frames: AtomicU64,
/// JPEG encoder for non-JPEG input formats
jpeg_encoder: ParkingMutex<Option<JpegEncoder>>,
}
impl MjpegStreamHandler {
/// Create a new MJPEG stream handler
pub fn new() -> Self {
Self::with_drop_limit(100) // Default: drop up to 100 same frames
}
/// Create handler with custom drop limit
pub fn with_drop_limit(max_drop: u64) -> Self {
let (frame_notify, _) = broadcast::channel(4); // Reduced from 16 for lower latency
Self {
current_frame: ArcSwap::from_pointee(None),
frame_notify,
online: AtomicBool::new(false),
sequence: AtomicU64::new(0),
clients: ParkingRwLock::new(HashMap::new()),
jpeg_encoder: ParkingMutex::new(None),
auto_pause_config: ParkingRwLock::new(AutoPauseConfig::default()),
last_frame_ts: ParkingRwLock::new(None),
dropped_same_frames: AtomicU64::new(0),
max_drop_same_frames: AtomicU64::new(max_drop),
}
}
/// Update current frame
pub fn update_frame(&self, frame: VideoFrame) {
// If frame is not JPEG, encode it
let frame = if !frame.format.is_compressed() {
match self.encode_to_jpeg(&frame) {
Ok(jpeg_frame) => jpeg_frame,
Err(e) => {
warn!("Failed to encode frame to JPEG: {}", e);
return;
}
}
} else {
frame
};
// Frame deduplication (ustreamer-style)
// Check if this frame is identical to the previous one
let max_drop = self.max_drop_same_frames.load(Ordering::Relaxed);
if max_drop > 0 && frame.online {
let current = self.current_frame.load();
if let Some(ref prev_frame) = **current {
let dropped_count = self.dropped_same_frames.load(Ordering::Relaxed);
// Check if we should drop this frame
if dropped_count < max_drop && frames_are_identical(prev_frame, &frame) {
// Check last frame timestamp to ensure minimum 1fps
let last_ts = *self.last_frame_ts.read();
let should_force_send = if let Some(ts) = last_ts {
ts.elapsed() >= Duration::from_secs(1)
} else {
false
};
if !should_force_send {
// Drop this duplicate frame
self.dropped_same_frames.fetch_add(1, Ordering::Relaxed);
return;
}
// If more than 1 second since last frame, force send even if identical
}
}
}
// Frame is different or limit reached or forced by 1fps guarantee, update
self.dropped_same_frames.store(0, Ordering::Relaxed);
self.sequence.fetch_add(1, Ordering::Relaxed);
self.online.store(true, Ordering::SeqCst);
*self.last_frame_ts.write() = Some(Instant::now());
self.current_frame.store(Arc::new(Some(frame)));
// Notify waiting clients
let _ = self.frame_notify.send(());
}
/// Encode a non-JPEG frame to JPEG
fn encode_to_jpeg(&self, frame: &VideoFrame) -> Result<VideoFrame, String> {
let resolution = frame.resolution;
let sequence = self.sequence.load(Ordering::Relaxed);
// Get or create encoder
        let mut encoder_guard = self.jpeg_encoder.lock();
        if encoder_guard.is_none() {
            // Surface encoder creation failure to the caller rather than
            // panicking; retrying with an identical config cannot succeed
            let config = EncoderConfig::jpeg(resolution, 85);
            let enc = JpegEncoder::new(config)
                .map_err(|e| format!("Failed to create JPEG encoder: {}", e))?;
            debug!("Created JPEG encoder for MJPEG stream: {}x{}", resolution.width, resolution.height);
            *encoder_guard = Some(enc);
        }
        let encoder = encoder_guard.as_mut().expect("encoder initialized above");
// Check if resolution changed
if encoder.config().resolution != resolution {
debug!("Resolution changed, recreating JPEG encoder: {}x{}", resolution.width, resolution.height);
let config = EncoderConfig::jpeg(resolution, 85);
*encoder = JpegEncoder::new(config).map_err(|e| format!("Failed to create encoder: {}", e))?;
}
// Encode based on input format
let encoded = match frame.format {
PixelFormat::Yuyv => {
encoder.encode_yuyv(frame.data(), sequence)
.map_err(|e| format!("YUYV encode failed: {}", e))?
}
PixelFormat::Nv12 => {
encoder.encode_nv12(frame.data(), sequence)
.map_err(|e| format!("NV12 encode failed: {}", e))?
}
PixelFormat::Rgb24 => {
encoder.encode_rgb(frame.data(), sequence)
.map_err(|e| format!("RGB encode failed: {}", e))?
}
PixelFormat::Bgr24 => {
encoder.encode_bgr(frame.data(), sequence)
.map_err(|e| format!("BGR encode failed: {}", e))?
}
_ => {
return Err(format!("Unsupported format for JPEG encoding: {}", frame.format));
}
};
// Create new VideoFrame with JPEG data
Ok(VideoFrame::from_vec(
encoded.data.to_vec(),
resolution,
PixelFormat::Mjpeg,
0, // stride not relevant for JPEG
sequence,
))
}
/// Set stream offline
pub fn set_offline(&self) {
self.online.store(false, Ordering::SeqCst);
let _ = self.frame_notify.send(());
}
/// Set stream online (called when streaming starts)
pub fn set_online(&self) {
self.online.store(true, Ordering::SeqCst);
}
/// Check if stream is online
pub fn is_online(&self) -> bool {
self.online.load(Ordering::SeqCst)
}
/// Get current client count
pub fn client_count(&self) -> u64 {
self.clients.read().len() as u64
}
/// Register a new client
pub fn register_client(&self, client_id: ClientId) {
let session = ClientSession::new(client_id.clone());
self.clients.write().insert(client_id.clone(), session);
info!("Client {} connected (total: {})", client_id, self.client_count());
}
/// Unregister a client
pub fn unregister_client(&self, client_id: &str) {
if let Some(session) = self.clients.write().remove(client_id) {
let duration = session.connected_duration();
let duration_secs = duration.as_secs_f32();
let avg_fps = if duration_secs > 0.1 {
session.frames_sent as f32 / duration_secs
} else {
0.0
};
info!(
"Client {} disconnected after {:.1}s ({} frames, {:.1} avg FPS)",
client_id, duration_secs, session.frames_sent, avg_fps
);
}
}
/// Record frame sent to a specific client
pub fn record_frame_sent(&self, client_id: &str) {
if let Some(session) = self.clients.write().get_mut(client_id) {
session.last_activity = Instant::now();
session.frames_sent += 1;
session.fps_calculator.record_frame();
}
}
/// Get per-client statistics
pub fn get_clients_stat(&self) -> HashMap<String, crate::events::types::ClientStats> {
self.clients
.read()
.iter()
.map(|(id, session)| {
(
id.clone(),
crate::events::types::ClientStats {
id: id.clone(),
fps: session.fps_calculator.current_fps(),
connected_secs: session.connected_duration().as_secs(),
},
)
})
.collect()
}
/// Get auto-pause configuration
pub fn auto_pause_config(&self) -> AutoPauseConfig {
self.auto_pause_config.read().clone()
}
/// Update auto-pause configuration
pub fn set_auto_pause_config(&self, config: AutoPauseConfig) {
let config_clone = config.clone();
*self.auto_pause_config.write() = config;
info!(
"Auto-pause config updated: enabled={}, delay={}s, timeout={}s",
config_clone.enabled, config_clone.shutdown_delay_secs, config_clone.client_timeout_secs
);
}
/// Get current frame (if any)
pub fn current_frame(&self) -> Option<VideoFrame> {
(**self.current_frame.load()).clone()
}
/// Subscribe to frame updates
pub fn subscribe(&self) -> broadcast::Receiver<()> {
self.frame_notify.subscribe()
}
/// Disconnect all clients (used during config changes)
/// This clears the client list and sets the stream offline,
/// which will cause all active MJPEG streams to terminate.
pub fn disconnect_all_clients(&self) {
let count = {
let mut clients = self.clients.write();
let count = clients.len();
clients.clear();
count
};
if count > 0 {
info!("Disconnected all {} MJPEG clients for config change", count);
}
// Set offline to signal all streaming tasks to stop
self.set_offline();
}
}
impl Default for MjpegStreamHandler {
fn default() -> Self {
Self::new()
}
}
/// RAII guard for client lifecycle management
/// Ensures cleanup even on panic or abrupt disconnection
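///
/// # Example
///
/// A minimal sketch (ignored doc-test), assuming `handler` is an
/// `Arc<MjpegStreamHandler>` shared with the streaming route:
///
/// ```rust,ignore
/// {
///     let guard = ClientGuard::new("client-1".to_string(), handler.clone());
///     // ... stream frames, calling handler.record_frame_sent(guard.id()) ...
/// } // guard dropped here: the client is unregistered automatically
/// ```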
pub struct ClientGuard {
client_id: ClientId,
handler: Arc<MjpegStreamHandler>,
}
impl ClientGuard {
/// Create a new client guard
pub fn new(client_id: ClientId, handler: Arc<MjpegStreamHandler>) -> Self {
handler.register_client(client_id.clone());
Self {
client_id,
handler,
}
}
/// Get client ID
pub fn id(&self) -> &ClientId {
&self.client_id
}
}
impl Drop for ClientGuard {
fn drop(&mut self) {
self.handler.unregister_client(&self.client_id);
}
}
impl MjpegStreamHandler {
/// Start stale client cleanup task
/// Should be called once when handler is created
pub fn start_cleanup_task(self: Arc<Self>) {
let handler = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
let timeout_secs = handler.auto_pause_config().client_timeout_secs;
let timeout = Duration::from_secs(timeout_secs);
let now = Instant::now();
let mut stale = Vec::new();
// Find stale clients
{
let clients = handler.clients.read();
for (id, session) in clients.iter() {
if now.duration_since(session.last_activity) > timeout {
stale.push(id.clone());
}
}
}
// Remove stale clients
if !stale.is_empty() {
let mut clients = handler.clients.write();
for id in stale {
if let Some(session) = clients.remove(&id) {
warn!(
"Removed stale client {} (inactive for {:.1}s)",
id,
now.duration_since(session.last_activity).as_secs_f32()
);
}
}
}
}
});
}
}
/// Compare two frames for equality (hash-based, ustreamer-style)
/// Returns true if frames are identical in geometry and content
fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
// Quick checks first (geometry)
if a.len() != b.len() {
return false;
}
if a.resolution.width != b.resolution.width || a.resolution.height != b.resolution.height {
return false;
}
if a.format != b.format {
return false;
}
if a.stride != b.stride {
return false;
}
if a.online != b.online {
return false;
}
// Compare hashes instead of full binary data
// Hash is computed once and cached in OnceLock for efficiency
// This is much faster than binary comparison for large frames (1080p MJPEG)
a.get_hash() == b.get_hash()
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use crate::video::{format::Resolution, PixelFormat};
#[tokio::test]
async fn test_stream_handler() {
let handler = MjpegStreamHandler::new();
assert!(!handler.is_online());
assert_eq!(handler.client_count(), 0);
// Constructing a frame of this shape must not panic
let _frame = VideoFrame::new(
Bytes::from(vec![0xFF, 0xD8, 0x00, 0x00, 0xFF, 0xD9]),
Resolution::VGA,
PixelFormat::Mjpeg,
0,
1,
);
}
#[test]
fn test_fps_calculator() {
let mut calc = FpsCalculator::new();
// Initially empty
assert_eq!(calc.current_fps(), 0);
// Record some frames
calc.record_frame();
calc.record_frame();
calc.record_frame();
// Should have 3 frames in window
assert_eq!(calc.frame_times.len(), 3);
}
}

487
src/stream/mjpeg_streamer.rs Normal file
View File

@@ -0,0 +1,487 @@
//! MJPEG Streamer - High-level MJPEG/HTTP streaming manager
//!
//! This module provides a unified interface for MJPEG streaming mode,
//! integrating video capture, MJPEG distribution, and WebSocket HID.
//!
//! # Architecture
//!
//! ```text
//! MjpegStreamer
//! |
//! +-- VideoCapturer (V4L2 video capture)
//! +-- MjpegStreamHandler (HTTP multipart video)
//! +-- WsHidHandler (WebSocket HID)
//! ```
//!
//! Note: Audio WebSocket is handled separately by audio_ws.rs (/api/ws/audio)
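//!
//! # Example
//!
//! A minimal wiring sketch (ignored doc-test: it assumes a V4L2 capture
//! device is attached and a Tokio runtime is running):
//!
//! ```rust,ignore
//! let streamer = MjpegStreamer::new();
//! streamer.init_auto().await?;  // auto-detect the best capture device
//! streamer.start().await?;      // begin capture and frame forwarding
//! // Hand streamer.mjpeg_handler() to the HTTP layer for MJPEG clients.
//! ```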
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::info;
use crate::audio::AudioController;
use crate::error::{AppError, Result};
use crate::events::{EventBus, SystemEvent};
use crate::hid::HidController;
use crate::video::capture::{CaptureConfig, VideoCapturer};
use crate::video::device::{enumerate_devices, find_best_device, VideoDeviceInfo};
use crate::video::format::{PixelFormat, Resolution};
use crate::video::frame::VideoFrame;
use super::mjpeg::MjpegStreamHandler;
use super::ws_hid::WsHidHandler;
/// MJPEG streamer configuration
#[derive(Debug, Clone)]
pub struct MjpegStreamerConfig {
/// Device path (None = auto-detect)
pub device_path: Option<PathBuf>,
/// Desired resolution
pub resolution: Resolution,
/// Desired format
pub format: PixelFormat,
/// Desired FPS
pub fps: u32,
/// JPEG quality (1-100)
pub jpeg_quality: u8,
}
impl Default for MjpegStreamerConfig {
fn default() -> Self {
Self {
device_path: None,
resolution: Resolution::HD1080,
format: PixelFormat::Mjpeg,
fps: 30,
jpeg_quality: 80,
}
}
}
/// MJPEG streamer state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MjpegStreamerState {
/// Not initialized
Uninitialized,
/// Ready but not streaming
Ready,
/// Actively streaming
Streaming,
/// No video signal
NoSignal,
/// Error occurred
Error,
}
impl std::fmt::Display for MjpegStreamerState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MjpegStreamerState::Uninitialized => write!(f, "uninitialized"),
MjpegStreamerState::Ready => write!(f, "ready"),
MjpegStreamerState::Streaming => write!(f, "streaming"),
MjpegStreamerState::NoSignal => write!(f, "no_signal"),
MjpegStreamerState::Error => write!(f, "error"),
}
}
}
/// MJPEG streamer statistics
#[derive(Debug, Clone, Default)]
pub struct MjpegStreamerStats {
/// Current state
pub state: String,
/// Current device path
pub device: Option<String>,
/// Video resolution
pub resolution: Option<(u32, u32)>,
/// Video format
pub format: Option<String>,
/// Current FPS
pub fps: u32,
/// MJPEG client count
pub mjpeg_clients: u64,
/// WebSocket HID client count
pub ws_hid_clients: usize,
/// Total frames captured
pub frames_captured: u64,
}
/// MJPEG Streamer
///
/// High-level manager for MJPEG/HTTP streaming mode.
/// Integrates video capture, MJPEG distribution, and WebSocket HID.
pub struct MjpegStreamer {
// === Video ===
config: RwLock<MjpegStreamerConfig>,
capturer: RwLock<Option<Arc<VideoCapturer>>>,
mjpeg_handler: Arc<MjpegStreamHandler>,
current_device: RwLock<Option<VideoDeviceInfo>>,
state: RwLock<MjpegStreamerState>,
// === Audio (controller reference only, WS handled by audio_ws.rs) ===
audio_controller: RwLock<Option<Arc<AudioController>>>,
audio_enabled: AtomicBool,
// === HID ===
ws_hid_handler: Arc<WsHidHandler>,
hid_controller: RwLock<Option<Arc<HidController>>>,
// === Control ===
start_lock: tokio::sync::Mutex<()>,
events: RwLock<Option<Arc<EventBus>>>,
config_changing: AtomicBool,
}
impl MjpegStreamer {
/// Create a new MJPEG streamer
pub fn new() -> Arc<Self> {
Arc::new(Self {
config: RwLock::new(MjpegStreamerConfig::default()),
capturer: RwLock::new(None),
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
current_device: RwLock::new(None),
state: RwLock::new(MjpegStreamerState::Uninitialized),
audio_controller: RwLock::new(None),
audio_enabled: AtomicBool::new(false),
ws_hid_handler: WsHidHandler::new(),
hid_controller: RwLock::new(None),
start_lock: tokio::sync::Mutex::new(()),
events: RwLock::new(None),
config_changing: AtomicBool::new(false),
})
}
/// Create with specific config
pub fn with_config(config: MjpegStreamerConfig) -> Arc<Self> {
Arc::new(Self {
config: RwLock::new(config),
capturer: RwLock::new(None),
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
current_device: RwLock::new(None),
state: RwLock::new(MjpegStreamerState::Uninitialized),
audio_controller: RwLock::new(None),
audio_enabled: AtomicBool::new(false),
ws_hid_handler: WsHidHandler::new(),
hid_controller: RwLock::new(None),
start_lock: tokio::sync::Mutex::new(()),
events: RwLock::new(None),
config_changing: AtomicBool::new(false),
})
}
// ========================================================================
// Configuration and Setup
// ========================================================================
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Set audio controller (for reference, WebSocket handled by audio_ws.rs)
pub async fn set_audio_controller(&self, audio: Arc<AudioController>) {
*self.audio_controller.write().await = Some(audio);
info!("MjpegStreamer: Audio controller set");
}
/// Set HID controller
pub async fn set_hid_controller(&self, hid: Arc<HidController>) {
*self.hid_controller.write().await = Some(hid.clone());
self.ws_hid_handler.set_hid_controller(hid);
info!("MjpegStreamer: HID controller set");
}
/// Enable or disable audio
pub fn set_audio_enabled(&self, enabled: bool) {
self.audio_enabled.store(enabled, Ordering::SeqCst);
}
/// Check if audio is enabled
pub fn is_audio_enabled(&self) -> bool {
self.audio_enabled.load(Ordering::SeqCst)
}
// ========================================================================
// State and Status
// ========================================================================
/// Get current state
pub async fn state(&self) -> MjpegStreamerState {
*self.state.read().await
}
/// Check if config is currently being changed
pub fn is_config_changing(&self) -> bool {
self.config_changing.load(Ordering::SeqCst)
}
/// Get current device info
pub async fn current_device(&self) -> Option<VideoDeviceInfo> {
self.current_device.read().await.clone()
}
/// Get statistics
pub async fn stats(&self) -> MjpegStreamerStats {
let state = *self.state.read().await;
let device = self.current_device.read().await;
let config = self.config.read().await;
let (resolution, format, frames_captured) = if let Some(ref cap) = *self.capturer.read().await {
let stats = cap.stats().await;
(
Some((config.resolution.width, config.resolution.height)),
Some(config.format.to_string()),
stats.frames_captured,
)
} else {
(None, None, 0)
};
MjpegStreamerStats {
state: state.to_string(),
device: device.as_ref().map(|d| d.path.display().to_string()),
resolution,
format,
fps: config.fps,
mjpeg_clients: self.mjpeg_handler.client_count(),
ws_hid_clients: self.ws_hid_handler.client_count(),
frames_captured,
}
}
// ========================================================================
// Handler Access
// ========================================================================
/// Get MJPEG handler for HTTP streaming
pub fn mjpeg_handler(&self) -> Arc<MjpegStreamHandler> {
self.mjpeg_handler.clone()
}
/// Get WebSocket HID handler
pub fn ws_hid_handler(&self) -> Arc<WsHidHandler> {
self.ws_hid_handler.clone()
}
/// Get frame sender for WebRTC integration
pub async fn frame_sender(&self) -> Option<broadcast::Sender<VideoFrame>> {
if let Some(ref cap) = *self.capturer.read().await {
Some(cap.frame_sender())
} else {
None
}
}
// ========================================================================
// Initialization
// ========================================================================
/// Initialize with auto-detected device
pub async fn init_auto(self: &Arc<Self>) -> Result<()> {
let best = find_best_device()?;
self.init_with_device(best).await
}
/// Initialize with specific device
pub async fn init_with_device(self: &Arc<Self>, device: VideoDeviceInfo) -> Result<()> {
info!("MjpegStreamer: Initializing with device: {}", device.path.display());
let config = self.config.read().await.clone();
// Create capture config
let capture_config = CaptureConfig {
device_path: device.path.clone(),
resolution: config.resolution,
format: config.format,
fps: config.fps,
buffer_count: 4,
timeout: std::time::Duration::from_secs(5),
jpeg_quality: config.jpeg_quality,
};
// Create capturer
let capturer = Arc::new(VideoCapturer::new(capture_config));
// Store device and capturer
*self.current_device.write().await = Some(device);
*self.capturer.write().await = Some(capturer);
*self.state.write().await = MjpegStreamerState::Ready;
self.publish_state_change().await;
Ok(())
}
// ========================================================================
// Streaming Control
// ========================================================================
/// Start streaming
pub async fn start(self: &Arc<Self>) -> Result<()> {
let _lock = self.start_lock.lock().await;
if self.config_changing.load(Ordering::SeqCst) {
return Err(AppError::VideoError("Config change in progress".to_string()));
}
let state = *self.state.read().await;
if state == MjpegStreamerState::Streaming {
return Ok(());
}
// Get capturer
let capturer = self.capturer.read().await.clone();
let capturer = capturer.ok_or_else(|| AppError::VideoError("Not initialized".to_string()))?;
// Start capture
capturer.start().await?;
// Start frame forwarding task
let handler = self.mjpeg_handler.clone();
let mut frame_rx = capturer.frame_sender().subscribe();
tokio::spawn(async move {
while let Ok(frame) = frame_rx.recv().await {
handler.update_frame(frame);
}
});
// Note: Audio WebSocket is handled separately by audio_ws.rs (/api/ws/audio)
*self.state.write().await = MjpegStreamerState::Streaming;
self.mjpeg_handler.set_online();
self.publish_state_change().await;
info!("MjpegStreamer: Streaming started");
Ok(())
}
/// Stop streaming
pub async fn stop(&self) -> Result<()> {
let state = *self.state.read().await;
if state != MjpegStreamerState::Streaming {
return Ok(());
}
// Stop capturer
if let Some(ref cap) = *self.capturer.read().await {
let _ = cap.stop().await;
}
// Set offline
self.mjpeg_handler.set_offline();
*self.state.write().await = MjpegStreamerState::Ready;
self.publish_state_change().await;
info!("MjpegStreamer: Streaming stopped");
Ok(())
}
/// Check if streaming
pub async fn is_streaming(&self) -> bool {
*self.state.read().await == MjpegStreamerState::Streaming
}
// ========================================================================
// Configuration Updates
// ========================================================================
/// Apply video configuration
///
/// This stops the current stream, reconfigures the capturer, and restarts.
pub async fn apply_config(self: &Arc<Self>, config: MjpegStreamerConfig) -> Result<()> {
info!("MjpegStreamer: Applying config: {:?}", config);
self.config_changing.store(true, Ordering::SeqCst);
// Run the reconfiguration steps through an inner block so the
// config_changing flag is always cleared, even on early failure.
let result: Result<()> = async {
// Stop current stream
self.stop().await?;
// Disconnect all MJPEG clients
self.mjpeg_handler.disconnect_all_clients();
// Release capturer
*self.capturer.write().await = None;
// Update config
*self.config.write().await = config.clone();
// Re-initialize if device path is set
if let Some(ref path) = config.device_path {
let devices = enumerate_devices()?;
let device = devices
.into_iter()
.find(|d| d.path == *path)
.ok_or_else(|| AppError::VideoError(format!("Device not found: {}", path.display())))?;
self.init_with_device(device).await?;
}
Ok(())
}
.await;
self.config_changing.store(false, Ordering::SeqCst);
self.publish_state_change().await;
result
}
// ========================================================================
// Internal
// ========================================================================
/// Publish state change event
async fn publish_state_change(&self) {
if let Some(ref events) = *self.events.read().await {
let state = *self.state.read().await;
let device = self.current_device.read().await;
events.publish(SystemEvent::StreamStateChanged {
state: state.to_string(),
device: device.as_ref().map(|d| d.path.display().to_string()),
});
}
}
}
impl Default for MjpegStreamer {
fn default() -> Self {
Self {
config: RwLock::new(MjpegStreamerConfig::default()),
capturer: RwLock::new(None),
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
current_device: RwLock::new(None),
state: RwLock::new(MjpegStreamerState::Uninitialized),
audio_controller: RwLock::new(None),
audio_enabled: AtomicBool::new(false),
ws_hid_handler: WsHidHandler::new(),
hid_controller: RwLock::new(None),
start_lock: tokio::sync::Mutex::new(()),
events: RwLock::new(None),
config_changing: AtomicBool::new(false),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mjpeg_streamer_creation() {
let streamer = MjpegStreamer::new();
assert!(!streamer.is_config_changing());
assert!(!streamer.is_audio_enabled());
}
#[test]
fn test_mjpeg_streamer_config_default() {
let config = MjpegStreamerConfig::default();
assert_eq!(config.resolution, Resolution::HD1080);
assert_eq!(config.format, PixelFormat::Mjpeg);
assert_eq!(config.fps, 30);
}
#[test]
fn test_mjpeg_streamer_state_display() {
assert_eq!(MjpegStreamerState::Streaming.to_string(), "streaming");
assert_eq!(MjpegStreamerState::Ready.to_string(), "ready");
}
}

17
src/stream/mod.rs Normal file
View File

@@ -0,0 +1,17 @@
//! Video streaming module
//!
//! Provides MJPEG streaming and WebSocket handlers for MJPEG mode.
//!
//! # Components
//!
//! - `MjpegStreamer` - High-level MJPEG streaming manager
//! - `MjpegStreamHandler` - HTTP multipart MJPEG video streaming
//! - `WsHidHandler` - WebSocket HID input handler
pub mod mjpeg;
pub mod mjpeg_streamer;
pub mod ws_hid;
pub use mjpeg::{ClientGuard, MjpegStreamHandler};
pub use mjpeg_streamer::{MjpegStreamer, MjpegStreamerConfig, MjpegStreamerState, MjpegStreamerStats};
pub use ws_hid::WsHidHandler;

280
src/stream/ws_hid.rs Normal file
View File

@@ -0,0 +1,280 @@
//! WebSocket HID Handler for MJPEG mode
//!
//! This module provides a standalone WebSocket HID handler that can be used
//! independently of the application state. It manages multiple WebSocket
//! connections and forwards HID events to the HID controller.
//!
//! # Protocol
//!
//! Only binary protocol is supported for optimal performance.
//! See `crate::hid::datachannel` for message format details.
//!
//! # Architecture
//!
//! ```text
//! WsHidHandler
//! |
//! +-- clients: HashMap<ClientId, WsHidClient>
//! +-- hid_controller: Arc<HidController>
//! |
//! +-- add_client() -> spawns client handler task
//! +-- remove_client()
//! ```
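//!
//! # Example
//!
//! A sketch of accepting a client from an axum WebSocket upgrade (the route
//! signature and the client-id scheme here are illustrative assumptions; on
//! connect the handler replies with a one-byte status, 0x00 ok / 0x01 error):
//!
//! ```rust,ignore
//! async fn ws_hid(
//!     ws: axum::extract::WebSocketUpgrade,
//!     handler: std::sync::Arc<WsHidHandler>,
//! ) -> impl axum::response::IntoResponse {
//!     ws.on_upgrade(move |socket| async move {
//!         handler.add_client("client-1".to_string(), socket).await;
//!     })
//! }
//! ```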
use axum::extract::ws::{Message, WebSocket};
use futures::{SinkExt, StreamExt};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::hid::datachannel::{parse_hid_message, HidChannelEvent};
use crate::hid::HidController;
/// Client ID type
pub type ClientId = String;
/// WebSocket HID client information
#[derive(Debug)]
pub struct WsHidClient {
/// Client ID
pub id: ClientId,
/// Connection timestamp
pub connected_at: Instant,
/// Events processed
pub events_processed: AtomicU64,
/// Shutdown signal sender
shutdown_tx: mpsc::Sender<()>,
}
impl WsHidClient {
/// Get events processed count
pub fn events_count(&self) -> u64 {
self.events_processed.load(Ordering::Relaxed)
}
/// Get connection duration in seconds
pub fn connected_secs(&self) -> u64 {
self.connected_at.elapsed().as_secs()
}
}
/// WebSocket HID Handler
///
/// Manages WebSocket connections for HID input in MJPEG mode.
/// Only binary protocol is supported for optimal performance.
pub struct WsHidHandler {
/// HID controller reference
hid_controller: RwLock<Option<Arc<HidController>>>,
/// Active clients
clients: RwLock<HashMap<ClientId, Arc<WsHidClient>>>,
/// Running state
running: AtomicBool,
/// Total events processed
total_events: AtomicU64,
}
impl WsHidHandler {
/// Create a new WebSocket HID handler
pub fn new() -> Arc<Self> {
Arc::new(Self {
hid_controller: RwLock::new(None),
clients: RwLock::new(HashMap::new()),
running: AtomicBool::new(true),
total_events: AtomicU64::new(0),
})
}
/// Set HID controller
pub fn set_hid_controller(&self, hid: Arc<HidController>) {
*self.hid_controller.write() = Some(hid);
info!("WsHidHandler: HID controller set");
}
/// Get HID controller
pub fn hid_controller(&self) -> Option<Arc<HidController>> {
self.hid_controller.read().clone()
}
/// Check if HID controller is available
pub fn is_hid_available(&self) -> bool {
self.hid_controller.read().is_some()
}
/// Get client count
pub fn client_count(&self) -> usize {
self.clients.read().len()
}
/// Check if running
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
/// Stop the handler
pub fn stop(&self) {
self.running.store(false, Ordering::SeqCst);
// Signal all clients to disconnect
let clients = self.clients.read();
for client in clients.values() {
let _ = client.shutdown_tx.try_send(());
}
}
/// Get total events processed
pub fn total_events(&self) -> u64 {
self.total_events.load(Ordering::Relaxed)
}
/// Add a new WebSocket client
///
/// This spawns a background task to handle the WebSocket connection.
pub async fn add_client(self: &Arc<Self>, client_id: ClientId, socket: WebSocket) {
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let client = Arc::new(WsHidClient {
id: client_id.clone(),
connected_at: Instant::now(),
events_processed: AtomicU64::new(0),
shutdown_tx,
});
self.clients.write().insert(client_id.clone(), client.clone());
info!(
"WsHidHandler: Client {} connected (total: {})",
client_id,
self.client_count()
);
// Spawn handler task
let handler = self.clone();
tokio::spawn(async move {
handler
.handle_client(client_id.clone(), socket, client, shutdown_rx)
.await;
handler.remove_client(&client_id);
});
}
/// Remove a client
pub fn remove_client(&self, client_id: &str) {
if let Some(client) = self.clients.write().remove(client_id) {
info!(
"WsHidHandler: Client {} disconnected after {}s ({} events)",
client_id,
client.connected_secs(),
client.events_count()
);
}
}
/// Handle a WebSocket client connection
async fn handle_client(
&self,
client_id: ClientId,
socket: WebSocket,
client: Arc<WsHidClient>,
mut shutdown_rx: mpsc::Receiver<()>,
) {
let (mut sender, mut receiver) = socket.split();
// Send initial status as binary: 0x00 = ok, 0x01 = error
let status_byte = if self.is_hid_available() { 0x00u8 } else { 0x01u8 };
let _ = sender.send(Message::Binary(vec![status_byte])).await;
loop {
tokio::select! {
biased;
_ = shutdown_rx.recv() => {
debug!("WsHidHandler: Client {} received shutdown signal", client_id);
break;
}
msg = receiver.next() => {
match msg {
Some(Ok(Message::Binary(data))) => {
if let Err(e) = self.handle_binary_message(&data, &client).await {
warn!("WsHidHandler: Failed to handle binary message: {}", e);
}
}
Some(Ok(Message::Ping(data))) => {
let _ = sender.send(Message::Pong(data)).await;
}
// Ignore text messages - binary protocol only
Some(Ok(Message::Text(_))) => {
warn!("WsHidHandler: Ignoring text message from client {} (binary protocol only)", client_id);
}
Some(Ok(Message::Close(_))) => {
debug!("WsHidHandler: Client {} closed connection", client_id);
break;
}
Some(Err(e)) => {
error!("WsHidHandler: WebSocket error for client {}: {}", client_id, e);
break;
}
None => {
debug!("WsHidHandler: Client {} stream ended", client_id);
break;
}
// Pong and any other message types are ignored
_ => {}
}
}
}
}
}
/// Handle binary HID message
async fn handle_binary_message(&self, data: &[u8], client: &WsHidClient) -> Result<(), String> {
let hid = self
.hid_controller
.read()
.clone()
.ok_or("HID controller not available")?;
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
match event {
HidChannelEvent::Keyboard(kb_event) => {
hid.send_keyboard(kb_event)
.await
.map_err(|e| e.to_string())?;
}
HidChannelEvent::Mouse(ms_event) => {
hid.send_mouse(ms_event).await.map_err(|e| e.to_string())?;
}
}
client.events_processed.fetch_add(1, Ordering::Relaxed);
self.total_events.fetch_add(1, Ordering::Relaxed);
Ok(())
}
}
impl Default for WsHidHandler {
fn default() -> Self {
Self {
hid_controller: RwLock::new(None),
clients: RwLock::new(HashMap::new()),
running: AtomicBool::new(true),
total_events: AtomicU64::new(0),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ws_hid_handler_creation() {
let handler = WsHidHandler::new();
assert!(handler.is_running());
assert_eq!(handler.client_count(), 0);
assert!(!handler.is_hid_available());
}
}

7
src/utils/mod.rs Normal file
View File

@@ -0,0 +1,7 @@
//! Utility modules for One-KVM
//!
//! This module contains common utilities used across the codebase.
pub mod throttle;
pub use throttle::LogThrottler;

247
src/utils/throttle.rs Normal file
View File

@@ -0,0 +1,247 @@
//! Log throttling utility
//!
//! Provides a mechanism to limit how often the same log message is recorded,
//! preventing log flooding when errors occur repeatedly.
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Log throttler that limits how often the same message is logged
///
/// This is useful for preventing log flooding when errors occur repeatedly,
/// such as when a device is disconnected and reconnection attempts fail.
///
/// # Example
///
/// ```rust
/// use std::time::Duration;
/// use one_kvm::utils::LogThrottler;
///
/// let throttler = LogThrottler::new(Duration::from_secs(5));
///
/// // First call returns true
/// assert!(throttler.should_log("device_error"));
///
/// // Subsequent calls within 5 seconds return false
/// assert!(!throttler.should_log("device_error"));
/// ```
pub struct LogThrottler {
/// Map of message key to last log time
last_logged: RwLock<HashMap<String, Instant>>,
/// Throttle interval
interval: Duration,
}
impl LogThrottler {
/// Create a new log throttler with the specified interval
///
/// # Arguments
///
/// * `interval` - The minimum time between log messages for the same key
pub fn new(interval: Duration) -> Self {
Self {
last_logged: RwLock::new(HashMap::new()),
interval,
}
}
/// Create a new log throttler with interval specified in seconds
pub fn with_secs(secs: u64) -> Self {
Self::new(Duration::from_secs(secs))
}
/// Check if a message should be logged (not throttled)
///
/// Returns `true` if the message should be logged, `false` if it should be throttled.
/// If `true` is returned, the internal timestamp is updated.
///
/// # Arguments
///
/// * `key` - A unique identifier for the message type
pub fn should_log(&self, key: &str) -> bool {
let now = Instant::now();
// First check with read lock (fast path)
{
let map = self.last_logged.read().unwrap();
if let Some(last) = map.get(key) {
if now.duration_since(*last) < self.interval {
return false;
}
}
}
// Update with write lock
let mut map = self.last_logged.write().unwrap();
// Double-check after acquiring write lock
if let Some(last) = map.get(key) {
if now.duration_since(*last) < self.interval {
return false;
}
}
map.insert(key.to_string(), now);
true
}
/// Clear throttle state for a specific key
///
/// This should be called when an error condition recovers,
/// so the next error will be logged immediately.
///
/// # Arguments
///
/// * `key` - The key to clear
pub fn clear(&self, key: &str) {
self.last_logged.write().unwrap().remove(key);
}
/// Clear all throttle state
pub fn clear_all(&self) {
self.last_logged.write().unwrap().clear();
}
/// Get the number of tracked keys
pub fn len(&self) -> usize {
self.last_logged.read().unwrap().len()
}
/// Check if the throttler is empty
pub fn is_empty(&self) -> bool {
self.last_logged.read().unwrap().is_empty()
}
}
impl Default for LogThrottler {
/// Create a default log throttler with a 5-second interval
fn default() -> Self {
Self::with_secs(5)
}
}
/// Macro for throttled warning logging
///
/// # Example
///
/// ```rust
/// use one_kvm::utils::LogThrottler;
/// use one_kvm::warn_throttled;
///
/// let throttler = LogThrottler::default();
/// warn_throttled!(throttler, "my_error", "Error occurred: {}", "details");
/// ```
#[macro_export]
macro_rules! warn_throttled {
($throttler:expr, $key:expr, $($arg:tt)*) => {
if $throttler.should_log($key) {
tracing::warn!($($arg)*);
}
};
}
/// Macro for throttled error logging
#[macro_export]
macro_rules! error_throttled {
($throttler:expr, $key:expr, $($arg:tt)*) => {
if $throttler.should_log($key) {
tracing::error!($($arg)*);
}
};
}
/// Macro for throttled info logging
#[macro_export]
macro_rules! info_throttled {
($throttler:expr, $key:expr, $($arg:tt)*) => {
if $throttler.should_log($key) {
tracing::info!($($arg)*);
}
};
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
#[test]
fn test_should_log_first_call() {
let throttler = LogThrottler::with_secs(1);
assert!(throttler.should_log("test_key"));
}
#[test]
fn test_throttling() {
let throttler = LogThrottler::new(Duration::from_millis(100));
// First call should succeed
assert!(throttler.should_log("test_key"));
// Immediate second call should be throttled
assert!(!throttler.should_log("test_key"));
// Wait for throttle to expire
thread::sleep(Duration::from_millis(150));
// Should succeed again
assert!(throttler.should_log("test_key"));
}
#[test]
fn test_different_keys() {
let throttler = LogThrottler::with_secs(10);
// Different keys should be independent
assert!(throttler.should_log("key1"));
assert!(throttler.should_log("key2"));
assert!(!throttler.should_log("key1"));
assert!(!throttler.should_log("key2"));
}
#[test]
fn test_clear() {
let throttler = LogThrottler::with_secs(10);
assert!(throttler.should_log("test_key"));
assert!(!throttler.should_log("test_key"));
// Clear the key
throttler.clear("test_key");
// Should be able to log again
assert!(throttler.should_log("test_key"));
}
#[test]
fn test_clear_all() {
let throttler = LogThrottler::with_secs(10);
assert!(throttler.should_log("key1"));
assert!(throttler.should_log("key2"));
throttler.clear_all();
assert!(throttler.should_log("key1"));
assert!(throttler.should_log("key2"));
}
#[test]
fn test_default() {
let throttler = LogThrottler::default();
assert!(throttler.should_log("test"));
}
#[test]
fn test_len_and_is_empty() {
let throttler = LogThrottler::with_secs(10);
assert!(throttler.is_empty());
assert_eq!(throttler.len(), 0);
throttler.should_log("key1");
assert!(!throttler.is_empty());
assert_eq!(throttler.len(), 1);
throttler.should_log("key2");
assert_eq!(throttler.len(), 2);
}
}

693
src/video/capture.rs Normal file
View File

@@ -0,0 +1,693 @@
//! V4L2 video capture implementation
//!
//! Provides async video capture using memory-mapped buffers.
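//!
//! # Example
//!
//! A minimal capture-loop sketch (ignored doc-test: assumes `/dev/video0`
//! exists and a Tokio runtime is running):
//!
//! ```rust,ignore
//! let config = CaptureConfig::for_device("/dev/video0")
//!     .with_resolution(1280, 720)
//!     .with_fps(30);
//! let capturer = VideoCapturer::new(config);
//! let mut frames = capturer.subscribe(); // subscribe before start to avoid missing frames
//! capturer.start().await?;
//! while let Ok(frame) = frames.recv().await {
//!     println!("got {} bytes", frame.len());
//! }
//! ```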
use bytes::Bytes;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, error, info, warn};
use v4l::buffer::Type as BufferType;
use v4l::io::traits::CaptureStream;
use v4l::prelude::*;
use v4l::video::capture::Parameters;
use v4l::video::Capture;
use v4l::Format;
use super::format::{PixelFormat, Resolution};
use super::frame::VideoFrame;
use crate::error::{AppError, Result};
/// Default number of capture buffers (reduced from 4 to 2 for lower latency)
const DEFAULT_BUFFER_COUNT: u32 = 2;
/// Default capture timeout in seconds
const DEFAULT_TIMEOUT: u64 = 2;
/// Minimum valid frame size (bytes)
const MIN_FRAME_SIZE: usize = 128;
/// Video capturer configuration
#[derive(Debug, Clone)]
pub struct CaptureConfig {
/// Device path
pub device_path: PathBuf,
/// Desired resolution
pub resolution: Resolution,
/// Desired pixel format
pub format: PixelFormat,
/// Desired frame rate (0 = max available)
pub fps: u32,
/// Number of capture buffers
pub buffer_count: u32,
/// Capture timeout
pub timeout: Duration,
/// JPEG quality (1-100, for MJPEG sources with hardware quality control)
pub jpeg_quality: u8,
}
impl Default for CaptureConfig {
fn default() -> Self {
Self {
device_path: PathBuf::from("/dev/video0"),
resolution: Resolution::HD1080,
format: PixelFormat::Mjpeg,
fps: 30,
buffer_count: DEFAULT_BUFFER_COUNT,
timeout: Duration::from_secs(DEFAULT_TIMEOUT),
jpeg_quality: 80,
}
}
}
impl CaptureConfig {
/// Create config for a specific device
pub fn for_device(path: impl AsRef<Path>) -> Self {
Self {
device_path: path.as_ref().to_path_buf(),
..Default::default()
}
}
/// Set resolution
pub fn with_resolution(mut self, width: u32, height: u32) -> Self {
self.resolution = Resolution::new(width, height);
self
}
/// Set format
pub fn with_format(mut self, format: PixelFormat) -> Self {
self.format = format;
self
}
/// Set frame rate
pub fn with_fps(mut self, fps: u32) -> Self {
self.fps = fps;
self
}
}
/// Capture statistics
#[derive(Debug, Clone, Default)]
pub struct CaptureStats {
/// Total frames captured
pub frames_captured: u64,
/// Frames dropped (invalid/too small)
pub frames_dropped: u64,
/// Current FPS (calculated)
pub current_fps: f32,
/// Average frame size in bytes
pub avg_frame_size: usize,
/// Capture errors
pub errors: u64,
/// Last frame timestamp
pub last_frame_ts: Option<Instant>,
/// Whether signal is present
pub signal_present: bool,
}
/// Video capturer state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
/// Not started
Stopped,
/// Starting (initializing device)
Starting,
/// Running and capturing
Running,
/// No signal from source
NoSignal,
/// Error occurred
Error,
/// Device was lost (disconnected)
DeviceLost,
}
/// Async video capturer
pub struct VideoCapturer {
config: CaptureConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
stats: Arc<Mutex<CaptureStats>>,
frame_tx: broadcast::Sender<VideoFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Last error that occurred (device path, reason)
last_error: Arc<parking_lot::RwLock<Option<(String, String)>>>,
}
impl VideoCapturer {
/// Create a new video capturer
pub fn new(config: CaptureConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16); // Buffer up to 16 frames
Self {
config,
state: Arc::new(state_tx),
state_rx,
stats: Arc::new(Mutex::new(CaptureStats::default())),
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
last_error: Arc::new(parking_lot::RwLock::new(None)),
}
}
/// Get current capture state
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
/// Get last error (device path, reason)
pub fn last_error(&self) -> Option<(String, String)> {
self.last_error.read().clone()
}
/// Clear last error
pub fn clear_error(&self) {
*self.last_error.write() = None;
}
/// Subscribe to frames
pub fn subscribe(&self) -> broadcast::Receiver<VideoFrame> {
self.frame_tx.subscribe()
}
/// Get frame sender (for sharing with other components like WebRTC)
pub fn frame_sender(&self) -> broadcast::Sender<VideoFrame> {
self.frame_tx.clone()
}
/// Get capture statistics
pub async fn stats(&self) -> CaptureStats {
self.stats.lock().await.clone()
}
/// Get config
pub fn config(&self) -> &CaptureConfig {
&self.config
}
/// Start capturing in background
pub async fn start(&self) -> Result<()> {
let current_state = self.state();
// Already running or starting - nothing to do
if current_state == CaptureState::Running || current_state == CaptureState::Starting {
return Ok(());
}
info!(
"Starting capture on {:?} at {}x{} {}",
self.config.device_path,
self.config.resolution.width,
self.config.resolution.height,
self.config.format
);
// Set Starting state immediately to prevent concurrent start attempts
let _ = self.state.send(CaptureState::Starting);
// Clear any previous error
*self.last_error.write() = None;
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let stats = self.stats.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let last_error = self.last_error.clone();
let handle = tokio::task::spawn_blocking(move || {
capture_loop(config, state, stats, frame_tx, stop_flag, sequence, last_error);
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
/// Stop capturing
pub async fn stop(&self) -> Result<()> {
info!("Stopping capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
/// Check if capturing
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
/// Get the latest frame (if any receivers would get it)
pub fn latest_frame(&self) -> Option<VideoFrame> {
// A broadcast channel does not retain past values, so returning the
// latest frame would require tracking it separately. Use subscribe().
None
}
}
/// Main capture loop (runs in blocking thread)
fn capture_loop(
config: CaptureConfig,
state: Arc<watch::Sender<CaptureState>>,
stats: Arc<Mutex<CaptureStats>>,
frame_tx: broadcast::Sender<VideoFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
error_holder: Arc<parking_lot::RwLock<Option<(String, String)>>>,
) {
let result = run_capture(
&config,
&state,
&stats,
&frame_tx,
&stop_flag,
&sequence,
);
match result {
Ok(_) => {
let _ = state.send(CaptureState::Stopped);
}
Err(AppError::VideoDeviceLost { device, reason }) => {
error!("Video device lost: {} - {}", device, reason);
// Store the error for recovery handling
*error_holder.write() = Some((device, reason));
let _ = state.send(CaptureState::DeviceLost);
}
Err(e) => {
error!("Capture error: {}", e);
let _ = state.send(CaptureState::Error);
}
}
}
fn run_capture(
config: &CaptureConfig,
state: &watch::Sender<CaptureState>,
stats: &Arc<Mutex<CaptureStats>>,
frame_tx: &broadcast::Sender<VideoFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
) -> Result<()> {
// Retry logic for device busy errors
const MAX_RETRIES: u32 = 5;
const RETRY_DELAY_MS: u64 = 200;
let mut last_error = None;
for attempt in 0..MAX_RETRIES {
if stop_flag.load(Ordering::Relaxed) {
return Ok(());
}
// Open device
let device = match Device::with_path(&config.device_path) {
Ok(d) => d,
Err(e) => {
let err_str = e.to_string();
if err_str.contains("busy") || err_str.contains("resource") {
warn!(
"Device busy on attempt {}/{}, retrying in {}ms...",
attempt + 1,
MAX_RETRIES,
RETRY_DELAY_MS
);
std::thread::sleep(Duration::from_millis(RETRY_DELAY_MS));
last_error = Some(AppError::VideoError(format!(
"Failed to open device {:?}: {}",
config.device_path, e
)));
continue;
}
return Err(AppError::VideoError(format!(
"Failed to open device {:?}: {}",
config.device_path, e
)));
}
};
// Set format
let format = Format::new(
config.resolution.width,
config.resolution.height,
config.format.to_fourcc(),
);
let actual_format = match device.set_format(&format) {
Ok(f) => f,
Err(e) => {
let err_str = e.to_string();
if err_str.contains("busy") || err_str.contains("resource") {
warn!(
"Device busy on set_format attempt {}/{}, retrying in {}ms...",
attempt + 1,
MAX_RETRIES,
RETRY_DELAY_MS
);
std::thread::sleep(Duration::from_millis(RETRY_DELAY_MS));
last_error = Some(AppError::VideoError(format!("Failed to set format: {}", e)));
continue;
}
return Err(AppError::VideoError(format!("Failed to set format: {}", e)));
}
};
// Device opened and format set successfully - proceed with capture
return run_capture_inner(
config,
state,
stats,
frame_tx,
stop_flag,
sequence,
device,
actual_format,
);
}
// All retries exhausted
Err(last_error.unwrap_or_else(|| {
AppError::VideoError("Failed to open device after all retries".to_string())
}))
}
/// Inner capture function after device is successfully opened
fn run_capture_inner(
config: &CaptureConfig,
state: &watch::Sender<CaptureState>,
stats: &Arc<Mutex<CaptureStats>>,
frame_tx: &broadcast::Sender<VideoFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
device: Device,
actual_format: Format,
) -> Result<()> {
info!(
"Capture format: {}x{} {:?} stride={}",
actual_format.width, actual_format.height, actual_format.fourcc, actual_format.stride
);
let resolution = Resolution::new(actual_format.width, actual_format.height);
let pixel_format = PixelFormat::from_fourcc(actual_format.fourcc).unwrap_or(config.format);
// Try to set hardware FPS (V4L2 VIDIOC_S_PARM)
if config.fps > 0 {
match device.set_params(&Parameters::with_fps(config.fps)) {
Ok(actual_params) => {
// Extract actual FPS from returned interval (numerator/denominator)
let actual_hw_fps = if actual_params.interval.numerator > 0 {
actual_params.interval.denominator / actual_params.interval.numerator
} else {
0
};
if actual_hw_fps == config.fps {
info!("Hardware FPS set successfully: {} fps", actual_hw_fps);
} else if actual_hw_fps > 0 {
info!(
"Hardware FPS coerced: requested {} fps, got {} fps",
config.fps, actual_hw_fps
);
} else {
warn!("Hardware FPS setting returned invalid interval");
}
}
Err(e) => {
warn!("Failed to set hardware FPS: {}", e);
}
}
}
// Create stream with mmap buffers
let mut stream =
MmapStream::with_buffers(&device, BufferType::VideoCapture, config.buffer_count)
.map_err(|e| AppError::VideoError(format!("Failed to create stream: {}", e)))?;
let _ = state.send(CaptureState::Running);
info!("Capture started");
// FPS calculation variables
let mut fps_frame_count = 0u64;
let mut fps_window_start = Instant::now();
let fps_window_duration = Duration::from_secs(1);
// Main capture loop
while !stop_flag.load(Ordering::Relaxed) {
// Try to capture a frame
let (buf, meta) = match stream.next() {
Ok(frame_data) => frame_data,
Err(e) => {
if e.kind() == io::ErrorKind::TimedOut {
warn!("Capture timeout - no signal?");
let _ = state.send(CaptureState::NoSignal);
// Update stats
if let Ok(mut s) = stats.try_lock() {
s.signal_present = false;
}
// Wait a bit before retrying
std::thread::sleep(Duration::from_millis(100));
continue;
}
// Check for device loss errors
let is_device_lost = match e.raw_os_error() {
Some(6) => true, // ENXIO - No such device or address
Some(19) => true, // ENODEV - No such device
Some(5) => true, // EIO - I/O error (device removed)
Some(32) => true, // EPIPE - Broken pipe
Some(108) => true, // ESHUTDOWN - Transport endpoint shutdown
_ => false,
};
if is_device_lost {
let device_path = config.device_path.display().to_string();
error!("Video device lost: {} - {}", device_path, e);
return Err(AppError::VideoDeviceLost {
device: device_path,
reason: e.to_string(),
});
}
error!("Capture error: {}", e);
if let Ok(mut s) = stats.try_lock() {
s.errors += 1;
}
continue;
}
};
// Use actual bytes used, not buffer size
let frame_size = meta.bytesused as usize;
// Validate frame
if frame_size < MIN_FRAME_SIZE {
debug!("Dropping small frame: {} bytes (bytesused={})", frame_size, meta.bytesused);
if let Ok(mut s) = stats.try_lock() {
s.frames_dropped += 1;
}
continue;
}
// For JPEG formats, validate header
if pixel_format.is_compressed() && !is_valid_jpeg(&buf[..frame_size]) {
debug!("Dropping invalid JPEG frame (size={})", frame_size);
if let Ok(mut s) = stats.try_lock() {
s.frames_dropped += 1;
}
continue;
}
// Create frame with actual data size
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = VideoFrame::new(
Bytes::copy_from_slice(&buf[..frame_size]),
resolution,
pixel_format,
actual_format.stride,
seq,
);
// Update state if was no signal
if *state.borrow() == CaptureState::NoSignal {
let _ = state.send(CaptureState::Running);
}
// Send frame to subscribers
let receiver_count = frame_tx.receiver_count();
if receiver_count > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No active receivers for frame: {}", e);
}
} else if seq % 60 == 0 {
// Log every 60 frames (about 1 second at 60fps) when no receivers
debug!("No receivers for video frames (receiver_count=0)");
}
// Update stats
if let Ok(mut s) = stats.try_lock() {
s.frames_captured += 1;
s.signal_present = true;
s.last_frame_ts = Some(Instant::now());
// Update FPS calculation
fps_frame_count += 1;
let elapsed = fps_window_start.elapsed();
if elapsed >= fps_window_duration {
// Calculate FPS from the completed window
s.current_fps = (fps_frame_count as f32 / elapsed.as_secs_f32()).max(0.0);
// Reset for next window
fps_frame_count = 0;
fps_window_start = Instant::now();
} else if elapsed.as_millis() > 100 && fps_frame_count > 0 {
// Provide partial estimate if we have at least 100ms of data
s.current_fps = (fps_frame_count as f32 / elapsed.as_secs_f32()).max(0.0);
}
}
}
info!("Capture stopped");
Ok(())
}
/// Validate JPEG frame data
fn is_valid_jpeg(data: &[u8]) -> bool {
if data.len() < 125 {
return false;
}
// Check start marker (0xFFD8)
let start_marker = ((data[0] as u16) << 8) | data[1] as u16;
if start_marker != 0xFFD8 {
return false;
}
// Check end marker
let end = data.len();
let end_marker = ((data[end - 2] as u16) << 8) | data[end - 1] as u16;
// Valid end markers: 0xFFD9, 0xD900, 0x0000 (padded)
matches!(end_marker, 0xFFD9 | 0xD900 | 0x0000)
}
/// Frame grabber for one-shot capture
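///
/// # Example
///
/// A one-shot grab sketch (ignored doc-test: assumes the device exists):
///
/// ```rust,ignore
/// let grabber = FrameGrabber::new("/dev/video0");
/// let frame = grabber.grab(Resolution::HD1080, PixelFormat::Mjpeg).await?;
/// println!("grabbed {} bytes", frame.len());
/// ```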
pub struct FrameGrabber {
device_path: PathBuf,
}
impl FrameGrabber {
/// Create a new frame grabber
pub fn new(device_path: impl AsRef<Path>) -> Self {
Self {
device_path: device_path.as_ref().to_path_buf(),
}
}
/// Capture a single frame
pub async fn grab(
&self,
resolution: Resolution,
format: PixelFormat,
) -> Result<VideoFrame> {
let device_path = self.device_path.clone();
tokio::task::spawn_blocking(move || {
grab_single_frame(&device_path, resolution, format)
})
.await
.map_err(|e| AppError::VideoError(format!("Grab task failed: {}", e)))?
}
}
fn grab_single_frame(
device_path: &Path,
resolution: Resolution,
format: PixelFormat,
) -> Result<VideoFrame> {
let device = Device::with_path(device_path).map_err(|e| {
AppError::VideoError(format!("Failed to open device: {}", e))
})?;
let fmt = Format::new(resolution.width, resolution.height, format.to_fourcc());
let actual = device.set_format(&fmt).map_err(|e| {
AppError::VideoError(format!("Failed to set format: {}", e))
})?;
let mut stream = MmapStream::with_buffers(&device, BufferType::VideoCapture, 2)
.map_err(|e| AppError::VideoError(format!("Failed to create stream: {}", e)))?;
// Try to get a valid frame (skip first few which might be bad)
for attempt in 0..5 {
match stream.next() {
Ok((buf, _meta)) => {
if buf.len() >= MIN_FRAME_SIZE {
let actual_format =
PixelFormat::from_fourcc(actual.fourcc).unwrap_or(format);
return Ok(VideoFrame::new(
Bytes::copy_from_slice(buf),
Resolution::new(actual.width, actual.height),
actual_format,
actual.stride,
0,
));
}
}
Err(e) => {
if attempt == 4 {
return Err(AppError::VideoError(format!(
"Failed to grab frame: {}",
e
)));
}
}
}
}
Err(AppError::VideoError("Failed to capture valid frame".to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_valid_jpeg() {
// Valid JPEG header and footer
let mut data = vec![0xFF, 0xD8]; // SOI
data.extend(vec![0u8; 200]); // Content
data.extend([0xFF, 0xD9]); // EOI
assert!(is_valid_jpeg(&data));
// Invalid - too small
assert!(!is_valid_jpeg(&[0xFF, 0xD8, 0xFF, 0xD9]));
// Invalid - wrong header
let mut bad = vec![0x00, 0x00];
bad.extend(vec![0u8; 200]);
assert!(!is_valid_jpeg(&bad));
}
}

640
src/video/convert.rs Normal file
View File

@@ -0,0 +1,640 @@
//! Pixel format conversion utilities
//!
//! This module provides SIMD-accelerated color space conversion using libyuv.
//! Primary use case: YUYV (from V4L2 capture) → YUV420P/NV12 (for H264 encoding)
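//!
//! # Example
//!
//! A YUYV → YUV420P conversion sketch (ignored doc-test; assumes an
//! even-dimension frame so the 4:2:0 plane sizes divide evenly):
//!
//! ```rust,ignore
//! let resolution = Resolution::new(1280, 720);
//! let mut converter = PixelConverter::yuyv_to_yuv420p(resolution);
//! let yuyv = vec![0u8; yuyv_buffer_size(resolution)]; // 2 bytes per pixel
//! let i420 = converter.convert(&yuyv)?;
//! assert_eq!(i420.len(), yuv420p_buffer_size(resolution));
//! ```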
use crate::error::{AppError, Result};
use crate::video::format::{PixelFormat, Resolution};
/// YUV420P buffer with separate Y, U, V planes
pub struct Yuv420pBuffer {
/// Raw buffer containing all planes
data: Vec<u8>,
/// Width of the frame
width: u32,
/// Height of the frame
height: u32,
/// Y plane offset (always 0)
y_offset: usize,
/// U plane offset
u_offset: usize,
/// V plane offset
v_offset: usize,
}
impl Yuv420pBuffer {
/// Create a new YUV420P buffer for the given resolution
pub fn new(resolution: Resolution) -> Self {
let width = resolution.width;
let height = resolution.height;
// YUV420P: Y = width*height, U = width*height/4, V = width*height/4
let y_size = (width * height) as usize;
let uv_size = y_size / 4;
let total_size = y_size + uv_size * 2;
Self {
data: vec![0u8; total_size],
width,
height,
y_offset: 0,
u_offset: y_size,
v_offset: y_size + uv_size,
}
}
/// Get the raw buffer as bytes
pub fn as_bytes(&self) -> &[u8] {
&self.data
}
/// Get the raw buffer as mutable bytes
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
&mut self.data
}
/// Get Y plane
pub fn y_plane(&self) -> &[u8] {
&self.data[self.y_offset..self.u_offset]
}
/// Get Y plane mutable
pub fn y_plane_mut(&mut self) -> &mut [u8] {
let u_offset = self.u_offset;
&mut self.data[self.y_offset..u_offset]
}
/// Get U plane
pub fn u_plane(&self) -> &[u8] {
&self.data[self.u_offset..self.v_offset]
}
/// Get U plane mutable
pub fn u_plane_mut(&mut self) -> &mut [u8] {
let v_offset = self.v_offset;
let u_offset = self.u_offset;
&mut self.data[u_offset..v_offset]
}
/// Get V plane
pub fn v_plane(&self) -> &[u8] {
&self.data[self.v_offset..]
}
/// Get V plane mutable
pub fn v_plane_mut(&mut self) -> &mut [u8] {
let v_offset = self.v_offset;
&mut self.data[v_offset..]
}
/// Get buffer length
pub fn len(&self) -> usize {
self.data.len()
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Get resolution
pub fn resolution(&self) -> Resolution {
Resolution::new(self.width, self.height)
}
}
/// NV12 buffer with Y plane and interleaved UV plane
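///
/// For 1280x720 this is 921_600 Y bytes plus 460_800 interleaved UV bytes,
/// 1_382_400 bytes in total.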
pub struct Nv12Buffer {
/// Raw buffer containing Y plane followed by interleaved UV plane
data: Vec<u8>,
/// Width of the frame
width: u32,
/// Height of the frame
height: u32,
}
impl Nv12Buffer {
/// Create a new NV12 buffer for the given resolution
pub fn new(resolution: Resolution) -> Self {
let width = resolution.width;
let height = resolution.height;
// NV12: Y = width*height, UV = width*height/2 (interleaved)
let y_size = (width * height) as usize;
let uv_size = y_size / 2;
let total_size = y_size + uv_size;
Self {
data: vec![0u8; total_size],
width,
height,
}
}
/// Get the raw buffer as bytes
pub fn as_bytes(&self) -> &[u8] {
&self.data
}
/// Get the raw buffer as mutable bytes
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
&mut self.data
}
/// Get Y plane
pub fn y_plane(&self) -> &[u8] {
let y_size = (self.width * self.height) as usize;
&self.data[..y_size]
}
/// Get Y plane mutable
pub fn y_plane_mut(&mut self) -> &mut [u8] {
let y_size = (self.width * self.height) as usize;
&mut self.data[..y_size]
}
/// Get UV plane (interleaved)
pub fn uv_plane(&self) -> &[u8] {
let y_size = (self.width * self.height) as usize;
&self.data[y_size..]
}
/// Get UV plane mutable
pub fn uv_plane_mut(&mut self) -> &mut [u8] {
let y_size = (self.width * self.height) as usize;
&mut self.data[y_size..]
}
/// Get buffer length
pub fn len(&self) -> usize {
self.data.len()
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Get resolution
pub fn resolution(&self) -> Resolution {
Resolution::new(self.width, self.height)
}
}
/// Pixel format converter using libyuv (SIMD accelerated)
pub struct PixelConverter {
/// Source format
src_format: PixelFormat,
/// Destination format
dst_format: PixelFormat,
/// Frame resolution
resolution: Resolution,
/// Output buffer (reused across conversions)
output_buffer: Yuv420pBuffer,
}
impl PixelConverter {
/// Create a new converter for YUYV → YUV420P
pub fn yuyv_to_yuv420p(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Yuyv,
dst_format: PixelFormat::Yuv420,
resolution,
output_buffer: Yuv420pBuffer::new(resolution),
}
}
/// Create a new converter for UYVY → YUV420P
pub fn uyvy_to_yuv420p(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Uyvy,
dst_format: PixelFormat::Yuv420,
resolution,
output_buffer: Yuv420pBuffer::new(resolution),
}
}
/// Create a new converter for YVYU → YUV420P
pub fn yvyu_to_yuv420p(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Yvyu,
dst_format: PixelFormat::Yuv420,
resolution,
output_buffer: Yuv420pBuffer::new(resolution),
}
}
/// Create a new converter for NV12 → YUV420P
pub fn nv12_to_yuv420p(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Nv12,
dst_format: PixelFormat::Yuv420,
resolution,
output_buffer: Yuv420pBuffer::new(resolution),
}
}
/// Create a new converter for YVU420 → YUV420P (swap U and V planes)
pub fn yvu420_to_yuv420p(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Yvu420,
dst_format: PixelFormat::Yuv420,
resolution,
output_buffer: Yuv420pBuffer::new(resolution),
}
}
/// Create a new converter for RGB24 → YUV420P
pub fn rgb24_to_yuv420p(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Rgb24,
dst_format: PixelFormat::Yuv420,
resolution,
output_buffer: Yuv420pBuffer::new(resolution),
}
}
/// Create a new converter for BGR24 → YUV420P
pub fn bgr24_to_yuv420p(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Bgr24,
dst_format: PixelFormat::Yuv420,
resolution,
output_buffer: Yuv420pBuffer::new(resolution),
}
}
/// Convert a frame and return reference to the output buffer
pub fn convert(&mut self, input: &[u8]) -> Result<&[u8]> {
let width = self.resolution.width as i32;
let height = self.resolution.height as i32;
let expected_size = self.output_buffer.len();
match (self.src_format, self.dst_format) {
(PixelFormat::Yuyv, PixelFormat::Yuv420) => {
libyuv::yuy2_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
}
(PixelFormat::Uyvy, PixelFormat::Yuv420) => {
libyuv::uyvy_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
}
(PixelFormat::Nv12, PixelFormat::Yuv420) => {
libyuv::nv12_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
}
(PixelFormat::Rgb24, PixelFormat::Yuv420) => {
libyuv::rgb24_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
}
(PixelFormat::Bgr24, PixelFormat::Yuv420) => {
libyuv::bgr24_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
}
(PixelFormat::Yvyu, PixelFormat::Yuv420) => {
// YVYU is not directly supported by libyuv, use software conversion
self.convert_yvyu_to_yuv420p_sw(input)?;
}
(PixelFormat::Yvu420, PixelFormat::Yuv420) => {
// YVU420 just swaps U and V planes
self.convert_yvu420_to_yuv420p_sw(input)?;
}
(PixelFormat::Yuv420, PixelFormat::Yuv420) => {
// No conversion needed, just copy
if input.len() < expected_size {
return Err(AppError::VideoError(format!(
"Input buffer too small: {} < {}",
input.len(),
expected_size
)));
}
self.output_buffer.as_bytes_mut().copy_from_slice(&input[..expected_size]);
}
_ => {
return Err(AppError::VideoError(format!(
"Unsupported conversion: {}{}",
self.src_format, self.dst_format
)));
}
};
Ok(self.output_buffer.as_bytes())
}
/// Get output buffer length
pub fn output_len(&self) -> usize {
self.output_buffer.len()
}
/// Get resolution
pub fn resolution(&self) -> Resolution {
self.resolution
}
/// Software conversion for YVYU (not supported by libyuv)
fn convert_yvyu_to_yuv420p_sw(&mut self, yvyu: &[u8]) -> Result<()> {
let width = self.resolution.width as usize;
let height = self.resolution.height as usize;
// Validate input length up front so the pixel loop below cannot index out
// of bounds (assumes even dimensions, as 4:2:0 output requires)
if yvyu.len() < width * height * 2 {
return Err(AppError::VideoError(format!(
"Input buffer too small: {} < {}",
yvyu.len(),
width * height * 2
)));
}
let y_size = width * height;
let uv_size = y_size / 4;
let half_width = width / 2;
let data = self.output_buffer.as_bytes_mut();
let (y_plane, uv_planes) = data.split_at_mut(y_size);
let (u_plane, v_plane) = uv_planes.split_at_mut(uv_size);
for row in (0..height).step_by(2) {
let yvyu_row0_offset = row * width * 2;
let yvyu_row1_offset = (row + 1) * width * 2;
let y_row0_offset = row * width;
let y_row1_offset = (row + 1) * width;
let uv_row_offset = (row / 2) * half_width;
for col in (0..width).step_by(2) {
let yvyu_offset0 = yvyu_row0_offset + col * 2;
let yvyu_offset1 = yvyu_row1_offset + col * 2;
// YVYU: Y0, V0, Y1, U0
let y0_0 = yvyu[yvyu_offset0];
let v0 = yvyu[yvyu_offset0 + 1];
let y0_1 = yvyu[yvyu_offset0 + 2];
let u0 = yvyu[yvyu_offset0 + 3];
let y1_0 = yvyu[yvyu_offset1];
let v1 = yvyu[yvyu_offset1 + 1];
let y1_1 = yvyu[yvyu_offset1 + 2];
let u1 = yvyu[yvyu_offset1 + 3];
y_plane[y_row0_offset + col] = y0_0;
y_plane[y_row0_offset + col + 1] = y0_1;
y_plane[y_row1_offset + col] = y1_0;
y_plane[y_row1_offset + col + 1] = y1_1;
let uv_idx = uv_row_offset + col / 2;
u_plane[uv_idx] = ((u0 as u16 + u1 as u16) / 2) as u8;
v_plane[uv_idx] = ((v0 as u16 + v1 as u16) / 2) as u8;
}
}
Ok(())
}
/// Software conversion for YVU420 (just swap U and V)
fn convert_yvu420_to_yuv420p_sw(&mut self, yvu420: &[u8]) -> Result<()> {
let width = self.resolution.width as usize;
let height = self.resolution.height as usize;
let y_size = width * height;
let uv_size = y_size / 4;
// Validate input length up front; the slicing below would otherwise panic
let expected_size = y_size + 2 * uv_size;
if yvu420.len() < expected_size {
return Err(AppError::VideoError(format!(
"Input buffer too small: {} < {}",
yvu420.len(),
expected_size
)));
}
let data = self.output_buffer.as_bytes_mut();
let (y_plane, uv_planes) = data.split_at_mut(y_size);
let (u_plane, v_plane) = uv_planes.split_at_mut(uv_size);
// Copy Y plane directly
y_plane.copy_from_slice(&yvu420[..y_size]);
// In YVU420, V comes before U; bound both chroma slices so trailing input
// padding cannot trigger a copy_from_slice length mismatch
let v_src = &yvu420[y_size..y_size + uv_size];
let u_src = &yvu420[y_size + uv_size..expected_size];
// Swap U and V
u_plane.copy_from_slice(u_src);
v_plane.copy_from_slice(v_src);
Ok(())
}
}
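// Illustrative usage sketch (added for clarity, not part of the original
// commit): one reusable PixelConverter drives a whole capture loop, so no
// per-frame allocation occurs. `frames` stands in for a capture source.
#[allow(dead_code)]
fn example_yuyv_capture_loop(frames: &[Vec<u8>]) -> Result<usize> {
let mut converter = PixelConverter::yuyv_to_yuv420p(Resolution::HD720);
let mut total_bytes = 0;
for frame in frames {
// convert() borrows the converter's internal buffer; copy out if the
// data must outlive the next call
let yuv = converter.convert(frame)?;
total_bytes += yuv.len();
}
Ok(total_bytes)
}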
/// Calculate YUV420P buffer size for a given resolution
pub fn yuv420p_buffer_size(resolution: Resolution) -> usize {
let pixels = (resolution.width * resolution.height) as usize;
pixels + pixels / 2
}
/// Calculate YUYV buffer size for a given resolution
pub fn yuyv_buffer_size(resolution: Resolution) -> usize {
(resolution.width * resolution.height * 2) as usize
}
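// Worked example: at 1920x1080 there are 2_073_600 pixels, so
// yuv420p_buffer_size = 2_073_600 + 2_073_600 / 2 = 3_110_400 bytes, while
// yuyv_buffer_size = 2_073_600 * 2 = 4_147_200 bytes.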
// ============================================================================
// MJPEG Decoder - Decodes JPEG to YUV420P using libyuv
// ============================================================================
/// MJPEG/JPEG decoder that outputs YUV420P using libyuv
pub struct MjpegDecoder {
/// Resolution hint (can be updated from decoded frame)
resolution: Resolution,
/// YUV420P output buffer
yuv_buffer: Yuv420pBuffer,
}
impl MjpegDecoder {
/// Create a new MJPEG decoder with expected resolution
pub fn new(resolution: Resolution) -> Result<Self> {
Ok(Self {
resolution,
yuv_buffer: Yuv420pBuffer::new(resolution),
})
}
/// Decode MJPEG/JPEG data to YUV420P using libyuv
pub fn decode(&mut self, jpeg_data: &[u8]) -> Result<&[u8]> {
// Get MJPEG dimensions
let (width, height) = libyuv::mjpeg_size(jpeg_data)
.map_err(|e| AppError::VideoError(format!("Failed to get MJPEG size: {}", e)))?;
// Check if resolution changed
if width != self.resolution.width as i32 || height != self.resolution.height as i32 {
tracing::debug!(
"MJPEG resolution changed: {}x{} -> {}x{}",
self.resolution.width,
self.resolution.height,
width,
height
);
self.resolution = Resolution::new(width as u32, height as u32);
self.yuv_buffer = Yuv420pBuffer::new(self.resolution);
}
// Decode MJPEG directly to I420 using libyuv
libyuv::mjpeg_to_i420(jpeg_data, self.yuv_buffer.as_bytes_mut(), width, height)
.map_err(|e| AppError::VideoError(format!("MJPEG decode failed: {}", e)))?;
Ok(self.yuv_buffer.as_bytes())
}
/// Get current resolution
pub fn resolution(&self) -> Resolution {
self.resolution
}
/// Get YUV420P buffer size
pub fn yuv_buffer_size(&self) -> usize {
self.yuv_buffer.len()
}
}
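// Illustrative usage sketch (added for clarity, not part of the original
// commit): decoding a stream of JPEG frames. The decoder transparently
// reallocates its buffer when the source resolution changes, so the caller
// only handles the returned slice.
#[allow(dead_code)]
fn example_mjpeg_stream(jpeg_frames: &[Vec<u8>]) -> Result<()> {
let mut decoder = MjpegDecoder::new(Resolution::HD1080)?;
for jpeg in jpeg_frames {
let decoded_len = decoder.decode(jpeg)?.len();
debug_assert_eq!(decoded_len, decoder.yuv_buffer_size());
}
Ok(())
}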
// ============================================================================
// NV12 Converter for VAAPI encoder (using libyuv)
// ============================================================================
/// Pixel format converter that outputs NV12 (for VAAPI encoders)
pub struct Nv12Converter {
/// Source format
src_format: PixelFormat,
/// Frame resolution
resolution: Resolution,
/// Output buffer (reused across conversions)
output_buffer: Nv12Buffer,
}
impl Nv12Converter {
/// Create a new converter for BGR24 → NV12
pub fn bgr24_to_nv12(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Bgr24,
resolution,
output_buffer: Nv12Buffer::new(resolution),
}
}
/// Create a new converter for RGB24 → NV12
pub fn rgb24_to_nv12(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Rgb24,
resolution,
output_buffer: Nv12Buffer::new(resolution),
}
}
/// Create a new converter for YUYV → NV12
pub fn yuyv_to_nv12(resolution: Resolution) -> Self {
Self {
src_format: PixelFormat::Yuyv,
resolution,
output_buffer: Nv12Buffer::new(resolution),
}
}
/// Convert a frame and return reference to the output buffer
pub fn convert(&mut self, input: &[u8]) -> Result<&[u8]> {
let width = self.resolution.width as i32;
let height = self.resolution.height as i32;
let dst = self.output_buffer.as_bytes_mut();
let result = match self.src_format {
PixelFormat::Bgr24 => libyuv::bgr24_to_nv12(input, dst, width, height),
PixelFormat::Rgb24 => libyuv::rgb24_to_nv12(input, dst, width, height),
PixelFormat::Yuyv => libyuv::yuy2_to_nv12(input, dst, width, height),
_ => {
return Err(AppError::VideoError(format!(
"Unsupported conversion to NV12: {}",
self.src_format
)));
}
};
result.map_err(|e| AppError::VideoError(format!("libyuv NV12 conversion failed: {}", e)))?;
Ok(self.output_buffer.as_bytes())
}
/// Get output buffer length
pub fn output_len(&self) -> usize {
self.output_buffer.len()
}
/// Get resolution
pub fn resolution(&self) -> Resolution {
self.resolution
}
}
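// Illustrative usage sketch (added for clarity, not part of the original
// commit): converting a captured YUYV frame to NV12 before handing it to a
// VAAPI H264 encoder.
#[allow(dead_code)]
fn example_yuyv_to_nv12(yuyv_frame: &[u8]) -> Result<Vec<u8>> {
let mut converter = Nv12Converter::yuyv_to_nv12(Resolution::HD720);
let nv12 = converter.convert(yuyv_frame)?;
Ok(nv12.to_vec())
}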
// ============================================================================
// Standalone conversion functions (using libyuv)
// ============================================================================
/// Convert BGR24 to NV12 using libyuv
pub fn bgr_to_nv12(bgr: &[u8], nv12: &mut [u8], width: usize, height: usize) {
if let Err(e) = libyuv::bgr24_to_nv12(bgr, nv12, width as i32, height as i32) {
tracing::error!("libyuv BGR24→NV12 conversion failed: {}", e);
}
}
/// Convert RGB24 to NV12 using libyuv
pub fn rgb_to_nv12(rgb: &[u8], nv12: &mut [u8], width: usize, height: usize) {
if let Err(e) = libyuv::rgb24_to_nv12(rgb, nv12, width as i32, height as i32) {
tracing::error!("libyuv RGB24→NV12 conversion failed: {}", e);
}
}
/// Convert YUYV to NV12 using libyuv
pub fn yuyv_to_nv12(yuyv: &[u8], nv12: &mut [u8], width: usize, height: usize) {
if let Err(e) = libyuv::yuy2_to_nv12(yuyv, nv12, width as i32, height as i32) {
tracing::error!("libyuv YUYV→NV12 conversion failed: {}", e);
}
}
// ============================================================================
// Extended PixelConverter for MJPEG support
// ============================================================================
/// MJPEG to YUV420P converter (wraps MjpegDecoder)
pub struct MjpegToYuv420Converter {
decoder: MjpegDecoder,
}
impl MjpegToYuv420Converter {
/// Create a new MJPEG to YUV420P converter
pub fn new(resolution: Resolution) -> Result<Self> {
Ok(Self {
decoder: MjpegDecoder::new(resolution)?,
})
}
/// Convert MJPEG data to YUV420P
pub fn convert(&mut self, mjpeg_data: &[u8]) -> Result<&[u8]> {
self.decoder.decode(mjpeg_data)
}
/// Get current resolution
pub fn resolution(&self) -> Resolution {
self.decoder.resolution()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_yuv420p_buffer_creation() {
let buffer = Yuv420pBuffer::new(Resolution::HD720);
assert_eq!(buffer.len(), 1280 * 720 * 3 / 2);
assert_eq!(buffer.y_plane().len(), 1280 * 720);
assert_eq!(buffer.u_plane().len(), 1280 * 720 / 4);
assert_eq!(buffer.v_plane().len(), 1280 * 720 / 4);
}
#[test]
fn test_nv12_buffer_creation() {
let buffer = Nv12Buffer::new(Resolution::HD720);
assert_eq!(buffer.len(), 1280 * 720 * 3 / 2);
assert_eq!(buffer.y_plane().len(), 1280 * 720);
assert_eq!(buffer.uv_plane().len(), 1280 * 720 / 2);
}
#[test]
fn test_yuyv_to_yuv420p_conversion() {
let resolution = Resolution::new(4, 4);
let mut converter = PixelConverter::yuyv_to_yuv420p(resolution);
// Create YUYV data (4x4 = 32 bytes)
let yuyv = vec![
16, 128, 17, 129, 18, 130, 19, 131,
20, 132, 21, 133, 22, 134, 23, 135,
24, 136, 25, 137, 26, 138, 27, 139,
28, 140, 29, 141, 30, 142, 31, 143,
];
let result = converter.convert(&yuyv).unwrap();
assert_eq!(result.len(), 24); // 4*4 + 2*2 + 2*2 = 24 bytes
}
}

645
src/video/decoder/mjpeg.rs Normal file

@@ -0,0 +1,645 @@
//! MJPEG decoders
//!
//! Uses hwcodec's FFmpeg backend to decode MJPEG to NV12. Decoding runs in
//! software (VAAPI does not support MJPEG decode on most hardware), but the
//! direct NV12 output is the optimal format for VAAPI H264 encoding.
use std::sync::Once;
use tracing::{debug, info, warn};
use hwcodec::ffmpeg::AVHWDeviceType;
use hwcodec::ffmpeg::AVPixelFormat;
use hwcodec::ffmpeg_ram::decode::{DecodeContext, DecodeFrame, Decoder};
use crate::error::{AppError, Result};
use crate::video::format::Resolution;
// libyuv for SIMD-accelerated YUV conversion
static INIT_LOGGING: Once = Once::new();
/// Initialize hwcodec logging (only once)
fn init_hwcodec_logging() {
INIT_LOGGING.call_once(|| {
debug!("hwcodec MJPEG decoder logging initialized");
});
}
/// MJPEG VAAPI decoder configuration
#[derive(Debug, Clone)]
pub struct MjpegVaapiDecoderConfig {
/// Expected resolution (can be updated from decoded frame)
pub resolution: Resolution,
/// Use hardware acceleration (VAAPI)
pub use_hwaccel: bool,
}
impl Default for MjpegVaapiDecoderConfig {
fn default() -> Self {
Self {
resolution: Resolution::HD1080,
use_hwaccel: true,
}
}
}
/// Decoded frame data in NV12 format
#[derive(Debug, Clone)]
pub struct DecodedNv12Frame {
/// Y plane data
pub y_plane: Vec<u8>,
/// UV interleaved plane data
pub uv_plane: Vec<u8>,
/// Y plane linesize (stride)
pub y_linesize: i32,
/// UV plane linesize (stride)
pub uv_linesize: i32,
/// Frame width
pub width: i32,
/// Frame height
pub height: i32,
}
/// Decoded frame data in YUV420P (I420) format
#[derive(Debug, Clone)]
pub struct DecodedYuv420pFrame {
/// Y plane data
pub y_plane: Vec<u8>,
/// U plane data
pub u_plane: Vec<u8>,
/// V plane data
pub v_plane: Vec<u8>,
/// Y plane linesize (stride)
pub y_linesize: i32,
/// U plane linesize (stride)
pub u_linesize: i32,
/// V plane linesize (stride)
pub v_linesize: i32,
/// Frame width
pub width: i32,
/// Frame height
pub height: i32,
}
impl DecodedYuv420pFrame {
/// Get packed YUV420P data (Y plane followed by U and V planes, with stride removed)
pub fn to_packed_yuv420p(&self) -> Vec<u8> {
let width = self.width as usize;
let height = self.height as usize;
let y_size = width * height;
let uv_size = width * height / 4;
let mut packed = Vec::with_capacity(y_size + uv_size * 2);
// Copy Y plane, removing stride padding if any
if self.y_linesize as usize == width {
packed.extend_from_slice(&self.y_plane[..y_size]);
} else {
for row in 0..height {
let src_offset = row * self.y_linesize as usize;
packed.extend_from_slice(&self.y_plane[src_offset..src_offset + width]);
}
}
// Copy U plane
let uv_width = width / 2;
let uv_height = height / 2;
if self.u_linesize as usize == uv_width {
packed.extend_from_slice(&self.u_plane[..uv_size]);
} else {
for row in 0..uv_height {
let src_offset = row * self.u_linesize as usize;
packed.extend_from_slice(&self.u_plane[src_offset..src_offset + uv_width]);
}
}
// Copy V plane
if self.v_linesize as usize == uv_width {
packed.extend_from_slice(&self.v_plane[..uv_size]);
} else {
for row in 0..uv_height {
let src_offset = row * self.v_linesize as usize;
packed.extend_from_slice(&self.v_plane[src_offset..src_offset + uv_width]);
}
}
packed
}
/// Copy packed YUV420P data to external buffer (zero allocation)
/// Returns the number of bytes written, or None if buffer too small
pub fn copy_to_packed_yuv420p(&self, dst: &mut [u8]) -> Option<usize> {
let width = self.width as usize;
let height = self.height as usize;
let y_size = width * height;
let uv_size = width * height / 4;
let total_size = y_size + uv_size * 2;
if dst.len() < total_size {
return None;
}
// Copy Y plane
if self.y_linesize as usize == width {
dst[..y_size].copy_from_slice(&self.y_plane[..y_size]);
} else {
for row in 0..height {
let src_offset = row * self.y_linesize as usize;
let dst_offset = row * width;
dst[dst_offset..dst_offset + width]
.copy_from_slice(&self.y_plane[src_offset..src_offset + width]);
}
}
// Copy U plane
let uv_width = width / 2;
let uv_height = height / 2;
if self.u_linesize as usize == uv_width {
dst[y_size..y_size + uv_size].copy_from_slice(&self.u_plane[..uv_size]);
} else {
for row in 0..uv_height {
let src_offset = row * self.u_linesize as usize;
let dst_offset = y_size + row * uv_width;
dst[dst_offset..dst_offset + uv_width]
.copy_from_slice(&self.u_plane[src_offset..src_offset + uv_width]);
}
}
// Copy V plane
let v_offset = y_size + uv_size;
if self.v_linesize as usize == uv_width {
dst[v_offset..v_offset + uv_size].copy_from_slice(&self.v_plane[..uv_size]);
} else {
for row in 0..uv_height {
let src_offset = row * self.v_linesize as usize;
let dst_offset = v_offset + row * uv_width;
dst[dst_offset..dst_offset + uv_width]
.copy_from_slice(&self.v_plane[src_offset..src_offset + uv_width]);
}
}
Some(total_size)
}
}
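// Illustrative usage sketch (added for clarity, not part of the original
// commit): reusing one packed buffer across frames via the zero-allocation
// copy path instead of to_packed_yuv420p().
#[allow(dead_code)]
fn example_pack_frame(frame: &DecodedYuv420pFrame, scratch: &mut Vec<u8>) -> usize {
let needed = (frame.width * frame.height) as usize * 3 / 2;
if scratch.len() < needed {
scratch.resize(needed, 0);
}
// copy_to_packed_yuv420p returns None only when dst is too small, which
// the resize above rules out
frame.copy_to_packed_yuv420p(scratch).unwrap_or(0)
}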
impl DecodedNv12Frame {
/// Get packed NV12 data (Y plane followed by UV plane, with stride removed)
pub fn to_packed_nv12(&self) -> Vec<u8> {
let width = self.width as usize;
let height = self.height as usize;
let y_size = width * height;
let uv_size = width * height / 2;
let mut packed = Vec::with_capacity(y_size + uv_size);
// Copy Y plane, removing stride padding if any
if self.y_linesize as usize == width {
// No padding, direct copy
packed.extend_from_slice(&self.y_plane[..y_size]);
} else {
// Has padding, copy row by row
for row in 0..height {
let src_offset = row * self.y_linesize as usize;
packed.extend_from_slice(&self.y_plane[src_offset..src_offset + width]);
}
}
// Copy UV plane, removing stride padding if any
let uv_height = height / 2;
if self.uv_linesize as usize == width {
// No padding, direct copy
packed.extend_from_slice(&self.uv_plane[..uv_size]);
} else {
// Has padding, copy row by row
for row in 0..uv_height {
let src_offset = row * self.uv_linesize as usize;
packed.extend_from_slice(&self.uv_plane[src_offset..src_offset + width]);
}
}
packed
}
/// Copy packed NV12 data to external buffer (zero allocation)
/// Returns the number of bytes written, or None if buffer too small
pub fn copy_to_packed_nv12(&self, dst: &mut [u8]) -> Option<usize> {
let width = self.width as usize;
let height = self.height as usize;
let y_size = width * height;
let uv_size = width * height / 2;
let total_size = y_size + uv_size;
if dst.len() < total_size {
return None;
}
// Copy Y plane, removing stride padding if any
if self.y_linesize as usize == width {
// No padding, direct copy
dst[..y_size].copy_from_slice(&self.y_plane[..y_size]);
} else {
// Has padding, copy row by row
for row in 0..height {
let src_offset = row * self.y_linesize as usize;
let dst_offset = row * width;
dst[dst_offset..dst_offset + width]
.copy_from_slice(&self.y_plane[src_offset..src_offset + width]);
}
}
// Copy UV plane, removing stride padding if any
let uv_height = height / 2;
if self.uv_linesize as usize == width {
// No padding, direct copy
dst[y_size..total_size].copy_from_slice(&self.uv_plane[..uv_size]);
} else {
// Has padding, copy row by row
for row in 0..uv_height {
let src_offset = row * self.uv_linesize as usize;
let dst_offset = y_size + row * width;
dst[dst_offset..dst_offset + width]
.copy_from_slice(&self.uv_plane[src_offset..src_offset + width]);
}
}
Some(total_size)
}
}
/// MJPEG VAAPI decoder
///
/// Despite the name, frames are decoded in software (VAAPI does not support
/// MJPEG decode on most hardware) and emitted as NV12, which is the optimal
/// input format for the VAAPI H264 encoder.
pub struct MjpegVaapiDecoder {
/// hwcodec decoder instance
decoder: Decoder,
/// Configuration
config: MjpegVaapiDecoderConfig,
/// Frame counter
frame_count: u64,
/// Whether hardware acceleration is active
hwaccel_active: bool,
}
impl MjpegVaapiDecoder {
/// Create a new MJPEG decoder
/// Note: VAAPI does not support MJPEG decoding on most hardware,
/// so we use software decoding and convert to NV12 for VAAPI encoding.
pub fn new(config: MjpegVaapiDecoderConfig) -> Result<Self> {
init_hwcodec_logging();
// VAAPI doesn't support MJPEG decoding, always use software decoder
// The output will be converted to NV12 for VAAPI H264 encoding
let device_type = AVHWDeviceType::AV_HWDEVICE_TYPE_NONE;
info!(
"Creating MJPEG decoder with software decoding (VAAPI doesn't support MJPEG decode)"
);
let ctx = DecodeContext {
name: "mjpeg".to_string(),
device_type,
thread_count: 4, // Use multiple threads for software decoding
};
let decoder = Decoder::new(ctx).map_err(|_| {
AppError::VideoError("Failed to create MJPEG software decoder".to_string())
})?;
// hwaccel is not actually active for MJPEG decoding
let hwaccel_active = false;
info!(
"MJPEG decoder created successfully (software decode, will convert to NV12)"
);
Ok(Self {
decoder,
config,
frame_count: 0,
hwaccel_active,
})
}
/// Create with default config (VAAPI enabled)
pub fn with_vaapi(resolution: Resolution) -> Result<Self> {
Self::new(MjpegVaapiDecoderConfig {
resolution,
use_hwaccel: true,
})
}
/// Create with software decoding (fallback)
pub fn with_software(resolution: Resolution) -> Result<Self> {
Self::new(MjpegVaapiDecoderConfig {
resolution,
use_hwaccel: false,
})
}
/// Check if hardware acceleration is active
pub fn is_hwaccel_active(&self) -> bool {
self.hwaccel_active
}
/// Decode MJPEG frame to NV12
///
/// Returns the decoded frame in NV12 format, or an error if decoding fails.
pub fn decode(&mut self, jpeg_data: &[u8]) -> Result<DecodedNv12Frame> {
if jpeg_data.len() < 2 {
return Err(AppError::VideoError("JPEG data too small".to_string()));
}
// Verify JPEG signature (FFD8)
if jpeg_data[0] != 0xFF || jpeg_data[1] != 0xD8 {
return Err(AppError::VideoError("Invalid JPEG signature".to_string()));
}
self.frame_count += 1;
let frames = self.decoder.decode(jpeg_data).map_err(|e| {
AppError::VideoError(format!("MJPEG decode failed: error code {}", e))
})?;
if frames.is_empty() {
return Err(AppError::VideoError("Decoder returned no frames".to_string()));
}
let frame = &frames[0];
// Handle different output formats
// VAAPI MJPEG decoder may output NV12, YUV420P, or YUVJ420P (JPEG full-range)
if frame.pixfmt == AVPixelFormat::AV_PIX_FMT_NV12
|| frame.pixfmt == AVPixelFormat::AV_PIX_FMT_NV21
{
// NV12/NV21 format: Y plane + UV interleaved plane
if frame.data.len() < 2 {
return Err(AppError::VideoError("Invalid NV12 frame data".to_string()));
}
return Ok(DecodedNv12Frame {
y_plane: frame.data[0].clone(),
uv_plane: frame.data[1].clone(),
y_linesize: frame.linesize[0],
uv_linesize: frame.linesize[1],
width: frame.width,
height: frame.height,
});
}
// YUV420P or YUVJ420P (JPEG full-range) - need to convert to NV12
if frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUV420P
|| frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUVJ420P
{
return Self::convert_yuv420p_to_nv12_static(frame);
}
// YUV422P or YUVJ422P (JPEG full-range 4:2:2) - need to convert to NV12
if frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUV422P
|| frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUVJ422P
{
return Self::convert_yuv422p_to_nv12_static(frame);
}
Err(AppError::VideoError(format!(
"Unexpected output format: {:?} (expected NV12, YUV420P, YUV422P, or YUVJ variants)",
frame.pixfmt
)))
}
/// Convert YUV420P frame to NV12 format using libyuv (SIMD accelerated)
fn convert_yuv420p_to_nv12_static(frame: &DecodeFrame) -> Result<DecodedNv12Frame> {
if frame.data.len() < 3 {
return Err(AppError::VideoError("Invalid YUV420P frame data".to_string()));
}
let width = frame.width as i32;
let height = frame.height as i32;
let y_linesize = frame.linesize[0];
let u_linesize = frame.linesize[1];
let v_linesize = frame.linesize[2];
// Allocate packed NV12 output buffer
let nv12_size = (width * height * 3 / 2) as usize;
let mut nv12_data = vec![0u8; nv12_size];
// Use libyuv for SIMD-accelerated I420 → NV12 conversion
libyuv::i420_to_nv12_planar(
&frame.data[0], y_linesize,
&frame.data[1], u_linesize,
&frame.data[2], v_linesize,
&mut nv12_data,
width, height,
).map_err(|e| AppError::VideoError(format!("libyuv I420→NV12 failed: {}", e)))?;
// Split into Y and UV planes for DecodedNv12Frame
let y_size = (width * height) as usize;
let y_plane = nv12_data[..y_size].to_vec();
let uv_plane = nv12_data[y_size..].to_vec();
Ok(DecodedNv12Frame {
y_plane,
uv_plane,
y_linesize: width, // Output is packed, no padding
uv_linesize: width,
width: frame.width,
height: frame.height,
})
}
/// Convert YUV422P frame to NV12 format using libyuv (SIMD accelerated)
/// Pipeline: I422 (YUV422P) → I420 → NV12
fn convert_yuv422p_to_nv12_static(frame: &DecodeFrame) -> Result<DecodedNv12Frame> {
if frame.data.len() < 3 {
return Err(AppError::VideoError("Invalid YUV422P frame data".to_string()));
}
let width = frame.width as i32;
let height = frame.height as i32;
let y_linesize = frame.linesize[0];
let u_linesize = frame.linesize[1];
let v_linesize = frame.linesize[2];
// Step 1: I422 → I420 (vertical chroma downsampling via SIMD)
let i420_size = (width * height * 3 / 2) as usize;
let mut i420_data = vec![0u8; i420_size];
libyuv::i422_to_i420_planar(
&frame.data[0], y_linesize,
&frame.data[1], u_linesize,
&frame.data[2], v_linesize,
&mut i420_data,
width, height,
).map_err(|e| AppError::VideoError(format!("libyuv I422→I420 failed: {}", e)))?;
// Step 2: I420 → NV12 (UV interleaving via SIMD)
let nv12_size = (width * height * 3 / 2) as usize;
let mut nv12_data = vec![0u8; nv12_size];
libyuv::i420_to_nv12(&i420_data, &mut nv12_data, width, height)
.map_err(|e| AppError::VideoError(format!("libyuv I420→NV12 failed: {}", e)))?;
// Split into Y and UV planes for DecodedNv12Frame
let y_size = (width * height) as usize;
let y_plane = nv12_data[..y_size].to_vec();
let uv_plane = nv12_data[y_size..].to_vec();
Ok(DecodedNv12Frame {
y_plane,
uv_plane,
y_linesize: width,
uv_linesize: width,
width: frame.width,
height: frame.height,
})
}
/// Get frame count
pub fn frame_count(&self) -> u64 {
self.frame_count
}
/// Get current resolution from config
pub fn resolution(&self) -> Resolution {
self.config.resolution
}
}
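// Illustrative usage sketch (added for clarity, not part of the original
// commit): decode one JPEG frame and pack it into a caller-owned NV12 buffer
// for the encoder.
#[allow(dead_code)]
fn example_decode_to_nv12(decoder: &mut MjpegVaapiDecoder, jpeg: &[u8], dst: &mut [u8]) -> Result<usize> {
let frame = decoder.decode(jpeg)?;
frame
.copy_to_packed_nv12(dst)
.ok_or_else(|| AppError::VideoError("NV12 destination buffer too small".to_string()))
}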
/// Libyuv-based MJPEG decoder for direct YUV420P output
///
/// This decoder is optimized for software encoders (libvpx, libx265) that need YUV420P input.
/// It uses libyuv's MJPGToI420 to decode directly to I420/YUV420P format.
pub struct MjpegTurboDecoder {
/// Frame counter
frame_count: u64,
}
impl MjpegTurboDecoder {
/// Create a new libyuv-based MJPEG decoder
pub fn new(resolution: Resolution) -> Result<Self> {
info!(
"Created libyuv MJPEG decoder for {}x{} (direct YUV420P output)",
resolution.width, resolution.height
);
Ok(Self {
frame_count: 0,
})
}
/// Decode MJPEG frame directly to YUV420P using libyuv
///
/// This is the optimal path for software encoders that need YUV420P input.
/// libyuv handles all JPEG subsampling formats internally.
pub fn decode_to_yuv420p(&mut self, jpeg_data: &[u8]) -> Result<DecodedYuv420pFrame> {
if jpeg_data.len() < 2 || jpeg_data[0] != 0xFF || jpeg_data[1] != 0xD8 {
return Err(AppError::VideoError("Invalid JPEG data".to_string()));
}
self.frame_count += 1;
// Get JPEG dimensions
let (width, height) = libyuv::mjpeg_size(jpeg_data)
.map_err(|e| AppError::VideoError(format!("Failed to read MJPEG size: {}", e)))?;
let y_size = (width * height) as usize;
let uv_size = y_size / 4;
let yuv420_size = y_size + uv_size * 2;
let mut yuv_data = vec![0u8; yuv420_size];
libyuv::mjpeg_to_i420(jpeg_data, &mut yuv_data, width, height)
.map_err(|e| AppError::VideoError(format!("libyuv MJPEG→I420 failed: {}", e)))?;
Ok(DecodedYuv420pFrame {
y_plane: yuv_data[..y_size].to_vec(),
u_plane: yuv_data[y_size..y_size + uv_size].to_vec(),
v_plane: yuv_data[y_size + uv_size..].to_vec(),
y_linesize: width,
u_linesize: width / 2,
v_linesize: width / 2,
width,
height,
})
}
/// Decode directly to packed YUV420P buffer using libyuv
///
/// This uses libyuv's MJPGToI420 which handles all JPEG subsampling formats
/// and converts to I420 directly.
pub fn decode_to_yuv420p_buffer(&mut self, jpeg_data: &[u8], dst: &mut [u8]) -> Result<usize> {
if jpeg_data.len() < 2 || jpeg_data[0] != 0xFF || jpeg_data[1] != 0xD8 {
return Err(AppError::VideoError("Invalid JPEG data".to_string()));
}
self.frame_count += 1;
// Get JPEG dimensions from libyuv
let (width, height) = libyuv::mjpeg_size(jpeg_data)
.map_err(|e| AppError::VideoError(format!("Failed to read MJPEG size: {}", e)))?;
let yuv420_size = (width * height * 3 / 2) as usize;
if dst.len() < yuv420_size {
return Err(AppError::VideoError(format!(
"Buffer too small: {} < {}", dst.len(), yuv420_size
)));
}
// Decode MJPEG directly to I420 using libyuv
// libyuv handles all JPEG subsampling formats (4:2:0, 4:2:2, 4:4:4) internally
libyuv::mjpeg_to_i420(jpeg_data, &mut dst[..yuv420_size], width, height)
.map_err(|e| AppError::VideoError(format!("libyuv MJPEG→I420 failed: {}", e)))?;
Ok(yuv420_size)
}
/// Get frame count
pub fn frame_count(&self) -> u64 {
self.frame_count
}
}
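// Illustrative usage sketch (added for clarity, not part of the original
// commit): the buffer variant avoids the three per-plane Vec allocations of
// decode_to_yuv420p, which matters at 30+ fps. Sizing the buffer for 1080p is
// an assumption here.
#[allow(dead_code)]
fn example_turbo_decode(decoder: &mut MjpegTurboDecoder, jpeg: &[u8]) -> Result<Vec<u8>> {
let mut dst = vec![0u8; 1920 * 1080 * 3 / 2];
let written = decoder.decode_to_yuv420p_buffer(jpeg, &mut dst)?;
dst.truncate(written);
Ok(dst)
}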
/// Check if MJPEG VAAPI decoder is available
pub fn is_mjpeg_vaapi_available() -> bool {
let ctx = DecodeContext {
name: "mjpeg".to_string(),
device_type: AVHWDeviceType::AV_HWDEVICE_TYPE_VAAPI,
thread_count: 1,
};
match Decoder::new(ctx) {
Ok(_) => {
info!("MJPEG VAAPI decoder is available");
true
}
Err(_) => {
warn!("MJPEG VAAPI decoder is not available");
false
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mjpeg_vaapi_availability() {
let available = is_mjpeg_vaapi_available();
println!("MJPEG VAAPI available: {}", available);
}
#[test]
fn test_decoder_creation() {
let config = MjpegVaapiDecoderConfig::default();
match MjpegVaapiDecoder::new(config) {
Ok(decoder) => {
println!("Decoder created, hwaccel: {}", decoder.is_hwaccel_active());
}
Err(e) => {
println!("Failed to create decoder: {}", e);
}
}
}
}

11
src/video/decoder/mod.rs Normal file

@@ -0,0 +1,11 @@
//! Video decoder implementations
//!
//! This module provides video decoding capabilities including:
//! - MJPEG decoding to NV12 via MjpegVaapiDecoder (software decode, for VAAPI encode pipelines)
//! - MJPEG libyuv decoding via MjpegTurboDecoder (outputs YUV420P directly)
pub mod mjpeg;
pub use mjpeg::{
DecodedYuv420pFrame, MjpegTurboDecoder, MjpegVaapiDecoder, MjpegVaapiDecoderConfig,
};

459
src/video/device.rs Normal file

@@ -0,0 +1,459 @@
//! V4L2 device enumeration and capability query
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};
use v4l::capability::Flags;
use v4l::prelude::*;
use v4l::video::Capture;
use v4l::Format;
use v4l::FourCC;
use super::format::{PixelFormat, Resolution};
use crate::error::{AppError, Result};
/// Information about a video device
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoDeviceInfo {
/// Device path (e.g., /dev/video0)
pub path: PathBuf,
/// Device name from driver
pub name: String,
/// Driver name
pub driver: String,
/// Bus info
pub bus_info: String,
/// Card name
pub card: String,
/// Supported pixel formats
pub formats: Vec<FormatInfo>,
/// Device capabilities
pub capabilities: DeviceCapabilities,
/// Whether this is likely an HDMI capture card
pub is_capture_card: bool,
/// Priority score for device selection (higher is better)
pub priority: u32,
}
/// Information about a supported format
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormatInfo {
/// Pixel format
pub format: PixelFormat,
/// Supported resolutions
pub resolutions: Vec<ResolutionInfo>,
/// Description from driver
pub description: String,
}
/// Information about a supported resolution and frame rates
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolutionInfo {
pub width: u32,
pub height: u32,
pub fps: Vec<u32>,
}
impl ResolutionInfo {
pub fn new(width: u32, height: u32, fps: Vec<u32>) -> Self {
Self { width, height, fps }
}
pub fn resolution(&self) -> Resolution {
Resolution::new(self.width, self.height)
}
}
/// Device capabilities
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DeviceCapabilities {
pub video_capture: bool,
pub video_capture_mplane: bool,
pub video_output: bool,
pub streaming: bool,
pub read_write: bool,
}
/// Wrapper around a V4L2 video device
pub struct VideoDevice {
pub path: PathBuf,
device: Device,
}
impl VideoDevice {
/// Open a video device by path
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref().to_path_buf();
debug!("Opening video device: {:?}", path);
let device = Device::with_path(&path).map_err(|e| {
AppError::VideoError(format!("Failed to open device {:?}: {}", path, e))
})?;
Ok(Self { path, device })
}
/// Get device capabilities
pub fn capabilities(&self) -> Result<DeviceCapabilities> {
let caps = self.device.query_caps().map_err(|e| {
AppError::VideoError(format!("Failed to query capabilities: {}", e))
})?;
Ok(DeviceCapabilities {
video_capture: caps.capabilities.contains(Flags::VIDEO_CAPTURE),
video_capture_mplane: caps.capabilities.contains(Flags::VIDEO_CAPTURE_MPLANE),
video_output: caps.capabilities.contains(Flags::VIDEO_OUTPUT),
streaming: caps.capabilities.contains(Flags::STREAMING),
read_write: caps.capabilities.contains(Flags::READ_WRITE),
})
}
/// Get detailed device information
pub fn info(&self) -> Result<VideoDeviceInfo> {
let caps = self.device.query_caps().map_err(|e| {
AppError::VideoError(format!("Failed to query capabilities: {}", e))
})?;
let capabilities = DeviceCapabilities {
video_capture: caps.capabilities.contains(Flags::VIDEO_CAPTURE),
video_capture_mplane: caps.capabilities.contains(Flags::VIDEO_CAPTURE_MPLANE),
video_output: caps.capabilities.contains(Flags::VIDEO_OUTPUT),
streaming: caps.capabilities.contains(Flags::STREAMING),
read_write: caps.capabilities.contains(Flags::READ_WRITE),
};
let formats = self.enumerate_formats()?;
// Determine if this is likely an HDMI capture card
let is_capture_card = Self::detect_capture_card(&caps.card, &caps.driver, &formats);
// Calculate priority score
let priority = Self::calculate_priority(&caps.card, &caps.driver, &formats, is_capture_card);
Ok(VideoDeviceInfo {
path: self.path.clone(),
name: caps.card.clone(),
driver: caps.driver.clone(),
bus_info: caps.bus.clone(),
card: caps.card,
formats,
capabilities,
is_capture_card,
priority,
})
}
/// Enumerate supported formats
pub fn enumerate_formats(&self) -> Result<Vec<FormatInfo>> {
let mut formats = Vec::new();
// Get supported formats
let format_descs = self.device.enum_formats().map_err(|e| {
AppError::VideoError(format!("Failed to enumerate formats: {}", e))
})?;
for desc in format_descs {
// Try to convert FourCC to our PixelFormat
if let Some(format) = PixelFormat::from_fourcc(desc.fourcc) {
let resolutions = self.enumerate_resolutions(desc.fourcc)?;
formats.push(FormatInfo {
format,
resolutions,
description: desc.description.clone(),
});
} else {
debug!(
"Skipping unsupported format: {:?} ({})",
desc.fourcc, desc.description
);
}
}
// Sort by format priority (MJPEG first)
formats.sort_by(|a, b| b.format.priority().cmp(&a.format.priority()));
Ok(formats)
}
/// Enumerate resolutions for a specific format
fn enumerate_resolutions(&self, fourcc: FourCC) -> Result<Vec<ResolutionInfo>> {
let mut resolutions = Vec::new();
// Try to enumerate frame sizes
match self.device.enum_framesizes(fourcc) {
Ok(sizes) => {
for size in sizes {
match size.size {
v4l::framesize::FrameSizeEnum::Discrete(d) => {
let fps = self.enumerate_fps(fourcc, d.width, d.height).unwrap_or_default();
resolutions.push(ResolutionInfo::new(d.width, d.height, fps));
}
v4l::framesize::FrameSizeEnum::Stepwise(s) => {
// For stepwise, add some common resolutions
for res in [
Resolution::VGA,
Resolution::HD720,
Resolution::HD1080,
Resolution::UHD4K,
] {
if res.width >= s.min_width
&& res.width <= s.max_width
&& res.height >= s.min_height
&& res.height <= s.max_height
{
let fps = self.enumerate_fps(fourcc, res.width, res.height).unwrap_or_default();
resolutions.push(ResolutionInfo::new(res.width, res.height, fps));
}
}
}
}
}
}
Err(e) => {
debug!("Failed to enumerate frame sizes for {:?}: {}", fourcc, e);
}
}
// Sort by resolution (largest first)
resolutions.sort_by(|a, b| (b.width * b.height).cmp(&(a.width * a.height)));
resolutions.dedup_by(|a, b| a.width == b.width && a.height == b.height);
Ok(resolutions)
}
/// Enumerate FPS for a specific resolution
fn enumerate_fps(&self, fourcc: FourCC, width: u32, height: u32) -> Result<Vec<u32>> {
let mut fps_list = Vec::new();
match self.device.enum_frameintervals(fourcc, width, height) {
Ok(intervals) => {
for interval in intervals {
match interval.interval {
v4l::frameinterval::FrameIntervalEnum::Discrete(fraction) => {
if fraction.numerator > 0 {
let fps = fraction.denominator / fraction.numerator;
fps_list.push(fps);
}
}
v4l::frameinterval::FrameIntervalEnum::Stepwise(step) => {
// The longest interval (step.max) gives the minimum fps and the
// shortest (step.min) the maximum; guard both numerators against zero
if step.max.numerator > 0 && step.min.numerator > 0 {
let min_fps = step.max.denominator / step.max.numerator;
let max_fps = step.min.denominator / step.min.numerator;
fps_list.push(min_fps);
if max_fps != min_fps {
fps_list.push(max_fps);
}
}
}
}
}
}
Err(_) => {
// If enumeration fails, assume 30fps
fps_list.push(30);
}
}
fps_list.sort_by(|a, b| b.cmp(a));
fps_list.dedup();
Ok(fps_list)
}
/// Get current format
pub fn get_format(&self) -> Result<Format> {
self.device.format().map_err(|e| {
AppError::VideoError(format!("Failed to get format: {}", e))
})
}
/// Set capture format
pub fn set_format(&self, width: u32, height: u32, format: PixelFormat) -> Result<Format> {
let fmt = Format::new(width, height, format.to_fourcc());
// Request the format
let actual = self.device.set_format(&fmt).map_err(|e| {
AppError::VideoError(format!("Failed to set format: {}", e))
})?;
if actual.width != width || actual.height != height {
warn!(
"Requested {}x{}, got {}x{}",
width, height, actual.width, actual.height
);
}
Ok(actual)
}
/// Detect if device is likely an HDMI capture card
fn detect_capture_card(card: &str, driver: &str, formats: &[FormatInfo]) -> bool {
let card_lower = card.to_lowercase();
let driver_lower = driver.to_lowercase();
// Known capture card patterns
let capture_patterns = [
"hdmi",
"capture",
"grabber",
"usb3",
"ms2109",
"ms2130",
"macrosilicon",
"tc358743",
"uvc",
];
// Check card/driver names
for pattern in capture_patterns {
if card_lower.contains(pattern) || driver_lower.contains(pattern) {
return true;
}
}
// Capture cards usually support MJPEG and high resolutions
let has_mjpeg = formats.iter().any(|f| f.format == PixelFormat::Mjpeg);
let has_1080p = formats.iter().any(|f| {
f.resolutions
.iter()
.any(|r| r.width >= 1920 && r.height >= 1080)
});
has_mjpeg && has_1080p
}
/// Calculate device priority for selection
fn calculate_priority(
_card: &str,
driver: &str,
formats: &[FormatInfo],
is_capture_card: bool,
) -> u32 {
let mut priority = 0u32;
// Capture cards get highest priority
if is_capture_card {
priority += 1000;
}
// MJPEG support is valuable
if formats.iter().any(|f| f.format == PixelFormat::Mjpeg) {
priority += 100;
}
// High resolution support
let max_resolution = formats
.iter()
.flat_map(|f| &f.resolutions)
.map(|r| r.width * r.height)
.max()
.unwrap_or(0);
priority += (max_resolution / 100000) as u32;
// Known good drivers get bonus
let good_drivers = ["uvcvideo", "tc358743"];
if good_drivers.iter().any(|d| driver.contains(d)) {
priority += 50;
}
priority
}
/// Get the inner device reference (for advanced usage)
pub fn inner(&self) -> &Device {
&self.device
}
}
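// Illustrative usage sketch (added for clarity, not part of the original
// commit): opening a device and negotiating MJPEG at 1080p. The driver may
// return a different size, which set_format() reports back.
#[allow(dead_code)]
fn example_open_and_configure() -> Result<()> {
let device = VideoDevice::open("/dev/video0")?;
let caps = device.capabilities()?;
if !caps.video_capture && !caps.video_capture_mplane {
return Err(AppError::VideoError("not a capture device".to_string()));
}
let actual = device.set_format(1920, 1080, PixelFormat::Mjpeg)?;
info!("Negotiated {}x{}", actual.width, actual.height);
Ok(())
}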
/// Enumerate all video capture devices
pub fn enumerate_devices() -> Result<Vec<VideoDeviceInfo>> {
info!("Enumerating video devices...");
let mut devices = Vec::new();
// Scan /dev/video* devices
for entry in std::fs::read_dir("/dev").map_err(|e| {
AppError::VideoError(format!("Failed to read /dev: {}", e))
})? {
let entry = match entry {
Ok(e) => e,
Err(_) => continue,
};
let path = entry.path();
let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
if !name.starts_with("video") {
continue;
}
debug!("Found video device: {:?}", path);
// Try to open and query the device
match VideoDevice::open(&path) {
Ok(device) => {
match device.info() {
Ok(info) => {
// Only include devices with video capture capability
if info.capabilities.video_capture || info.capabilities.video_capture_mplane
{
info!(
"Found capture device: {} ({}) - {} formats",
info.name,
info.driver,
info.formats.len()
);
devices.push(info);
} else {
debug!("Skipping non-capture device: {:?}", path);
}
}
Err(e) => {
debug!("Failed to get info for {:?}: {}", path, e);
}
}
}
Err(e) => {
debug!("Failed to open {:?}: {}", path, e);
}
}
}
// Sort by priority (highest first)
devices.sort_by(|a, b| b.priority.cmp(&a.priority));
info!("Found {} video capture devices", devices.len());
Ok(devices)
}
/// Find the best video device for KVM use
pub fn find_best_device() -> Result<VideoDeviceInfo> {
let devices = enumerate_devices()?;
devices.into_iter().next().ok_or_else(|| {
AppError::VideoError("No video capture devices found".to_string())
})
}
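// Illustrative usage sketch (added for clarity, not part of the original
// commit): picking the highest-priority device and its preferred mode from
// the sorted enumeration (formats are MJPEG-first, resolutions largest-first).
#[allow(dead_code)]
fn example_pick_device() -> Result<()> {
let best = find_best_device()?;
info!("Selected {:?} ({})", best.path, best.name);
if let Some(fmt) = best.formats.first() {
if let Some(res) = fmt.resolutions.first() {
info!("Preferred mode: {} at {}x{}", fmt.format, res.width, res.height);
}
}
Ok(())
}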
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pixel_format_conversion() {
let format = PixelFormat::Mjpeg;
let fourcc = format.to_fourcc();
let back = PixelFormat::from_fourcc(fourcc);
assert_eq!(back, Some(format));
}
#[test]
fn test_resolution() {
let res = Resolution::HD1080;
assert_eq!(res.width, 1920);
assert_eq!(res.height, 1080);
assert!(res.is_valid());
}
}

370
src/video/encoder/codec.rs Normal file

@@ -0,0 +1,370 @@
//! WebRTC Video Codec abstraction layer
//!
//! This module provides a unified interface for video codecs used in WebRTC streaming.
//! It supports multiple codec types (H264, VP8, VP9, H265) with a common API.
//!
//! # Architecture
//!
//! ```text
//! VideoCodec (trait)
//! |
//! +-- H264Codec (current implementation)
//! +-- VP8Codec (reserved)
//! +-- VP9Codec (reserved)
//! +-- H265Codec (reserved)
//! ```
use bytes::Bytes;
use std::time::Duration;
use crate::error::Result;
use crate::video::format::Resolution;
/// Supported video codec types for WebRTC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VideoCodecType {
/// H.264/AVC - widely supported, good compression
H264,
/// VP8 - royalty-free, good browser support
VP8,
/// VP9 - better compression than VP8
VP9,
/// H.265/HEVC - best compression, limited browser support
H265,
}
impl VideoCodecType {
/// Get the codec name for SDP
pub fn sdp_name(&self) -> &'static str {
match self {
VideoCodecType::H264 => "H264",
VideoCodecType::VP8 => "VP8",
VideoCodecType::VP9 => "VP9",
VideoCodecType::H265 => "H265",
}
}
/// Get the default RTP payload type
pub fn default_payload_type(&self) -> u8 {
match self {
VideoCodecType::H264 => 96,
VideoCodecType::VP8 => 97,
VideoCodecType::VP9 => 98,
VideoCodecType::H265 => 99,
}
}
/// Get the RTP clock rate (always 90000 for video)
pub fn clock_rate(&self) -> u32 {
90000
}
/// Get the MIME type
pub fn mime_type(&self) -> &'static str {
match self {
VideoCodecType::H264 => "video/H264",
VideoCodecType::VP8 => "video/VP8",
VideoCodecType::VP9 => "video/VP9",
VideoCodecType::H265 => "video/H265",
}
}
}
impl std::fmt::Display for VideoCodecType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.sdp_name())
}
}
/// Encoded video frame for WebRTC transmission
#[derive(Debug, Clone)]
pub struct CodecFrame {
/// Encoded data (Annex B format for H264/H265, raw for VP8/VP9)
pub data: Bytes,
/// Presentation timestamp in milliseconds
pub pts_ms: i64,
/// Whether this is a keyframe (IDR for H264, key frame for VP8/VP9)
pub is_keyframe: bool,
/// Codec type
pub codec: VideoCodecType,
/// Frame sequence number
pub sequence: u64,
/// Frame duration
pub duration: Duration,
}
impl CodecFrame {
/// Create a new H264 frame
pub fn h264(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
Self {
data,
pts_ms,
is_keyframe,
codec: VideoCodecType::H264,
sequence,
duration: Duration::from_millis(1000 / fps as u64),
}
}
/// Create a new VP8 frame
pub fn vp8(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
Self {
data,
pts_ms,
is_keyframe,
codec: VideoCodecType::VP8,
sequence,
duration: Duration::from_millis(1000 / fps as u64),
}
}
/// Create a new VP9 frame
pub fn vp9(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
Self {
data,
pts_ms,
is_keyframe,
codec: VideoCodecType::VP9,
sequence,
duration: Duration::from_millis(1000 / fps as u64),
}
}
/// Create a new H265 frame
pub fn h265(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
Self {
data,
pts_ms,
is_keyframe,
codec: VideoCodecType::H265,
sequence,
duration: Duration::from_millis(1000 / fps as u64),
}
}
/// Get frame size in bytes
pub fn len(&self) -> usize {
self.data.len()
}
/// Check if frame is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
/// Video codec configuration
#[derive(Debug, Clone)]
pub struct VideoCodecConfig {
/// Codec type
pub codec: VideoCodecType,
/// Target resolution
pub resolution: Resolution,
/// Target bitrate in kbps
pub bitrate_kbps: u32,
/// Target FPS
pub fps: u32,
/// GOP size (keyframe interval in frames)
pub gop_size: u32,
/// Profile (codec-specific)
pub profile: Option<String>,
/// Level (codec-specific)
pub level: Option<String>,
}
impl Default for VideoCodecConfig {
fn default() -> Self {
Self {
codec: VideoCodecType::H264,
resolution: Resolution::HD720,
bitrate_kbps: 8000,
fps: 30,
gop_size: 30,
profile: None,
level: None,
}
}
}
impl VideoCodecConfig {
/// Create H264 config with common settings
pub fn h264(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
Self {
codec: VideoCodecType::H264,
resolution,
bitrate_kbps,
fps,
gop_size: fps, // 1 second GOP
profile: Some("baseline".to_string()),
level: Some("3.1".to_string()),
}
}
/// Create VP8 config
pub fn vp8(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
Self {
codec: VideoCodecType::VP8,
resolution,
bitrate_kbps,
fps,
gop_size: fps,
profile: None,
level: None,
}
}
/// Create VP9 config
pub fn vp9(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
Self {
codec: VideoCodecType::VP9,
resolution,
bitrate_kbps,
fps,
gop_size: fps,
profile: None,
level: None,
}
}
/// Create H265 config
pub fn h265(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
Self {
codec: VideoCodecType::H265,
resolution,
bitrate_kbps,
fps,
gop_size: fps,
profile: Some("main".to_string()),
level: Some("4.0".to_string()),
}
}
}
/// WebRTC video codec trait
///
/// This trait defines the interface for video codecs used in WebRTC streaming.
/// Implementations should handle format conversion internally if needed.
pub trait VideoCodec: Send {
/// Get codec type
fn codec_type(&self) -> VideoCodecType;
/// Get codec name for display
fn codec_name(&self) -> &'static str;
/// Get RTP payload type
fn payload_type(&self) -> u8 {
self.codec_type().default_payload_type()
}
/// Get SDP fmtp parameters (codec-specific)
///
/// For H264: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f"
/// For VP8/VP9: None or empty
fn sdp_fmtp(&self) -> Option<String>;
/// Encode a raw frame (NV12 format expected)
///
/// # Arguments
/// * `frame` - Raw frame data in NV12 format
/// * `pts_ms` - Presentation timestamp in milliseconds
///
/// # Returns
/// * `Ok(Some(frame))` - Encoded frame
/// * `Ok(None)` - Encoder is buffering (no output yet)
/// * `Err(e)` - Encoding error
fn encode(&mut self, frame: &[u8], pts_ms: i64) -> Result<Option<CodecFrame>>;
/// Set target bitrate dynamically
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()>;
/// Request a keyframe on next encode
fn request_keyframe(&mut self);
/// Get current configuration
fn config(&self) -> &VideoCodecConfig;
/// Flush any pending frames
fn flush(&mut self) -> Result<Vec<CodecFrame>> {
Ok(vec![])
}
/// Reset encoder state
fn reset(&mut self) -> Result<()> {
Ok(())
}
}
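// Illustrative usage sketch (added for clarity, not part of the original
// commit): a codec-agnostic encode loop over the trait object, roughly as a
// WebRTC sender task would drive it.
#[allow(dead_code)]
fn example_encode_loop(codec: &mut dyn VideoCodec, frames: &[Vec<u8>]) -> Result<Vec<CodecFrame>> {
let mut encoded = Vec::new();
// Start with a keyframe so a newly joined viewer can decode immediately
codec.request_keyframe();
for (i, raw) in frames.iter().enumerate() {
let pts_ms = i as i64 * 1000 / codec.config().fps as i64;
if let Some(frame) = codec.encode(raw, pts_ms)? {
encoded.push(frame);
}
}
encoded.extend(codec.flush()?);
Ok(encoded)
}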
/// Video codec factory trait
///
/// Used to create codec instances and query available codecs.
pub trait VideoCodecFactory: Send + Sync {
/// Create a codec with the given configuration
fn create(&self, config: VideoCodecConfig) -> Result<Box<dyn VideoCodec>>;
/// Get supported codec types
fn supported_codecs(&self) -> Vec<VideoCodecType>;
/// Check if a specific codec is available
fn is_codec_available(&self, codec: VideoCodecType) -> bool {
self.supported_codecs().contains(&codec)
}
/// Get the best available codec (based on priority)
fn best_codec(&self) -> Option<VideoCodecType> {
// Priority: H264 > VP8 > VP9 > H265
let supported = self.supported_codecs();
if supported.contains(&VideoCodecType::H264) {
Some(VideoCodecType::H264)
} else if supported.contains(&VideoCodecType::VP8) {
Some(VideoCodecType::VP8)
} else if supported.contains(&VideoCodecType::VP9) {
Some(VideoCodecType::VP9)
} else if supported.contains(&VideoCodecType::H265) {
Some(VideoCodecType::H265)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_codec_type_properties() {
assert_eq!(VideoCodecType::H264.sdp_name(), "H264");
assert_eq!(VideoCodecType::H264.default_payload_type(), 96);
assert_eq!(VideoCodecType::H264.clock_rate(), 90000);
assert_eq!(VideoCodecType::H264.mime_type(), "video/H264");
}
#[test]
fn test_codec_frame_creation() {
let data = Bytes::from(vec![0x00, 0x00, 0x00, 0x01, 0x65]);
let frame = CodecFrame::h264(data.clone(), 1000, true, 1, 30);
assert_eq!(frame.codec, VideoCodecType::H264);
assert!(frame.is_keyframe);
assert_eq!(frame.pts_ms, 1000);
assert_eq!(frame.sequence, 1);
assert_eq!(frame.len(), 5);
}
#[test]
fn test_codec_config_default() {
let config = VideoCodecConfig::default();
assert_eq!(config.codec, VideoCodecType::H264);
assert_eq!(config.bitrate_kbps, 8000); // matches the Default impl above
assert_eq!(config.fps, 30);
}
#[test]
fn test_codec_config_h264() {
let config = VideoCodecConfig::h264(Resolution::HD1080, 4000, 60);
assert_eq!(config.codec, VideoCodecType::H264);
assert_eq!(config.bitrate_kbps, 4000);
assert_eq!(config.fps, 60);
assert_eq!(config.gop_size, 60);
}
}

532
src/video/encoder/h264.rs Normal file

@@ -0,0 +1,532 @@
//! H.264 encoder using hwcodec (rustdesk's FFmpeg wrapper)
//!
//! Supports multiple encoder backends via FFmpeg:
//! - VAAPI (Intel/AMD/NVIDIA on Linux)
//! - NVENC (NVIDIA)
//! - AMF (AMD)
//! - Software (libx264)
//!
//! The encoder is selected automatically based on availability.
use bytes::Bytes;
use std::sync::Once;
use tracing::{debug, error, info, warn};
use hwcodec::common::{Quality, RateControl};
use hwcodec::ffmpeg::AVPixelFormat;
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
use hwcodec::ffmpeg_ram::CodecInfo;
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
use crate::error::{AppError, Result};
use crate::video::format::{PixelFormat, Resolution};
static INIT_LOGGING: Once = Once::new();
/// Initialize hwcodec logging (only once)
fn init_hwcodec_logging() {
INIT_LOGGING.call_once(|| {
// hwcodec uses the `log` crate, which will work with our tracing subscriber
debug!("hwcodec logging initialized");
});
}
/// H.264 encoder type (detected from hwcodec)
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum H264EncoderType {
/// NVIDIA NVENC
Nvenc,
/// Intel Quick Sync (QSV)
Qsv,
/// AMD AMF
Amf,
/// VAAPI (Linux generic)
Vaapi,
/// RKMPP (Rockchip) - requires hwcodec extension
Rkmpp,
/// V4L2 M2M (ARM generic) - requires hwcodec extension
V4l2M2m,
/// Software encoding (libx264/openh264)
Software,
/// No encoder available
None,
}
impl std::fmt::Display for H264EncoderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
H264EncoderType::Nvenc => write!(f, "NVENC"),
H264EncoderType::Qsv => write!(f, "QSV"),
H264EncoderType::Amf => write!(f, "AMF"),
H264EncoderType::Vaapi => write!(f, "VAAPI"),
H264EncoderType::Rkmpp => write!(f, "RKMPP"),
H264EncoderType::V4l2M2m => write!(f, "V4L2 M2M"),
H264EncoderType::Software => write!(f, "Software"),
H264EncoderType::None => write!(f, "None"),
}
}
}
impl Default for H264EncoderType {
fn default() -> Self {
Self::None
}
}
/// Map codec name to encoder type
fn codec_name_to_type(name: &str) -> H264EncoderType {
if name.contains("nvenc") {
H264EncoderType::Nvenc
} else if name.contains("qsv") {
H264EncoderType::Qsv
} else if name.contains("amf") {
H264EncoderType::Amf
} else if name.contains("vaapi") {
H264EncoderType::Vaapi
} else if name.contains("rkmpp") {
H264EncoderType::Rkmpp
} else if name.contains("v4l2m2m") {
H264EncoderType::V4l2M2m
} else {
H264EncoderType::Software
}
}
/// Input pixel format for H264 encoder
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum H264InputFormat {
/// YUV420P (I420) - planar Y, U, V
Yuv420p,
/// NV12 - Y plane + interleaved UV plane (optimal for VAAPI)
Nv12,
}
impl Default for H264InputFormat {
fn default() -> Self {
Self::Nv12 // Default to NV12 for VAAPI compatibility
}
}
/// H.264 encoder configuration
#[derive(Debug, Clone)]
pub struct H264Config {
/// Base encoder config
pub base: EncoderConfig,
/// Target bitrate in kbps
pub bitrate_kbps: u32,
/// GOP size (keyframe interval)
pub gop_size: u32,
/// Frame rate
pub fps: u32,
/// Input pixel format
pub input_format: H264InputFormat,
}
impl Default for H264Config {
fn default() -> Self {
Self {
base: EncoderConfig::default(),
bitrate_kbps: 8000,
gop_size: 30,
fps: 30,
input_format: H264InputFormat::Nv12,
}
}
}
impl H264Config {
/// Create config for low latency streaming with NV12 input (optimal for VAAPI)
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
base: EncoderConfig::h264(resolution, bitrate_kbps),
bitrate_kbps,
gop_size: 30,
fps: 30,
input_format: H264InputFormat::Nv12,
}
}
/// Create config for low latency streaming with YUV420P input
pub fn low_latency_yuv420p(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
base: EncoderConfig::h264(resolution, bitrate_kbps),
bitrate_kbps,
gop_size: 30,
fps: 30,
input_format: H264InputFormat::Yuv420p,
}
}
/// Create config for quality streaming
pub fn quality(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
base: EncoderConfig::h264(resolution, bitrate_kbps),
bitrate_kbps,
gop_size: 60,
fps: 30,
input_format: H264InputFormat::Nv12,
}
}
/// Set input format
pub fn with_input_format(mut self, format: H264InputFormat) -> Self {
self.input_format = format;
self
}
}
/// Get available H264 encoders from hwcodec
pub fn get_available_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
init_hwcodec_logging();
let ctx = EncodeContext {
name: String::new(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt: AVPixelFormat::AV_PIX_FMT_YUV420P,
align: 1,
fps: 30,
gop: 30,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (ultrafast)
kbs: 2000,
q: 23,
thread_count: 4,
};
HwEncoder::available_encoders(ctx, None)
}
/// Detect best available H.264 encoder
pub fn detect_best_encoder(width: u32, height: u32) -> (H264EncoderType, Option<String>) {
let encoders = get_available_encoders(width, height);
if encoders.is_empty() {
warn!("No H.264 encoders available from hwcodec");
return (H264EncoderType::None, None);
}
// Find H264 encoder (not H265)
for codec in &encoders {
if codec.format == hwcodec::common::DataFormat::H264 {
let encoder_type = codec_name_to_type(&codec.name);
info!("Best H.264 encoder: {} ({})", codec.name, encoder_type);
return (encoder_type, Some(codec.name.clone()));
}
}
(H264EncoderType::None, None)
}
/// Encoded frame from hwcodec (cloned for ownership)
#[derive(Debug, Clone)]
pub struct HwEncodeFrame {
pub data: Vec<u8>,
pub pts: i64,
pub key: i32,
}
/// H.264 encoder using hwcodec
pub struct H264Encoder {
/// hwcodec encoder instance
inner: HwEncoder,
/// Encoder configuration
config: H264Config,
/// Detected encoder type
encoder_type: H264EncoderType,
/// Codec name
codec_name: String,
/// Frame counter
frame_count: u64,
/// YUV420P buffer for input (reserved for future use)
#[allow(dead_code)]
yuv_buffer: Vec<u8>,
/// Required YUV buffer length from hwcodec
yuv_length: i32,
}
impl H264Encoder {
/// Create a new H.264 encoder with automatic codec detection
pub fn new(config: H264Config) -> Result<Self> {
init_hwcodec_logging();
let width = config.base.resolution.width;
let height = config.base.resolution.height;
// Detect best encoder
let (_encoder_type, codec_name) = detect_best_encoder(width, height);
let codec_name = codec_name.ok_or_else(|| {
AppError::VideoError("No H.264 encoder available".to_string())
})?;
Self::with_codec(config, &codec_name)
}
/// Create encoder with specific codec name
pub fn with_codec(config: H264Config, codec_name: &str) -> Result<Self> {
init_hwcodec_logging();
let width = config.base.resolution.width;
let height = config.base.resolution.height;
// Select pixel format based on config
let pixfmt = match config.input_format {
H264InputFormat::Nv12 => AVPixelFormat::AV_PIX_FMT_NV12,
H264InputFormat::Yuv420p => AVPixelFormat::AV_PIX_FMT_YUV420P,
};
info!(
"Creating H.264 encoder: {} at {}x{} @ {} kbps (input: {:?})",
codec_name, width, height, config.bitrate_kbps, config.input_format
);
let ctx = EncodeContext {
name: codec_name.to_string(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt,
align: 1,
fps: config.fps as i32,
gop: config.gop_size as i32,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (lowest latency)
kbs: config.bitrate_kbps as i32,
q: 23,
thread_count: 4, // Use 4 threads for better performance
};
let inner = HwEncoder::new(ctx).map_err(|_| {
AppError::VideoError(format!("Failed to create encoder: {}", codec_name))
})?;
let yuv_length = inner.length;
let encoder_type = codec_name_to_type(codec_name);
info!(
"H.264 encoder created: {} (type: {}, buffer_length: {}, input_format: {:?})",
codec_name, encoder_type, yuv_length, config.input_format
);
Ok(Self {
inner,
config,
encoder_type,
codec_name: codec_name.to_string(),
frame_count: 0,
yuv_buffer: vec![0u8; yuv_length as usize],
yuv_length,
})
}
/// Create with auto-detected encoder
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
let config = H264Config::low_latency(resolution, bitrate_kbps);
Self::new(config)
}
/// Get encoder type
pub fn encoder_type(&self) -> &H264EncoderType {
&self.encoder_type
}
/// Get codec name
pub fn codec_name(&self) -> &str {
&self.codec_name
}
/// Update bitrate dynamically
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
AppError::VideoError("Failed to set bitrate".to_string())
})?;
self.config.bitrate_kbps = bitrate_kbps;
debug!("Bitrate updated to {} kbps", bitrate_kbps);
Ok(())
}
/// Request next frame to be a keyframe (IDR)
pub fn request_keyframe(&mut self) {
self.inner.request_keyframe();
debug!("H264 keyframe requested");
}
/// Encode raw frame data (YUV420P or NV12 depending on config)
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
if data.len() < self.yuv_length as usize {
return Err(AppError::VideoError(format!(
"Frame data too small: {} < {}",
data.len(),
self.yuv_length
)));
}
self.frame_count += 1;
match self.inner.encode(data, pts_ms) {
Ok(frames) => {
// Copy frame data to owned HwEncodeFrame
let owned_frames: Vec<HwEncodeFrame> = frames
.iter()
.map(|f| HwEncodeFrame {
data: f.data.clone(),
pts: f.pts,
key: f.key,
})
.collect();
Ok(owned_frames)
}
Err(e) => {
                // Some encoders (notably x264) can fail on the first ~30
                // frames while they warm up; log these as warnings rather
                // than errors to avoid alarming users.
                if self.frame_count <= 30 {
                    warn!(
                        "Encode failed during initialization (frame {}): {} - this can be normal while the encoder warms up",
                        self.frame_count, e
                    );
} else {
error!("Encode failed: {}", e);
}
Err(AppError::VideoError(format!("Encode failed: {}", e)))
}
}
}
/// Encode YUV420P data (legacy method, use encode_raw for new code)
pub fn encode_yuv420p(&mut self, yuv_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
self.encode_raw(yuv_data, pts_ms)
}
/// Encode NV12 data
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
self.encode_raw(nv12_data, pts_ms)
}
/// Get input format
pub fn input_format(&self) -> H264InputFormat {
self.config.input_format
}
/// Get YUV buffer info (linesize, offset, length)
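    ///
    /// `linesize` holds per-plane strides, `offset` the chroma plane start
    /// positions within the buffer, and `length` the total input size that
    /// `encode_raw` expects (this is hwcodec's reported layout; treat the
    /// exact offset semantics as an assumption). A packing sketch:
    /// ```text
    /// let (linesize, offset, len) = enc.yuv_info();
    /// let mut frame = vec![0u8; len as usize];
    /// // fill Y rows with stride linesize[0], then the chroma data
    /// // starting at offset[0] (interleaved UV for NV12)
    /// ```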
pub fn yuv_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
(
self.inner.linesize.clone(),
self.inner.offset.clone(),
self.inner.length,
)
}
}
// SAFETY: H264Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
// that are not Send by default. However, we ensure that H264Encoder is only used from
// a single task/thread at a time (encoding is sequential), so this is safe.
// The raw pointers are internal FFmpeg context that doesn't escape the encoder.
unsafe impl Send for H264Encoder {}
impl Encoder for H264Encoder {
fn name(&self) -> &str {
&self.codec_name
}
fn output_format(&self) -> EncodedFormat {
EncodedFormat::H264
}
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
        // The caller must supply data in the configured input format
        // (NV12 or YUV420P); derive a millisecond pts from the frame index
        // at the configured fps.
        let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
        let frames = self.encode_raw(data, pts_ms)?;
if frames.is_empty() {
// Encoder needs more frames (shouldn't happen with our config)
warn!("Encoder returned no frames");
return Err(AppError::VideoError("Encoder returned no frames".to_string()));
}
// Take the first frame
let frame = &frames[0];
let key_frame = frame.key == 1;
Ok(EncodedFrame::h264(
Bytes::from(frame.data.clone()),
self.config.base.resolution,
key_frame,
sequence,
frame.pts as u64,
frame.pts as u64,
))
}
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
// hwcodec doesn't have explicit flush, return empty
Ok(vec![])
}
fn reset(&mut self) -> Result<()> {
self.frame_count = 0;
Ok(())
}
fn config(&self) -> &EncoderConfig {
&self.config.base
}
fn supports_format(&self, format: PixelFormat) -> bool {
// Check if the format matches our configured input format
match self.config.input_format {
H264InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
H264InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
}
}
}
/// Encoder statistics
#[derive(Debug, Clone, Default)]
pub struct EncoderStats {
/// Total frames encoded
pub frames_encoded: u64,
/// Total bytes output
pub bytes_output: u64,
/// Current encoding FPS
pub fps: f32,
/// Average encoding time per frame (ms)
pub avg_encode_time_ms: f32,
/// Keyframes encoded
pub keyframes: u64,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_encoder() {
let (encoder_type, codec_name) = detect_best_encoder(1280, 720);
println!("Detected encoder: {:?} ({:?})", encoder_type, codec_name);
}
#[test]
fn test_available_encoders() {
let encoders = get_available_encoders(1280, 720);
println!("Available encoders:");
for enc in &encoders {
println!(" - {} ({:?})", enc.name, enc.format);
}
}
#[test]
fn test_create_encoder() {
let config = H264Config::low_latency(Resolution::HD720, 2000);
match H264Encoder::new(config) {
Ok(encoder) => {
println!("Created encoder: {} ({})", encoder.codec_name(), encoder.encoder_type());
}
Err(e) => {
println!("Failed to create encoder: {}", e);
}
}
}
}

577
src/video/encoder/h265.rs Normal file
View File

@@ -0,0 +1,577 @@
//! H.265/HEVC encoder using hwcodec (FFmpeg wrapper)
//!
//! Supports both hardware and software encoding:
//! - Hardware: VAAPI, NVENC, QSV, AMF, RKMPP, V4L2 M2M
//! - Software: libx265 (CPU-based, high CPU usage)
//!
//! Hardware encoding is preferred when available for better performance.
use bytes::Bytes;
use std::sync::Once;
use tracing::{debug, error, info, warn};
use hwcodec::common::{DataFormat, Quality, RateControl};
use hwcodec::ffmpeg::AVPixelFormat;
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
use hwcodec::ffmpeg_ram::CodecInfo;
use super::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
use crate::error::{AppError, Result};
use crate::video::format::{PixelFormat, Resolution};
static INIT_LOGGING: Once = Once::new();
/// Initialize hwcodec logging (only once)
fn init_hwcodec_logging() {
INIT_LOGGING.call_once(|| {
debug!("hwcodec logging initialized for H265");
});
}
/// H.265 encoder type (detected from hwcodec)
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum H265EncoderType {
/// NVIDIA NVENC
Nvenc,
/// Intel Quick Sync (QSV)
Qsv,
/// AMD AMF
Amf,
/// VAAPI (Linux generic)
Vaapi,
/// RKMPP (Rockchip)
Rkmpp,
/// V4L2 M2M (ARM generic)
V4l2M2m,
/// Software encoder (libx265)
Software,
/// No encoder available
None,
}
impl std::fmt::Display for H265EncoderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
H265EncoderType::Nvenc => write!(f, "NVENC"),
H265EncoderType::Qsv => write!(f, "QSV"),
H265EncoderType::Amf => write!(f, "AMF"),
H265EncoderType::Vaapi => write!(f, "VAAPI"),
H265EncoderType::Rkmpp => write!(f, "RKMPP"),
H265EncoderType::V4l2M2m => write!(f, "V4L2 M2M"),
H265EncoderType::Software => write!(f, "Software"),
H265EncoderType::None => write!(f, "None"),
}
}
}
impl Default for H265EncoderType {
fn default() -> Self {
Self::None
}
}
impl From<EncoderBackend> for H265EncoderType {
fn from(backend: EncoderBackend) -> Self {
match backend {
EncoderBackend::Nvenc => H265EncoderType::Nvenc,
EncoderBackend::Qsv => H265EncoderType::Qsv,
EncoderBackend::Amf => H265EncoderType::Amf,
EncoderBackend::Vaapi => H265EncoderType::Vaapi,
EncoderBackend::Rkmpp => H265EncoderType::Rkmpp,
EncoderBackend::V4l2m2m => H265EncoderType::V4l2M2m,
EncoderBackend::Software => H265EncoderType::Software,
}
}
}
/// Input pixel format for H265 encoder
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum H265InputFormat {
/// YUV420P (I420) - planar Y, U, V
Yuv420p,
/// NV12 - Y plane + interleaved UV plane (optimal for hardware encoders)
Nv12,
}
impl Default for H265InputFormat {
fn default() -> Self {
Self::Nv12 // Default to NV12 for hardware encoder compatibility
}
}
/// H.265 encoder configuration
#[derive(Debug, Clone)]
pub struct H265Config {
/// Base encoder config
pub base: EncoderConfig,
/// Target bitrate in kbps
pub bitrate_kbps: u32,
/// GOP size (keyframe interval)
pub gop_size: u32,
/// Frame rate
pub fps: u32,
/// Input pixel format
pub input_format: H265InputFormat,
}
impl Default for H265Config {
fn default() -> Self {
Self {
base: EncoderConfig::default(),
bitrate_kbps: 8000,
gop_size: 30,
fps: 30,
input_format: H265InputFormat::Nv12,
}
}
}
impl H265Config {
/// Create config for low latency streaming with NV12 input
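    ///
    /// For example (sketch):
    /// ```text
    /// let cfg = H265Config::low_latency(Resolution::HD1080, 8000)
    ///     .with_input_format(H265InputFormat::Yuv420p);
    /// ```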
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
base: EncoderConfig {
resolution,
input_format: PixelFormat::Nv12,
quality: bitrate_kbps,
fps: 30,
gop_size: 30,
},
bitrate_kbps,
gop_size: 30,
fps: 30,
input_format: H265InputFormat::Nv12,
}
}
/// Create config for quality streaming
pub fn quality(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
base: EncoderConfig {
resolution,
input_format: PixelFormat::Nv12,
quality: bitrate_kbps,
fps: 30,
gop_size: 60,
},
bitrate_kbps,
gop_size: 60,
fps: 30,
input_format: H265InputFormat::Nv12,
}
}
/// Set input format
pub fn with_input_format(mut self, format: H265InputFormat) -> Self {
self.input_format = format;
self
}
}
/// Get available H265 encoders from hwcodec (hardware and software)
pub fn get_available_h265_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
init_hwcodec_logging();
let ctx = EncodeContext {
name: String::new(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
align: 1,
fps: 30,
gop: 30,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Default,
kbs: 2000,
q: 23,
thread_count: 1,
};
let all_encoders = HwEncoder::available_encoders(ctx, None);
// Include both hardware and software H265 encoders
all_encoders
.into_iter()
.filter(|e| e.format == DataFormat::H265)
.collect()
}
/// Detect best available H.265 encoder (hardware preferred, software fallback)
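///
/// A selection sketch built on this probe:
/// ```text
/// let (kind, name) = detect_best_h265_encoder(1920, 1080);
/// if let Some(name) = name {
///     let enc = H265Encoder::with_codec(H265Config::default(), &name)?;
///     info!("encoding H.265 via {} ({})", enc.codec_name(), kind);
/// }
/// ```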
pub fn detect_best_h265_encoder(width: u32, height: u32) -> (H265EncoderType, Option<String>) {
let encoders = get_available_h265_encoders(width, height);
if encoders.is_empty() {
warn!("No H.265 encoders available");
return (H265EncoderType::None, None);
}
    // Prefer hardware encoders over software (libx265); hwcodec returns the
    // list in priority order (NVENC > QSV > AMF > VAAPI > RKMPP > V4L2 M2M),
    // so the first non-software entry is the best hardware choice.
let codec = encoders
.iter()
.find(|e| !e.name.contains("libx265"))
.or_else(|| encoders.first())
.unwrap();
let encoder_type = if codec.name.contains("nvenc") {
H265EncoderType::Nvenc
} else if codec.name.contains("qsv") {
H265EncoderType::Qsv
} else if codec.name.contains("amf") {
H265EncoderType::Amf
} else if codec.name.contains("vaapi") {
H265EncoderType::Vaapi
} else if codec.name.contains("rkmpp") {
H265EncoderType::Rkmpp
} else if codec.name.contains("v4l2m2m") {
H265EncoderType::V4l2M2m
} else if codec.name.contains("libx265") {
H265EncoderType::Software
} else {
H265EncoderType::Software // Default to software for unknown
};
info!(
"Selected H.265 encoder: {} ({})",
codec.name, encoder_type
);
(encoder_type, Some(codec.name.clone()))
}
/// Check if H265 hardware encoding is available
pub fn is_h265_available() -> bool {
let registry = EncoderRegistry::global();
registry.is_format_available(VideoEncoderType::H265, true)
}
/// Encoded frame from hwcodec (cloned for ownership)
#[derive(Debug, Clone)]
pub struct HwEncodeFrame {
pub data: Vec<u8>,
pub pts: i64,
pub key: i32,
}
/// H.265 encoder using hwcodec (hardware preferred, software fallback)
pub struct H265Encoder {
/// hwcodec encoder instance
inner: HwEncoder,
/// Encoder configuration
config: H265Config,
/// Detected encoder type
encoder_type: H265EncoderType,
/// Codec name
codec_name: String,
/// Frame counter
frame_count: u64,
/// Required buffer length from hwcodec
buffer_length: i32,
}
impl H265Encoder {
    /// Create a new H.265 encoder with automatic codec detection
    /// (hardware preferred, libx265 software fallback).
    ///
    /// Returns an error only if no H.265 encoder is available at all.
pub fn new(config: H265Config) -> Result<Self> {
init_hwcodec_logging();
let width = config.base.resolution.width;
let height = config.base.resolution.height;
// Detect best hardware encoder
let (encoder_type, codec_name) = detect_best_h265_encoder(width, height);
if encoder_type == H265EncoderType::None {
return Err(AppError::VideoError(
"No H.265 encoder available. Please ensure FFmpeg is built with libx265 support.".to_string(),
));
}
let codec_name = codec_name.unwrap();
Self::with_codec(config, &codec_name)
}
/// Create encoder with specific codec name
pub fn with_codec(config: H265Config, codec_name: &str) -> Result<Self> {
init_hwcodec_logging();
// Determine if this is a software encoder
let is_software = codec_name.contains("libx265");
// Warn about software encoder performance
if is_software {
warn!(
"Using software H.265 encoder (libx265) - high CPU usage expected. \
Hardware encoder is recommended for better performance."
);
}
let width = config.base.resolution.width;
let height = config.base.resolution.height;
// Software encoders (libx265) require YUV420P, hardware encoders use NV12
let (pixfmt, actual_input_format) = if is_software {
(AVPixelFormat::AV_PIX_FMT_YUV420P, H265InputFormat::Yuv420p)
} else {
match config.input_format {
H265InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, H265InputFormat::Nv12),
H265InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, H265InputFormat::Yuv420p),
}
};
info!(
"Creating H.265 encoder: {} at {}x{} @ {} kbps (input: {:?})",
codec_name, width, height, config.bitrate_kbps, actual_input_format
);
let ctx = EncodeContext {
name: codec_name.to_string(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt,
align: 1,
fps: config.fps as i32,
gop: config.gop_size as i32,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Default,
kbs: config.bitrate_kbps as i32,
q: 23,
thread_count: 1,
};
let inner = HwEncoder::new(ctx).map_err(|_| {
AppError::VideoError(format!("Failed to create H.265 encoder: {}", codec_name))
})?;
let buffer_length = inner.length;
let backend = EncoderBackend::from_codec_name(codec_name);
let encoder_type = H265EncoderType::from(backend);
// Update config to reflect actual input format used
let mut config = config;
config.input_format = actual_input_format;
info!(
"H.265 encoder created: {} (type: {}, buffer_length: {})",
codec_name, encoder_type, buffer_length
);
Ok(Self {
inner,
config,
encoder_type,
codec_name: codec_name.to_string(),
frame_count: 0,
buffer_length,
})
}
/// Create with auto-detected encoder
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
let config = H265Config::low_latency(resolution, bitrate_kbps);
Self::new(config)
}
/// Get encoder type
pub fn encoder_type(&self) -> &H265EncoderType {
&self.encoder_type
}
/// Get codec name
pub fn codec_name(&self) -> &str {
&self.codec_name
}
/// Update bitrate dynamically
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
AppError::VideoError("Failed to set H.265 bitrate".to_string())
})?;
self.config.bitrate_kbps = bitrate_kbps;
debug!("H.265 bitrate updated to {} kbps", bitrate_kbps);
Ok(())
}
/// Request next frame to be a keyframe (IDR)
pub fn request_keyframe(&mut self) {
self.inner.request_keyframe();
debug!("H265 keyframe requested");
}
/// Encode raw frame data (NV12 or YUV420P depending on config)
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
if data.len() < self.buffer_length as usize {
return Err(AppError::VideoError(format!(
"Frame data too small: {} < {}",
data.len(),
self.buffer_length
)));
}
self.frame_count += 1;
// Debug log every 30 frames (1 second at 30fps)
if self.frame_count % 30 == 1 {
debug!(
"[H265] Encoding frame #{}: input_size={}, pts_ms={}, codec={}",
self.frame_count,
data.len(),
pts_ms,
self.codec_name
);
}
match self.inner.encode(data, pts_ms) {
Ok(frames) => {
let owned_frames: Vec<HwEncodeFrame> = frames
.iter()
.map(|f| HwEncodeFrame {
data: f.data.clone(),
pts: f.pts,
key: f.key,
})
.collect();
// Log encoded output
if !owned_frames.is_empty() {
let total_size: usize = owned_frames.iter().map(|f| f.data.len()).sum();
let keyframe = owned_frames.iter().any(|f| f.key == 1);
if keyframe || self.frame_count % 30 == 1 {
debug!(
"[H265] Encoded frame #{}: output_size={}, keyframe={}, frame_count={}",
self.frame_count, total_size, keyframe, owned_frames.len()
);
// Log first few bytes of keyframe for debugging
if keyframe && !owned_frames[0].data.is_empty() {
let preview_len = owned_frames[0].data.len().min(32);
debug!(
"[H265] Keyframe data preview: {:02x?}",
&owned_frames[0].data[..preview_len]
);
}
}
} else {
warn!("[H265] Encoder returned empty frame list for frame #{}", self.frame_count);
}
Ok(owned_frames)
}
Err(e) => {
error!("[H265] Encode failed at frame #{}: {}", self.frame_count, e);
Err(AppError::VideoError(format!("H.265 encode failed: {}", e)))
}
}
}
/// Encode NV12 data
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
self.encode_raw(nv12_data, pts_ms)
}
/// Get input format
pub fn input_format(&self) -> H265InputFormat {
self.config.input_format
}
/// Get buffer info (linesize, offset, length)
pub fn buffer_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
(
self.inner.linesize.clone(),
self.inner.offset.clone(),
self.inner.length,
)
}
}
// SAFETY: H265Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
// that are not Send by default. However, we ensure that H265Encoder is only used from
// a single task/thread at a time (encoding is sequential), so this is safe.
unsafe impl Send for H265Encoder {}
impl Encoder for H265Encoder {
fn name(&self) -> &str {
&self.codec_name
}
fn output_format(&self) -> EncodedFormat {
EncodedFormat::H265
}
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
let frames = self.encode_raw(data, pts_ms)?;
if frames.is_empty() {
warn!("H.265 encoder returned no frames");
return Err(AppError::VideoError(
"H.265 encoder returned no frames".to_string(),
));
}
let frame = &frames[0];
let key_frame = frame.key == 1;
Ok(EncodedFrame {
data: Bytes::from(frame.data.clone()),
format: EncodedFormat::H265,
resolution: self.config.base.resolution,
key_frame,
sequence,
timestamp: std::time::Instant::now(),
pts: frame.pts as u64,
dts: frame.pts as u64,
})
}
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
Ok(vec![])
}
fn reset(&mut self) -> Result<()> {
self.frame_count = 0;
Ok(())
}
fn config(&self) -> &EncoderConfig {
&self.config.base
}
fn supports_format(&self, format: PixelFormat) -> bool {
match self.config.input_format {
H265InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
H265InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_h265_encoder() {
let (encoder_type, codec_name) = detect_best_h265_encoder(1280, 720);
println!("Detected H.265 encoder: {:?} ({:?})", encoder_type, codec_name);
}
#[test]
fn test_available_h265_encoders() {
let encoders = get_available_h265_encoders(1280, 720);
println!("Available H.265 hardware encoders:");
for enc in &encoders {
println!(" - {} ({:?})", enc.name, enc.format);
}
}
#[test]
fn test_h265_availability() {
let available = is_h265_available();
println!("H.265 hardware encoding available: {}", available);
}
}

226
src/video/encoder/jpeg.rs Normal file
View File

@@ -0,0 +1,226 @@
//! JPEG encoder implementation
//!
//! Provides JPEG encoding for raw video frames (YUYV, NV12, RGB, BGR)
//! Uses libyuv for SIMD-accelerated color space conversion to I420,
//! then turbojpeg for direct YUV encoding (skips internal color conversion).
use bytes::Bytes;
use super::traits::{EncodedFormat, EncodedFrame, EncoderConfig};
use crate::error::{AppError, Result};
use crate::video::format::{PixelFormat, Resolution};
/// JPEG encoder using libyuv + turbojpeg
///
/// Encoding pipeline (all SIMD accelerated):
/// ```text
/// YUYV/NV12/BGR24/RGB24 ──libyuv──> I420 ──turbojpeg──> JPEG
/// ```
///
/// Note: This encoder is NOT thread-safe due to turbojpeg limitations.
/// Use it from a single thread or wrap in a Mutex.
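///
/// A minimal sketch (YUYV input at 1280x720; the zeroed buffer is a
/// placeholder for real capture data):
/// ```text
/// let mut enc = JpegEncoder::with_quality(Resolution::HD720, 80)?;
/// let yuyv = vec![0u8; 1280 * 720 * 2];
/// let frame = enc.encode_yuyv(&yuyv, 0)?;
/// assert_eq!(frame.format, EncodedFormat::Jpeg);
/// ```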
pub struct JpegEncoder {
config: EncoderConfig,
compressor: turbojpeg::Compressor,
/// I420 buffer for YUV encoding (Y + U + V planes)
i420_buffer: Vec<u8>,
}
impl JpegEncoder {
/// Create a new JPEG encoder
pub fn new(config: EncoderConfig) -> Result<Self> {
let resolution = config.resolution;
let width = resolution.width as usize;
let height = resolution.height as usize;
// I420: Y = width*height, U = width*height/4, V = width*height/4
let i420_size = width * height * 3 / 2;
let mut compressor = turbojpeg::Compressor::new()
.map_err(|e| AppError::VideoError(format!("Failed to create turbojpeg compressor: {}", e)))?;
compressor.set_quality(config.quality.min(100) as i32)
.map_err(|e| AppError::VideoError(format!("Failed to set JPEG quality: {}", e)))?;
Ok(Self {
config,
compressor,
i420_buffer: vec![0u8; i420_size],
})
}
/// Create with specific quality
pub fn with_quality(resolution: Resolution, quality: u32) -> Result<Self> {
let config = EncoderConfig::jpeg(resolution, quality);
Self::new(config)
}
/// Set JPEG quality (1-100)
pub fn set_quality(&mut self, quality: u32) -> Result<()> {
self.compressor.set_quality(quality.min(100) as i32)
.map_err(|e| AppError::VideoError(format!("Failed to set JPEG quality: {}", e)))?;
self.config.quality = quality;
Ok(())
}
/// Encode I420 buffer to JPEG using turbojpeg's YUV encoder
#[inline]
fn encode_i420_to_jpeg(&mut self, sequence: u64) -> Result<EncodedFrame> {
let width = self.config.resolution.width as usize;
let height = self.config.resolution.height as usize;
// Create YuvImage for turbojpeg (I420 = YUV420 = Sub2x2)
let yuv_image = turbojpeg::YuvImage {
pixels: self.i420_buffer.as_slice(),
width,
height,
align: 1, // No padding between rows
subsamp: turbojpeg::Subsamp::Sub2x2, // YUV 4:2:0
};
// Compress YUV directly to JPEG (skips color space conversion!)
let jpeg_data = self.compressor.compress_yuv_to_vec(yuv_image)
.map_err(|e| AppError::VideoError(format!("JPEG compression failed: {}", e)))?;
Ok(EncodedFrame::jpeg(
Bytes::from(jpeg_data),
self.config.resolution,
sequence,
))
}
/// Encode YUYV (YUV422) frame to JPEG
pub fn encode_yuyv(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let width = self.config.resolution.width as usize;
let height = self.config.resolution.height as usize;
let expected_size = width * height * 2;
if data.len() < expected_size {
return Err(AppError::VideoError(format!(
"YUYV data too small: {} < {}",
data.len(),
expected_size
)));
}
// Convert YUYV to I420 using libyuv (SIMD accelerated)
libyuv::yuy2_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
.map_err(|e| AppError::VideoError(format!("libyuv YUYV→I420 failed: {}", e)))?;
self.encode_i420_to_jpeg(sequence)
}
/// Encode NV12 frame to JPEG
pub fn encode_nv12(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let width = self.config.resolution.width as usize;
let height = self.config.resolution.height as usize;
let expected_size = width * height * 3 / 2;
if data.len() < expected_size {
return Err(AppError::VideoError(format!(
"NV12 data too small: {} < {}",
data.len(),
expected_size
)));
}
// Convert NV12 to I420 using libyuv (SIMD accelerated)
libyuv::nv12_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
.map_err(|e| AppError::VideoError(format!("libyuv NV12→I420 failed: {}", e)))?;
self.encode_i420_to_jpeg(sequence)
}
/// Encode RGB24 frame to JPEG
pub fn encode_rgb(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let width = self.config.resolution.width as usize;
let height = self.config.resolution.height as usize;
let expected_size = width * height * 3;
if data.len() < expected_size {
return Err(AppError::VideoError(format!(
"RGB data too small: {} < {}",
data.len(),
expected_size
)));
}
// Convert RGB24 to I420 using libyuv (SIMD accelerated)
libyuv::rgb24_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
.map_err(|e| AppError::VideoError(format!("libyuv RGB24→I420 failed: {}", e)))?;
self.encode_i420_to_jpeg(sequence)
}
/// Encode BGR24 frame to JPEG
pub fn encode_bgr(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let width = self.config.resolution.width as usize;
let height = self.config.resolution.height as usize;
let expected_size = width * height * 3;
if data.len() < expected_size {
return Err(AppError::VideoError(format!(
"BGR data too small: {} < {}",
data.len(),
expected_size
)));
}
// Convert BGR24 to I420 using libyuv (SIMD accelerated)
// Note: libyuv's RAWToI420 is BGR24 → I420
libyuv::bgr24_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
.map_err(|e| AppError::VideoError(format!("libyuv BGR24→I420 failed: {}", e)))?;
self.encode_i420_to_jpeg(sequence)
}
}
impl crate::video::encoder::traits::Encoder for JpegEncoder {
fn name(&self) -> &str {
"JPEG (libyuv+turbojpeg)"
}
fn output_format(&self) -> EncodedFormat {
EncodedFormat::Jpeg
}
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
match self.config.input_format {
PixelFormat::Yuyv | PixelFormat::Yvyu => self.encode_yuyv(data, sequence),
PixelFormat::Nv12 => self.encode_nv12(data, sequence),
PixelFormat::Rgb24 => self.encode_rgb(data, sequence),
PixelFormat::Bgr24 => self.encode_bgr(data, sequence),
_ => Err(AppError::VideoError(format!(
"Unsupported input format for JPEG: {}",
self.config.input_format
))),
}
}
fn config(&self) -> &EncoderConfig {
&self.config
}
fn supports_format(&self, format: PixelFormat) -> bool {
matches!(
format,
PixelFormat::Yuyv
| PixelFormat::Yvyu
| PixelFormat::Nv12
| PixelFormat::Rgb24
| PixelFormat::Bgr24
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_i420_buffer_size() {
// 1920x1080 I420 = 1920*1080 + 960*540 + 960*540 = 3110400 bytes
let config = EncoderConfig::jpeg(Resolution::HD1080, 80);
let encoder = JpegEncoder::new(config).unwrap();
assert_eq!(encoder.i420_buffer.len(), 1920 * 1080 * 3 / 2);
}
}

43
src/video/encoder/mod.rs Normal file
View File

@@ -0,0 +1,43 @@
//! Video encoder implementations
//!
//! This module provides video encoding capabilities including:
//! - JPEG encoding for raw frames (YUYV, NV12, etc.)
//! - H264 encoding (hardware + software)
//! - H265 encoding (hardware preferred, libx265 software fallback)
//! - VP8 encoding (VAAPI hardware preferred, libvpx software fallback)
//! - VP9 encoding (VAAPI hardware preferred, libvpx-vp9 software fallback)
//! - WebRTC video codec abstraction
//! - Encoder registry for automatic detection
pub mod codec;
pub mod h264;
pub mod h265;
pub mod jpeg;
pub mod registry;
pub mod traits;
pub mod vp8;
pub mod vp9;
// Core traits and types
pub use traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig, EncoderFactory};
// WebRTC codec abstraction
pub use codec::{CodecFrame, VideoCodec, VideoCodecConfig, VideoCodecFactory, VideoCodecType};
// Encoder registry
pub use registry::{AvailableEncoder, EncoderBackend, EncoderRegistry, VideoEncoderType};
// H264 encoder
pub use h264::{H264Config, H264Encoder, H264EncoderType, H264InputFormat};
// H265 encoder (hardware only)
pub use h265::{H265Config, H265Encoder, H265EncoderType, H265InputFormat};
// VP8 encoder (hardware only)
pub use vp8::{VP8Config, VP8Encoder, VP8EncoderType, VP8InputFormat};
// VP9 encoder (hardware only)
pub use vp9::{VP9Config, VP9Encoder, VP9EncoderType, VP9InputFormat};
// JPEG encoder
pub use jpeg::JpegEncoder;

531
src/video/encoder/registry.rs Normal file
View File

@@ -0,0 +1,531 @@
//! Encoder registry - Detection and management of available video encoders
//!
//! This module provides:
//! - Automatic detection of available hardware/software encoders
//! - Encoder selection based on format and priority
//! - Global registry for encoder availability queries
use std::collections::HashMap;
use std::sync::OnceLock;
use tracing::{debug, info};
use hwcodec::common::{DataFormat, Quality, RateControl};
use hwcodec::ffmpeg::AVPixelFormat;
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
use hwcodec::ffmpeg_ram::CodecInfo;
/// Video encoder format type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VideoEncoderType {
/// H.264/AVC
H264,
/// H.265/HEVC
H265,
/// VP8
VP8,
/// VP9
VP9,
}
impl VideoEncoderType {
/// Convert to hwcodec DataFormat
pub fn to_data_format(&self) -> DataFormat {
match self {
VideoEncoderType::H264 => DataFormat::H264,
VideoEncoderType::H265 => DataFormat::H265,
VideoEncoderType::VP8 => DataFormat::VP8,
VideoEncoderType::VP9 => DataFormat::VP9,
}
}
/// Create from hwcodec DataFormat
pub fn from_data_format(format: DataFormat) -> Option<Self> {
match format {
DataFormat::H264 => Some(VideoEncoderType::H264),
DataFormat::H265 => Some(VideoEncoderType::H265),
DataFormat::VP8 => Some(VideoEncoderType::VP8),
DataFormat::VP9 => Some(VideoEncoderType::VP9),
_ => None,
}
}
/// Get codec name prefix for FFmpeg
pub fn codec_prefix(&self) -> &'static str {
match self {
VideoEncoderType::H264 => "h264",
VideoEncoderType::H265 => "hevc",
VideoEncoderType::VP8 => "vp8",
VideoEncoderType::VP9 => "vp9",
}
}
/// Get display name
pub fn display_name(&self) -> &'static str {
match self {
VideoEncoderType::H264 => "H.264",
VideoEncoderType::H265 => "H.265/HEVC",
VideoEncoderType::VP8 => "VP8",
VideoEncoderType::VP9 => "VP9",
}
}
/// Check if this format requires hardware-only encoding
/// H264 supports software fallback, others require hardware
pub fn hardware_only(&self) -> bool {
match self {
VideoEncoderType::H264 => false,
VideoEncoderType::H265 => true,
VideoEncoderType::VP8 => true,
VideoEncoderType::VP9 => true,
}
}
}
impl std::fmt::Display for VideoEncoderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.display_name())
}
}
/// Encoder backend type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncoderBackend {
/// Intel/AMD/NVIDIA VAAPI (Linux)
Vaapi,
/// NVIDIA NVENC
Nvenc,
/// Intel Quick Sync Video
Qsv,
/// AMD AMF
Amf,
/// Rockchip MPP
Rkmpp,
/// V4L2 Memory-to-Memory (ARM)
V4l2m2m,
/// Software encoding (libx264, libx265, libvpx)
Software,
}
impl EncoderBackend {
/// Detect backend from codec name
pub fn from_codec_name(name: &str) -> Self {
if name.contains("vaapi") {
EncoderBackend::Vaapi
} else if name.contains("nvenc") {
EncoderBackend::Nvenc
} else if name.contains("qsv") {
EncoderBackend::Qsv
} else if name.contains("amf") {
EncoderBackend::Amf
} else if name.contains("rkmpp") {
EncoderBackend::Rkmpp
} else if name.contains("v4l2m2m") {
EncoderBackend::V4l2m2m
} else {
EncoderBackend::Software
}
}
/// Check if this is a hardware backend
pub fn is_hardware(&self) -> bool {
!matches!(self, EncoderBackend::Software)
}
/// Get display name
pub fn display_name(&self) -> &'static str {
match self {
EncoderBackend::Vaapi => "VAAPI",
EncoderBackend::Nvenc => "NVENC",
EncoderBackend::Qsv => "QSV",
EncoderBackend::Amf => "AMF",
EncoderBackend::Rkmpp => "RKMPP",
EncoderBackend::V4l2m2m => "V4L2 M2M",
EncoderBackend::Software => "Software",
}
}
/// Parse from string (case-insensitive)
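    ///
    /// e.g. `EncoderBackend::from_str("VAAPI") == Some(EncoderBackend::Vaapi)`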
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"vaapi" => Some(EncoderBackend::Vaapi),
"nvenc" => Some(EncoderBackend::Nvenc),
"qsv" => Some(EncoderBackend::Qsv),
"amf" => Some(EncoderBackend::Amf),
"rkmpp" => Some(EncoderBackend::Rkmpp),
"v4l2m2m" | "v4l2" => Some(EncoderBackend::V4l2m2m),
"software" | "cpu" => Some(EncoderBackend::Software),
_ => None,
}
}
}
impl std::fmt::Display for EncoderBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.display_name())
}
}
/// Information about an available encoder
#[derive(Debug, Clone)]
pub struct AvailableEncoder {
/// Encoder format type
pub format: VideoEncoderType,
/// FFmpeg codec name (e.g., "h264_vaapi", "hevc_nvenc")
pub codec_name: String,
/// Backend type
pub backend: EncoderBackend,
/// Priority (lower is better)
pub priority: i32,
/// Whether this is a hardware encoder
pub is_hardware: bool,
}
impl AvailableEncoder {
/// Create from hwcodec CodecInfo
pub fn from_codec_info(info: &CodecInfo) -> Option<Self> {
let format = VideoEncoderType::from_data_format(info.format)?;
let backend = EncoderBackend::from_codec_name(&info.name);
let is_hardware = backend.is_hardware();
Some(Self {
format,
codec_name: info.name.clone(),
backend,
priority: info.priority,
is_hardware,
})
}
}
/// Global encoder registry
///
/// Detects and caches available encoders at startup.
/// Use `EncoderRegistry::global()` to access the singleton instance.
pub struct EncoderRegistry {
/// Available encoders grouped by format
encoders: HashMap<VideoEncoderType, Vec<AvailableEncoder>>,
/// Detection resolution (used for testing)
detection_resolution: (u32, u32),
}
impl EncoderRegistry {
/// Get the global registry instance
///
/// The registry is initialized lazily on first access with 1920x1080 detection.
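    ///
    /// A typical query (sketch):
    /// ```text
    /// let reg = EncoderRegistry::global();
    /// if let Some(enc) = reg.best_encoder(VideoEncoderType::H264, false) {
    ///     println!("using {} via {}", enc.codec_name, enc.backend);
    /// }
    /// ```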
pub fn global() -> &'static Self {
static INSTANCE: OnceLock<EncoderRegistry> = OnceLock::new();
INSTANCE.get_or_init(|| {
let mut registry = EncoderRegistry::new();
registry.detect_encoders(1920, 1080);
registry
})
}
/// Create a new empty registry
pub fn new() -> Self {
Self {
encoders: HashMap::new(),
detection_resolution: (0, 0),
}
}
/// Detect all available encoders
///
/// This queries hwcodec/FFmpeg for available encoders and populates the registry.
pub fn detect_encoders(&mut self, width: u32, height: u32) {
info!("Detecting available video encoders at {}x{}", width, height);
self.encoders.clear();
self.detection_resolution = (width, height);
// Create test context for encoder detection
let ctx = EncodeContext {
name: String::new(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
align: 1,
fps: 30,
gop: 30,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Default,
kbs: 2000,
q: 23,
thread_count: 1,
};
// Get all available encoders from hwcodec
let all_encoders = HwEncoder::available_encoders(ctx, None);
info!("Found {} encoders from hwcodec", all_encoders.len());
for codec_info in &all_encoders {
if let Some(encoder) = AvailableEncoder::from_codec_info(codec_info) {
debug!(
"Detected encoder: {} ({}) - {} priority={}",
encoder.codec_name,
encoder.format,
encoder.backend,
encoder.priority
);
self.encoders
.entry(encoder.format)
.or_default()
.push(encoder);
}
}
// Sort encoders by priority (lower is better)
for encoders in self.encoders.values_mut() {
encoders.sort_by_key(|e| e.priority);
}
// Register software encoders as fallback
info!("Registering software encoders...");
let software_encoders = [
(VideoEncoderType::H264, "libx264", 100),
(VideoEncoderType::H265, "libx265", 100),
(VideoEncoderType::VP8, "libvpx", 100),
(VideoEncoderType::VP9, "libvpx-vp9", 100),
];
for (format, codec_name, priority) in software_encoders {
self.encoders
.entry(format)
.or_default()
.push(AvailableEncoder {
format,
codec_name: codec_name.to_string(),
backend: EncoderBackend::Software,
priority,
is_hardware: false,
});
debug!(
"Registered software encoder: {} for {} (priority: {})",
codec_name, format, priority
);
}
// Log summary
for (format, encoders) in &self.encoders {
let hw_count = encoders.iter().filter(|e| e.is_hardware).count();
let sw_count = encoders.len() - hw_count;
info!(
"{}: {} encoders ({} hardware, {} software)",
format,
encoders.len(),
hw_count,
sw_count
);
}
}
/// Get the best encoder for a format
///
/// # Arguments
/// * `format` - The video format to encode
/// * `hardware_only` - If true, only return hardware encoders
///
/// # Returns
/// The best available encoder, or None if no suitable encoder is found
pub fn best_encoder(
&self,
format: VideoEncoderType,
hardware_only: bool,
) -> Option<&AvailableEncoder> {
self.encoders.get(&format)?.iter().find(|e| {
if hardware_only {
e.is_hardware
} else {
true
}
})
}
/// Get all encoders for a format
pub fn encoders_for_format(&self, format: VideoEncoderType) -> &[AvailableEncoder] {
self.encoders
.get(&format)
.map(|v| v.as_slice())
.unwrap_or(&[])
}
/// Get all available formats
///
/// # Arguments
/// * `hardware_only` - If true, only return formats with hardware encoders
pub fn available_formats(&self, hardware_only: bool) -> Vec<VideoEncoderType> {
self.encoders
.iter()
.filter(|(_, encoders)| {
if hardware_only {
encoders.iter().any(|e| e.is_hardware)
} else {
!encoders.is_empty()
}
})
.map(|(format, _)| *format)
.collect()
}
/// Check if a format is available
///
/// # Arguments
/// * `format` - The video format to check
/// * `hardware_only` - If true, only check for hardware encoders
pub fn is_format_available(&self, format: VideoEncoderType, hardware_only: bool) -> bool {
self.best_encoder(format, hardware_only).is_some()
}
/// Get available formats for user selection
///
/// Returns formats that are actually usable based on their requirements:
/// - H264: Available if any encoder exists (hardware or software)
/// - H265/VP8/VP9: Available only if hardware encoder exists
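    ///
    /// Sketch: listing what a settings UI should offer:
    /// ```text
    /// for fmt in EncoderRegistry::global().selectable_formats() {
    ///     println!("{}", fmt.display_name());
    /// }
    /// ```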
pub fn selectable_formats(&self) -> Vec<VideoEncoderType> {
let mut formats = Vec::new();
// H264 - supports software fallback
if self.is_format_available(VideoEncoderType::H264, false) {
formats.push(VideoEncoderType::H264);
}
// H265/VP8/VP9 - hardware only
for format in [
VideoEncoderType::H265,
VideoEncoderType::VP8,
VideoEncoderType::VP9,
] {
if self.is_format_available(format, true) {
formats.push(format);
}
}
formats
}
/// Get detection resolution
pub fn detection_resolution(&self) -> (u32, u32) {
self.detection_resolution
}
/// Get all available backend types
pub fn available_backends(&self) -> Vec<EncoderBackend> {
use std::collections::HashSet;
let mut backends = HashSet::new();
for encoders in self.encoders.values() {
for encoder in encoders {
backends.insert(encoder.backend);
}
}
let mut result: Vec<_> = backends.into_iter().collect();
// Sort: hardware backends first, software last
result.sort_by_key(|b| if b.is_hardware() { 0 } else { 1 });
result
}
/// Get formats supported by a specific backend
pub fn formats_for_backend(&self, backend: EncoderBackend) -> Vec<VideoEncoderType> {
let mut formats = Vec::new();
for (format, encoders) in &self.encoders {
if encoders.iter().any(|e| e.backend == backend) {
formats.push(*format);
}
}
formats
}
/// Get encoder for a format with specific backend
pub fn encoder_with_backend(
&self,
format: VideoEncoderType,
backend: EncoderBackend,
) -> Option<&AvailableEncoder> {
self.encoders
.get(&format)?
.iter()
.find(|e| e.backend == backend)
}
/// Get encoders grouped by backend for a format
pub fn encoders_by_backend(
&self,
format: VideoEncoderType,
) -> HashMap<EncoderBackend, Vec<&AvailableEncoder>> {
let mut grouped = HashMap::new();
if let Some(encoders) = self.encoders.get(&format) {
for encoder in encoders {
grouped
.entry(encoder.backend)
.or_insert_with(Vec::new)
.push(encoder);
}
}
grouped
}
}
impl Default for EncoderRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_video_encoder_type_display() {
assert_eq!(VideoEncoderType::H264.display_name(), "H.264");
assert_eq!(VideoEncoderType::H265.display_name(), "H.265/HEVC");
assert_eq!(VideoEncoderType::VP8.display_name(), "VP8");
assert_eq!(VideoEncoderType::VP9.display_name(), "VP9");
}
#[test]
fn test_encoder_backend_detection() {
assert_eq!(
EncoderBackend::from_codec_name("h264_vaapi"),
EncoderBackend::Vaapi
);
assert_eq!(
EncoderBackend::from_codec_name("hevc_nvenc"),
EncoderBackend::Nvenc
);
assert_eq!(
EncoderBackend::from_codec_name("h264_qsv"),
EncoderBackend::Qsv
);
assert_eq!(
EncoderBackend::from_codec_name("libx264"),
EncoderBackend::Software
);
}
#[test]
fn test_hardware_only_requirement() {
assert!(!VideoEncoderType::H264.hardware_only());
assert!(VideoEncoderType::H265.hardware_only());
assert!(VideoEncoderType::VP8.hardware_only());
assert!(VideoEncoderType::VP9.hardware_only());
}
#[test]
fn test_registry_detection() {
let mut registry = EncoderRegistry::new();
registry.detect_encoders(1280, 720);
// Should have detected at least H264 (software fallback available)
println!("Available formats: {:?}", registry.available_formats(false));
println!(
"Selectable formats: {:?}",
registry.selectable_formats()
);
}
}

188
src/video/encoder/traits.rs Normal file
View File

@@ -0,0 +1,188 @@
//! Encoder traits and common types
use bytes::Bytes;
use std::time::Instant;
use crate::video::format::{PixelFormat, Resolution};
use crate::error::Result;
/// Encoder configuration
#[derive(Debug, Clone)]
pub struct EncoderConfig {
/// Target resolution
pub resolution: Resolution,
/// Input pixel format
pub input_format: PixelFormat,
/// Output quality (1-100 for JPEG, bitrate kbps for H264)
pub quality: u32,
/// Target frame rate
pub fps: u32,
/// Keyframe interval (for H264)
pub gop_size: u32,
}
impl Default for EncoderConfig {
fn default() -> Self {
Self {
resolution: Resolution::HD1080,
input_format: PixelFormat::Yuyv,
quality: 80,
fps: 30,
gop_size: 30,
}
}
}
impl EncoderConfig {
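    /// JPEG preset: `quality` is the 1-100 JPEG quality; `gop_size` is 1
    /// because every JPEG frame is a keyframe.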
pub fn jpeg(resolution: Resolution, quality: u32) -> Self {
Self {
resolution,
input_format: PixelFormat::Yuyv,
quality,
fps: 30,
gop_size: 1,
}
}
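    /// H.264 preset: `quality` carries the target bitrate in kbps, as noted
    /// on the field above.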
pub fn h264(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
resolution,
input_format: PixelFormat::Yuyv,
quality: bitrate_kbps,
fps: 30,
gop_size: 30,
}
}
}
/// Encoded frame output
#[derive(Debug, Clone)]
pub struct EncodedFrame {
/// Encoded data
pub data: Bytes,
/// Output format (JPEG, H264, etc.)
pub format: EncodedFormat,
/// Resolution
pub resolution: Resolution,
/// Whether this is a key frame
pub key_frame: bool,
/// Frame sequence number
pub sequence: u64,
/// Encoding timestamp
pub timestamp: Instant,
/// Presentation timestamp (for video sync)
pub pts: u64,
/// Decode timestamp (for B-frames)
pub dts: u64,
}
impl EncodedFrame {
pub fn jpeg(data: Bytes, resolution: Resolution, sequence: u64) -> Self {
Self {
data,
format: EncodedFormat::Jpeg,
resolution,
key_frame: true,
sequence,
timestamp: Instant::now(),
pts: sequence,
dts: sequence,
}
}
pub fn h264(
data: Bytes,
resolution: Resolution,
key_frame: bool,
sequence: u64,
pts: u64,
dts: u64,
) -> Self {
Self {
data,
format: EncodedFormat::H264,
resolution,
key_frame,
sequence,
timestamp: Instant::now(),
pts,
dts,
}
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
/// Encoded output format
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodedFormat {
Jpeg,
H264,
H265,
Vp8,
Vp9,
Av1,
}
impl std::fmt::Display for EncodedFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EncodedFormat::Jpeg => write!(f, "JPEG"),
EncodedFormat::H264 => write!(f, "H.264"),
EncodedFormat::H265 => write!(f, "H.265"),
EncodedFormat::Vp8 => write!(f, "VP8"),
EncodedFormat::Vp9 => write!(f, "VP9"),
EncodedFormat::Av1 => write!(f, "AV1"),
}
}
}
/// Generic encoder trait
/// Note: Not Sync because some encoders (like turbojpeg) are not thread-safe
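///
/// A sketch of driving any encoder behind the trait:
/// ```text
/// fn run(enc: &mut dyn Encoder, frames: &[Vec<u8>]) -> Result<()> {
///     for (i, raw) in frames.iter().enumerate() {
///         let out = enc.encode(raw, i as u64)?;
///         // ship out.data downstream
///     }
///     Ok(())
/// }
/// ```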
pub trait Encoder: Send {
/// Get encoder name
fn name(&self) -> &str;
/// Get output format
fn output_format(&self) -> EncodedFormat;
/// Encode a raw frame
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame>;
/// Flush any pending frames
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
Ok(vec![])
}
/// Reset encoder state
fn reset(&mut self) -> Result<()> {
Ok(())
}
/// Get current configuration
fn config(&self) -> &EncoderConfig;
/// Check if encoder supports the given input format
fn supports_format(&self, format: PixelFormat) -> bool;
}
/// Encoder factory for creating encoders
pub trait EncoderFactory: Send + Sync {
/// Create an encoder with the given configuration
fn create(&self, config: EncoderConfig) -> Result<Box<dyn Encoder>>;
/// Get encoder type name
fn encoder_type(&self) -> &str;
/// Check if this encoder is available on the system
fn is_available(&self) -> bool;
/// Get encoder priority (higher = preferred)
fn priority(&self) -> u32;
}

488
src/video/encoder/vp8.rs Normal file
View File

@@ -0,0 +1,488 @@
//! VP8 encoder using hwcodec (FFmpeg wrapper)
//!
//! Supports both hardware and software encoding:
//! - Hardware: VAAPI (Intel on Linux)
//! - Software: libvpx (CPU-based, high CPU usage)
//!
//! Hardware encoding is preferred when available for better performance.
use bytes::Bytes;
use std::sync::Once;
use tracing::{debug, error, info, warn};
use hwcodec::common::{DataFormat, Quality, RateControl};
use hwcodec::ffmpeg::AVPixelFormat;
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
use hwcodec::ffmpeg_ram::CodecInfo;
use super::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
use crate::error::{AppError, Result};
use crate::video::format::{PixelFormat, Resolution};
static INIT_LOGGING: Once = Once::new();
/// Initialize hwcodec logging (only once)
fn init_hwcodec_logging() {
INIT_LOGGING.call_once(|| {
debug!("hwcodec logging initialized for VP8");
});
}
/// VP8 encoder type (detected from hwcodec)
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VP8EncoderType {
/// VAAPI (Intel on Linux)
Vaapi,
/// Software encoder (libvpx)
Software,
/// No encoder available
None,
}
impl std::fmt::Display for VP8EncoderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VP8EncoderType::Vaapi => write!(f, "VAAPI"),
VP8EncoderType::Software => write!(f, "Software"),
VP8EncoderType::None => write!(f, "None"),
}
}
}
impl Default for VP8EncoderType {
fn default() -> Self {
Self::None
}
}
impl From<EncoderBackend> for VP8EncoderType {
fn from(backend: EncoderBackend) -> Self {
match backend {
EncoderBackend::Vaapi => VP8EncoderType::Vaapi,
EncoderBackend::Software => VP8EncoderType::Software,
_ => VP8EncoderType::None,
}
}
}
/// Input pixel format for VP8 encoder
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VP8InputFormat {
/// YUV420P (I420) - planar Y, U, V
Yuv420p,
/// NV12 - Y plane + interleaved UV plane
Nv12,
}
impl Default for VP8InputFormat {
fn default() -> Self {
Self::Nv12 // Default to NV12 for VAAPI compatibility
}
}
/// VP8 encoder configuration
#[derive(Debug, Clone)]
pub struct VP8Config {
/// Base encoder config
pub base: EncoderConfig,
/// Target bitrate in kbps
pub bitrate_kbps: u32,
/// GOP size (keyframe interval)
pub gop_size: u32,
/// Frame rate
pub fps: u32,
/// Input pixel format
pub input_format: VP8InputFormat,
}
impl Default for VP8Config {
fn default() -> Self {
Self {
base: EncoderConfig::default(),
bitrate_kbps: 8000,
gop_size: 30,
fps: 30,
input_format: VP8InputFormat::Nv12,
}
}
}
impl VP8Config {
/// Create config for low latency streaming with NV12 input
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
base: EncoderConfig {
resolution,
input_format: PixelFormat::Nv12,
quality: bitrate_kbps,
fps: 30,
gop_size: 30,
},
bitrate_kbps,
gop_size: 30,
fps: 30,
input_format: VP8InputFormat::Nv12,
}
}
/// Set input format
pub fn with_input_format(mut self, format: VP8InputFormat) -> Self {
self.input_format = format;
self
}
}
/// Get available VP8 encoders from hwcodec (hardware and software)
pub fn get_available_vp8_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
init_hwcodec_logging();
let ctx = EncodeContext {
name: String::new(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
align: 1,
fps: 30,
gop: 30,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Default,
kbs: 2000,
q: 23,
thread_count: 1,
};
let all_encoders = HwEncoder::available_encoders(ctx, None);
// Include both hardware and software VP8 encoders
all_encoders
.into_iter()
.filter(|e| e.format == DataFormat::VP8)
.collect()
}
/// Detect best available VP8 encoder (hardware preferred, software fallback)
pub fn detect_best_vp8_encoder(width: u32, height: u32) -> (VP8EncoderType, Option<String>) {
let encoders = get_available_vp8_encoders(width, height);
if encoders.is_empty() {
warn!("No VP8 encoders available");
return (VP8EncoderType::None, None);
}
// Prefer hardware encoders (VAAPI) over software (libvpx)
let codec = encoders
.iter()
.find(|e| e.name.contains("vaapi"))
.or_else(|| encoders.first())
.unwrap();
let encoder_type = if codec.name.contains("vaapi") {
VP8EncoderType::Vaapi
} else if codec.name.contains("libvpx") {
VP8EncoderType::Software
} else {
VP8EncoderType::Software // Default to software for unknown
};
info!(
"Selected VP8 encoder: {} ({})",
codec.name, encoder_type
);
(encoder_type, Some(codec.name.clone()))
}
/// Check if VP8 hardware encoding is available
pub fn is_vp8_available() -> bool {
let registry = EncoderRegistry::global();
registry.is_format_available(VideoEncoderType::VP8, true)
}
/// Encoded frame from hwcodec (cloned for ownership)
#[derive(Debug, Clone)]
pub struct HwEncodeFrame {
pub data: Vec<u8>,
pub pts: i64,
pub key: i32,
}
/// VP8 encoder using hwcodec (VAAPI hardware preferred, libvpx software fallback)
pub struct VP8Encoder {
/// hwcodec encoder instance
inner: HwEncoder,
/// Encoder configuration
config: VP8Config,
/// Detected encoder type
encoder_type: VP8EncoderType,
/// Codec name
codec_name: String,
/// Frame counter
frame_count: u64,
/// Required buffer length from hwcodec
buffer_length: i32,
}
impl VP8Encoder {
    /// Create a new VP8 encoder with automatic codec detection
    /// (VAAPI hardware preferred, libvpx software fallback).
    ///
    /// Returns an error only if no VP8 encoder is available at all.
    /// VP8 hardware encoding currently requires Intel VAAPI.
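    ///
    /// A minimal sketch (buffer sized from `buffer_info()`):
    /// ```text
    /// let mut enc = VP8Encoder::auto(Resolution::HD720, 2000)?;
    /// let (_strides, _offsets, len) = enc.buffer_info();
    /// let nv12 = vec![0u8; len as usize];
    /// let packets = enc.encode_nv12(&nv12, 0)?;
    /// ```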
pub fn new(config: VP8Config) -> Result<Self> {
init_hwcodec_logging();
let width = config.base.resolution.width;
let height = config.base.resolution.height;
let (encoder_type, codec_name) = detect_best_vp8_encoder(width, height);
if encoder_type == VP8EncoderType::None {
return Err(AppError::VideoError(
"No VP8 encoder available. Please ensure FFmpeg is built with libvpx support.".to_string(),
));
}
let codec_name = codec_name.unwrap();
Self::with_codec(config, &codec_name)
}
/// Create encoder with specific codec name
pub fn with_codec(config: VP8Config, codec_name: &str) -> Result<Self> {
init_hwcodec_logging();
// Determine if this is a software encoder
let is_software = codec_name.contains("libvpx");
// Warn about software encoder performance
if is_software {
warn!(
"Using software VP8 encoder (libvpx) - high CPU usage expected. \
Hardware encoder is recommended for better performance."
);
}
let width = config.base.resolution.width;
let height = config.base.resolution.height;
// Software encoders (libvpx) require YUV420P, hardware (VAAPI) uses NV12
let (pixfmt, actual_input_format) = if is_software {
(AVPixelFormat::AV_PIX_FMT_YUV420P, VP8InputFormat::Yuv420p)
} else {
match config.input_format {
VP8InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, VP8InputFormat::Nv12),
VP8InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, VP8InputFormat::Yuv420p),
}
};
info!(
"Creating VP8 encoder: {} at {}x{} @ {} kbps (input: {:?})",
codec_name, width, height, config.bitrate_kbps, actual_input_format
);
let ctx = EncodeContext {
name: codec_name.to_string(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt,
align: 1,
fps: config.fps as i32,
gop: config.gop_size as i32,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Default,
kbs: config.bitrate_kbps as i32,
q: 23,
thread_count: 1,
};
let inner = HwEncoder::new(ctx).map_err(|_| {
AppError::VideoError(format!("Failed to create VP8 encoder: {}", codec_name))
})?;
let buffer_length = inner.length;
let backend = EncoderBackend::from_codec_name(codec_name);
let encoder_type = VP8EncoderType::from(backend);
// Update config to reflect actual input format used
let mut config = config;
config.input_format = actual_input_format;
info!(
"VP8 encoder created: {} (type: {}, buffer_length: {})",
codec_name, encoder_type, buffer_length
);
Ok(Self {
inner,
config,
encoder_type,
codec_name: codec_name.to_string(),
frame_count: 0,
buffer_length,
})
}
/// Create with auto-detected encoder
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
let config = VP8Config::low_latency(resolution, bitrate_kbps);
Self::new(config)
}
/// Get encoder type
pub fn encoder_type(&self) -> &VP8EncoderType {
&self.encoder_type
}
/// Get codec name
pub fn codec_name(&self) -> &str {
&self.codec_name
}
/// Update bitrate dynamically
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
AppError::VideoError("Failed to set VP8 bitrate".to_string())
})?;
self.config.bitrate_kbps = bitrate_kbps;
debug!("VP8 bitrate updated to {} kbps", bitrate_kbps);
Ok(())
}
/// Encode raw frame data
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
if data.len() < self.buffer_length as usize {
return Err(AppError::VideoError(format!(
"Frame data too small: {} < {}",
data.len(),
self.buffer_length
)));
}
self.frame_count += 1;
match self.inner.encode(data, pts_ms) {
Ok(frames) => {
let owned_frames: Vec<HwEncodeFrame> = frames
.iter()
.map(|f| HwEncodeFrame {
data: f.data.clone(),
pts: f.pts,
key: f.key,
})
.collect();
Ok(owned_frames)
}
Err(e) => {
error!("VP8 encode failed: {}", e);
Err(AppError::VideoError(format!("VP8 encode failed: {}", e)))
}
}
}
/// Encode NV12 data
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
self.encode_raw(nv12_data, pts_ms)
}
/// Get input format
pub fn input_format(&self) -> VP8InputFormat {
self.config.input_format
}
/// Get buffer info
pub fn buffer_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
(
self.inner.linesize.clone(),
self.inner.offset.clone(),
self.inner.length,
)
}
}
// SAFETY: VP8Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
// that are not Send by default. However, we ensure that VP8Encoder is only used from
// a single task/thread at a time (encoding is sequential), so this is safe.
unsafe impl Send for VP8Encoder {}
impl Encoder for VP8Encoder {
fn name(&self) -> &str {
&self.codec_name
}
fn output_format(&self) -> EncodedFormat {
EncodedFormat::Vp8
}
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
let frames = self.encode_raw(data, pts_ms)?;
if frames.is_empty() {
warn!("VP8 encoder returned no frames");
return Err(AppError::VideoError(
"VP8 encoder returned no frames".to_string(),
));
}
let frame = &frames[0];
let key_frame = frame.key == 1;
Ok(EncodedFrame {
data: Bytes::from(frame.data.clone()),
format: EncodedFormat::Vp8,
resolution: self.config.base.resolution,
key_frame,
sequence,
timestamp: std::time::Instant::now(),
pts: frame.pts as u64,
dts: frame.pts as u64,
})
}
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
Ok(vec![])
}
fn reset(&mut self) -> Result<()> {
self.frame_count = 0;
Ok(())
}
fn config(&self) -> &EncoderConfig {
&self.config.base
}
fn supports_format(&self, format: PixelFormat) -> bool {
match self.config.input_format {
VP8InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
VP8InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_vp8_encoder() {
let (encoder_type, codec_name) = detect_best_vp8_encoder(1280, 720);
println!("Detected VP8 encoder: {:?} ({:?})", encoder_type, codec_name);
}
#[test]
fn test_available_vp8_encoders() {
let encoders = get_available_vp8_encoders(1280, 720);
println!("Available VP8 hardware encoders:");
for enc in &encoders {
println!(" - {} ({:?})", enc.name, enc.format);
}
}
#[test]
fn test_vp8_availability() {
let available = is_vp8_available();
println!("VP8 hardware encoding available: {}", available);
}
}

488
src/video/encoder/vp9.rs Normal file
View File

@@ -0,0 +1,488 @@
//! VP9 encoder using hwcodec (FFmpeg wrapper)
//!
//! Supports both hardware and software encoding:
//! - Hardware: VAAPI (Intel on Linux)
//! - Software: libvpx-vp9 (CPU-based, high CPU usage)
//!
//! Hardware encoding is preferred when available for better performance.
use bytes::Bytes;
use std::sync::Once;
use tracing::{debug, error, info, warn};
use hwcodec::common::{DataFormat, Quality, RateControl};
use hwcodec::ffmpeg::AVPixelFormat;
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
use hwcodec::ffmpeg_ram::CodecInfo;
use super::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
use crate::error::{AppError, Result};
use crate::video::format::{PixelFormat, Resolution};
static INIT_LOGGING: Once = Once::new();
/// Initialize hwcodec logging (only once)
fn init_hwcodec_logging() {
INIT_LOGGING.call_once(|| {
debug!("hwcodec logging initialized for VP9");
});
}
/// VP9 encoder type (detected from hwcodec)
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VP9EncoderType {
/// VAAPI (Intel on Linux)
Vaapi,
/// Software encoder (libvpx-vp9)
Software,
/// No encoder available
None,
}
impl std::fmt::Display for VP9EncoderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VP9EncoderType::Vaapi => write!(f, "VAAPI"),
VP9EncoderType::Software => write!(f, "Software"),
VP9EncoderType::None => write!(f, "None"),
}
}
}
impl Default for VP9EncoderType {
fn default() -> Self {
Self::None
}
}
impl From<EncoderBackend> for VP9EncoderType {
fn from(backend: EncoderBackend) -> Self {
match backend {
EncoderBackend::Vaapi => VP9EncoderType::Vaapi,
EncoderBackend::Software => VP9EncoderType::Software,
_ => VP9EncoderType::None,
}
}
}
/// Input pixel format for VP9 encoder
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VP9InputFormat {
/// YUV420P (I420) - planar Y, U, V
Yuv420p,
/// NV12 - Y plane + interleaved UV plane
Nv12,
}
impl Default for VP9InputFormat {
fn default() -> Self {
Self::Nv12 // Default to NV12 for VAAPI compatibility
}
}
/// VP9 encoder configuration
#[derive(Debug, Clone)]
pub struct VP9Config {
/// Base encoder config
pub base: EncoderConfig,
/// Target bitrate in kbps
pub bitrate_kbps: u32,
/// GOP size (keyframe interval)
pub gop_size: u32,
/// Frame rate
pub fps: u32,
/// Input pixel format
pub input_format: VP9InputFormat,
}
impl Default for VP9Config {
fn default() -> Self {
Self {
base: EncoderConfig::default(),
bitrate_kbps: 8000,
gop_size: 30,
fps: 30,
input_format: VP9InputFormat::Nv12,
}
}
}
impl VP9Config {
/// Create config for low latency streaming with NV12 input
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
base: EncoderConfig {
resolution,
input_format: PixelFormat::Nv12,
quality: bitrate_kbps,
fps: 30,
gop_size: 30,
},
bitrate_kbps,
gop_size: 30,
fps: 30,
input_format: VP9InputFormat::Nv12,
}
}
/// Set input format
pub fn with_input_format(mut self, format: VP9InputFormat) -> Self {
self.input_format = format;
self
}
}
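// Illustrative sketch (not from the original commit): building a 1080p config
// that feeds the encoder planar YUV420P instead of the default NV12.
#[allow(dead_code)]
fn example_vp9_yuv420p_config() -> VP9Config {
VP9Config::low_latency(Resolution::new(1920, 1080), 6000)
.with_input_format(VP9InputFormat::Yuv420p)
}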
/// Get available VP9 encoders from hwcodec (hardware and software)
pub fn get_available_vp9_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
init_hwcodec_logging();
let ctx = EncodeContext {
name: String::new(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
align: 1,
fps: 30,
gop: 30,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Default,
kbs: 2000,
q: 23,
thread_count: 1,
};
let all_encoders = HwEncoder::available_encoders(ctx, None);
// Include both hardware and software VP9 encoders
all_encoders
.into_iter()
.filter(|e| e.format == DataFormat::VP9)
.collect()
}
/// Detect best available VP9 encoder (hardware preferred, software fallback)
pub fn detect_best_vp9_encoder(width: u32, height: u32) -> (VP9EncoderType, Option<String>) {
let encoders = get_available_vp9_encoders(width, height);
if encoders.is_empty() {
warn!("No VP9 encoders available");
return (VP9EncoderType::None, None);
}
// Prefer hardware encoders (VAAPI) over software (libvpx-vp9)
let codec = encoders
.iter()
.find(|e| e.name.contains("vaapi"))
.or_else(|| encoders.first())
.unwrap();
let encoder_type = if codec.name.contains("vaapi") {
VP9EncoderType::Vaapi
} else if codec.name.contains("libvpx") {
VP9EncoderType::Software
} else {
VP9EncoderType::Software // Default to software for unknown
};
info!(
"Selected VP9 encoder: {} ({})",
codec.name, encoder_type
);
(encoder_type, Some(codec.name.clone()))
}
/// Check if VP9 hardware encoding is available
pub fn is_vp9_available() -> bool {
let registry = EncoderRegistry::global();
registry.is_format_available(VideoEncoderType::VP9, true)
}
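// Illustrative probe sketch (not from the original commit): log what hwcodec
// and the registry report for VP9 before committing to an encoder.
#[allow(dead_code)]
fn probe_vp9_support() {
let (encoder_type, codec_name) = detect_best_vp9_encoder(1920, 1080);
info!(
"VP9 probe: type={}, codec={:?}, available={}",
encoder_type,
codec_name,
is_vp9_available()
);
}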
/// Encoded frame from hwcodec (cloned for ownership)
#[derive(Debug, Clone)]
pub struct HwEncodeFrame {
pub data: Vec<u8>,
pub pts: i64,
pub key: i32,
}
/// VP9 encoder using hwcodec (hardware VAAPI or software libvpx-vp9)
pub struct VP9Encoder {
/// hwcodec encoder instance
inner: HwEncoder,
/// Encoder configuration
config: VP9Config,
/// Detected encoder type
encoder_type: VP9EncoderType,
/// Codec name
codec_name: String,
/// Frame counter
frame_count: u64,
/// Required buffer length from hwcodec
buffer_length: i32,
}
impl VP9Encoder {
/// Create a new VP9 encoder with automatic codec detection
///
/// Hardware encoding (Intel VAAPI) is preferred; libvpx-vp9 is used as a
/// software fallback. Returns an error only if no VP9 encoder is available.
pub fn new(config: VP9Config) -> Result<Self> {
init_hwcodec_logging();
let width = config.base.resolution.width;
let height = config.base.resolution.height;
let (encoder_type, codec_name) = detect_best_vp9_encoder(width, height);
if encoder_type == VP9EncoderType::None {
return Err(AppError::VideoError(
"No VP9 encoder available. Please ensure FFmpeg is built with libvpx support.".to_string(),
));
}
let codec_name = codec_name.unwrap();
Self::with_codec(config, &codec_name)
}
/// Create encoder with specific codec name
pub fn with_codec(config: VP9Config, codec_name: &str) -> Result<Self> {
init_hwcodec_logging();
// Determine if this is a software encoder
let is_software = codec_name.contains("libvpx");
// Warn about software encoder performance
if is_software {
warn!(
"Using software VP9 encoder (libvpx-vp9) - high CPU usage expected. \
Hardware encoder is recommended for better performance."
);
}
let width = config.base.resolution.width;
let height = config.base.resolution.height;
// Software encoders (libvpx-vp9) require YUV420P, hardware (VAAPI) uses NV12
let (pixfmt, actual_input_format) = if is_software {
(AVPixelFormat::AV_PIX_FMT_YUV420P, VP9InputFormat::Yuv420p)
} else {
match config.input_format {
VP9InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, VP9InputFormat::Nv12),
VP9InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, VP9InputFormat::Yuv420p),
}
};
info!(
"Creating VP9 encoder: {} at {}x{} @ {} kbps (input: {:?})",
codec_name, width, height, config.bitrate_kbps, actual_input_format
);
let ctx = EncodeContext {
name: codec_name.to_string(),
mc_name: None,
width: width as i32,
height: height as i32,
pixfmt,
align: 1,
fps: config.fps as i32,
gop: config.gop_size as i32,
rc: RateControl::RC_CBR,
quality: Quality::Quality_Default,
kbs: config.bitrate_kbps as i32,
q: 31,
thread_count: 4, // VP9 benefits from multi-threading
};
let inner = HwEncoder::new(ctx).map_err(|_| {
AppError::VideoError(format!("Failed to create VP9 encoder: {}", codec_name))
})?;
let buffer_length = inner.length;
let backend = EncoderBackend::from_codec_name(codec_name);
let encoder_type = VP9EncoderType::from(backend);
// Update config to reflect actual input format used
let mut config = config;
config.input_format = actual_input_format;
info!(
"VP9 encoder created: {} (type: {}, buffer_length: {})",
codec_name, encoder_type, buffer_length
);
Ok(Self {
inner,
config,
encoder_type,
codec_name: codec_name.to_string(),
frame_count: 0,
buffer_length,
})
}
/// Create with auto-detected encoder
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
let config = VP9Config::low_latency(resolution, bitrate_kbps);
Self::new(config)
}
/// Get encoder type
pub fn encoder_type(&self) -> &VP9EncoderType {
&self.encoder_type
}
/// Get codec name
pub fn codec_name(&self) -> &str {
&self.codec_name
}
/// Update bitrate dynamically
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
AppError::VideoError("Failed to set VP9 bitrate".to_string())
})?;
self.config.bitrate_kbps = bitrate_kbps;
debug!("VP9 bitrate updated to {} kbps", bitrate_kbps);
Ok(())
}
/// Encode raw frame data
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
if data.len() < self.buffer_length as usize {
return Err(AppError::VideoError(format!(
"Frame data too small: {} < {}",
data.len(),
self.buffer_length
)));
}
self.frame_count += 1;
match self.inner.encode(data, pts_ms) {
Ok(frames) => {
let owned_frames: Vec<HwEncodeFrame> = frames
.iter()
.map(|f| HwEncodeFrame {
data: f.data.clone(),
pts: f.pts,
key: f.key,
})
.collect();
Ok(owned_frames)
}
Err(e) => {
error!("VP9 encode failed: {}", e);
Err(AppError::VideoError(format!("VP9 encode failed: {}", e)))
}
}
}
/// Encode NV12 data
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
self.encode_raw(nv12_data, pts_ms)
}
/// Get input format
pub fn input_format(&self) -> VP9InputFormat {
self.config.input_format
}
/// Get buffer info
pub fn buffer_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
(
self.inner.linesize.clone(),
self.inner.offset.clone(),
self.inner.length,
)
}
}
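// Illustrative usage sketch (not from the original commit): encoding a single
// NV12 frame. The all-zero buffer is a stand-in for real capture data; it is
// sized assuming the encoder's required buffer length equals width*height*3/2.
#[allow(dead_code)]
fn example_encode_one_frame() -> Result<()> {
let mut encoder = VP9Encoder::auto(Resolution::new(1280, 720), 4000)?;
let nv12 = vec![0u8; 1280 * 720 * 3 / 2];
for frame in encoder.encode_raw(&nv12, 0)? {
debug!("encoded {} bytes (keyframe: {})", frame.data.len(), frame.key == 1);
}
Ok(())
}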
// SAFETY: VP9Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
// that are not Send by default. However, we ensure that VP9Encoder is only used from
// a single task/thread at a time (encoding is sequential), so this is safe.
unsafe impl Send for VP9Encoder {}
impl Encoder for VP9Encoder {
fn name(&self) -> &str {
&self.codec_name
}
fn output_format(&self) -> EncodedFormat {
EncodedFormat::Vp9
}
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
let frames = self.encode_raw(data, pts_ms)?;
if frames.is_empty() {
warn!("VP9 encoder returned no frames");
return Err(AppError::VideoError(
"VP9 encoder returned no frames".to_string(),
));
}
let frame = &frames[0];
let key_frame = frame.key == 1;
Ok(EncodedFrame {
data: Bytes::from(frame.data.clone()),
format: EncodedFormat::Vp9,
resolution: self.config.base.resolution,
key_frame,
sequence,
timestamp: std::time::Instant::now(),
pts: frame.pts as u64,
dts: frame.pts as u64,
})
}
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
Ok(vec![])
}
fn reset(&mut self) -> Result<()> {
self.frame_count = 0;
Ok(())
}
fn config(&self) -> &EncoderConfig {
&self.config.base
}
fn supports_format(&self, format: PixelFormat) -> bool {
match self.config.input_format {
VP9InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
VP9InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_vp9_encoder() {
let (encoder_type, codec_name) = detect_best_vp9_encoder(1280, 720);
println!("Detected VP9 encoder: {:?} ({:?})", encoder_type, codec_name);
}
#[test]
fn test_available_vp9_encoders() {
let encoders = get_available_vp9_encoders(1280, 720);
println!("Available VP9 hardware encoders:");
for enc in &encoders {
println!(" - {} ({:?})", enc.name, enc.format);
}
}
#[test]
fn test_vp9_availability() {
let available = is_vp9_available();
println!("VP9 hardware encoding available: {}", available);
}
}

259
src/video/format.rs Normal file
View File

@@ -0,0 +1,259 @@
//! Pixel format definitions and conversions
use serde::{Deserialize, Serialize};
use std::fmt;
use v4l::format::fourcc;
/// Supported pixel formats
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum PixelFormat {
/// MJPEG compressed format (preferred for capture cards)
Mjpeg,
/// JPEG compressed format
Jpeg,
/// YUYV 4:2:2 packed format
Yuyv,
/// YVYU 4:2:2 packed format
Yvyu,
/// UYVY 4:2:2 packed format
Uyvy,
/// NV12 semi-planar format (Y plane + interleaved UV)
Nv12,
/// NV16 semi-planar format
Nv16,
/// NV24 semi-planar format
Nv24,
/// YUV420 planar format
Yuv420,
/// YVU420 planar format
Yvu420,
/// RGB565 format
Rgb565,
/// RGB24 format (3 bytes per pixel)
Rgb24,
/// BGR24 format (3 bytes per pixel)
Bgr24,
/// Grayscale format
Grey,
}
impl PixelFormat {
/// Convert to V4L2 FourCC
pub fn to_fourcc(&self) -> fourcc::FourCC {
match self {
PixelFormat::Mjpeg => fourcc::FourCC::new(b"MJPG"),
PixelFormat::Jpeg => fourcc::FourCC::new(b"JPEG"),
PixelFormat::Yuyv => fourcc::FourCC::new(b"YUYV"),
PixelFormat::Yvyu => fourcc::FourCC::new(b"YVYU"),
PixelFormat::Uyvy => fourcc::FourCC::new(b"UYVY"),
PixelFormat::Nv12 => fourcc::FourCC::new(b"NV12"),
PixelFormat::Nv16 => fourcc::FourCC::new(b"NV16"),
PixelFormat::Nv24 => fourcc::FourCC::new(b"NV24"),
PixelFormat::Yuv420 => fourcc::FourCC::new(b"YU12"),
PixelFormat::Yvu420 => fourcc::FourCC::new(b"YV12"),
PixelFormat::Rgb565 => fourcc::FourCC::new(b"RGBP"),
PixelFormat::Rgb24 => fourcc::FourCC::new(b"RGB3"),
PixelFormat::Bgr24 => fourcc::FourCC::new(b"BGR3"),
PixelFormat::Grey => fourcc::FourCC::new(b"GREY"),
}
}
/// Try to convert from V4L2 FourCC
pub fn from_fourcc(fourcc: fourcc::FourCC) -> Option<Self> {
let repr = fourcc.repr;
match &repr {
b"MJPG" => Some(PixelFormat::Mjpeg),
b"JPEG" => Some(PixelFormat::Jpeg),
b"YUYV" => Some(PixelFormat::Yuyv),
b"YVYU" => Some(PixelFormat::Yvyu),
b"UYVY" => Some(PixelFormat::Uyvy),
b"NV12" => Some(PixelFormat::Nv12),
b"NV16" => Some(PixelFormat::Nv16),
b"NV24" => Some(PixelFormat::Nv24),
b"YU12" | b"I420" => Some(PixelFormat::Yuv420),
b"YV12" => Some(PixelFormat::Yvu420),
b"RGBP" => Some(PixelFormat::Rgb565),
b"RGB3" => Some(PixelFormat::Rgb24),
b"BGR3" => Some(PixelFormat::Bgr24),
b"GREY" | b"Y800" => Some(PixelFormat::Grey),
_ => None,
}
}
/// Check if format is compressed (JPEG/MJPEG)
pub fn is_compressed(&self) -> bool {
matches!(self, PixelFormat::Mjpeg | PixelFormat::Jpeg)
}
/// Get bytes per pixel for uncompressed formats
/// Returns None for compressed formats
pub fn bytes_per_pixel(&self) -> Option<usize> {
match self {
PixelFormat::Mjpeg | PixelFormat::Jpeg => None, // compressed, variable size
PixelFormat::Yuyv | PixelFormat::Yvyu | PixelFormat::Uyvy => Some(2),
// 4:2:0 planar/semi-planar formats average 1.5 bytes per pixel, which is
// not a whole byte count; use frame_size() for exact sizes.
PixelFormat::Nv12 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => None,
PixelFormat::Nv16 => None, // semi-planar 4:2:2; use frame_size()
PixelFormat::Nv24 => None, // semi-planar 4:4:4; use frame_size()
PixelFormat::Rgb565 => Some(2),
PixelFormat::Rgb24 | PixelFormat::Bgr24 => Some(3),
PixelFormat::Grey => Some(1),
}
}
/// Calculate expected frame size for a given resolution
/// Returns None for compressed formats (variable size)
pub fn frame_size(&self, resolution: Resolution) -> Option<usize> {
let pixels = (resolution.width * resolution.height) as usize;
match self {
PixelFormat::Mjpeg | PixelFormat::Jpeg => None,
PixelFormat::Yuyv | PixelFormat::Yvyu | PixelFormat::Uyvy => Some(pixels * 2),
PixelFormat::Nv12 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => Some(pixels * 3 / 2),
PixelFormat::Nv16 => Some(pixels * 2),
PixelFormat::Nv24 => Some(pixels * 3),
PixelFormat::Rgb565 => Some(pixels * 2),
PixelFormat::Rgb24 | PixelFormat::Bgr24 => Some(pixels * 3),
PixelFormat::Grey => Some(pixels),
}
}
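// Worked example: for NV12 at 1920x1080 the Y plane is 1920 * 1080 =
// 2_073_600 bytes and the interleaved UV plane is half that, so
// frame_size() returns 2_073_600 * 3 / 2 = 3_110_400 bytes.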
/// Get priority for format selection (higher is better)
/// MJPEG is preferred for HDMI capture cards
pub fn priority(&self) -> u8 {
match self {
PixelFormat::Mjpeg => 100,
PixelFormat::Jpeg => 99,
PixelFormat::Yuyv => 80,
PixelFormat::Nv12 => 75,
PixelFormat::Yuv420 => 70,
PixelFormat::Uyvy => 65,
PixelFormat::Yvyu => 64,
PixelFormat::Yvu420 => 63,
PixelFormat::Nv16 => 60,
PixelFormat::Nv24 => 55,
PixelFormat::Rgb24 => 50,
PixelFormat::Bgr24 => 49,
PixelFormat::Rgb565 => 40,
PixelFormat::Grey => 10,
}
}
/// Get all supported formats
pub fn all() -> &'static [PixelFormat] {
&[
PixelFormat::Mjpeg,
PixelFormat::Jpeg,
PixelFormat::Yuyv,
PixelFormat::Yvyu,
PixelFormat::Uyvy,
PixelFormat::Nv12,
PixelFormat::Nv16,
PixelFormat::Nv24,
PixelFormat::Yuv420,
PixelFormat::Yvu420,
PixelFormat::Rgb565,
PixelFormat::Rgb24,
PixelFormat::Bgr24,
PixelFormat::Grey,
]
}
}
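// Illustrative selection sketch (not from the original commit): picking the
// highest-priority format among those a device reports (`reported` is a
// hypothetical slice filled from V4L2 enumeration).
#[allow(dead_code)]
fn pick_best_format(reported: &[PixelFormat]) -> Option<PixelFormat> {
reported.iter().copied().max_by_key(|f| f.priority())
}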
impl fmt::Display for PixelFormat {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = match self {
PixelFormat::Mjpeg => "MJPEG",
PixelFormat::Jpeg => "JPEG",
PixelFormat::Yuyv => "YUYV",
PixelFormat::Yvyu => "YVYU",
PixelFormat::Uyvy => "UYVY",
PixelFormat::Nv12 => "NV12",
PixelFormat::Nv16 => "NV16",
PixelFormat::Nv24 => "NV24",
PixelFormat::Yuv420 => "YUV420",
PixelFormat::Yvu420 => "YVU420",
PixelFormat::Rgb565 => "RGB565",
PixelFormat::Rgb24 => "RGB24",
PixelFormat::Bgr24 => "BGR24",
PixelFormat::Grey => "GREY",
};
write!(f, "{}", name)
}
}
impl std::str::FromStr for PixelFormat {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_uppercase().as_str() {
"MJPEG" | "MJPG" => Ok(PixelFormat::Mjpeg),
"JPEG" => Ok(PixelFormat::Jpeg),
"YUYV" => Ok(PixelFormat::Yuyv),
"YVYU" => Ok(PixelFormat::Yvyu),
"UYVY" => Ok(PixelFormat::Uyvy),
"NV12" => Ok(PixelFormat::Nv12),
"NV16" => Ok(PixelFormat::Nv16),
"NV24" => Ok(PixelFormat::Nv24),
"YUV420" | "I420" => Ok(PixelFormat::Yuv420),
"YVU420" | "YV12" => Ok(PixelFormat::Yvu420),
"RGB565" => Ok(PixelFormat::Rgb565),
"RGB24" => Ok(PixelFormat::Rgb24),
"BGR24" => Ok(PixelFormat::Bgr24),
"GREY" | "GRAY" => Ok(PixelFormat::Grey),
_ => Err(format!("Unknown pixel format: {}", s)),
}
}
}
/// Resolution (width x height)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Resolution {
pub width: u32,
pub height: u32,
}
impl Resolution {
pub fn new(width: u32, height: u32) -> Self {
Self { width, height }
}
/// Check if resolution is valid
pub fn is_valid(&self) -> bool {
self.width >= 160 && self.width <= 15360 && self.height >= 120 && self.height <= 8640
}
/// Get total pixels
pub fn pixels(&self) -> u64 {
self.width as u64 * self.height as u64
}
/// Common resolutions
pub const VGA: Resolution = Resolution {
width: 640,
height: 480,
};
pub const HD720: Resolution = Resolution {
width: 1280,
height: 720,
};
pub const HD1080: Resolution = Resolution {
width: 1920,
height: 1080,
};
pub const UHD4K: Resolution = Resolution {
width: 3840,
height: 2160,
};
}
impl fmt::Display for Resolution {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}x{}", self.width, self.height)
}
}
impl From<(u32, u32)> for Resolution {
fn from((width, height): (u32, u32)) -> Self {
Self { width, height }
}
}
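// Illustrative round-trip checks (not from the original commit): the FourCC
// and string mappings defined above are mutually consistent.
#[cfg(test)]
mod format_examples {
use super::*;
#[test]
fn fourcc_and_string_round_trip() {
for fmt in PixelFormat::all() {
assert_eq!(PixelFormat::from_fourcc(fmt.to_fourcc()), Some(*fmt));
assert_eq!(fmt.to_string().parse::<PixelFormat>(), Ok(*fmt));
}
}
}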

239
src/video/frame.rs Normal file
View File

@@ -0,0 +1,239 @@
//! Video frame data structures
use bytes::Bytes;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Instant;
use super::format::{PixelFormat, Resolution};
/// A video frame with metadata
#[derive(Debug, Clone)]
pub struct VideoFrame {
/// Raw frame data
data: Arc<Bytes>,
/// Cached xxHash64 of frame data (lazy computed for deduplication)
hash: Arc<OnceLock<u64>>,
/// Frame resolution
pub resolution: Resolution,
/// Pixel format
pub format: PixelFormat,
/// Stride (bytes per line)
pub stride: u32,
/// Whether this is a key frame (for compressed formats)
pub key_frame: bool,
/// Frame sequence number
pub sequence: u64,
/// Timestamp when frame was captured
pub capture_ts: Instant,
/// Whether capture is online (signal present)
pub online: bool,
}
impl VideoFrame {
/// Create a new video frame
pub fn new(
data: Bytes,
resolution: Resolution,
format: PixelFormat,
stride: u32,
sequence: u64,
) -> Self {
Self {
data: Arc::new(data),
hash: Arc::new(OnceLock::new()),
resolution,
format,
stride,
key_frame: true,
sequence,
capture_ts: Instant::now(),
online: true,
}
}
/// Create a frame from a Vec<u8>
pub fn from_vec(
data: Vec<u8>,
resolution: Resolution,
format: PixelFormat,
stride: u32,
sequence: u64,
) -> Self {
Self::new(Bytes::from(data), resolution, format, stride, sequence)
}
/// Get frame data as bytes slice
pub fn data(&self) -> &[u8] {
&self.data
}
/// Get frame data as Bytes (cheap clone)
pub fn data_bytes(&self) -> Bytes {
(*self.data).clone()
}
/// Get data length
pub fn len(&self) -> usize {
self.data.len()
}
/// Check if frame is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Get width
pub fn width(&self) -> u32 {
self.resolution.width
}
/// Get height
pub fn height(&self) -> u32 {
self.resolution.height
}
/// Get age of this frame (time since capture)
pub fn age(&self) -> std::time::Duration {
self.capture_ts.elapsed()
}
/// Check if this frame is still fresh (within threshold)
pub fn is_fresh(&self, max_age_ms: u64) -> bool {
self.age().as_millis() < max_age_ms as u128
}
/// Get hash of frame data (computed once, cached)
/// Used for fast frame deduplication comparison
pub fn get_hash(&self) -> u64 {
*self.hash.get_or_init(|| {
xxhash_rust::xxh64::xxh64(self.data.as_ref(), 0)
})
}
/// Check if format is JPEG/MJPEG
pub fn is_jpeg(&self) -> bool {
self.format.is_compressed()
}
/// Validate JPEG frame data
pub fn is_valid_jpeg(&self) -> bool {
if !self.is_jpeg() {
return false;
}
if self.data.len() < 125 {
return false;
}
// Check JPEG header
let start_marker = ((self.data[0] as u16) << 8) | self.data[1] as u16;
if start_marker != 0xFFD8 {
return false;
}
// Check JPEG end marker
let end = self.data.len();
let end_marker = ((self.data[end - 2] as u16) << 8) | self.data[end - 1] as u16;
// Valid end markers: 0xFFD9, 0xD900, 0x0000 (padded)
matches!(end_marker, 0xFFD9 | 0xD900 | 0x0000)
}
/// Create an offline placeholder frame
pub fn offline(resolution: Resolution, format: PixelFormat) -> Self {
Self {
data: Arc::new(Bytes::new()),
hash: Arc::new(OnceLock::new()),
resolution,
format,
stride: 0,
key_frame: true,
sequence: 0,
capture_ts: Instant::now(),
online: false,
}
}
}
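// Illustrative dedup sketch (not from the original commit): identical frame
// bytes hash equal, so get_hash() gives a cheap skip-duplicate check.
#[cfg(test)]
mod hash_examples {
use super::*;
#[test]
fn identical_frames_hash_equal() {
let a = VideoFrame::from_vec(vec![7u8; 8], Resolution::new(4, 2), PixelFormat::Grey, 4, 1);
let b = VideoFrame::from_vec(vec![7u8; 8], Resolution::new(4, 2), PixelFormat::Grey, 4, 2);
assert_eq!(a.get_hash(), b.get_hash());
}
}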
/// Frame metadata without actual data (for logging/stats)
#[derive(Debug, Clone)]
pub struct FrameMeta {
pub resolution: Resolution,
pub format: PixelFormat,
pub size: usize,
pub sequence: u64,
pub key_frame: bool,
pub online: bool,
}
impl From<&VideoFrame> for FrameMeta {
fn from(frame: &VideoFrame) -> Self {
Self {
resolution: frame.resolution,
format: frame.format,
size: frame.len(),
sequence: frame.sequence,
key_frame: frame.key_frame,
online: frame.online,
}
}
}
/// Ring buffer for storing recent frames
pub struct FrameRing {
frames: Vec<Option<VideoFrame>>,
capacity: usize,
write_pos: usize,
count: usize,
}
impl FrameRing {
/// Create a new frame ring with specified capacity
pub fn new(capacity: usize) -> Self {
assert!(capacity > 0, "Ring capacity must be > 0");
Self {
frames: (0..capacity).map(|_| None).collect(),
capacity,
write_pos: 0,
count: 0,
}
}
/// Push a frame into the ring
pub fn push(&mut self, frame: VideoFrame) {
self.frames[self.write_pos] = Some(frame);
self.write_pos = (self.write_pos + 1) % self.capacity;
if self.count < self.capacity {
self.count += 1;
}
}
/// Get the latest frame
pub fn latest(&self) -> Option<&VideoFrame> {
if self.count == 0 {
return None;
}
let pos = if self.write_pos == 0 {
self.capacity - 1
} else {
self.write_pos - 1
};
self.frames[pos].as_ref()
}
/// Get number of frames in ring
pub fn len(&self) -> usize {
self.count
}
/// Check if ring is empty
pub fn is_empty(&self) -> bool {
self.count == 0
}
/// Clear all frames
pub fn clear(&mut self) {
for frame in &mut self.frames {
*frame = None;
}
self.write_pos = 0;
self.count = 0;
}
}
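// Illustrative ring sketch (not from the original commit): the ring keeps the
// last `capacity` frames and latest() always returns the most recent push.
#[cfg(test)]
mod ring_examples {
use super::*;
fn dummy(seq: u64) -> VideoFrame {
VideoFrame::from_vec(vec![0u8; 4], Resolution::new(2, 2), PixelFormat::Grey, 2, seq)
}
#[test]
fn latest_tracks_most_recent_push() {
let mut ring = FrameRing::new(2);
assert!(ring.is_empty());
ring.push(dummy(1));
ring.push(dummy(2));
ring.push(dummy(3)); // overwrites the slot holding frame 1
assert_eq!(ring.len(), 2);
assert_eq!(ring.latest().unwrap().sequence, 3);
}
}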

539
src/video/h264_pipeline.rs Normal file
View File

@@ -0,0 +1,539 @@
//! H264 video encoding pipeline for WebRTC streaming
//!
//! This module provides a complete H264 encoding pipeline that connects:
//! 1. Video capture (YUYV/MJPEG from V4L2)
//! 2. Pixel conversion (YUYV/RGB24/BGR24 → NV12) or MJPEG decode to NV12
//! 3. H264 encoding (via hwcodec)
//! 4. RTP packetization and WebRTC track output
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, error, info, warn};
use crate::error::{AppError, Result};
use crate::video::convert::Nv12Converter;
use crate::video::decoder::mjpeg::{MjpegVaapiDecoder, MjpegVaapiDecoderConfig};
use crate::video::encoder::h264::{H264Config, H264Encoder};
use crate::video::format::{PixelFormat, Resolution};
use crate::webrtc::rtp::{H264VideoTrack, H264VideoTrackConfig};
/// H264 pipeline configuration
#[derive(Debug, Clone)]
pub struct H264PipelineConfig {
/// Input resolution
pub resolution: Resolution,
/// Input pixel format (YUYV, NV12, etc.)
pub input_format: PixelFormat,
/// Target bitrate in kbps
pub bitrate_kbps: u32,
/// Target FPS
pub fps: u32,
/// GOP size (keyframe interval in frames)
pub gop_size: u32,
/// Track ID for WebRTC
pub track_id: String,
/// Stream ID for WebRTC
pub stream_id: String,
}
impl Default for H264PipelineConfig {
fn default() -> Self {
Self {
resolution: Resolution::HD720,
input_format: PixelFormat::Yuyv,
bitrate_kbps: 8000,
fps: 30,
gop_size: 30,
track_id: "video0".to_string(),
stream_id: "one-kvm-stream".to_string(),
}
}
}
/// H264 pipeline statistics
#[derive(Debug, Clone, Default)]
pub struct H264PipelineStats {
/// Total frames captured
pub frames_captured: u64,
/// Total frames encoded
pub frames_encoded: u64,
/// Frames dropped (encoding too slow)
pub frames_dropped: u64,
/// Total bytes encoded
pub bytes_encoded: u64,
/// Keyframes encoded
pub keyframes_encoded: u64,
/// Average encoding time per frame (ms)
pub avg_encode_time_ms: f32,
/// Current encoding FPS
pub current_fps: f32,
/// Errors encountered
pub errors: u64,
}
/// H264 video encoding pipeline
pub struct H264Pipeline {
config: H264PipelineConfig,
/// H264 encoder instance
encoder: Arc<Mutex<Option<H264Encoder>>>,
/// NV12 converter (for BGR24/RGB24/YUYV → NV12)
nv12_converter: Arc<Mutex<Option<Nv12Converter>>>,
/// MJPEG VAAPI decoder (for MJPEG input, outputs NV12)
mjpeg_decoder: Arc<Mutex<Option<MjpegVaapiDecoder>>>,
/// WebRTC video track
video_track: Arc<H264VideoTrack>,
/// Pipeline statistics
stats: Arc<Mutex<H264PipelineStats>>,
/// Running state
running: watch::Sender<bool>,
/// Encode time accumulator for averaging
encode_times: Arc<Mutex<Vec<f32>>>,
}
impl H264Pipeline {
/// Create a new H264 pipeline
pub fn new(config: H264PipelineConfig) -> Result<Self> {
info!(
"Creating H264 pipeline: {}x{} @ {} kbps, {} fps",
config.resolution.width,
config.resolution.height,
config.bitrate_kbps,
config.fps
);
// Determine encoder input format based on pipeline input
// NV12 is optimal for VAAPI, use it for all formats
// VAAPI encoders typically only support NV12 input
let encoder_input_format = crate::video::encoder::h264::H264InputFormat::Nv12;
// Create H264 encoder with appropriate input format
let encoder_config = H264Config {
base: crate::video::encoder::traits::EncoderConfig::h264(
config.resolution,
config.bitrate_kbps,
),
bitrate_kbps: config.bitrate_kbps,
gop_size: config.gop_size,
fps: config.fps,
input_format: encoder_input_format,
};
let encoder = H264Encoder::new(encoder_config)?;
info!(
"H264 encoder created: {} ({}) with {:?} input",
encoder.codec_name(),
encoder.encoder_type(),
encoder_input_format
);
// Create NV12 converter or MJPEG decoder based on input format
// All formats are converted to NV12 for VAAPI encoder
let (nv12_converter, mjpeg_decoder) = match config.input_format {
// NV12 input - direct passthrough
PixelFormat::Nv12 => {
info!("NV12 input: direct passthrough to encoder");
(None, None)
}
// YUYV (4:2:2 packed) → NV12
PixelFormat::Yuyv => {
info!("YUYV input: converting to NV12");
(Some(Nv12Converter::yuyv_to_nv12(config.resolution)), None)
}
// RGB24 → NV12
PixelFormat::Rgb24 => {
info!("RGB24 input: converting to NV12");
(Some(Nv12Converter::rgb24_to_nv12(config.resolution)), None)
}
// BGR24 → NV12
PixelFormat::Bgr24 => {
info!("BGR24 input: converting to NV12");
(Some(Nv12Converter::bgr24_to_nv12(config.resolution)), None)
}
// MJPEG/JPEG → NV12 (via hwcodec decoder)
PixelFormat::Mjpeg | PixelFormat::Jpeg => {
let decoder_config = MjpegVaapiDecoderConfig {
resolution: config.resolution,
use_hwaccel: true,
};
let decoder = MjpegVaapiDecoder::new(decoder_config)?;
info!(
"MJPEG decoder created for H264 pipeline (outputs NV12)"
);
(None, Some(decoder))
}
_ => {
return Err(AppError::VideoError(format!(
"Unsupported input format for H264 pipeline: {}",
config.input_format
)));
}
};
// Create WebRTC video track
let track_config = H264VideoTrackConfig {
track_id: config.track_id.clone(),
stream_id: config.stream_id.clone(),
resolution: config.resolution,
bitrate_kbps: config.bitrate_kbps,
fps: config.fps,
profile_level_id: None, // Let browser negotiate the best profile
};
let video_track = Arc::new(H264VideoTrack::new(track_config));
let (running_tx, _) = watch::channel(false);
Ok(Self {
config,
encoder: Arc::new(Mutex::new(Some(encoder))),
nv12_converter: Arc::new(Mutex::new(nv12_converter)),
mjpeg_decoder: Arc::new(Mutex::new(mjpeg_decoder)),
video_track,
stats: Arc::new(Mutex::new(H264PipelineStats::default())),
running: running_tx,
encode_times: Arc::new(Mutex::new(Vec::with_capacity(100))),
})
}
/// Get the WebRTC video track
pub fn video_track(&self) -> Arc<H264VideoTrack> {
self.video_track.clone()
}
/// Get current statistics
pub async fn stats(&self) -> H264PipelineStats {
self.stats.lock().await.clone()
}
/// Check if pipeline is running
pub fn is_running(&self) -> bool {
*self.running.borrow()
}
/// Start the encoding pipeline
///
/// This starts a background task that receives raw frames from the receiver,
/// encodes them to H264, and sends them to the WebRTC track.
pub async fn start(&self, mut frame_rx: broadcast::Receiver<Vec<u8>>) {
if *self.running.borrow() {
warn!("H264 pipeline already running");
return;
}
let _ = self.running.send(true);
info!("Starting H264 pipeline (input format: {})", self.config.input_format);
let encoder = self.encoder.lock().await.take();
let nv12_converter = self.nv12_converter.lock().await.take();
let mjpeg_decoder = self.mjpeg_decoder.lock().await.take();
let video_track = self.video_track.clone();
let stats = self.stats.clone();
let encode_times = self.encode_times.clone();
let config = self.config.clone();
let mut running_rx = self.running.subscribe();
// Spawn encoding task
tokio::spawn(async move {
let mut encoder = match encoder {
Some(e) => e,
None => {
error!("No encoder available");
return;
}
};
let mut nv12_converter = nv12_converter;
let mut mjpeg_decoder = mjpeg_decoder;
let mut frame_count: u64 = 0;
let mut last_fps_time = Instant::now();
let mut fps_frame_count: u64 = 0;
// Pre-allocated NV12 buffer for MJPEG decoder output (avoids per-frame allocation)
let nv12_size = (config.resolution.width * config.resolution.height * 3 / 2) as usize;
let mut nv12_buffer = vec![0u8; nv12_size];
// Flag for one-time warnings
let mut size_mismatch_warned = false;
loop {
tokio::select! {
biased;
_ = running_rx.changed() => {
if !*running_rx.borrow() {
info!("H264 pipeline stopping");
break;
}
}
result = frame_rx.recv() => {
match result {
Ok(raw_frame) => {
let start = Instant::now();
// Validate frame size for uncompressed formats
if let Some(expected_size) = config.input_format.frame_size(config.resolution) {
if raw_frame.len() != expected_size && !size_mismatch_warned {
warn!(
"Frame size mismatch: got {} bytes, expected {} for {} {}x{}",
raw_frame.len(),
expected_size,
config.input_format,
config.resolution.width,
config.resolution.height
);
size_mismatch_warned = true;
}
}
// Update captured count
{
let mut s = stats.lock().await;
s.frames_captured += 1;
}
// Convert to NV12 for VAAPI encoder
// MJPEG -> NV12 (via VAAPI decoder)
// BGR24/RGB24/YUYV -> NV12 (via NV12 converter)
// NV12 -> pass through
//
// Optimized: avoid unnecessary allocations and copies
frame_count += 1;
fps_frame_count += 1;
let pts_ms = (frame_count * 1000 / config.fps as u64) as i64;
let encode_result = if let Some(ref mut decoder) = mjpeg_decoder {
// MJPEG input - decode to NV12 via VAAPI
match decoder.decode(&raw_frame) {
Ok(nv12_frame) => {
// Calculate required size for this frame
let required_size = (nv12_frame.width * nv12_frame.height * 3 / 2) as usize;
// Resize buffer if needed (handles resolution changes)
if nv12_buffer.len() < required_size {
debug!(
"Resizing NV12 buffer: {} -> {} bytes (resolution: {}x{})",
nv12_buffer.len(), required_size,
nv12_frame.width, nv12_frame.height
);
nv12_buffer.resize(required_size, 0);
}
// Copy to pre-allocated buffer (guaranteed to fit after resize)
let written = nv12_frame.copy_to_packed_nv12(&mut nv12_buffer)
.expect("BUG: buffer too small after resize");
encoder.encode_raw(&nv12_buffer[..written], pts_ms)
}
Err(e) => {
error!("MJPEG VAAPI decode failed: {}", e);
let mut s = stats.lock().await;
s.errors += 1;
continue;
}
}
} else if let Some(ref mut conv) = nv12_converter {
// BGR24/RGB24/YUYV input - convert to NV12
// Optimized: pass reference directly without copy
match conv.convert(&raw_frame) {
Ok(nv12_data) => encoder.encode_raw(nv12_data, pts_ms),
Err(e) => {
error!("NV12 conversion failed: {}", e);
let mut s = stats.lock().await;
s.errors += 1;
continue;
}
}
} else {
// NV12 input - pass reference directly
encoder.encode_raw(&raw_frame, pts_ms)
};
match encode_result {
Ok(frames) => {
if !frames.is_empty() {
let frame = &frames[0];
let is_keyframe = frame.key == 1;
// Send to WebRTC track
let duration = Duration::from_millis(
1000 / config.fps as u64
);
if let Err(e) = video_track
.write_frame(&frame.data, duration, is_keyframe)
.await
{
error!("Failed to write frame to track: {}", e);
let mut s = stats.lock().await;
s.errors += 1;
} else {
// Update stats
let encode_time = start.elapsed().as_secs_f32() * 1000.0;
let mut s = stats.lock().await;
s.frames_encoded += 1;
s.bytes_encoded += frame.data.len() as u64;
if is_keyframe {
s.keyframes_encoded += 1;
}
// Update encode time average
let mut times = encode_times.lock().await;
times.push(encode_time);
if times.len() > 100 {
times.remove(0);
}
if !times.is_empty() {
s.avg_encode_time_ms =
times.iter().sum::<f32>() / times.len() as f32;
}
}
}
}
Err(e) => {
error!("Encoding failed: {}", e);
let mut s = stats.lock().await;
s.errors += 1;
}
}
// Update FPS every second
if last_fps_time.elapsed() >= Duration::from_secs(1) {
let mut s = stats.lock().await;
s.current_fps = fps_frame_count as f32
/ last_fps_time.elapsed().as_secs_f32();
fps_frame_count = 0;
last_fps_time = Instant::now();
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
let mut s = stats.lock().await;
s.frames_dropped += n;
}
Err(broadcast::error::RecvError::Closed) => {
info!("Frame channel closed, stopping H264 pipeline");
break;
}
}
}
}
}
info!("H264 pipeline task exited");
});
}
/// Stop the encoding pipeline
pub fn stop(&self) {
if *self.running.borrow() {
let _ = self.running.send(false);
info!("Stopping H264 pipeline");
}
}
/// Request a keyframe (force IDR)
pub async fn request_keyframe(&self) {
// Note: the encoder is moved into the encoding task by start(), so an
// on-demand IDR cannot be forced from here; keyframes are produced at
// the configured GOP boundaries.
debug!("Keyframe requested (will occur at next GOP boundary)");
}
/// Update bitrate dynamically
///
/// Note: this only takes effect before start(); once started, the encoder
/// has been moved into the encoding task and this call is a no-op.
pub async fn set_bitrate(&self, bitrate_kbps: u32) -> Result<()> {
if let Some(ref mut encoder) = *self.encoder.lock().await {
encoder.set_bitrate(bitrate_kbps)?;
info!("H264 pipeline bitrate updated to {} kbps", bitrate_kbps);
}
Ok(())
}
}
/// Builder for H264 pipeline configuration
pub struct H264PipelineBuilder {
config: H264PipelineConfig,
}
impl H264PipelineBuilder {
pub fn new() -> Self {
Self {
config: H264PipelineConfig::default(),
}
}
pub fn resolution(mut self, resolution: Resolution) -> Self {
self.config.resolution = resolution;
self
}
pub fn input_format(mut self, format: PixelFormat) -> Self {
self.config.input_format = format;
self
}
pub fn bitrate_kbps(mut self, bitrate: u32) -> Self {
self.config.bitrate_kbps = bitrate;
self
}
pub fn fps(mut self, fps: u32) -> Self {
self.config.fps = fps;
self
}
pub fn gop_size(mut self, gop: u32) -> Self {
self.config.gop_size = gop;
self
}
pub fn track_id(mut self, id: &str) -> Self {
self.config.track_id = id.to_string();
self
}
pub fn stream_id(mut self, id: &str) -> Self {
self.config.stream_id = id.to_string();
self
}
pub fn build(self) -> Result<H264Pipeline> {
H264Pipeline::new(self.config)
}
}
impl Default for H264PipelineBuilder {
fn default() -> Self {
Self::new()
}
}
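// Illustrative wiring sketch (not from the original commit): feeding the
// pipeline from a raw-frame broadcast channel. `capture_tx` is a hypothetical
// broadcast::Sender<Vec<u8>> owned by the V4L2 capture loop.
#[allow(dead_code)]
async fn example_run_pipeline(capture_tx: &broadcast::Sender<Vec<u8>>) -> Result<()> {
let pipeline = H264PipelineBuilder::new()
.resolution(Resolution::HD720)
.input_format(PixelFormat::Yuyv)
.bitrate_kbps(4000)
.build()?;
pipeline.start(capture_tx.subscribe()).await;
// The returned track would then be attached to a WebRTC peer connection.
let _track = pipeline.video_track();
Ok(())
}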
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pipeline_config_default() {
let config = H264PipelineConfig::default();
assert_eq!(config.resolution, Resolution::HD720);
assert_eq!(config.bitrate_kbps, 8000); // matches H264PipelineConfig::default()
assert_eq!(config.fps, 30);
assert_eq!(config.gop_size, 30);
}
#[test]
fn test_pipeline_builder() {
let builder = H264PipelineBuilder::new()
.resolution(Resolution::HD1080)
.bitrate_kbps(4000)
.fps(60)
.input_format(PixelFormat::Yuyv);
assert_eq!(builder.config.resolution, Resolution::HD1080);
assert_eq!(builder.config.bitrate_kbps, 4000);
assert_eq!(builder.config.fps, 60);
assert_eq!(builder.config.input_format, PixelFormat::Yuyv);
}
}

29
src/video/mod.rs Normal file
View File

@@ -0,0 +1,29 @@
//! Video capture and streaming module
//!
//! This module provides V4L2 video capture, encoding, and streaming functionality.
pub mod capture;
pub mod convert;
pub mod decoder;
pub mod device;
pub mod encoder;
pub mod format;
pub mod frame;
pub mod h264_pipeline;
pub mod shared_video_pipeline;
pub mod stream_manager;
pub mod streamer;
pub mod video_session;
pub use capture::VideoCapturer;
pub use convert::{MjpegDecoder, MjpegToYuv420Converter, PixelConverter, Yuv420pBuffer};
pub use decoder::{MjpegVaapiDecoder, MjpegVaapiDecoderConfig};
pub use device::{VideoDevice, VideoDeviceInfo};
pub use encoder::{JpegEncoder, H264Encoder, H264EncoderType};
pub use format::PixelFormat;
pub use frame::VideoFrame;
pub use h264_pipeline::{H264Pipeline, H264PipelineBuilder, H264PipelineConfig};
pub use shared_video_pipeline::{EncodedVideoFrame, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats};
pub use stream_manager::VideoStreamManager;
pub use streamer::{Streamer, StreamerState};
pub use video_session::{VideoSessionManager, VideoSessionManagerConfig, VideoSessionInfo, VideoSessionState, CodecInfo};

940
src/video/shared_video_pipeline.rs Normal file
View File

@@ -0,0 +1,940 @@
//! Universal shared video encoding pipeline
//!
//! Supports multiple codecs: H264, H265, VP8, VP9
//! A single encoder broadcasts to multiple WebRTC sessions.
//!
//! Architecture:
//! ```text
//! VideoCapturer (MJPEG/YUYV/NV12)
//! |
//! v (broadcast::Receiver<VideoFrame>)
//! SharedVideoPipeline (single encoder)
//! |
//! v (broadcast::Sender<EncodedVideoFrame>)
//! ┌────┴────┬────────┬────────┐
//! v v v v
//! Session1 Session2 Session3 ...
//! ```
use bytes::Bytes;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, watch, Mutex, RwLock};
use tracing::{debug, error, info, warn};
use crate::error::{AppError, Result};
use crate::video::convert::{Nv12Converter, PixelConverter};
use crate::video::decoder::mjpeg::{MjpegTurboDecoder, MjpegVaapiDecoder, MjpegVaapiDecoderConfig};
use crate::video::encoder::h264::{H264Config, H264Encoder};
use crate::video::encoder::h265::{H265Config, H265Encoder};
use crate::video::encoder::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
use crate::video::encoder::traits::EncoderConfig;
use crate::video::encoder::vp8::{VP8Config, VP8Encoder};
use crate::video::encoder::vp9::{VP9Config, VP9Encoder};
use crate::video::format::{PixelFormat, Resolution};
use crate::video::frame::VideoFrame;
/// Encoded video frame for distribution
#[derive(Debug, Clone)]
pub struct EncodedVideoFrame {
/// Encoded data (Annex B for H264/H265, raw for VP8/VP9)
pub data: Bytes,
/// Presentation timestamp in milliseconds
pub pts_ms: i64,
/// Whether this is a keyframe
pub is_keyframe: bool,
/// Frame sequence number
pub sequence: u64,
/// Frame duration
pub duration: Duration,
/// Codec type
pub codec: VideoEncoderType,
}
/// Shared video pipeline configuration
#[derive(Debug, Clone)]
pub struct SharedVideoPipelineConfig {
/// Input resolution
pub resolution: Resolution,
/// Input pixel format
pub input_format: PixelFormat,
/// Output codec type
pub output_codec: VideoEncoderType,
/// Target bitrate in kbps
pub bitrate_kbps: u32,
/// Target FPS
pub fps: u32,
/// GOP size
pub gop_size: u32,
/// Encoder backend (None = auto select best available)
pub encoder_backend: Option<EncoderBackend>,
}
impl Default for SharedVideoPipelineConfig {
fn default() -> Self {
Self {
resolution: Resolution::HD720,
input_format: PixelFormat::Yuyv,
output_codec: VideoEncoderType::H264,
bitrate_kbps: 8000,
fps: 30,
gop_size: 30,
encoder_backend: None,
}
}
}
impl SharedVideoPipelineConfig {
/// Create H264 config
pub fn h264(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
resolution,
output_codec: VideoEncoderType::H264,
bitrate_kbps,
..Default::default()
}
}
/// Create H265 config
pub fn h265(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
resolution,
output_codec: VideoEncoderType::H265,
bitrate_kbps,
..Default::default()
}
}
/// Create VP8 config
pub fn vp8(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
resolution,
output_codec: VideoEncoderType::VP8,
bitrate_kbps,
..Default::default()
}
}
/// Create VP9 config
pub fn vp9(resolution: Resolution, bitrate_kbps: u32) -> Self {
Self {
resolution,
output_codec: VideoEncoderType::VP9,
bitrate_kbps,
..Default::default()
}
}
}
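// Illustrative config sketch (not from the original commit): forcing the
// software backend for VP8, e.g. on a machine without VAAPI.
#[allow(dead_code)]
fn example_vp8_software_config(resolution: Resolution) -> SharedVideoPipelineConfig {
SharedVideoPipelineConfig {
encoder_backend: Some(EncoderBackend::Software),
..SharedVideoPipelineConfig::vp8(resolution, 2000)
}
}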
/// Pipeline statistics
#[derive(Debug, Clone, Default)]
pub struct SharedVideoPipelineStats {
pub frames_captured: u64,
pub frames_encoded: u64,
pub frames_dropped: u64,
pub bytes_encoded: u64,
pub keyframes_encoded: u64,
pub avg_encode_time_ms: f32,
pub current_fps: f32,
pub errors: u64,
pub subscribers: u64,
}
/// Universal video encoder trait object
#[allow(dead_code)]
trait VideoEncoderTrait: Send {
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>>;
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()>;
fn codec_name(&self) -> &str;
fn request_keyframe(&mut self);
}
/// Encoded frame from encoder
#[allow(dead_code)]
struct EncodedFrame {
data: Vec<u8>,
pts: i64,
key: i32,
}
/// H264 encoder wrapper
struct H264EncoderWrapper(H264Encoder);
impl VideoEncoderTrait for H264EncoderWrapper {
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
let frames = self.0.encode_raw(data, pts_ms)?;
Ok(frames
.into_iter()
.map(|f| EncodedFrame {
data: f.data,
pts: f.pts,
key: f.key,
})
.collect())
}
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.0.set_bitrate(bitrate_kbps)
}
fn codec_name(&self) -> &str {
self.0.codec_name()
}
fn request_keyframe(&mut self) {
self.0.request_keyframe()
}
}
/// H265 encoder wrapper
struct H265EncoderWrapper(H265Encoder);
impl VideoEncoderTrait for H265EncoderWrapper {
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
let frames = self.0.encode_raw(data, pts_ms)?;
Ok(frames
.into_iter()
.map(|f| EncodedFrame {
data: f.data,
pts: f.pts,
key: f.key,
})
.collect())
}
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.0.set_bitrate(bitrate_kbps)
}
fn codec_name(&self) -> &str {
self.0.codec_name()
}
fn request_keyframe(&mut self) {
self.0.request_keyframe()
}
}
/// VP8 encoder wrapper
struct VP8EncoderWrapper(VP8Encoder);
impl VideoEncoderTrait for VP8EncoderWrapper {
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
let frames = self.0.encode_raw(data, pts_ms)?;
Ok(frames
.into_iter()
.map(|f| EncodedFrame {
data: f.data,
pts: f.pts,
key: f.key,
})
.collect())
}
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.0.set_bitrate(bitrate_kbps)
}
fn codec_name(&self) -> &str {
self.0.codec_name()
}
fn request_keyframe(&mut self) {
// VP8 encoder doesn't support request_keyframe yet
}
}
/// VP9 encoder wrapper
struct VP9EncoderWrapper(VP9Encoder);
impl VideoEncoderTrait for VP9EncoderWrapper {
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
let frames = self.0.encode_raw(data, pts_ms)?;
Ok(frames
.into_iter()
.map(|f| EncodedFrame {
data: f.data,
pts: f.pts,
key: f.key,
})
.collect())
}
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
self.0.set_bitrate(bitrate_kbps)
}
fn codec_name(&self) -> &str {
self.0.codec_name()
}
fn request_keyframe(&mut self) {
// VP9 encoder doesn't support request_keyframe yet
}
}
/// Universal shared video pipeline
pub struct SharedVideoPipeline {
config: RwLock<SharedVideoPipelineConfig>,
encoder: Mutex<Option<Box<dyn VideoEncoderTrait + Send>>>,
nv12_converter: Mutex<Option<Nv12Converter>>,
yuv420p_converter: Mutex<Option<PixelConverter>>,
mjpeg_decoder: Mutex<Option<MjpegVaapiDecoder>>,
/// Turbojpeg decoder for direct MJPEG->YUV420P (optimized for software encoders)
mjpeg_turbo_decoder: Mutex<Option<MjpegTurboDecoder>>,
nv12_buffer: Mutex<Vec<u8>>,
/// YUV420P buffer for turbojpeg decoder output
yuv420p_buffer: Mutex<Vec<u8>>,
/// Whether the encoder needs YUV420P (true) or NV12 (false)
encoder_needs_yuv420p: Mutex<bool>,
frame_tx: broadcast::Sender<EncodedVideoFrame>,
stats: Mutex<SharedVideoPipelineStats>,
running: watch::Sender<bool>,
running_rx: watch::Receiver<bool>,
encode_times: Mutex<Vec<f32>>,
sequence: Mutex<u64>,
/// Atomic flag for keyframe request (avoids lock contention)
keyframe_requested: AtomicBool,
}
impl SharedVideoPipeline {
/// Create a new shared video pipeline
pub fn new(config: SharedVideoPipelineConfig) -> Result<Arc<Self>> {
info!(
"Creating shared video pipeline: {} {}x{} @ {} kbps (input: {})",
config.output_codec,
config.resolution.width,
config.resolution.height,
config.bitrate_kbps,
config.input_format
);
let (frame_tx, _) = broadcast::channel(16);
let (running_tx, running_rx) = watch::channel(false);
let nv12_size = (config.resolution.width * config.resolution.height * 3 / 2) as usize;
let yuv420p_size = nv12_size; // Same size as NV12
let pipeline = Arc::new(Self {
config: RwLock::new(config),
encoder: Mutex::new(None),
nv12_converter: Mutex::new(None),
yuv420p_converter: Mutex::new(None),
mjpeg_decoder: Mutex::new(None),
mjpeg_turbo_decoder: Mutex::new(None),
nv12_buffer: Mutex::new(vec![0u8; nv12_size]),
yuv420p_buffer: Mutex::new(vec![0u8; yuv420p_size]),
encoder_needs_yuv420p: Mutex::new(false),
frame_tx,
stats: Mutex::new(SharedVideoPipelineStats::default()),
running: running_tx,
running_rx,
encode_times: Mutex::new(Vec::with_capacity(100)),
sequence: Mutex::new(0),
keyframe_requested: AtomicBool::new(false),
});
Ok(pipeline)
}
/// Initialize encoder based on config
async fn init_encoder(&self) -> Result<()> {
let config = self.config.read().await.clone();
let registry = EncoderRegistry::global();
// Helper to get codec name for specific backend
let get_codec_name = |format: VideoEncoderType, backend: Option<EncoderBackend>| -> Option<String> {
match backend {
Some(b) => registry.encoder_with_backend(format, b).map(|e| e.codec_name.clone()),
None => registry.best_encoder(format, false).map(|e| e.codec_name.clone()),
}
};
// Create encoder based on codec type
let encoder: Box<dyn VideoEncoderTrait + Send> = match config.output_codec {
VideoEncoderType::H264 => {
let encoder_config = H264Config {
base: EncoderConfig::h264(config.resolution, config.bitrate_kbps),
bitrate_kbps: config.bitrate_kbps,
gop_size: config.gop_size,
fps: config.fps,
input_format: crate::video::encoder::h264::H264InputFormat::Nv12,
};
let encoder = if let Some(ref backend) = config.encoder_backend {
// Specific backend requested
let codec_name = get_codec_name(VideoEncoderType::H264, Some(*backend))
.ok_or_else(|| AppError::VideoError(format!(
"Backend {:?} does not support H.264", backend
)))?;
info!("Creating H264 encoder with backend {:?} (codec: {})", backend, codec_name);
H264Encoder::with_codec(encoder_config, &codec_name)?
} else {
// Auto select
H264Encoder::new(encoder_config)?
};
info!("Created H264 encoder: {}", encoder.codec_name());
Box::new(H264EncoderWrapper(encoder))
}
VideoEncoderType::H265 => {
let encoder_config = H265Config::low_latency(config.resolution, config.bitrate_kbps);
let encoder = if let Some(ref backend) = config.encoder_backend {
let codec_name = get_codec_name(VideoEncoderType::H265, Some(*backend))
.ok_or_else(|| AppError::VideoError(format!(
"Backend {:?} does not support H.265", backend
)))?;
info!("Creating H265 encoder with backend {:?} (codec: {})", backend, codec_name);
H265Encoder::with_codec(encoder_config, &codec_name)?
} else {
H265Encoder::new(encoder_config)?
};
info!("Created H265 encoder: {}", encoder.codec_name());
Box::new(H265EncoderWrapper(encoder))
}
VideoEncoderType::VP8 => {
let encoder_config = VP8Config::low_latency(config.resolution, config.bitrate_kbps);
let encoder = if let Some(ref backend) = config.encoder_backend {
let codec_name = get_codec_name(VideoEncoderType::VP8, Some(*backend))
.ok_or_else(|| AppError::VideoError(format!(
"Backend {:?} does not support VP8", backend
)))?;
info!("Creating VP8 encoder with backend {:?} (codec: {})", backend, codec_name);
VP8Encoder::with_codec(encoder_config, &codec_name)?
} else {
VP8Encoder::new(encoder_config)?
};
info!("Created VP8 encoder: {}", encoder.codec_name());
Box::new(VP8EncoderWrapper(encoder))
}
VideoEncoderType::VP9 => {
let encoder_config = VP9Config::low_latency(config.resolution, config.bitrate_kbps);
let encoder = if let Some(ref backend) = config.encoder_backend {
let codec_name = get_codec_name(VideoEncoderType::VP9, Some(*backend))
.ok_or_else(|| AppError::VideoError(format!(
"Backend {:?} does not support VP9", backend
)))?;
info!("Creating VP9 encoder with backend {:?} (codec: {})", backend, codec_name);
VP9Encoder::with_codec(encoder_config, &codec_name)?
} else {
VP9Encoder::new(encoder_config)?
};
info!("Created VP9 encoder: {}", encoder.codec_name());
Box::new(VP9EncoderWrapper(encoder))
}
};
// Determine if encoder needs YUV420P (software encoders) or NV12 (hardware encoders)
let codec_name = encoder.codec_name();
let needs_yuv420p = codec_name.contains("libvpx") || codec_name.contains("libx265");
info!(
"Encoder {} needs {} format",
codec_name,
if needs_yuv420p { "YUV420P" } else { "NV12" }
);
// Create converter or decoder based on input format and encoder needs
info!("Initializing input format handler for: {} -> {}",
config.input_format,
if needs_yuv420p { "YUV420P" } else { "NV12" });
let (nv12_converter, yuv420p_converter, mjpeg_decoder, mjpeg_turbo_decoder) = if needs_yuv420p {
// Software encoder needs YUV420P
match config.input_format {
PixelFormat::Yuv420 => {
info!("Using direct YUV420P input (no conversion)");
(None, None, None, None)
}
PixelFormat::Yuyv => {
info!("Using YUYV->YUV420P converter");
(None, Some(PixelConverter::yuyv_to_yuv420p(config.resolution)), None, None)
}
PixelFormat::Nv12 => {
info!("Using NV12->YUV420P converter");
(None, Some(PixelConverter::nv12_to_yuv420p(config.resolution)), None, None)
}
PixelFormat::Rgb24 => {
info!("Using RGB24->YUV420P converter");
(None, Some(PixelConverter::rgb24_to_yuv420p(config.resolution)), None, None)
}
PixelFormat::Bgr24 => {
info!("Using BGR24->YUV420P converter");
(None, Some(PixelConverter::bgr24_to_yuv420p(config.resolution)), None, None)
}
PixelFormat::Mjpeg | PixelFormat::Jpeg => {
// Use turbojpeg for direct MJPEG->YUV420P (no intermediate NV12)
info!("Using turbojpeg MJPEG decoder (direct YUV420P output)");
let turbo_decoder = MjpegTurboDecoder::new(config.resolution)?;
(None, None, None, Some(turbo_decoder))
}
_ => {
return Err(AppError::VideoError(format!(
"Unsupported input format: {}",
config.input_format
)));
}
}
} else {
// Hardware encoder needs NV12
match config.input_format {
PixelFormat::Nv12 => {
info!("Using direct NV12 input (no conversion)");
(None, None, None, None)
}
PixelFormat::Yuyv => {
info!("Using YUYV->NV12 converter");
(Some(Nv12Converter::yuyv_to_nv12(config.resolution)), None, None, None)
}
PixelFormat::Rgb24 => {
info!("Using RGB24->NV12 converter");
(Some(Nv12Converter::rgb24_to_nv12(config.resolution)), None, None, None)
}
PixelFormat::Bgr24 => {
info!("Using BGR24->NV12 converter");
(Some(Nv12Converter::bgr24_to_nv12(config.resolution)), None, None, None)
}
PixelFormat::Mjpeg | PixelFormat::Jpeg => {
info!("Using MJPEG decoder (NV12 output)");
let decoder_config = MjpegVaapiDecoderConfig {
resolution: config.resolution,
use_hwaccel: true,
};
let decoder = MjpegVaapiDecoder::new(decoder_config)?;
(None, None, Some(decoder), None)
}
_ => {
return Err(AppError::VideoError(format!(
"Unsupported input format: {}",
config.input_format
)));
}
}
};
*self.encoder.lock().await = Some(encoder);
*self.nv12_converter.lock().await = nv12_converter;
*self.yuv420p_converter.lock().await = yuv420p_converter;
*self.mjpeg_decoder.lock().await = mjpeg_decoder;
*self.mjpeg_turbo_decoder.lock().await = mjpeg_turbo_decoder;
*self.encoder_needs_yuv420p.lock().await = needs_yuv420p;
Ok(())
}
/// Subscribe to encoded frames
pub fn subscribe(&self) -> broadcast::Receiver<EncodedVideoFrame> {
self.frame_tx.subscribe()
}
/// Get subscriber count
pub fn subscriber_count(&self) -> usize {
self.frame_tx.receiver_count()
}
/// Request encoder to produce a keyframe on next encode
///
/// This is useful when a new client connects and needs an immediate
/// keyframe to start decoding the video stream.
///
/// Uses an atomic flag to avoid lock contention with the encoding loop.
pub async fn request_keyframe(&self) {
self.keyframe_requested.store(true, Ordering::Release);
info!("[Pipeline] Keyframe requested for new client");
}
/// Get current stats
pub async fn stats(&self) -> SharedVideoPipelineStats {
let mut stats = self.stats.lock().await.clone();
stats.subscribers = self.frame_tx.receiver_count() as u64;
stats
}
/// Check if running
pub fn is_running(&self) -> bool {
*self.running_rx.borrow()
}
/// Get current codec
pub async fn current_codec(&self) -> VideoEncoderType {
self.config.read().await.output_codec
}
/// Switch codec (requires restart)
pub async fn switch_codec(&self, codec: VideoEncoderType) -> Result<()> {
let was_running = self.is_running();
if was_running {
self.stop();
tokio::time::sleep(Duration::from_millis(100)).await;
}
{
let mut config = self.config.write().await;
config.output_codec = codec;
}
// Clear encoder state
*self.encoder.lock().await = None;
*self.nv12_converter.lock().await = None;
*self.yuv420p_converter.lock().await = None;
*self.mjpeg_decoder.lock().await = None;
*self.mjpeg_turbo_decoder.lock().await = None;
*self.encoder_needs_yuv420p.lock().await = false;
info!("Switched to {} codec", codec);
Ok(())
}
/// Start the pipeline
pub async fn start(self: &Arc<Self>, mut frame_rx: broadcast::Receiver<VideoFrame>) -> Result<()> {
if *self.running_rx.borrow() {
warn!("Pipeline already running");
return Ok(());
}
self.init_encoder().await?;
let _ = self.running.send(true);
let config = self.config.read().await.clone();
info!("Starting {} pipeline", config.output_codec);
let pipeline = self.clone();
tokio::spawn(async move {
let mut frame_count: u64 = 0;
let mut last_fps_time = Instant::now();
let mut fps_frame_count: u64 = 0;
let mut running_rx = pipeline.running_rx.clone();
loop {
tokio::select! {
biased;
_ = running_rx.changed() => {
if !*running_rx.borrow() {
break;
}
}
result = frame_rx.recv() => {
match result {
Ok(video_frame) => {
pipeline.stats.lock().await.frames_captured += 1;
if pipeline.frame_tx.receiver_count() == 0 {
continue;
}
let start = Instant::now();
match pipeline.encode_frame(&video_frame, frame_count).await {
Ok(Some(encoded_frame)) => {
let encode_time = start.elapsed().as_secs_f32() * 1000.0;
let _ = pipeline.frame_tx.send(encoded_frame.clone());
let is_keyframe = encoded_frame.is_keyframe;
// Update stats
{
let mut s = pipeline.stats.lock().await;
s.frames_encoded += 1;
s.bytes_encoded += encoded_frame.data.len() as u64;
if is_keyframe {
s.keyframes_encoded += 1;
}
let mut times = pipeline.encode_times.lock().await;
times.push(encode_time);
if times.len() > 100 {
times.remove(0);
}
if !times.is_empty() {
s.avg_encode_time_ms = times.iter().sum::<f32>() / times.len() as f32;
}
}
frame_count += 1;
fps_frame_count += 1;
}
Ok(None) => {}
Err(e) => {
error!("Encoding failed: {}", e);
pipeline.stats.lock().await.errors += 1;
}
}
if last_fps_time.elapsed() >= Duration::from_secs(1) {
let mut s = pipeline.stats.lock().await;
s.current_fps = fps_frame_count as f32 / last_fps_time.elapsed().as_secs_f32();
fps_frame_count = 0;
last_fps_time = Instant::now();
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
pipeline.stats.lock().await.frames_dropped += n;
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
info!("Video pipeline stopped");
});
Ok(())
}
/// Encode a single frame
async fn encode_frame(&self, frame: &VideoFrame, frame_count: u64) -> Result<Option<EncodedVideoFrame>> {
let config = self.config.read().await;
let raw_frame = frame.data();
let fps = config.fps;
let codec = config.output_codec;
drop(config);
let pts_ms = (frame_count * 1000 / fps as u64) as i64;
// Debug log for H265
if codec == VideoEncoderType::H265 && frame_count % 30 == 1 {
debug!(
"[Pipeline-H265] Processing frame #{}: input_size={}, pts_ms={}",
frame_count,
raw_frame.len(),
pts_ms
);
}
let mut mjpeg_decoder = self.mjpeg_decoder.lock().await;
let mut mjpeg_turbo_decoder = self.mjpeg_turbo_decoder.lock().await;
let mut nv12_converter = self.nv12_converter.lock().await;
let mut yuv420p_converter = self.yuv420p_converter.lock().await;
let needs_yuv420p = *self.encoder_needs_yuv420p.lock().await;
let mut encoder_guard = self.encoder.lock().await;
let encoder = encoder_guard.as_mut().ok_or_else(|| {
AppError::VideoError("Encoder not initialized".to_string())
})?;
// Check and consume keyframe request (atomic, no lock contention)
if self.keyframe_requested.swap(false, Ordering::AcqRel) {
encoder.request_keyframe();
debug!("[Pipeline] Keyframe will be generated for this frame");
}
        let encode_result = if let Some(turbo) = mjpeg_turbo_decoder.as_mut() {
            // Optimized path: MJPEG -> YUV420P directly via turbojpeg (for software encoders)
            let mut yuv420p_buffer = self.yuv420p_buffer.lock().await;
            let written = turbo.decode_to_yuv420p_buffer(raw_frame, &mut yuv420p_buffer)
                .map_err(|e| AppError::VideoError(format!("turbojpeg decode failed: {}", e)))?;
            encoder.encode_raw(&yuv420p_buffer[..written], pts_ms)
        } else if let Some(decoder) = mjpeg_decoder.as_mut() {
            // MJPEG input: decode to NV12 (for hardware encoders)
            let nv12_frame = decoder.decode(raw_frame)
                .map_err(|e| AppError::VideoError(format!("MJPEG decode failed: {}", e)))?;
            let required_size = (nv12_frame.width * nv12_frame.height * 3 / 2) as usize;
            let mut nv12_buffer = self.nv12_buffer.lock().await;
            if nv12_buffer.len() < required_size {
                nv12_buffer.resize(required_size, 0);
            }
            let written = nv12_frame.copy_to_packed_nv12(&mut nv12_buffer)
                .expect("Buffer too small");
            // Debug log for H265 after MJPEG decode
            if codec == VideoEncoderType::H265 && frame_count % 30 == 1 {
                debug!(
                    "[Pipeline-H265] MJPEG decoded: nv12_size={}, frame_width={}, frame_height={}",
                    written, nv12_frame.width, nv12_frame.height
                );
            }
            encoder.encode_raw(&nv12_buffer[..written], pts_ms)
        } else if needs_yuv420p && yuv420p_converter.is_some() {
            // Software encoder with direct input conversion to YUV420P
            let conv = yuv420p_converter.as_mut().unwrap();
            let yuv420p_data = conv.convert(raw_frame)
                .map_err(|e| AppError::VideoError(format!("YUV420P conversion failed: {}", e)))?;
            encoder.encode_raw(yuv420p_data, pts_ms)
        } else if let Some(conv) = nv12_converter.as_mut() {
            // Hardware encoder with input conversion to NV12
            let nv12_data = conv.convert(raw_frame)
                .map_err(|e| AppError::VideoError(format!("NV12 conversion failed: {}", e)))?;
            encoder.encode_raw(nv12_data, pts_ms)
        } else {
            // Direct input (already in correct format)
            encoder.encode_raw(raw_frame, pts_ms)
        };
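        // Release decoder, converter, and encoder locks before handling the result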
drop(encoder_guard);
drop(nv12_converter);
drop(yuv420p_converter);
drop(mjpeg_decoder);
drop(mjpeg_turbo_decoder);
match encode_result {
Ok(frames) => {
                if !frames.is_empty() {
                    if frames.len() > 1 {
                        warn!(
                            "Encoder returned {} frames for one input; only the first is forwarded",
                            frames.len()
                        );
                    }
                    let encoded = frames.into_iter().next().unwrap();
let is_keyframe = encoded.key == 1;
let sequence = {
let mut seq = self.sequence.lock().await;
*seq += 1;
*seq
};
// Debug log for H265 encoded frame
if codec == VideoEncoderType::H265 && (is_keyframe || frame_count % 30 == 1) {
debug!(
"[Pipeline-H265] Encoded frame #{}: output_size={}, keyframe={}, sequence={}",
frame_count,
encoded.data.len(),
is_keyframe,
sequence
);
// Log H265 NAL unit types in the encoded data
if is_keyframe {
let nal_types = parse_h265_nal_types(&encoded.data);
debug!("[Pipeline-H265] Keyframe NAL types: {:?}", nal_types);
}
}
let config = self.config.read().await;
Ok(Some(EncodedVideoFrame {
data: Bytes::from(encoded.data),
pts_ms,
is_keyframe,
sequence,
duration: Duration::from_millis(1000 / config.fps as u64),
codec,
}))
} else {
if codec == VideoEncoderType::H265 {
warn!("[Pipeline-H265] Encoder returned no frames for frame #{}", frame_count);
}
Ok(None)
}
}
Err(e) => {
if codec == VideoEncoderType::H265 {
error!("[Pipeline-H265] Encode error at frame #{}: {}", frame_count, e);
}
Err(e)
},
}
}
/// Stop the pipeline
pub fn stop(&self) {
if *self.running_rx.borrow() {
let _ = self.running.send(false);
info!("Stopping video pipeline");
}
}
/// Set bitrate
pub async fn set_bitrate(&self, bitrate_kbps: u32) -> Result<()> {
if let Some(ref mut encoder) = *self.encoder.lock().await {
encoder.set_bitrate(bitrate_kbps)?;
self.config.write().await.bitrate_kbps = bitrate_kbps;
}
Ok(())
}
/// Get current config
pub async fn config(&self) -> SharedVideoPipelineConfig {
self.config.read().await.clone()
}
}
impl Drop for SharedVideoPipeline {
fn drop(&mut self) {
let _ = self.running.send(false);
}
}
/// Parse H265 NAL unit types from Annex B data
fn parse_h265_nal_types(data: &[u8]) -> Vec<(u8, usize)> {
let mut nal_types = Vec::new();
let mut i = 0;
while i < data.len() {
// Find start code
let nal_start = if i + 4 <= data.len()
&& data[i] == 0
&& data[i + 1] == 0
&& data[i + 2] == 0
&& data[i + 3] == 1
{
i + 4
} else if i + 3 <= data.len()
&& data[i] == 0
&& data[i + 1] == 0
&& data[i + 2] == 1
{
i + 3
} else {
i += 1;
continue;
};
if nal_start >= data.len() {
break;
}
// Find next start code to get NAL size
let mut nal_end = data.len();
let mut j = nal_start + 1;
while j + 3 <= data.len() {
if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1)
|| (j + 4 <= data.len()
&& data[j] == 0
&& data[j + 1] == 0
&& data[j + 2] == 0
&& data[j + 3] == 1)
{
nal_end = j;
break;
}
j += 1;
}
// H265 NAL type is in bits 1-6 of first byte
let nal_type = (data[nal_start] >> 1) & 0x3F;
let nal_size = nal_end - nal_start;
nal_types.push((nal_type, nal_size));
i = nal_end;
}
nal_types
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pipeline_config() {
let h264 = SharedVideoPipelineConfig::h264(Resolution::HD1080, 4000);
assert_eq!(h264.output_codec, VideoEncoderType::H264);
let h265 = SharedVideoPipelineConfig::h265(Resolution::HD720, 2000);
assert_eq!(h265.output_codec, VideoEncoderType::H265);
}
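    // A minimal sanity check for `parse_h265_nal_types` on a hand-built
    // Annex B buffer: a 4-byte start code (VPS, nal_type 32) followed by a
    // 3-byte start code (SPS, nal_type 33); the payload bytes are arbitrary.
    #[test]
    fn test_parse_h265_nal_types() {
        let data = [
            0x00, 0x00, 0x00, 0x01, 0x40, 0x01, // VPS NAL: type 32, 2 bytes
            0x00, 0x00, 0x01, 0x42, 0x01, // SPS NAL: type 33, 2 bytes
        ];
        assert_eq!(parse_h265_nal_types(&data), vec![(32, 2), (33, 2)]);
    }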
}

574
src/video/stream_manager.rs Normal file
View File

@@ -0,0 +1,574 @@
//! Video Stream Manager
//!
//! Unified manager for video streaming that supports single-mode operation.
//! At any given time, only one streaming mode (MJPEG or WebRTC) is active.
//!
//! # Architecture
//!
//! ```text
//! VideoStreamManager (Public API - Single Entry Point)
//! │
//! ├── mode: StreamMode (current active mode)
//! │
//! ├── MJPEG Mode
//! │ └── Streamer ──► MjpegStreamHandler
//! │ (Future: MjpegStreamer with WsAudio/WsHid)
//! │
//! └── WebRTC Mode
//! └── WebRtcStreamer ──► H264SessionManager
//! (Extensible: H264, VP8, VP9, H265)
//! ```
//!
//! # Design Goals
//!
//! 1. **Single Entry Point**: All video operations go through VideoStreamManager
//! 2. **Mode Isolation**: MJPEG and WebRTC modes are cleanly separated
//! 3. **Extensible Codecs**: WebRTC supports multiple video codecs (H264 now, others reserved)
//! 4. **Simplified API**: Complex configuration flows are encapsulated
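//!
//! # Example
//!
//! A minimal usage sketch, assuming `streamer` and `webrtc_streamer` have
//! already been constructed and wired elsewhere:
//!
//! ```ignore
//! let manager = VideoStreamManager::with_webrtc_streamer(streamer, webrtc_streamer);
//! manager.init_with_mode(StreamMode::Mjpeg).await?;
//! // Later, switch the active transport:
//! manager.switch_mode(StreamMode::WebRTC).await?;
//! ```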
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use crate::config::{ConfigStore, StreamMode};
use crate::error::Result;
use crate::events::{EventBus, SystemEvent, VideoDeviceInfo};
use crate::hid::HidController;
use crate::stream::MjpegStreamHandler;
use crate::video::format::{PixelFormat, Resolution};
use crate::video::streamer::{Streamer, StreamerState};
use crate::webrtc::WebRtcStreamer;
/// Video stream manager configuration
#[derive(Debug, Clone)]
pub struct StreamManagerConfig {
/// Initial streaming mode
pub mode: StreamMode,
/// Video device path
pub device: Option<String>,
/// Video format
pub format: PixelFormat,
/// Resolution
pub resolution: Resolution,
/// FPS
pub fps: u32,
}
impl Default for StreamManagerConfig {
fn default() -> Self {
Self {
mode: StreamMode::Mjpeg,
device: None,
format: PixelFormat::Mjpeg,
resolution: Resolution::HD1080,
fps: 30,
}
}
}
/// Unified video stream manager
///
/// Manages both MJPEG and WebRTC streaming modes, ensuring only one is active
/// at any given time. This reduces resource usage and simplifies the architecture.
///
/// # Components
///
/// - **Streamer**: Handles video capture and MJPEG distribution
/// - **WebRtcStreamer**: High-level WebRTC manager with multi-codec support,
///   wrapping the legacy H264SessionManager for backward compatibility
pub struct VideoStreamManager {
/// Current streaming mode
mode: RwLock<StreamMode>,
/// MJPEG streamer (handles video capture and MJPEG distribution)
streamer: Arc<Streamer>,
/// WebRTC streamer (unified WebRTC manager with multi-codec support)
webrtc_streamer: Arc<WebRtcStreamer>,
/// Event bus for notifications
events: RwLock<Option<Arc<EventBus>>>,
/// Configuration store
config_store: RwLock<Option<ConfigStore>>,
/// Mode switching lock to prevent concurrent switch requests
switching: AtomicBool,
}
impl VideoStreamManager {
/// Create a new video stream manager with WebRtcStreamer
pub fn with_webrtc_streamer(
streamer: Arc<Streamer>,
webrtc_streamer: Arc<WebRtcStreamer>,
) -> Arc<Self> {
Arc::new(Self {
mode: RwLock::new(StreamMode::Mjpeg),
streamer,
webrtc_streamer,
events: RwLock::new(None),
config_store: RwLock::new(None),
switching: AtomicBool::new(false),
})
}
/// Check if mode switching is in progress
pub fn is_switching(&self) -> bool {
self.switching.load(Ordering::SeqCst)
}
/// Set event bus for notifications
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Set configuration store
pub async fn set_config_store(&self, config: ConfigStore) {
*self.config_store.write().await = Some(config);
}
/// Get current streaming mode
pub async fn current_mode(&self) -> StreamMode {
self.mode.read().await.clone()
}
/// Check if MJPEG mode is active
pub async fn is_mjpeg_enabled(&self) -> bool {
*self.mode.read().await == StreamMode::Mjpeg
}
/// Check if WebRTC mode is active
pub async fn is_webrtc_enabled(&self) -> bool {
*self.mode.read().await == StreamMode::WebRTC
}
/// Get the underlying streamer (for MJPEG mode)
pub fn streamer(&self) -> Arc<Streamer> {
self.streamer.clone()
}
/// Get the WebRTC streamer (unified interface with multi-codec support)
pub fn webrtc_streamer(&self) -> Arc<WebRtcStreamer> {
self.webrtc_streamer.clone()
}
/// Get the MJPEG stream handler
pub fn mjpeg_handler(&self) -> Arc<MjpegStreamHandler> {
self.streamer.mjpeg_handler()
}
/// Initialize with a specific mode
pub async fn init_with_mode(self: &Arc<Self>, mode: StreamMode) -> Result<()> {
info!("Initializing video stream manager with mode: {:?}", mode);
*self.mode.write().await = mode.clone();
// Check if streamer is already initialized (capturer exists)
let needs_init = self.streamer.state().await == StreamerState::Uninitialized;
if needs_init {
match mode {
StreamMode::Mjpeg => {
// Initialize MJPEG streamer
if let Err(e) = self.streamer.init_auto().await {
warn!("Failed to auto-initialize MJPEG streamer: {}", e);
}
}
StreamMode::WebRTC => {
// WebRTC is initialized on-demand when clients connect
// But we still need to initialize the video capture
if let Err(e) = self.streamer.init_auto().await {
warn!("Failed to auto-initialize video capture for WebRTC: {}", e);
}
}
}
}
// Always reconnect frame source after initialization
// This ensures WebRTC has the correct frame_tx from the current capturer
if let Some(frame_tx) = self.streamer.frame_sender().await {
// Synchronize WebRTC config with actual capture format
let (format, resolution, fps) = self.streamer.current_video_config().await;
info!(
"Reconnecting frame source to WebRTC after init: {}x{} {:?} @ {}fps (receiver_count={})",
resolution.width, resolution.height, format, fps, frame_tx.receiver_count()
);
self.webrtc_streamer.update_video_config(resolution, format, fps).await;
self.webrtc_streamer.set_video_source(frame_tx).await;
}
Ok(())
}
/// Switch streaming mode
///
/// This will:
/// 1. Acquire switching lock (prevent concurrent switches)
/// 2. Notify clients of the mode change
/// 3. Stop the current mode
/// 4. Start the new mode (ensuring video capture runs for WebRTC)
/// 5. Update configuration
pub async fn switch_mode(self: &Arc<Self>, new_mode: StreamMode) -> Result<()> {
let current_mode = self.mode.read().await.clone();
if current_mode == new_mode {
debug!("Already in {:?} mode, no switch needed", new_mode);
return Ok(());
}
// Acquire switching lock - prevent concurrent switch requests
if self.switching.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
debug!("Mode switch already in progress, ignoring duplicate request");
return Ok(());
}
// Use a helper to ensure we release the lock when done
let result = self.do_switch_mode(current_mode, new_mode.clone()).await;
self.switching.store(false, Ordering::SeqCst);
result
}
/// Internal implementation of mode switching (called with lock held)
async fn do_switch_mode(self: &Arc<Self>, current_mode: StreamMode, new_mode: StreamMode) -> Result<()> {
info!("Switching video mode: {:?} -> {:?}", current_mode, new_mode);
// Get the actual mode strings (with codec info for WebRTC)
let new_mode_str = match &new_mode {
StreamMode::Mjpeg => "mjpeg".to_string(),
StreamMode::WebRTC => {
let codec = self.webrtc_streamer.current_video_codec().await;
codec_to_string(codec)
}
};
let previous_mode_str = match &current_mode {
StreamMode::Mjpeg => "mjpeg".to_string(),
StreamMode::WebRTC => {
let codec = self.webrtc_streamer.current_video_codec().await;
codec_to_string(codec)
}
};
// 1. Publish mode change event (clients should prepare to reconnect)
self.publish_event(SystemEvent::StreamModeChanged {
mode: new_mode_str,
previous_mode: previous_mode_str,
})
.await;
// 2. Stop current mode
match current_mode {
StreamMode::Mjpeg => {
info!("Stopping MJPEG streaming");
// Only stop MJPEG distribution, keep video capture running for WebRTC
self.streamer.mjpeg_handler().set_offline();
if let Err(e) = self.streamer.stop().await {
warn!("Error stopping MJPEG streamer: {}", e);
}
}
StreamMode::WebRTC => {
info!("Closing all WebRTC sessions");
let closed = self.webrtc_streamer.close_all_sessions().await;
if closed > 0 {
info!("Closed {} WebRTC sessions", closed);
}
}
}
// 3. Update mode
*self.mode.write().await = new_mode.clone();
// 4. Start new mode
match new_mode {
StreamMode::Mjpeg => {
info!("Starting MJPEG streaming");
if let Err(e) = self.streamer.start().await {
error!("Failed to start MJPEG streamer: {}", e);
return Err(e);
}
}
StreamMode::WebRTC => {
// WebRTC mode: ensure video capture is running for H264 encoding
info!("Activating WebRTC mode");
// Initialize streamer if not already initialized
if self.streamer.state().await == StreamerState::Uninitialized {
info!("Initializing video capture for WebRTC");
if let Err(e) = self.streamer.init_auto().await {
error!("Failed to initialize video capture for WebRTC: {}", e);
return Err(e);
}
}
// Start video capture if not streaming
if self.streamer.state().await != StreamerState::Streaming {
info!("Starting video capture for WebRTC");
if let Err(e) = self.streamer.start().await {
error!("Failed to start video capture for WebRTC: {}", e);
return Err(e);
}
}
// Wait a bit for capture to stabilize
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Connect frame source to WebRTC with correct format
if let Some(frame_tx) = self.streamer.frame_sender().await {
// Synchronize WebRTC config with actual capture format
let (format, resolution, fps) = self.streamer.current_video_config().await;
info!(
"Connecting frame source to WebRTC pipeline: {}x{} {:?} @ {}fps",
resolution.width, resolution.height, format, fps
);
self.webrtc_streamer.update_video_config(resolution, format, fps).await;
self.webrtc_streamer.set_video_source(frame_tx).await;
} else {
warn!("No frame source available for WebRTC - sessions may fail to receive video");
}
info!("WebRTC mode activated (sessions created on-demand)");
}
}
// 5. Update configuration store if available
if let Some(ref config_store) = *self.config_store.read().await {
let mut config = (*config_store.get()).clone();
config.stream.mode = new_mode.clone();
if let Err(e) = config_store.set(config).await {
warn!("Failed to persist stream mode to config: {}", e);
}
}
info!("Video mode switched to {:?}", new_mode);
Ok(())
}
/// Apply video configuration (device, format, resolution, fps)
///
/// This is called when video settings change. It will restart the
/// appropriate streaming pipeline based on current mode.
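    ///
    /// A minimal call sketch (the device path and parameters are illustrative):
    ///
    /// ```ignore
    /// manager
    ///     .apply_video_config("/dev/video0", PixelFormat::Mjpeg, Resolution::HD1080, 30)
    ///     .await?;
    /// ```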
pub async fn apply_video_config(
self: &Arc<Self>,
device_path: &str,
format: PixelFormat,
resolution: Resolution,
fps: u32,
) -> Result<()> {
let mode = self.mode.read().await.clone();
info!(
"Applying video config: {} {:?} {}x{} @ {} fps (mode: {:?})",
device_path, format, resolution.width, resolution.height, fps, mode
);
// Apply to streamer (handles video capture)
self.streamer
.apply_video_config(device_path, format, resolution, fps)
.await?;
// Update WebRTC config if in WebRTC mode
if mode == StreamMode::WebRTC {
self.webrtc_streamer
.update_video_config(resolution, format, fps)
.await;
// Restart video capture for WebRTC (it was stopped during config change)
info!("Restarting video capture for WebRTC after config change");
if let Err(e) = self.streamer.start().await {
error!("Failed to restart video capture for WebRTC: {}", e);
return Err(e);
}
// Wait a bit for capture to stabilize
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Reconnect frame source with the new capturer
if let Some(frame_tx) = self.streamer.frame_sender().await {
                // Note: update_video_config was already called above with the
                // requested values; update again only if the actual capture differs
let (actual_format, actual_resolution, actual_fps) = self.streamer.current_video_config().await;
if actual_format != format || actual_resolution != resolution || actual_fps != fps {
info!(
"Actual capture config differs from requested, updating WebRTC: {}x{} {:?} @ {}fps",
actual_resolution.width, actual_resolution.height, actual_format, actual_fps
);
self.webrtc_streamer.update_video_config(actual_resolution, actual_format, actual_fps).await;
}
info!("Reconnecting frame source to WebRTC after config change");
self.webrtc_streamer.set_video_source(frame_tx).await;
} else {
warn!("No frame source available after config change");
}
}
Ok(())
}
/// Start streaming (based on current mode)
pub async fn start(self: &Arc<Self>) -> Result<()> {
let mode = self.mode.read().await.clone();
match mode {
StreamMode::Mjpeg => {
self.streamer.start().await?;
}
StreamMode::WebRTC => {
// Ensure video capture is running
if self.streamer.state().await == StreamerState::Uninitialized {
self.streamer.init_auto().await?;
}
if self.streamer.state().await != StreamerState::Streaming {
self.streamer.start().await?;
}
// Connect frame source with correct format
if let Some(frame_tx) = self.streamer.frame_sender().await {
// Synchronize WebRTC config with actual capture format
let (format, resolution, fps) = self.streamer.current_video_config().await;
self.webrtc_streamer.update_video_config(resolution, format, fps).await;
self.webrtc_streamer.set_video_source(frame_tx).await;
}
}
}
Ok(())
}
/// Stop streaming
pub async fn stop(&self) -> Result<()> {
let mode = self.mode.read().await.clone();
match mode {
StreamMode::Mjpeg => {
self.streamer.stop().await?;
}
StreamMode::WebRTC => {
self.webrtc_streamer.close_all_sessions().await;
self.streamer.stop().await?;
}
}
Ok(())
}
/// Get video device info for device_info event
pub async fn get_video_info(&self) -> VideoDeviceInfo {
let stats = self.streamer.stats().await;
let state = self.streamer.state().await;
let device = self.streamer.current_device().await;
let mode = self.mode.read().await.clone();
        // For WebRTC mode, return the specific codec type (h264, h265, vp8, vp9)
        // instead of the generic "webrtc" so the frontend does not default to h264
let stream_mode = match &mode {
StreamMode::Mjpeg => "mjpeg".to_string(),
StreamMode::WebRTC => {
let codec = self.webrtc_streamer.current_video_codec().await;
codec_to_string(codec)
}
};
VideoDeviceInfo {
available: state != StreamerState::Uninitialized,
device: device.map(|d| d.path.display().to_string()),
format: stats.format,
resolution: stats.resolution,
fps: stats.target_fps,
online: state == StreamerState::Streaming,
stream_mode,
config_changing: self.streamer.is_config_changing(),
error: if state == StreamerState::Error {
Some("Video stream error".to_string())
} else if state == StreamerState::NoSignal {
Some("No video signal".to_string())
} else {
None
},
}
}
/// Get MJPEG client count
pub fn mjpeg_client_count(&self) -> u64 {
self.streamer.mjpeg_handler().client_count()
}
/// Get WebRTC session count
pub async fn webrtc_session_count(&self) -> usize {
self.webrtc_streamer.session_count().await
}
/// Set HID controller for WebRTC DataChannel
pub async fn set_hid_controller(&self, hid: Arc<HidController>) {
self.webrtc_streamer.set_hid_controller(hid).await;
}
/// Set audio enabled state for WebRTC
pub async fn set_webrtc_audio_enabled(&self, enabled: bool) -> Result<()> {
self.webrtc_streamer.set_audio_enabled(enabled).await
}
/// Check if WebRTC audio is enabled
pub async fn is_webrtc_audio_enabled(&self) -> bool {
self.webrtc_streamer.is_audio_enabled().await
}
/// Reconnect audio sources for all WebRTC sessions
/// Call this after audio controller restarts (e.g., quality change)
pub async fn reconnect_webrtc_audio_sources(&self) {
self.webrtc_streamer.reconnect_audio_sources().await;
}
// =========================================================================
// Delegated methods from Streamer (for backward compatibility)
// =========================================================================
/// List available video devices
pub async fn list_devices(&self) -> crate::error::Result<Vec<crate::video::device::VideoDeviceInfo>> {
self.streamer.list_devices().await
}
/// Get streamer statistics
pub async fn stats(&self) -> crate::video::streamer::StreamerStats {
self.streamer.stats().await
}
/// Check if config is being changed
pub fn is_config_changing(&self) -> bool {
self.streamer.is_config_changing()
}
/// Check if streaming is active
pub async fn is_streaming(&self) -> bool {
self.streamer.is_streaming().await
}
/// Get frame sender for video frames
pub async fn frame_sender(&self) -> Option<tokio::sync::broadcast::Sender<crate::video::frame::VideoFrame>> {
self.streamer.frame_sender().await
}
/// Publish event to event bus
async fn publish_event(&self, event: SystemEvent) {
if let Some(ref events) = *self.events.read().await {
events.publish(event);
}
}
}
/// Convert VideoCodecType to lowercase string for frontend
fn codec_to_string(codec: crate::video::encoder::VideoCodecType) -> String {
match codec {
crate::video::encoder::VideoCodecType::H264 => "h264".to_string(),
crate::video::encoder::VideoCodecType::H265 => "h265".to_string(),
crate::video::encoder::VideoCodecType::VP8 => "vp8".to_string(),
crate::video::encoder::VideoCodecType::VP9 => "vp9".to_string(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::video::encoder::VideoCodecType;
#[test]
fn test_codec_to_string() {
assert_eq!(codec_to_string(VideoCodecType::H264), "h264");
assert_eq!(codec_to_string(VideoCodecType::H265), "h265");
assert_eq!(codec_to_string(VideoCodecType::VP8), "vp8");
assert_eq!(codec_to_string(VideoCodecType::VP9), "vp9");
}
}

892
src/video/streamer.rs Normal file
View File

@@ -0,0 +1,892 @@
//! Video streamer that integrates capture and streaming
//!
//! This module provides a high-level interface for video capture and streaming,
//! managing the lifecycle of the capture thread and MJPEG/WebRTC distribution.
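//!
//! A minimal usage sketch, auto-detecting the capture device:
//!
//! ```ignore
//! let streamer = Streamer::new();
//! streamer.init_auto().await?;
//! streamer.start().await?;
//! ```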
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, error, info, trace, warn};
use super::capture::{CaptureConfig, CaptureState, VideoCapturer};
use super::device::{enumerate_devices, find_best_device, VideoDeviceInfo};
use super::format::{PixelFormat, Resolution};
use super::frame::VideoFrame;
use crate::error::{AppError, Result};
use crate::events::{EventBus, SystemEvent};
use crate::stream::MjpegStreamHandler;
/// Streamer configuration
#[derive(Debug, Clone)]
pub struct StreamerConfig {
/// Device path (None = auto-detect)
pub device_path: Option<PathBuf>,
/// Desired resolution
pub resolution: Resolution,
/// Desired format
pub format: PixelFormat,
/// Desired FPS
pub fps: u32,
/// JPEG quality (1-100)
pub jpeg_quality: u8,
}
impl Default for StreamerConfig {
fn default() -> Self {
Self {
device_path: None,
resolution: Resolution::HD1080,
format: PixelFormat::Mjpeg,
fps: 30,
jpeg_quality: 80,
}
}
}
/// Streamer state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamerState {
/// Not initialized
Uninitialized,
/// Ready but not streaming
Ready,
/// Actively streaming
Streaming,
/// No video signal
NoSignal,
/// Error occurred
Error,
/// Device was lost (unplugged)
DeviceLost,
/// Device is being recovered (reconnecting)
Recovering,
}
/// Video streamer service
pub struct Streamer {
config: RwLock<StreamerConfig>,
capturer: RwLock<Option<Arc<VideoCapturer>>>,
mjpeg_handler: Arc<MjpegStreamHandler>,
current_device: RwLock<Option<VideoDeviceInfo>>,
state: RwLock<StreamerState>,
start_lock: tokio::sync::Mutex<()>,
/// Event bus for broadcasting state changes (optional)
events: RwLock<Option<Arc<EventBus>>>,
/// Last published state (for change detection)
last_published_state: RwLock<Option<StreamerState>>,
/// Flag to indicate config is being changed (prevents auto-start during config change)
config_changing: std::sync::atomic::AtomicBool,
/// Flag to indicate background tasks (stats, cleanup, monitor) have been started
/// These tasks should only be started once per Streamer instance
background_tasks_started: std::sync::atomic::AtomicBool,
/// Device recovery retry count
recovery_retry_count: std::sync::atomic::AtomicU32,
/// Device recovery in progress flag
recovery_in_progress: std::sync::atomic::AtomicBool,
/// Last lost device path (for recovery)
last_lost_device: RwLock<Option<String>>,
/// Last lost device reason (for logging)
last_lost_reason: RwLock<Option<String>>,
}
impl Streamer {
    /// Create a new streamer
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }
    /// Create with specific config
    pub fn with_config(config: StreamerConfig) -> Arc<Self> {
        Arc::new(Self {
            config: RwLock::new(config),
            ..Self::default()
        })
    }
/// Get current state as SystemEvent
pub async fn current_state_event(&self) -> SystemEvent {
let state = *self.state.read().await;
let device = self.current_device.read().await.as_ref().map(|d| d.path.display().to_string());
SystemEvent::StreamStateChanged {
state: match state {
StreamerState::Uninitialized => "uninitialized".to_string(),
StreamerState::Ready => "ready".to_string(),
StreamerState::Streaming => "streaming".to_string(),
StreamerState::NoSignal => "no_signal".to_string(),
StreamerState::Error => "error".to_string(),
StreamerState::DeviceLost => "device_lost".to_string(),
StreamerState::Recovering => "recovering".to_string(),
},
device,
}
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Get current state
pub async fn state(&self) -> StreamerState {
*self.state.read().await
}
/// Check if config is currently being changed
/// When true, auto-start should be blocked to prevent device busy errors
pub fn is_config_changing(&self) -> bool {
self.config_changing.load(std::sync::atomic::Ordering::SeqCst)
}
/// Get MJPEG handler for stream endpoints
pub fn mjpeg_handler(&self) -> Arc<MjpegStreamHandler> {
self.mjpeg_handler.clone()
}
/// Get frame sender for WebRTC integration
/// Returns None if no capturer is initialized
pub async fn frame_sender(&self) -> Option<broadcast::Sender<VideoFrame>> {
let capturer = self.capturer.read().await;
capturer.as_ref().map(|c| c.frame_sender())
}
/// Subscribe to video frames
/// Returns None if no capturer is initialized
pub async fn subscribe_frames(&self) -> Option<broadcast::Receiver<VideoFrame>> {
let capturer = self.capturer.read().await;
capturer.as_ref().map(|c| c.subscribe())
}
/// Get current device info
pub async fn current_device(&self) -> Option<VideoDeviceInfo> {
self.current_device.read().await.clone()
}
/// Get current video configuration (format, resolution, fps)
pub async fn current_video_config(&self) -> (PixelFormat, Resolution, u32) {
let config = self.config.read().await;
(config.format, config.resolution, config.fps)
}
/// List available video devices
pub async fn list_devices(&self) -> Result<Vec<VideoDeviceInfo>> {
enumerate_devices()
}
/// Validate and apply requested video parameters without auto-selection
pub async fn apply_video_config(
self: &Arc<Self>,
device_path: &str,
format: PixelFormat,
resolution: Resolution,
fps: u32,
) -> Result<()> {
// Set config_changing flag to prevent frontend mode sync during config change
self.config_changing.store(true, std::sync::atomic::Ordering::SeqCst);
let result = self.apply_video_config_inner(device_path, format, resolution, fps).await;
// Clear the flag after config change is complete
// The stream will be started by MJPEG client connection, not here
self.config_changing.store(false, std::sync::atomic::Ordering::SeqCst);
result
}
/// Internal implementation of apply_video_config
async fn apply_video_config_inner(
self: &Arc<Self>,
device_path: &str,
format: PixelFormat,
resolution: Resolution,
fps: u32,
) -> Result<()> {
// Publish "config changing" event
self.publish_event(SystemEvent::StreamConfigChanging {
reason: "device_switch".to_string(),
})
.await;
let devices = enumerate_devices()?;
let device = devices
.into_iter()
.find(|d| d.path.to_string_lossy() == device_path)
.ok_or_else(|| AppError::VideoError("Video device not found".to_string()))?;
// Validate format
let fmt_info = device
.formats
.iter()
.find(|f| f.format == format)
.ok_or_else(|| AppError::VideoError("Requested format not supported".to_string()))?;
// Validate resolution
if !fmt_info.resolutions.is_empty()
&& !fmt_info
.resolutions
.iter()
.any(|r| r.width == resolution.width && r.height == resolution.height)
{
return Err(AppError::VideoError("Requested resolution not supported".to_string()));
}
// IMPORTANT: Disconnect all MJPEG clients FIRST before stopping capture
// This prevents race conditions where clients try to reconnect and reopen the device
info!("Disconnecting all MJPEG clients before config change...");
self.mjpeg_handler.disconnect_all_clients();
// Give clients time to receive the disconnect signal and close their connections
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Stop existing capturer and wait for device release
{
// Take ownership of the old capturer to ensure it's dropped
let old_capturer = self.capturer.write().await.take();
if let Some(capturer) = old_capturer {
info!("Stopping existing capture before applying new config...");
if let Err(e) = capturer.stop().await {
warn!("Error stopping old capturer: {}", e);
}
// Explicitly drop the capturer to release V4L2 resources
drop(capturer);
}
}
// Update config
{
let mut cfg = self.config.write().await;
cfg.device_path = Some(device.path.clone());
cfg.format = format;
cfg.resolution = resolution;
cfg.fps = fps;
}
// Recreate capturer
let capture_config = CaptureConfig {
device_path: device.path.clone(),
resolution,
format,
fps,
jpeg_quality: self.config.read().await.jpeg_quality,
..Default::default()
};
let capturer = Arc::new(VideoCapturer::new(capture_config));
*self.capturer.write().await = Some(capturer.clone());
*self.current_device.write().await = Some(device.clone());
*self.state.write().await = StreamerState::Ready;
// Publish "config applied" event
info!("Publishing StreamConfigApplied event: {}x{} {:?} @ {}fps",
resolution.width, resolution.height, format, fps);
self.publish_event(SystemEvent::StreamConfigApplied {
device: device_path.to_string(),
resolution: (resolution.width, resolution.height),
format: format!("{:?}", format),
fps,
})
.await;
// Note: We don't auto-start here anymore.
// The stream will be started when MJPEG client connects (handlers.rs:790)
// This avoids race conditions between config change and client reconnection.
info!("Config applied, stream will start when client connects");
Ok(())
}
/// Initialize with auto-detected device
pub async fn init_auto(self: &Arc<Self>) -> Result<()> {
info!("Auto-detecting video device...");
let device = find_best_device()?;
info!("Found device: {} ({})", device.name, device.path.display());
self.init_with_device(device).await
}
/// Initialize with specific device
pub async fn init_with_device(self: &Arc<Self>, device: VideoDeviceInfo) -> Result<()> {
info!(
"Initializing streamer with device: {} ({})",
device.name,
device.path.display()
);
// Determine best format for this device
let config = self.config.read().await;
let format = self.select_format(&device, config.format)?;
let resolution = self.select_resolution(&device, &format, config.resolution)?;
drop(config);
// Update config with actual values
{
let mut config = self.config.write().await;
config.device_path = Some(device.path.clone());
config.format = format;
config.resolution = resolution;
}
// Store device info
*self.current_device.write().await = Some(device.clone());
// Create capturer
let config = self.config.read().await;
let capture_config = CaptureConfig {
device_path: device.path.clone(),
resolution: config.resolution,
format: config.format,
fps: config.fps,
jpeg_quality: config.jpeg_quality,
..Default::default()
};
drop(config);
let capturer = Arc::new(VideoCapturer::new(capture_config));
*self.capturer.write().await = Some(capturer);
*self.state.write().await = StreamerState::Ready;
info!("Streamer initialized: {} @ {}", format, resolution);
Ok(())
}
/// Select best format for device
fn select_format(&self, device: &VideoDeviceInfo, preferred: PixelFormat) -> Result<PixelFormat> {
// Check if preferred format is available
if device.formats.iter().any(|f| f.format == preferred) {
return Ok(preferred);
}
// Select best available format
device
.formats
.first()
.map(|f| f.format)
.ok_or_else(|| AppError::VideoError("No supported formats found".to_string()))
}
/// Select best resolution for format
fn select_resolution(
&self,
device: &VideoDeviceInfo,
format: &PixelFormat,
preferred: Resolution,
) -> Result<Resolution> {
let format_info = device
.formats
.iter()
.find(|f| &f.format == format)
.ok_or_else(|| AppError::VideoError("Format not found".to_string()))?;
// Check if preferred resolution is available
if format_info.resolutions.is_empty()
|| format_info.resolutions.iter().any(|r| {
r.width == preferred.width && r.height == preferred.height
})
{
return Ok(preferred);
}
// Select largest available resolution
format_info
.resolutions
.first()
.map(|r| r.resolution())
.ok_or_else(|| AppError::VideoError("No resolutions available".to_string()))
}
/// Restart the capturer only (for recovery - doesn't spawn new monitor)
///
/// This is a simpler version of start() used during device recovery.
/// It doesn't spawn a new state monitor since the existing one is still active.
async fn restart_capturer(&self) -> Result<()> {
let capturer = self.capturer.read().await;
let capturer = capturer
.as_ref()
.ok_or_else(|| AppError::VideoError("Capturer not initialized".to_string()))?;
// Start capture
capturer.start().await?;
// Set MJPEG handler online
self.mjpeg_handler.set_online();
// Start frame distribution task
let mjpeg_handler = self.mjpeg_handler.clone();
let mut frame_rx = capturer.subscribe();
tokio::spawn(async move {
debug!("Recovery frame distribution task started");
loop {
match frame_rx.recv().await {
Ok(frame) => {
mjpeg_handler.update_frame(frame);
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
trace!("Frame distribution lagged by {} frames", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
debug!("Frame channel closed");
break;
}
}
}
});
Ok(())
}
/// Start streaming
pub async fn start(self: &Arc<Self>) -> Result<()> {
let _lock = self.start_lock.lock().await;
let state = self.state().await;
if state == StreamerState::Streaming {
return Ok(());
}
if state == StreamerState::Uninitialized {
// Auto-initialize if not done
self.init_auto().await?;
}
let capturer = self.capturer.read().await;
let capturer = capturer
.as_ref()
.ok_or_else(|| AppError::VideoError("Capturer not initialized".to_string()))?;
// Start capture
capturer.start().await?;
// Set MJPEG handler online before starting frame distribution
// This is important after config changes where disconnect_all_clients() set it offline
self.mjpeg_handler.set_online();
// Start frame distribution task
let mjpeg_handler = self.mjpeg_handler.clone();
let mut frame_rx = capturer.subscribe();
let state_ref = Arc::downgrade(self);
tokio::spawn(async move {
info!("Frame distribution task started");
loop {
match frame_rx.recv().await {
Ok(frame) => {
mjpeg_handler.update_frame(frame);
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
trace!("Frame distribution lagged by {} frames", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
debug!("Frame channel closed");
break;
}
}
// Check if streamer still exists
if state_ref.upgrade().is_none() {
break;
}
}
info!("Frame distribution task ended");
});
// Monitor capture state
let mut state_rx = capturer.state_watch();
let state_ref = Arc::downgrade(self);
let mjpeg_handler = self.mjpeg_handler.clone();
tokio::spawn(async move {
while state_rx.changed().await.is_ok() {
let capture_state = *state_rx.borrow();
match capture_state {
CaptureState::Running => {
if let Some(streamer) = state_ref.upgrade() {
*streamer.state.write().await = StreamerState::Streaming;
}
}
CaptureState::NoSignal => {
mjpeg_handler.set_offline();
if let Some(streamer) = state_ref.upgrade() {
*streamer.state.write().await = StreamerState::NoSignal;
}
}
CaptureState::Stopped => {
mjpeg_handler.set_offline();
if let Some(streamer) = state_ref.upgrade() {
*streamer.state.write().await = StreamerState::Ready;
}
}
CaptureState::Error => {
mjpeg_handler.set_offline();
if let Some(streamer) = state_ref.upgrade() {
*streamer.state.write().await = StreamerState::Error;
}
}
CaptureState::DeviceLost => {
mjpeg_handler.set_offline();
if let Some(streamer) = state_ref.upgrade() {
*streamer.state.write().await = StreamerState::DeviceLost;
// Start device recovery task (fire and forget)
let streamer_clone = Arc::clone(&streamer);
tokio::spawn(async move {
streamer_clone.start_device_recovery_internal().await;
});
}
}
CaptureState::Starting => {
// Starting state - device is initializing, no action needed
}
}
}
});
// Start background tasks only once per Streamer instance
// Use compare_exchange to atomically check and set the flag
if self.background_tasks_started
.compare_exchange(false, true, std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst)
.is_ok()
{
info!("Starting background tasks (stats, cleanup, monitor)");
// Start stats broadcast task (sends stats updates every 1 second)
let stats_ref = Arc::downgrade(self);
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
interval.tick().await;
if let Some(streamer) = stats_ref.upgrade() {
let clients_stat = streamer.mjpeg_handler().get_clients_stat();
let clients = clients_stat.len() as u64;
streamer.publish_event(SystemEvent::StreamStatsUpdate {
clients,
clients_stat,
}).await;
} else {
break;
}
}
});
// Start client cleanup task (removes stale clients every 5s)
self.mjpeg_handler.clone().start_cleanup_task();
// Start auto-pause monitor task (stops stream if no clients)
let monitor_ref = Arc::downgrade(self);
let monitor_handler = self.mjpeg_handler.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(2));
let mut zero_since: Option<std::time::Instant> = None;
loop {
interval.tick().await;
let Some(streamer) = monitor_ref.upgrade() else { break; };
// Check auto-pause configuration
let config = monitor_handler.auto_pause_config();
if !config.enabled {
zero_since = None;
continue;
}
                    let count = monitor_handler.client_count();
                    if count == 0 {
                        // Only arm the shutdown timer while the stream is running;
                        // otherwise there is nothing to pause.
                        if !streamer.is_streaming().await {
                            zero_since = None;
                            continue;
                        }
                        if zero_since.is_none() {
                            zero_since = Some(std::time::Instant::now());
                            info!("No clients connected, starting shutdown timer ({}s)", config.shutdown_delay_secs);
                        } else if let Some(since) = zero_since {
                            if since.elapsed().as_secs() >= config.shutdown_delay_secs {
                                info!("Auto-pausing stream (no clients for {}s)", config.shutdown_delay_secs);
                                if let Err(e) = streamer.stop().await {
                                    error!("Auto-pause failed: {}", e);
                                }
                                // Reset the timer instead of exiting the task so
                                // auto-pause keeps working for future sessions.
                                zero_since = None;
                            }
                        }
                    } else if zero_since.is_some() {
                        info!("Clients reconnected, canceling auto-pause");
                        zero_since = None;
                    }
}
});
} else {
debug!("Background tasks already started, skipping");
}
*self.state.write().await = StreamerState::Streaming;
// Publish state change event so DeviceInfo broadcaster can update frontend
self.publish_event(self.current_state_event().await).await;
info!("Streaming started");
Ok(())
}
/// Stop streaming
pub async fn stop(&self) -> Result<()> {
if let Some(capturer) = self.capturer.read().await.as_ref() {
capturer.stop().await?;
}
self.mjpeg_handler.set_offline();
*self.state.write().await = StreamerState::Ready;
// Publish state change event so DeviceInfo broadcaster can update frontend
self.publish_event(self.current_state_event().await).await;
info!("Streaming stopped");
Ok(())
}
/// Check if streaming
pub async fn is_streaming(&self) -> bool {
self.state().await == StreamerState::Streaming
}
/// Get stream statistics
pub async fn stats(&self) -> StreamerStats {
let capturer = self.capturer.read().await;
let capture_stats = if let Some(c) = capturer.as_ref() {
Some(c.stats().await)
} else {
None
};
let config = self.config.read().await;
StreamerStats {
state: self.state().await,
device: self.current_device().await.map(|d| d.name),
format: Some(config.format.to_string()),
resolution: Some((config.resolution.width, config.resolution.height)),
clients: self.mjpeg_handler.client_count(),
target_fps: config.fps,
fps: capture_stats.as_ref().map(|s| s.current_fps).unwrap_or(0.0),
frames_captured: capture_stats.as_ref().map(|s| s.frames_captured).unwrap_or(0),
frames_dropped: capture_stats.as_ref().map(|s| s.frames_dropped).unwrap_or(0),
}
}
/// Publish event to event bus (if configured)
/// For StreamStateChanged events, only publishes if state actually changed (de-duplication)
async fn publish_event(&self, event: SystemEvent) {
if let Some(events) = self.events.read().await.as_ref() {
// For state change events, check if state actually changed
if let SystemEvent::StreamStateChanged { ref state, .. } = event {
let current_state = match state.as_str() {
"uninitialized" => StreamerState::Uninitialized,
"ready" => StreamerState::Ready,
"streaming" => StreamerState::Streaming,
"no_signal" => StreamerState::NoSignal,
"error" => StreamerState::Error,
"device_lost" => StreamerState::DeviceLost,
"recovering" => StreamerState::Recovering,
_ => StreamerState::Error,
};
let mut last_state = self.last_published_state.write().await;
if *last_state == Some(current_state) {
// State hasn't changed, skip publishing
trace!("Skipping duplicate stream state event: {}", state);
return;
}
*last_state = Some(current_state);
}
events.publish(event);
}
}
/// Start device recovery task (internal implementation)
///
/// This method starts a background task that attempts to reconnect
/// to the video device after it was lost. It retries every 1 second
/// until the device is recovered.
async fn start_device_recovery_internal(self: &Arc<Self>) {
// Check if recovery is already in progress
if self.recovery_in_progress.swap(true, std::sync::atomic::Ordering::SeqCst) {
debug!("Device recovery already in progress, skipping");
return;
}
        // Get last lost device info from the capturer, falling back to the
        // current device path. Use the async read() here: blocking_read()
        // would panic when called from inside the tokio runtime.
        let (device, reason) = {
            let last_error = {
                let capturer = self.capturer.read().await;
                capturer.as_ref().and_then(|cap| cap.last_error())
            };
            match last_error {
                Some(info) => info,
                None => {
                    let device_path = self
                        .current_device
                        .read()
                        .await
                        .as_ref()
                        .map(|d| d.path.display().to_string())
                        .unwrap_or_else(|| "unknown".to_string());
                    (device_path, "Device lost".to_string())
                }
            }
        };
// Store error info
*self.last_lost_device.write().await = Some(device.clone());
*self.last_lost_reason.write().await = Some(reason.clone());
self.recovery_retry_count.store(0, std::sync::atomic::Ordering::Relaxed);
// Publish device lost event
self.publish_event(SystemEvent::StreamDeviceLost {
device: device.clone(),
reason: reason.clone(),
}).await;
// Start recovery task
let streamer = Arc::clone(self);
tokio::spawn(async move {
let device_path = device.clone();
loop {
let attempt = streamer.recovery_retry_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
// Check if still in device lost state
let current_state = *streamer.state.read().await;
if current_state != StreamerState::DeviceLost && current_state != StreamerState::Recovering {
info!("Stream state changed during recovery, stopping recovery task");
break;
}
// Update state to Recovering
*streamer.state.write().await = StreamerState::Recovering;
// Publish reconnecting event (every 5 attempts to avoid spam)
if attempt == 1 || attempt % 5 == 0 {
streamer.publish_event(SystemEvent::StreamReconnecting {
device: device_path.clone(),
attempt,
}).await;
info!("Attempting to recover video device {} (attempt {})", device_path, attempt);
}
// Wait before retry (1 second)
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// Check if device file exists
let device_exists = std::path::Path::new(&device_path).exists();
if !device_exists {
debug!("Device {} not present yet", device_path);
continue;
}
// Try to restart capture
match streamer.restart_capturer().await {
Ok(_) => {
info!("Video device {} recovered after {} attempts", device_path, attempt);
streamer.recovery_in_progress.store(false, std::sync::atomic::Ordering::SeqCst);
// Publish recovered event
streamer.publish_event(SystemEvent::StreamRecovered {
device: device_path.clone(),
}).await;
// Clear error info
*streamer.last_lost_device.write().await = None;
*streamer.last_lost_reason.write().await = None;
return;
}
Err(e) => {
debug!("Failed to restart capture (attempt {}): {}", attempt, e);
}
}
}
streamer.recovery_in_progress.store(false, std::sync::atomic::Ordering::SeqCst);
});
}
}
impl Default for Streamer {
fn default() -> Self {
Self {
config: RwLock::new(StreamerConfig::default()),
capturer: RwLock::new(None),
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
current_device: RwLock::new(None),
state: RwLock::new(StreamerState::Uninitialized),
start_lock: tokio::sync::Mutex::new(()),
events: RwLock::new(None),
last_published_state: RwLock::new(None),
config_changing: std::sync::atomic::AtomicBool::new(false),
background_tasks_started: std::sync::atomic::AtomicBool::new(false),
recovery_retry_count: std::sync::atomic::AtomicU32::new(0),
recovery_in_progress: std::sync::atomic::AtomicBool::new(false),
last_lost_device: RwLock::new(None),
last_lost_reason: RwLock::new(None),
}
}
}
/// Streamer statistics
#[derive(Debug, Clone, serde::Serialize)]
pub struct StreamerStats {
pub state: StreamerState,
pub device: Option<String>,
pub format: Option<String>,
pub resolution: Option<(u32, u32)>,
pub clients: u64,
/// Target FPS from configuration
pub target_fps: u32,
/// Current actual FPS
pub fps: f32,
pub frames_captured: u64,
pub frames_dropped: u64,
}
impl serde::Serialize for StreamerState {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let s = match self {
StreamerState::Uninitialized => "uninitialized",
StreamerState::Ready => "ready",
StreamerState::Streaming => "streaming",
StreamerState::NoSignal => "no_signal",
StreamerState::Error => "error",
StreamerState::DeviceLost => "device_lost",
StreamerState::Recovering => "recovering",
};
serializer.serialize_str(s)
}
}

595
src/video/video_session.rs Normal file
View File

@@ -0,0 +1,595 @@
//! Video session management with multi-codec support
//!
//! This module provides session management for video streaming with:
//! - Multi-codec support (H264, H265, VP8, VP9)
//! - Session lifecycle management
//! - Dynamic codec switching
//! - Statistics and monitoring
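//!
//! A minimal usage sketch, assuming `frame_rx` is a broadcast receiver of
//! captured `VideoFrame`s:
//!
//! ```ignore
//! let manager = VideoSessionManager::with_defaults();
//! manager.set_frame_source(frame_rx).await;
//! let id = manager.create_session(Some(VideoEncoderType::H264)).await?;
//! let mut encoded_rx = manager.start_session(&id).await?;
//! ```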
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, info, warn};
use super::encoder::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
use super::format::Resolution;
use super::frame::VideoFrame;
use super::shared_video_pipeline::{
EncodedVideoFrame, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats,
};
use crate::error::{AppError, Result};
/// Maximum concurrent video sessions
const MAX_VIDEO_SESSIONS: usize = 8;
/// Video session state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VideoSessionState {
/// Session created but not started
Created,
/// Session is active and streaming
Active,
/// Session is paused
Paused,
/// Session is closing
Closing,
/// Session is closed
Closed,
}
impl std::fmt::Display for VideoSessionState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VideoSessionState::Created => write!(f, "Created"),
VideoSessionState::Active => write!(f, "Active"),
VideoSessionState::Paused => write!(f, "Paused"),
VideoSessionState::Closing => write!(f, "Closing"),
VideoSessionState::Closed => write!(f, "Closed"),
}
}
}
/// Video session information
#[derive(Debug, Clone)]
pub struct VideoSessionInfo {
/// Session ID
pub session_id: String,
/// Current codec
pub codec: VideoEncoderType,
/// Session state
pub state: VideoSessionState,
/// Creation time
pub created_at: Instant,
/// Last activity time
pub last_activity: Instant,
/// Frames received
pub frames_received: u64,
/// Bytes received
pub bytes_received: u64,
}
/// Individual video session
struct VideoSession {
/// Session ID
session_id: String,
/// Codec for this session
codec: VideoEncoderType,
/// Session state
state: VideoSessionState,
/// Creation time
created_at: Instant,
/// Last activity time
last_activity: Instant,
/// Frame receiver
frame_rx: Option<broadcast::Receiver<EncodedVideoFrame>>,
/// Stats
frames_received: u64,
bytes_received: u64,
}
impl VideoSession {
fn new(session_id: String, codec: VideoEncoderType) -> Self {
let now = Instant::now();
Self {
session_id,
codec,
state: VideoSessionState::Created,
created_at: now,
last_activity: now,
frame_rx: None,
frames_received: 0,
bytes_received: 0,
}
}
fn info(&self) -> VideoSessionInfo {
VideoSessionInfo {
session_id: self.session_id.clone(),
codec: self.codec,
state: self.state,
created_at: self.created_at,
last_activity: self.last_activity,
frames_received: self.frames_received,
bytes_received: self.bytes_received,
}
}
}
/// Video session manager configuration
#[derive(Debug, Clone)]
pub struct VideoSessionManagerConfig {
/// Default codec
pub default_codec: VideoEncoderType,
/// Default resolution
pub resolution: Resolution,
/// Default bitrate (kbps)
pub bitrate_kbps: u32,
/// Default FPS
pub fps: u32,
/// Session timeout (seconds)
pub session_timeout_secs: u64,
/// Encoder backend (None = auto select best available)
pub encoder_backend: Option<EncoderBackend>,
}
impl Default for VideoSessionManagerConfig {
fn default() -> Self {
Self {
default_codec: VideoEncoderType::H264,
resolution: Resolution::HD720,
bitrate_kbps: 8000,
fps: 30,
session_timeout_secs: 300,
encoder_backend: None,
}
}
}
/// Video session manager
///
/// Manages video encoding sessions with multi-codec support.
/// A single encoder is shared across all sessions with the same codec.
pub struct VideoSessionManager {
/// Configuration
config: VideoSessionManagerConfig,
/// Active sessions
sessions: RwLock<HashMap<String, VideoSession>>,
/// Current pipeline (shared across sessions with same codec)
pipeline: RwLock<Option<Arc<SharedVideoPipeline>>>,
/// Current codec (active pipeline codec)
current_codec: RwLock<Option<VideoEncoderType>>,
/// Video frame source
frame_source: RwLock<Option<broadcast::Receiver<VideoFrame>>>,
}
impl VideoSessionManager {
/// Create a new video session manager
pub fn new(config: VideoSessionManagerConfig) -> Self {
info!(
"Creating video session manager with default codec: {}",
config.default_codec
);
Self {
config,
sessions: RwLock::new(HashMap::new()),
pipeline: RwLock::new(None),
current_codec: RwLock::new(None),
frame_source: RwLock::new(None),
}
}
/// Create with default configuration
pub fn with_defaults() -> Self {
Self::new(VideoSessionManagerConfig::default())
}
/// Set the video frame source
pub async fn set_frame_source(&self, rx: broadcast::Receiver<VideoFrame>) {
*self.frame_source.write().await = Some(rx);
}
/// Get available codecs based on hardware capabilities
pub fn available_codecs(&self) -> Vec<VideoEncoderType> {
EncoderRegistry::global().selectable_formats()
}
/// Check if a codec is available
pub fn is_codec_available(&self, codec: VideoEncoderType) -> bool {
let hardware_only = codec.hardware_only();
EncoderRegistry::global().is_format_available(codec, hardware_only)
}
/// Create a new video session
    pub async fn create_session(&self, codec: Option<VideoEncoderType>) -> Result<String> {
        // Use specified codec or default
        let codec = codec.unwrap_or(self.config.default_codec);
        // Verify codec is available
        if !self.is_codec_available(codec) {
            return Err(AppError::VideoError(format!(
                "Codec {} is not available on this system",
                codec
            )));
        }
        // Generate session ID
        let session_id = uuid::Uuid::new_v4().to_string();
        // Create session
        let session = VideoSession::new(session_id.clone(), codec);
        // Store session, enforcing the limit under the write lock so concurrent
        // callers cannot race past a separate check
        let mut sessions = self.sessions.write().await;
        if sessions.len() >= MAX_VIDEO_SESSIONS {
            return Err(AppError::VideoError(format!(
                "Maximum video sessions ({}) reached",
                MAX_VIDEO_SESSIONS
            )));
        }
        sessions.insert(session_id.clone(), session);
        info!(
            "Video session created: {} (codec: {})",
            session_id, codec
        );
        Ok(session_id)
}
/// Start a video session (subscribe to encoded frames)
pub async fn start_session(
&self,
session_id: &str,
) -> Result<broadcast::Receiver<EncodedVideoFrame>> {
// Ensure pipeline is running with correct codec
self.ensure_pipeline_for_session(session_id).await?;
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(session_id)
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
// Get pipeline and subscribe
let pipeline = self.pipeline.read().await;
let pipeline = pipeline
.as_ref()
.ok_or_else(|| AppError::VideoError("Pipeline not initialized".to_string()))?;
let rx = pipeline.subscribe();
session.frame_rx = Some(pipeline.subscribe());
session.state = VideoSessionState::Active;
session.last_activity = Instant::now();
info!("Video session started: {}", session_id);
Ok(rx)
}
/// Ensure pipeline is running with correct codec for session
async fn ensure_pipeline_for_session(&self, session_id: &str) -> Result<()> {
let sessions = self.sessions.read().await;
let session = sessions
.get(session_id)
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
let required_codec = session.codec;
drop(sessions);
let current_codec = *self.current_codec.read().await;
// Check if we need to create or switch pipeline
if current_codec != Some(required_codec) {
self.switch_pipeline_codec(required_codec).await?;
}
        // Ensure pipeline is started
        let pipeline = self.pipeline.read().await.clone();
        if let Some(pipe) = pipeline {
            if !pipe.is_running() {
                // Need a frame source to start
                let frame_rx = {
                    let source = self.frame_source.read().await;
                    source.as_ref().map(|rx| rx.resubscribe())
                };
                if let Some(rx) = frame_rx {
                    pipe.start(rx).await?;
                }
            }
        }
Ok(())
}
/// Switch pipeline to different codec
async fn switch_pipeline_codec(&self, codec: VideoEncoderType) -> Result<()> {
info!("Switching pipeline to codec: {}", codec);
// Stop existing pipeline
{
let pipeline = self.pipeline.read().await;
if let Some(ref pipe) = *pipeline {
pipe.stop();
}
}
// Create new pipeline config
let pipeline_config = SharedVideoPipelineConfig {
resolution: self.config.resolution,
input_format: crate::video::format::PixelFormat::Mjpeg, // Common input
output_codec: codec,
bitrate_kbps: self.config.bitrate_kbps,
fps: self.config.fps,
gop_size: 30,
encoder_backend: self.config.encoder_backend,
};
// Create new pipeline
let new_pipeline = SharedVideoPipeline::new(pipeline_config)?;
// Update state
*self.pipeline.write().await = Some(new_pipeline);
*self.current_codec.write().await = Some(codec);
info!("Pipeline switched to codec: {}", codec);
Ok(())
}
/// Get session info
pub async fn get_session(&self, session_id: &str) -> Option<VideoSessionInfo> {
let sessions = self.sessions.read().await;
sessions.get(session_id).map(|s| s.info())
}
/// List all sessions
pub async fn list_sessions(&self) -> Vec<VideoSessionInfo> {
let sessions = self.sessions.read().await;
sessions.values().map(|s| s.info()).collect()
}
/// Pause a session
pub async fn pause_session(&self, session_id: &str) -> Result<()> {
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(session_id)
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
session.state = VideoSessionState::Paused;
session.last_activity = Instant::now();
debug!("Video session paused: {}", session_id);
Ok(())
}
/// Resume a session
pub async fn resume_session(&self, session_id: &str) -> Result<()> {
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(session_id)
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
session.state = VideoSessionState::Active;
session.last_activity = Instant::now();
debug!("Video session resumed: {}", session_id);
Ok(())
}
/// Close a session
pub async fn close_session(&self, session_id: &str) -> Result<()> {
let mut sessions = self.sessions.write().await;
if let Some(mut session) = sessions.remove(session_id) {
session.state = VideoSessionState::Closed;
session.frame_rx = None;
info!("Video session closed: {}", session_id);
}
// If no more sessions, consider stopping pipeline
if sessions.is_empty() {
drop(sessions);
self.maybe_stop_pipeline().await;
}
Ok(())
}
/// Stop pipeline if no active sessions
async fn maybe_stop_pipeline(&self) {
let sessions = self.sessions.read().await;
let has_active = sessions
.values()
.any(|s| s.state == VideoSessionState::Active);
drop(sessions);
if !has_active {
let pipeline = self.pipeline.read().await;
if let Some(ref pipe) = *pipeline {
pipe.stop();
debug!("Pipeline stopped - no active sessions");
}
}
}
/// Cleanup stale/timed out sessions
pub async fn cleanup_stale_sessions(&self) {
let timeout = std::time::Duration::from_secs(self.config.session_timeout_secs);
let now = Instant::now();
let stale_ids: Vec<String> = {
let sessions = self.sessions.read().await;
sessions
.iter()
.filter(|(_, s)| {
(s.state == VideoSessionState::Paused
|| s.state == VideoSessionState::Created)
&& now.duration_since(s.last_activity) > timeout
})
.map(|(id, _)| id.clone())
.collect()
};
if !stale_ids.is_empty() {
let mut sessions = self.sessions.write().await;
for id in stale_ids {
info!("Removing stale video session: {}", id);
sessions.remove(&id);
}
}
}
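// Illustrative sketch (not part of this file): `cleanup_stale_sessions` is
// meant to be driven by a periodic background task. The `Arc` wrapper and
// the 60-second period are assumptions for the example, not values taken
// from this commit.
//
// fn spawn_session_reaper(manager: std::sync::Arc<VideoSessionManager>) {
//     tokio::spawn(async move {
//         let mut tick = tokio::time::interval(std::time::Duration::from_secs(60));
//         loop {
//             tick.tick().await;
//             manager.cleanup_stale_sessions().await;
//         }
//     });
// }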
/// Get session count
pub async fn session_count(&self) -> usize {
self.sessions.read().await.len()
}
/// Get active session count
pub async fn active_session_count(&self) -> usize {
self.sessions
.read()
.await
.values()
.filter(|s| s.state == VideoSessionState::Active)
.count()
}
/// Get pipeline statistics
pub async fn pipeline_stats(&self) -> Option<SharedVideoPipelineStats> {
let pipeline = self.pipeline.read().await;
if let Some(ref pipe) = *pipeline {
Some(pipe.stats().await)
} else {
None
}
}
/// Get current active codec
pub async fn current_codec(&self) -> Option<VideoEncoderType> {
*self.current_codec.read().await
}
/// Set bitrate for current pipeline
pub async fn set_bitrate(&self, bitrate_kbps: u32) -> Result<()> {
let pipeline = self.pipeline.read().await;
if let Some(ref pipe) = *pipeline {
pipe.set_bitrate(bitrate_kbps).await?;
}
Ok(())
}
/// Request keyframe for all sessions
pub async fn request_keyframe(&self) {
// To be implemented once the encoders support forced keyframes
warn!("Keyframe request not yet implemented");
}
/// Change codec for a session (requires restart)
pub async fn change_session_codec(
&self,
session_id: &str,
new_codec: VideoEncoderType,
) -> Result<()> {
if !self.is_codec_available(new_codec) {
return Err(AppError::VideoError(format!(
"Codec {} is not available",
new_codec
)));
}
let mut sessions = self.sessions.write().await;
let session = sessions
.get_mut(session_id)
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
let old_codec = session.codec;
session.codec = new_codec;
session.state = VideoSessionState::Created; // Require restart
session.frame_rx = None;
session.last_activity = Instant::now();
info!(
"Session {} codec changed: {} -> {}",
session_id, old_codec, new_codec
);
Ok(())
}
/// Get codec info
pub fn get_codec_info(&self, codec: VideoEncoderType) -> Option<CodecInfo> {
let registry = EncoderRegistry::global();
let encoder = registry.best_encoder(codec, codec.hardware_only())?;
Some(CodecInfo {
codec_type: codec,
codec_name: encoder.codec_name.clone(),
backend: encoder.backend.to_string(),
is_hardware: encoder.is_hardware,
})
}
/// List all available codecs with their info
pub fn list_codec_info(&self) -> Vec<CodecInfo> {
self.available_codecs()
.iter()
.filter_map(|c| self.get_codec_info(*c))
.collect()
}
}
/// Codec information
#[derive(Debug, Clone)]
pub struct CodecInfo {
/// Codec type
pub codec_type: VideoEncoderType,
/// FFmpeg codec name
pub codec_name: String,
/// Backend (VAAPI, NVENC, etc.)
pub backend: String,
/// Whether this is hardware accelerated
pub is_hardware: bool,
}
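// Illustrative sketch (not part of this file): enumerating negotiable codecs
// with the methods defined above, e.g. to build an API response.
//
// let manager = VideoSessionManager::with_defaults();
// for info in manager.list_codec_info() {
//     println!(
//         "{}: {} via {} (hardware={})",
//         info.codec_type, info.codec_name, info.backend, info.is_hardware
//     );
// }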
impl Default for VideoSessionManager {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_session_state_display() {
assert_eq!(VideoSessionState::Active.to_string(), "Active");
assert_eq!(VideoSessionState::Closed.to_string(), "Closed");
}
#[test]
fn test_available_codecs() {
let manager = VideoSessionManager::with_defaults();
let codecs = manager.available_codecs();
println!("Available codecs: {:?}", codecs);
// H264 should always be available (software fallback)
assert!(codecs.contains(&VideoEncoderType::H264));
}
#[test]
fn test_codec_info() {
let manager = VideoSessionManager::with_defaults();
let info = manager.get_codec_info(VideoEncoderType::H264);
if let Some(info) = info {
println!(
"H264: {} ({}, hardware={})",
info.codec_name, info.backend, info.is_hardware
);
}
}
}

257
src/web/audio_ws.rs Normal file
View File

@@ -0,0 +1,257 @@
//! Audio WebSocket handler for MJPEG mode
//!
//! Provides a dedicated WebSocket endpoint (`/api/ws/audio`) for streaming
//! Opus-encoded audio data in binary format.
//!
//! ## Binary Protocol
//!
//! Each audio packet is sent as a binary WebSocket message with the following format:
//!
//! ```text
//! Byte 0: Type (0x02 = audio)
//! Bytes 1-4: Timestamp (u32 LE, milliseconds since stream start)
//! Bytes 5-6: Duration (u16 LE, milliseconds)
//! Bytes 7-10: Sequence (u32 LE)
//! Bytes 11-14: Data length (u32 LE)
//! Bytes 15+: Opus encoded data
//! ```
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::Response,
};
use futures::{SinkExt, StreamExt};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::broadcast;
use tracing::{debug, info, warn};
use crate::audio::OpusFrame;
use crate::state::AppState;
/// Audio packet type identifier
const AUDIO_PACKET_TYPE: u8 = 0x02;
/// Audio WebSocket upgrade handler
///
/// Upgrades HTTP connection to WebSocket for audio streaming.
pub async fn audio_ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> Response {
ws.on_upgrade(move |socket| handle_audio_socket(socket, state))
}
/// Handle audio WebSocket connection
async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
// Try to get Opus frame subscription
let opus_rx = match state.audio.subscribe_opus_async().await {
Some(rx) => rx,
None => {
warn!("Audio not streaming, rejecting WebSocket connection");
// Send error message before closing
let _ = sender
.send(Message::Text(
r#"{"error": "Audio not streaming"}"#.to_string(),
))
.await;
return;
}
};
let mut opus_rx = opus_rx;
let stream_start = Instant::now();
info!("Audio WebSocket client connected");
// Track connection for cleanup
let mut closed = false;
// Use interval instead of sleep for more efficient keepalive
let mut ping_interval = tokio::time::interval(std::time::Duration::from_secs(30));
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
loop {
tokio::select! {
// Receive Opus frames and send to client
opus_result = opus_rx.recv() => {
match opus_result {
Ok(frame) => {
let binary = encode_audio_packet(&frame, stream_start);
if sender.send(Message::Binary(binary)).await.is_err() {
debug!("Failed to send audio frame, client disconnected");
break;
}
}
Err(broadcast::error::RecvError::Lagged(n)) => {
warn!("Audio WebSocket client lagged by {} frames", n);
// Continue - just skip the missed frames
}
Err(broadcast::error::RecvError::Closed) => {
info!("Audio stream closed");
break;
}
}
}
// Handle client messages (ping/close)
msg = receiver.next() => {
match msg {
Some(Ok(Message::Close(_))) => {
debug!("Audio WebSocket client requested close");
closed = true;
break;
}
Some(Ok(Message::Ping(data))) => {
if sender.send(Message::Pong(data)).await.is_err() {
break;
}
}
Some(Ok(Message::Pong(_))) => {
// Pong received, connection is alive
}
Some(Ok(Message::Text(text))) => {
// Handle potential control messages
debug!("Received text message on audio WS: {}", text);
}
Some(Err(e)) => {
warn!("Audio WebSocket receive error: {}", e);
break;
}
None => {
// Connection closed
break;
}
_ => {}
}
}
// Periodic ping to keep connection alive (using interval)
_ = ping_interval.tick() => {
if sender.send(Message::Ping(vec![])).await.is_err() {
warn!("Failed to send ping, disconnecting");
break;
}
}
}
}
if !closed {
// Try to send close message
let _ = sender.send(Message::Close(None)).await;
}
info!("Audio WebSocket client disconnected");
}
/// Encode Opus frame to binary packet format
///
/// ## Format
///
/// | Offset | Size | Description |
/// |--------|------|-------------|
/// | 0 | 1 | Packet type (0x02 for audio) |
/// | 1 | 4 | Timestamp (u32 LE, ms since start) |
/// | 5 | 2 | Duration (u16 LE, ms) |
/// | 7 | 4 | Sequence number (u32 LE) |
/// | 11 | 4 | Data length (u32 LE) |
/// | 15 | N | Opus encoded data |
fn encode_audio_packet(frame: &OpusFrame, stream_start: Instant) -> Vec<u8> {
let timestamp_ms = stream_start.elapsed().as_millis() as u32;
let data_len = frame.data.len() as u32;
let mut buf = Vec::with_capacity(15 + frame.data.len());
// Header
buf.push(AUDIO_PACKET_TYPE);
buf.extend_from_slice(&timestamp_ms.to_le_bytes());
buf.extend_from_slice(&(frame.duration_ms as u16).to_le_bytes());
buf.extend_from_slice(&(frame.sequence as u32).to_le_bytes());
buf.extend_from_slice(&data_len.to_le_bytes());
// Opus data
buf.extend_from_slice(&frame.data);
buf
}
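// Illustrative byte layout (values are made up): a 5-byte Opus payload at
// t = 1000 ms with duration 20 ms and sequence 42 encodes as
// [0x02, 0xE8, 0x03, 0x00, 0x00, 0x14, 0x00, 0x2A, 0x00, 0x00, 0x00,
//  0x05, 0x00, 0x00, 0x00, <5 payload bytes>]
// i.e. the 15-byte header above followed by the raw Opus data.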
/// Decode audio packet from binary format (for testing/debugging)
#[allow(dead_code)]
pub fn decode_audio_packet(data: &[u8]) -> Option<AudioPacketHeader> {
if data.len() < 15 {
return None;
}
if data[0] != AUDIO_PACKET_TYPE {
return None;
}
let timestamp = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
let duration_ms = u16::from_le_bytes([data[5], data[6]]);
let sequence = u32::from_le_bytes([data[7], data[8], data[9], data[10]]);
let data_length = u32::from_le_bytes([data[11], data[12], data[13], data[14]]);
Some(AudioPacketHeader {
packet_type: data[0],
timestamp,
duration_ms,
sequence,
data_length,
})
}
/// Audio packet header (for decoding/testing)
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct AudioPacketHeader {
pub packet_type: u8,
pub timestamp: u32,
pub duration_ms: u16,
pub sequence: u32,
pub data_length: u32,
}
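// Illustrative sketch (not part of this file): a native client could consume
// `/api/ws/audio` with tokio-tungstenite (already used elsewhere in this
// crate) and parse each frame with `decode_audio_packet`. The URL is a
// placeholder.
//
// use futures::StreamExt;
// use tokio_tungstenite::{connect_async, tungstenite::Message};
//
// async fn dump_audio_headers(url: &str) -> Result<(), Box<dyn std::error::Error>> {
//     let (ws, _) = connect_async(url).await?;
//     let (_tx, mut rx) = ws.split();
//     while let Some(msg) = rx.next().await {
//         if let Message::Binary(data) = msg? {
//             if let Some(h) = decode_audio_packet(&data) {
//                 println!(
//                     "seq={} ts={}ms dur={}ms len={}",
//                     h.sequence, h.timestamp, h.duration_ms, h.data_length
//                 );
//             }
//         }
//     }
//     Ok(())
// }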
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
#[test]
fn test_encode_decode_packet() {
let frame = OpusFrame {
data: Bytes::from(vec![1, 2, 3, 4, 5]),
duration_ms: 20,
sequence: 42,
timestamp: Instant::now(),
rtp_timestamp: 0,
};
let stream_start = Instant::now();
let encoded = encode_audio_packet(&frame, stream_start);
assert!(encoded.len() >= 15);
assert_eq!(encoded[0], AUDIO_PACKET_TYPE);
let header = decode_audio_packet(&encoded).unwrap();
assert_eq!(header.packet_type, AUDIO_PACKET_TYPE);
assert_eq!(header.duration_ms, 20);
assert_eq!(header.sequence, 42);
assert_eq!(header.data_length, 5);
}
#[test]
fn test_decode_invalid_packet() {
// Too short
assert!(decode_audio_packet(&[]).is_none());
assert!(decode_audio_packet(&[0x02; 10]).is_none());
// Wrong type
let bad = vec![0x01; 20];
assert!(decode_audio_packet(&bad).is_none());
}
}

View File

@@ -0,0 +1,362 @@
//! Configuration hot-reload logic
//!
//! Config-apply functions extracted from handlers.rs; they are responsible for applying configuration changes to each subsystem.
use std::sync::Arc;
use crate::config::*;
use crate::error::{AppError, Result};
use crate::state::AppState;
/// Apply video configuration changes
pub async fn apply_video_config(
state: &Arc<AppState>,
old_config: &VideoConfig,
new_config: &VideoConfig,
) -> Result<()> {
// Check whether the configuration actually changed
if old_config == new_config {
tracing::info!("Video config unchanged, skipping reload");
return Ok(());
}
tracing::info!("Applying video config changes...");
let device = new_config
.device
.clone()
.ok_or_else(|| AppError::BadRequest("video_device is required".to_string()))?;
let format = new_config
.format
.as_ref()
.and_then(|f| {
serde_json::from_value::<crate::video::format::PixelFormat>(
serde_json::Value::String(f.clone()),
)
.ok()
})
.unwrap_or(crate::video::format::PixelFormat::Mjpeg);
let resolution =
crate::video::format::Resolution::new(new_config.width, new_config.height);
// Step 1: Update the WebRTC streamer config (stops the existing pipeline and sessions)
state
.stream_manager
.webrtc_streamer()
.update_video_config(resolution, format, new_config.fps)
.await;
tracing::info!("WebRTC streamer config updated");
// Step 2: Apply the video config to the streamer (recreates the capturer)
state
.stream_manager
.streamer()
.apply_video_config(&device, format, resolution, new_config.fps)
.await
.map_err(|e| AppError::VideoError(format!("Failed to apply video config: {}", e)))?;
tracing::info!("Video config applied to streamer");
// Step 3: Restart the streamer
if let Err(e) = state.stream_manager.start().await {
tracing::error!("Failed to start streamer after config change: {}", e);
} else {
tracing::info!("Streamer started after config change");
}
// Step 4: Update the WebRTC frame source
if let Some(frame_tx) = state.stream_manager.frame_sender().await {
let receiver_count = frame_tx.receiver_count();
state
.stream_manager
.webrtc_streamer()
.set_video_source(frame_tx)
.await;
tracing::info!(
"WebRTC streamer frame source updated (receiver_count={})",
receiver_count
);
} else {
tracing::warn!("No frame source available after config change");
}
tracing::info!("Video config applied successfully");
Ok(())
}
/// Apply stream configuration changes
pub async fn apply_stream_config(
state: &Arc<AppState>,
old_config: &StreamConfig,
new_config: &StreamConfig,
) -> Result<()> {
tracing::info!("Applying stream config changes...");
// Update the encoder backend
if old_config.encoder != new_config.encoder {
let encoder_backend = new_config.encoder.to_backend();
tracing::info!(
"Updating encoder backend to: {:?} (from config: {:?})",
encoder_backend,
new_config.encoder
);
state
.stream_manager
.webrtc_streamer()
.update_encoder_backend(encoder_backend)
.await;
}
// Update the bitrate
if old_config.bitrate_kbps != new_config.bitrate_kbps {
state
.stream_manager
.webrtc_streamer()
.set_bitrate(new_config.bitrate_kbps)
.await
.ok(); // Ignore error if no active stream
}
// Update the ICE config (STUN/TURN)
let ice_changed = old_config.stun_server != new_config.stun_server
|| old_config.turn_server != new_config.turn_server
|| old_config.turn_username != new_config.turn_username
|| old_config.turn_password != new_config.turn_password;
if ice_changed {
tracing::info!(
"Updating ICE config: STUN={:?}, TURN={:?}",
new_config.stun_server,
new_config.turn_server
);
state
.stream_manager
.webrtc_streamer()
.update_ice_config(
new_config.stun_server.clone(),
new_config.turn_server.clone(),
new_config.turn_username.clone(),
new_config.turn_password.clone(),
)
.await;
}
tracing::info!(
"Stream config applied: encoder={:?}, bitrate={} kbps",
new_config.encoder,
new_config.bitrate_kbps
);
Ok(())
}
/// Apply HID configuration changes
pub async fn apply_hid_config(
state: &Arc<AppState>,
old_config: &HidConfig,
new_config: &HidConfig,
) -> Result<()> {
// Check whether a reload is needed
if old_config.backend == new_config.backend
&& old_config.ch9329_port == new_config.ch9329_port
&& old_config.ch9329_baudrate == new_config.ch9329_baudrate
&& old_config.otg_udc == new_config.otg_udc
{
tracing::info!("HID config unchanged, skipping reload");
return Ok(());
}
tracing::info!("Applying HID config changes...");
let new_hid_backend = match new_config.backend {
HidBackend::Otg => crate::hid::HidBackendType::Otg,
HidBackend::Ch9329 => crate::hid::HidBackendType::Ch9329 {
port: new_config.ch9329_port.clone(),
baud_rate: new_config.ch9329_baudrate,
},
HidBackend::None => crate::hid::HidBackendType::None,
};
state
.hid
.reload(new_hid_backend)
.await
.map_err(|e| AppError::Config(format!("HID reload failed: {}", e)))?;
tracing::info!("HID backend reloaded successfully: {:?}", new_config.backend);
// When switching to OTG backend, automatically enable MSD if not already enabled
// OTG HID and MSD share the same USB gadget, so it makes sense to enable both
if new_config.backend == HidBackend::Otg && old_config.backend != HidBackend::Otg {
let msd_guard = state.msd.read().await;
if msd_guard.is_none() {
drop(msd_guard); // Release read lock before acquiring write lock
tracing::info!("OTG HID enabled, automatically initializing MSD...");
// Get MSD config from store
let config = state.config.get();
let msd = crate::msd::MsdController::new(
state.otg_service.clone(),
&config.msd.images_path,
&config.msd.drive_path,
);
if let Err(e) = msd.init().await {
tracing::warn!("Failed to auto-initialize MSD for OTG: {}", e);
} else {
let events = state.events.clone();
msd.set_event_bus(events).await;
*state.msd.write().await = Some(msd);
tracing::info!("MSD automatically initialized for OTG mode");
}
}
}
Ok(())
}
/// Apply MSD configuration changes
pub async fn apply_msd_config(
state: &Arc<AppState>,
old_config: &MsdConfig,
new_config: &MsdConfig,
) -> Result<()> {
tracing::info!("MSD config sent, checking if reload needed...");
tracing::debug!("Old MSD config: {:?}", old_config);
tracing::debug!("New MSD config: {:?}", new_config);
// Check if MSD enabled state changed
let old_msd_enabled = old_config.enabled;
let new_msd_enabled = new_config.enabled;
tracing::info!("MSD enabled: old={}, new={}", old_msd_enabled, new_msd_enabled);
if old_msd_enabled != new_msd_enabled {
if new_msd_enabled {
// MSD was disabled, now enabled - need to initialize
tracing::info!("MSD enabled in config, initializing...");
let msd = crate::msd::MsdController::new(
state.otg_service.clone(),
&new_config.images_path,
&new_config.drive_path,
);
msd.init().await.map_err(|e| {
AppError::Config(format!("MSD initialization failed: {}", e))
})?;
// Set event bus
let events = state.events.clone();
msd.set_event_bus(events).await;
// Store the initialized controller
*state.msd.write().await = Some(msd);
tracing::info!("MSD initialized successfully");
} else {
// MSD was enabled, now disabled - shutdown
tracing::info!("MSD disabled in config, shutting down...");
if let Some(msd) = state.msd.write().await.as_mut() {
if let Err(e) = msd.shutdown().await {
tracing::warn!("MSD shutdown failed: {}", e);
}
}
*state.msd.write().await = None;
tracing::info!("MSD shutdown complete");
}
} else {
tracing::info!(
"MSD enabled state unchanged ({}), no reload needed",
new_msd_enabled
);
}
Ok(())
}
/// Apply ATX configuration changes
pub async fn apply_atx_config(
state: &Arc<AppState>,
_old_config: &AtxConfig,
new_config: &AtxConfig,
) -> Result<()> {
tracing::info!("Applying ATX config changes...");
// Convert AtxConfig to AtxControllerConfig
let controller_config = new_config.to_controller_config();
// Reload the ATX controller with new configuration
let atx_guard = state.atx.read().await;
if let Some(atx) = atx_guard.as_ref() {
if let Err(e) = atx.reload(controller_config).await {
tracing::error!("ATX reload failed: {}", e);
return Err(AppError::Config(format!("ATX reload failed: {}", e)));
}
tracing::info!("ATX controller reloaded successfully");
} else {
// ATX controller not initialized, create a new one if enabled
drop(atx_guard);
if new_config.enabled {
tracing::info!("ATX enabled in config, initializing...");
let atx = crate::atx::AtxController::new(controller_config);
if let Err(e) = atx.init().await {
tracing::warn!("ATX initialization failed: {}", e);
} else {
*state.atx.write().await = Some(atx);
tracing::info!("ATX controller initialized successfully");
}
}
}
Ok(())
}
/// Apply audio configuration changes
pub async fn apply_audio_config(
state: &Arc<AppState>,
_old_config: &AudioConfig,
new_config: &AudioConfig,
) -> Result<()> {
tracing::info!("Applying audio config changes...");
// Create audio controller config from new config
let audio_config = crate::audio::AudioControllerConfig {
enabled: new_config.enabled,
device: new_config.device.clone(),
quality: crate::audio::AudioQuality::from_str(&new_config.quality),
};
// Update audio controller
if let Err(e) = state.audio.update_config(audio_config).await {
tracing::error!("Audio config update failed: {}", e);
// Don't fail - audio errors are not critical
} else {
tracing::info!(
"Audio config applied: enabled={}, device={}",
new_config.enabled,
new_config.device
);
}
// Also update WebRTC audio enabled state
if let Err(e) = state
.stream_manager
.set_webrtc_audio_enabled(new_config.enabled)
.await
{
tracing::warn!("Failed to update WebRTC audio state: {}", e);
} else {
tracing::info!("WebRTC audio enabled: {}", new_config.enabled);
}
// Reconnect audio sources for existing WebRTC sessions
if new_config.enabled {
state.stream_manager.reconnect_webrtc_audio_sources().await;
}
Ok(())
}

View File

@@ -0,0 +1,46 @@
//! ATX config handler
use axum::{extract::State, Json};
use std::sync::Arc;
use crate::config::AtxConfig;
use crate::error::Result;
use crate::state::AppState;
use super::apply::apply_atx_config;
use super::types::AtxConfigUpdate;
/// Get the ATX config
pub async fn get_atx_config(State(state): State<Arc<AppState>>) -> Json<AtxConfig> {
Json(state.config.get().atx.clone())
}
/// Update the ATX config
pub async fn update_atx_config(
State(state): State<Arc<AppState>>,
Json(req): Json<AtxConfigUpdate>,
) -> Result<Json<AtxConfig>> {
// 1. Validate the request
req.validate()?;
// 2. Capture the old config
let old_atx_config = state.config.get().atx.clone();
// 3. Apply the update to the config store
state
.config
.update(|config| {
req.apply_to(&mut config.atx);
})
.await?;
// 4. Read back the new config
let new_atx_config = state.config.get().atx.clone();
// 5. Apply to the subsystem (hot reload)
if let Err(e) = apply_atx_config(&state, &old_atx_config, &new_atx_config).await {
tracing::error!("Failed to apply ATX config: {}", e);
}
Ok(Json(new_atx_config))
}

View File

@@ -0,0 +1,46 @@
//! Audio config handler
use axum::{extract::State, Json};
use std::sync::Arc;
use crate::config::AudioConfig;
use crate::error::Result;
use crate::state::AppState;
use super::apply::apply_audio_config;
use super::types::AudioConfigUpdate;
/// Get the audio config
pub async fn get_audio_config(State(state): State<Arc<AppState>>) -> Json<AudioConfig> {
Json(state.config.get().audio.clone())
}
/// Update the audio config
pub async fn update_audio_config(
State(state): State<Arc<AppState>>,
Json(req): Json<AudioConfigUpdate>,
) -> Result<Json<AudioConfig>> {
// 1. Validate the request
req.validate()?;
// 2. Capture the old config
let old_audio_config = state.config.get().audio.clone();
// 3. Apply the update to the config store
state
.config
.update(|config| {
req.apply_to(&mut config.audio);
})
.await?;
// 4. Read back the new config
let new_audio_config = state.config.get().audio.clone();
// 5. Apply to the subsystem (hot reload)
if let Err(e) = apply_audio_config(&state, &old_audio_config, &new_audio_config).await {
tracing::error!("Failed to apply audio config: {}", e);
}
Ok(Json(new_audio_config))
}

View File

@@ -0,0 +1,46 @@
//! HID config handler
use axum::{extract::State, Json};
use std::sync::Arc;
use crate::config::HidConfig;
use crate::error::Result;
use crate::state::AppState;
use super::apply::apply_hid_config;
use super::types::HidConfigUpdate;
/// Get the HID config
pub async fn get_hid_config(State(state): State<Arc<AppState>>) -> Json<HidConfig> {
Json(state.config.get().hid.clone())
}
/// Update the HID config
pub async fn update_hid_config(
State(state): State<Arc<AppState>>,
Json(req): Json<HidConfigUpdate>,
) -> Result<Json<HidConfig>> {
// 1. Validate the request
req.validate()?;
// 2. Capture the old config
let old_hid_config = state.config.get().hid.clone();
// 3. Apply the update to the config store
state
.config
.update(|config| {
req.apply_to(&mut config.hid);
})
.await?;
// 4. Read back the new config
let new_hid_config = state.config.get().hid.clone();
// 5. Apply to the subsystem (hot reload)
if let Err(e) = apply_hid_config(&state, &old_hid_config, &new_hid_config).await {
tracing::error!("Failed to apply HID config: {}", e);
}
Ok(Json(new_hid_config))
}

View File

@@ -0,0 +1,48 @@
//! Config management handler module
//!
//! Provides a RESTful, domain-separated config API:
//! - GET /api/config/video - get video config
//! - PATCH /api/config/video - update video config
//! - GET /api/config/stream - get stream config
//! - PATCH /api/config/stream - update stream config
//! - GET /api/config/hid - get HID config
//! - PATCH /api/config/hid - update HID config
//! - GET /api/config/msd - get MSD config
//! - PATCH /api/config/msd - update MSD config
//! - GET /api/config/atx - get ATX config
//! - PATCH /api/config/atx - update ATX config
//! - GET /api/config/audio - get audio config
//! - PATCH /api/config/audio - update audio config
mod apply;
mod types;
mod video;
mod stream;
mod hid;
mod msd;
mod atx;
mod audio;
// Re-export handler functions
pub use video::{get_video_config, update_video_config};
pub use stream::{get_stream_config, update_stream_config};
pub use hid::{get_hid_config, update_hid_config};
pub use msd::{get_msd_config, update_msd_config};
pub use atx::{get_atx_config, update_atx_config};
pub use audio::{get_audio_config, update_audio_config};
// Keep the global config query (backward compatibility)
use axum::{extract::State, Json};
use std::sync::Arc;
use crate::config::AppConfig;
use crate::state::AppState;
/// Get the full configuration
pub async fn get_all_config(State(state): State<Arc<AppState>>) -> Json<AppConfig> {
let mut config = (*state.config.get()).clone();
// Do not expose sensitive information
config.auth.totp_secret = None;
Json(config)
}

View File

@@ -0,0 +1,46 @@
//! MSD config handler
use axum::{extract::State, Json};
use std::sync::Arc;
use crate::config::MsdConfig;
use crate::error::Result;
use crate::state::AppState;
use super::apply::apply_msd_config;
use super::types::MsdConfigUpdate;
/// Get the MSD config
pub async fn get_msd_config(State(state): State<Arc<AppState>>) -> Json<MsdConfig> {
Json(state.config.get().msd.clone())
}
/// Update the MSD config
pub async fn update_msd_config(
State(state): State<Arc<AppState>>,
Json(req): Json<MsdConfigUpdate>,
) -> Result<Json<MsdConfig>> {
// 1. Validate the request
req.validate()?;
// 2. Capture the old config
let old_msd_config = state.config.get().msd.clone();
// 3. Apply the update to the config store
state
.config
.update(|config| {
req.apply_to(&mut config.msd);
})
.await?;
// 4. Read back the new config
let new_msd_config = state.config.get().msd.clone();
// 5. Apply to the subsystem (hot reload)
if let Err(e) = apply_msd_config(&state, &old_msd_config, &new_msd_config).await {
tracing::error!("Failed to apply MSD config: {}", e);
}
Ok(Json(new_msd_config))
}

View File

@@ -0,0 +1,46 @@
//! Stream config handler
use axum::{extract::State, Json};
use std::sync::Arc;
use crate::error::Result;
use crate::state::AppState;
use super::apply::apply_stream_config;
use super::types::{StreamConfigResponse, StreamConfigUpdate};
/// Get the stream config
pub async fn get_stream_config(State(state): State<Arc<AppState>>) -> Json<StreamConfigResponse> {
let config = state.config.get();
Json(StreamConfigResponse::from(&config.stream))
}
/// Update the stream config
pub async fn update_stream_config(
State(state): State<Arc<AppState>>,
Json(req): Json<StreamConfigUpdate>,
) -> Result<Json<StreamConfigResponse>> {
// 1. Validate the request
req.validate()?;
// 2. Capture the old config
let old_stream_config = state.config.get().stream.clone();
// 3. Apply the update to the config store
state
.config
.update(|config| {
req.apply_to(&mut config.stream);
})
.await?;
// 4. Read back the new config
let new_stream_config = state.config.get().stream.clone();
// 5. Apply to the subsystem (hot reload)
if let Err(e) = apply_stream_config(&state, &old_stream_config, &new_stream_config).await {
tracing::error!("Failed to apply stream config: {}", e);
}
Ok(Json(StreamConfigResponse::from(&new_stream_config)))
}

View File

@@ -0,0 +1,396 @@
use serde::Deserialize;
use typeshare::typeshare;
use crate::config::*;
use crate::error::AppError;
// ===== Video Config =====
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct VideoConfigUpdate {
pub device: Option<String>,
pub format: Option<String>,
pub width: Option<u32>,
pub height: Option<u32>,
pub fps: Option<u32>,
pub quality: Option<u32>,
}
impl VideoConfigUpdate {
pub fn validate(&self) -> crate::error::Result<()> {
if let Some(width) = self.width {
if !(320..=7680).contains(&width) {
return Err(AppError::BadRequest("Invalid width: must be 320-7680".into()));
}
}
if let Some(height) = self.height {
if !(240..=4320).contains(&height) {
return Err(AppError::BadRequest("Invalid height: must be 240-4320".into()));
}
}
if let Some(fps) = self.fps {
if !(1..=120).contains(&fps) {
return Err(AppError::BadRequest("Invalid fps: must be 1-120".into()));
}
}
if let Some(quality) = self.quality {
if !(1..=100).contains(&quality) {
return Err(AppError::BadRequest("Invalid quality: must be 1-100".into()));
}
}
Ok(())
}
pub fn apply_to(&self, config: &mut VideoConfig) {
if let Some(ref device) = self.device {
config.device = Some(device.clone());
}
if let Some(ref format) = self.format {
config.format = Some(format.clone());
}
if let Some(width) = self.width {
config.width = width;
}
if let Some(height) = self.height {
config.height = height;
}
if let Some(fps) = self.fps {
config.fps = fps;
}
if let Some(quality) = self.quality {
config.quality = quality;
}
}
}
// ===== Stream Config =====
/// Stream config response (includes a has_turn_password field)
#[typeshare]
#[derive(Debug, serde::Serialize)]
pub struct StreamConfigResponse {
pub mode: StreamMode,
pub encoder: EncoderType,
pub bitrate_kbps: u32,
pub gop_size: u32,
pub stun_server: Option<String>,
pub turn_server: Option<String>,
pub turn_username: Option<String>,
/// Whether a TURN password is set (the password itself is never returned)
pub has_turn_password: bool,
}
impl From<&StreamConfig> for StreamConfigResponse {
fn from(config: &StreamConfig) -> Self {
Self {
mode: config.mode.clone(),
encoder: config.encoder.clone(),
bitrate_kbps: config.bitrate_kbps,
gop_size: config.gop_size,
stun_server: config.stun_server.clone(),
turn_server: config.turn_server.clone(),
turn_username: config.turn_username.clone(),
has_turn_password: config.turn_password.is_some(),
}
}
}
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct StreamConfigUpdate {
pub mode: Option<StreamMode>,
pub encoder: Option<EncoderType>,
pub bitrate_kbps: Option<u32>,
pub gop_size: Option<u32>,
/// STUN server URL (e.g., "stun:stun.l.google.com:19302")
pub stun_server: Option<String>,
/// TURN server URL (e.g., "turn:turn.example.com:3478")
pub turn_server: Option<String>,
/// TURN username
pub turn_username: Option<String>,
/// TURN password
pub turn_password: Option<String>,
}
impl StreamConfigUpdate {
pub fn validate(&self) -> crate::error::Result<()> {
if let Some(bitrate) = self.bitrate_kbps {
if !(1000..=15000).contains(&bitrate) {
return Err(AppError::BadRequest("Bitrate must be 1000-15000 kbps".into()));
}
}
if let Some(gop) = self.gop_size {
if !(10..=300).contains(&gop) {
return Err(AppError::BadRequest("GOP size must be 10-300".into()));
}
}
// Validate STUN server format
if let Some(ref stun) = self.stun_server {
if !stun.is_empty() && !stun.starts_with("stun:") {
return Err(AppError::BadRequest(
"STUN server must start with 'stun:' (e.g., stun:stun.l.google.com:19302)".into(),
));
}
}
// Validate TURN server format
if let Some(ref turn) = self.turn_server {
if !turn.is_empty() && !turn.starts_with("turn:") && !turn.starts_with("turns:") {
return Err(AppError::BadRequest(
"TURN server must start with 'turn:' or 'turns:' (e.g., turn:turn.example.com:3478)".into(),
));
}
}
Ok(())
}
pub fn apply_to(&self, config: &mut StreamConfig) {
if let Some(mode) = self.mode.clone() {
config.mode = mode;
}
if let Some(encoder) = self.encoder.clone() {
config.encoder = encoder;
}
if let Some(bitrate) = self.bitrate_kbps {
config.bitrate_kbps = bitrate;
}
if let Some(gop) = self.gop_size {
config.gop_size = gop;
}
// STUN/TURN settings - empty string means clear, Some("value") means set
if let Some(ref stun) = self.stun_server {
config.stun_server = if stun.is_empty() { None } else { Some(stun.clone()) };
}
if let Some(ref turn) = self.turn_server {
config.turn_server = if turn.is_empty() { None } else { Some(turn.clone()) };
}
if let Some(ref username) = self.turn_username {
config.turn_username = if username.is_empty() { None } else { Some(username.clone()) };
}
if let Some(ref password) = self.turn_password {
config.turn_password = if password.is_empty() { None } else { Some(password.clone()) };
}
}
}
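// Illustrative sketch (not part of this file): how a partial update merges
// into an existing config. The field values are made up and
// `StreamConfig: Default` is assumed.
//
// let update: StreamConfigUpdate =
//     serde_json::from_str(r#"{"bitrate_kbps": 4000, "turn_server": ""}"#).unwrap();
// update.validate().unwrap();
// let mut config = StreamConfig::default();
// update.apply_to(&mut config);
// assert_eq!(config.bitrate_kbps, 4000);
// assert_eq!(config.turn_server, None); // empty string clears the value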
// ===== HID Config =====
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct HidConfigUpdate {
pub backend: Option<HidBackend>,
pub ch9329_port: Option<String>,
pub ch9329_baudrate: Option<u32>,
pub otg_udc: Option<String>,
pub mouse_absolute: Option<bool>,
}
impl HidConfigUpdate {
pub fn validate(&self) -> crate::error::Result<()> {
if let Some(baudrate) = self.ch9329_baudrate {
let valid_rates = [9600, 19200, 38400, 57600, 115200];
if !valid_rates.contains(&baudrate) {
return Err(AppError::BadRequest(
"Invalid baudrate: must be 9600, 19200, 38400, 57600, or 115200".into(),
));
}
}
Ok(())
}
pub fn apply_to(&self, config: &mut HidConfig) {
if let Some(backend) = self.backend.clone() {
config.backend = backend;
}
if let Some(ref port) = self.ch9329_port {
config.ch9329_port = port.clone();
}
if let Some(baudrate) = self.ch9329_baudrate {
config.ch9329_baudrate = baudrate;
}
if let Some(ref udc) = self.otg_udc {
config.otg_udc = Some(udc.clone());
}
if let Some(absolute) = self.mouse_absolute {
config.mouse_absolute = absolute;
}
}
}
// ===== MSD Config =====
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct MsdConfigUpdate {
pub enabled: Option<bool>,
pub images_path: Option<String>,
pub drive_path: Option<String>,
pub virtual_drive_size_mb: Option<u32>,
}
impl MsdConfigUpdate {
pub fn validate(&self) -> crate::error::Result<()> {
if let Some(size) = self.virtual_drive_size_mb {
if !(1..=10240).contains(&size) {
return Err(AppError::BadRequest("Drive size must be 1-10240 MB".into()));
}
}
Ok(())
}
pub fn apply_to(&self, config: &mut MsdConfig) {
if let Some(enabled) = self.enabled {
config.enabled = enabled;
}
if let Some(ref path) = self.images_path {
config.images_path = path.clone();
}
if let Some(ref path) = self.drive_path {
config.drive_path = path.clone();
}
if let Some(size) = self.virtual_drive_size_mb {
config.virtual_drive_size_mb = size;
}
}
}
// ===== ATX Config =====
/// Update for a single ATX key configuration
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct AtxKeyConfigUpdate {
pub driver: Option<crate::atx::AtxDriverType>,
pub device: Option<String>,
pub pin: Option<u32>,
pub active_level: Option<crate::atx::ActiveLevel>,
}
/// Update for LED sensing configuration
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct AtxLedConfigUpdate {
pub enabled: Option<bool>,
pub gpio_chip: Option<String>,
pub gpio_pin: Option<u32>,
pub inverted: Option<bool>,
}
/// ATX configuration update request
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct AtxConfigUpdate {
pub enabled: Option<bool>,
/// Power button configuration
pub power: Option<AtxKeyConfigUpdate>,
/// Reset button configuration
pub reset: Option<AtxKeyConfigUpdate>,
/// LED sensing configuration
pub led: Option<AtxLedConfigUpdate>,
/// Network interface for WOL packets (empty = auto)
pub wol_interface: Option<String>,
}
impl AtxConfigUpdate {
pub fn validate(&self) -> crate::error::Result<()> {
// Validate power key config if present
if let Some(ref power) = self.power {
Self::validate_key_config(power, "power")?;
}
// Validate reset key config if present
if let Some(ref reset) = self.reset {
Self::validate_key_config(reset, "reset")?;
}
Ok(())
}
fn validate_key_config(key: &AtxKeyConfigUpdate, name: &str) -> crate::error::Result<()> {
if let Some(ref device) = key.device {
if !device.is_empty() && !std::path::Path::new(device).exists() {
return Err(AppError::BadRequest(format!(
"{} device '{}' does not exist",
name, device
)));
}
}
Ok(())
}
pub fn apply_to(&self, config: &mut AtxConfig) {
if let Some(enabled) = self.enabled {
config.enabled = enabled;
}
if let Some(ref power) = self.power {
Self::apply_key_update(power, &mut config.power);
}
if let Some(ref reset) = self.reset {
Self::apply_key_update(reset, &mut config.reset);
}
if let Some(ref led) = self.led {
Self::apply_led_update(led, &mut config.led);
}
if let Some(ref wol_interface) = self.wol_interface {
config.wol_interface = wol_interface.clone();
}
}
fn apply_key_update(update: &AtxKeyConfigUpdate, config: &mut crate::atx::AtxKeyConfig) {
if let Some(driver) = update.driver {
config.driver = driver;
}
if let Some(ref device) = update.device {
config.device = device.clone();
}
if let Some(pin) = update.pin {
config.pin = pin;
}
if let Some(level) = update.active_level {
config.active_level = level;
}
}
fn apply_led_update(update: &AtxLedConfigUpdate, config: &mut crate::atx::AtxLedConfig) {
if let Some(enabled) = update.enabled {
config.enabled = enabled;
}
if let Some(ref chip) = update.gpio_chip {
config.gpio_chip = chip.clone();
}
if let Some(pin) = update.gpio_pin {
config.gpio_pin = pin;
}
if let Some(inverted) = update.inverted {
config.inverted = inverted;
}
}
}
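// Illustrative sketch (not part of this file): ATX updates are nested
// partial PATCHes. The values are made up and `AtxConfig: Default` is
// assumed.
//
// let update: AtxConfigUpdate =
//     serde_json::from_str(r#"{"enabled": true, "power": {"pin": 17}}"#).unwrap();
// let mut config = AtxConfig::default();
// update.apply_to(&mut config);
// assert!(config.enabled);
// assert_eq!(config.power.pin, 17); // driver/device/active_level untouched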
// ===== Audio Config =====
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct AudioConfigUpdate {
pub enabled: Option<bool>,
pub device: Option<String>,
pub quality: Option<String>,
}
impl AudioConfigUpdate {
pub fn validate(&self) -> crate::error::Result<()> {
if let Some(ref quality) = self.quality {
if !["voice", "balanced", "high"].contains(&quality.as_str()) {
return Err(AppError::BadRequest(
"Invalid quality: must be 'voice', 'balanced', or 'high'".into(),
));
}
}
Ok(())
}
pub fn apply_to(&self, config: &mut AudioConfig) {
if let Some(enabled) = self.enabled {
config.enabled = enabled;
}
if let Some(ref device) = self.device {
config.device = device.clone();
}
if let Some(ref quality) = self.quality {
config.quality = quality.clone();
}
}
}

View File

@@ -0,0 +1,47 @@
//! Video config handler
use axum::{extract::State, Json};
use std::sync::Arc;
use crate::config::VideoConfig;
use crate::error::Result;
use crate::state::AppState;
use super::apply::apply_video_config;
use super::types::VideoConfigUpdate;
/// Get the video config
pub async fn get_video_config(State(state): State<Arc<AppState>>) -> Json<VideoConfig> {
Json(state.config.get().video.clone())
}
/// Update the video config
pub async fn update_video_config(
State(state): State<Arc<AppState>>,
Json(req): Json<VideoConfigUpdate>,
) -> Result<Json<VideoConfig>> {
// 1. Validate the request
req.validate()?;
// 2. Capture the old config
let old_video_config = state.config.get().video.clone();
// 3. Apply the update to the config store
state
.config
.update(|config| {
req.apply_to(&mut config.video);
})
.await?;
// 4. Read back the new config
let new_video_config = state.config.get().video.clone();
// 5. Apply to the subsystem (hot reload)
if let Err(e) = apply_video_config(&state, &old_video_config, &new_video_config).await {
tracing::error!("Failed to apply video config: {}", e);
// By design, only log the error here; do not roll back the config
}
Ok(Json(new_video_config))
}

View File

@@ -0,0 +1,14 @@
//! Device discovery handlers
//!
//! Provides API endpoints for discovering available hardware devices.
use axum::Json;
use crate::atx::{discover_devices, AtxDevices};
/// GET /api/devices/atx - List available ATX devices
///
/// Returns lists of available GPIO chips and USB HID relay devices.
pub async fn list_atx_devices() -> Json<AtxDevices> {
Json(discover_devices())
}

View File

@@ -0,0 +1,352 @@
//! Extension management API handlers
use axum::{
extract::{Path, Query, State},
Json,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use typeshare::typeshare;
use crate::error::{AppError, Result};
use crate::extensions::{
EasytierConfig, EasytierInfo, ExtensionId, ExtensionInfo, ExtensionLogs,
ExtensionsStatus, GostcConfig, GostcInfo, TtydConfig, TtydInfo,
};
use crate::state::AppState;
// ============================================================================
// Get all extensions status
// ============================================================================
/// Get status of all extensions
/// GET /api/extensions
pub async fn list_extensions(State(state): State<Arc<AppState>>) -> Json<ExtensionsStatus> {
let config = state.config.get();
let mgr = &state.extensions;
Json(ExtensionsStatus {
ttyd: TtydInfo {
available: mgr.check_available(ExtensionId::Ttyd),
status: mgr.status(ExtensionId::Ttyd).await,
config: config.extensions.ttyd.clone(),
},
gostc: GostcInfo {
available: mgr.check_available(ExtensionId::Gostc),
status: mgr.status(ExtensionId::Gostc).await,
config: config.extensions.gostc.clone(),
},
easytier: EasytierInfo {
available: mgr.check_available(ExtensionId::Easytier),
status: mgr.status(ExtensionId::Easytier).await,
config: config.extensions.easytier.clone(),
},
})
}
// ============================================================================
// Individual extension status
// ============================================================================
/// Get status of a single extension
/// GET /api/extensions/:id
pub async fn get_extension(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
) -> Result<Json<ExtensionInfo>> {
let ext_id: ExtensionId = id
.parse()
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
let mgr = &state.extensions;
Ok(Json(ExtensionInfo {
available: mgr.check_available(ext_id),
status: mgr.status(ext_id).await,
}))
}
// ============================================================================
// Start/Stop extensions
// ============================================================================
/// Start an extension
/// POST /api/extensions/:id/start
pub async fn start_extension(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
) -> Result<Json<ExtensionInfo>> {
let ext_id: ExtensionId = id
.parse()
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
let config = state.config.get();
let mgr = &state.extensions;
// Start the extension
mgr.start(ext_id, &config.extensions)
.await
.map_err(AppError::Internal)?;
// Return updated status
Ok(Json(ExtensionInfo {
available: mgr.check_available(ext_id),
status: mgr.status(ext_id).await,
}))
}
/// Stop an extension
/// POST /api/extensions/:id/stop
pub async fn stop_extension(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
) -> Result<Json<ExtensionInfo>> {
let ext_id: ExtensionId = id
.parse()
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
let mgr = &state.extensions;
// Stop the extension
mgr.stop(ext_id)
.await
.map_err(AppError::Internal)?;
// Return updated status
Ok(Json(ExtensionInfo {
available: mgr.check_available(ext_id),
status: mgr.status(ext_id).await,
}))
}
// ============================================================================
// Extension logs
// ============================================================================
/// Query parameters for logs
#[derive(Deserialize, Default)]
pub struct LogsQuery {
/// Number of lines to return (default: 100, max: 500)
pub lines: Option<usize>,
}
/// Get extension logs
/// GET /api/extensions/:id/logs
pub async fn get_extension_logs(
State(state): State<Arc<AppState>>,
Path(id): Path<String>,
Query(params): Query<LogsQuery>,
) -> Result<Json<ExtensionLogs>> {
let ext_id: ExtensionId = id
.parse()
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
let lines = params.lines.unwrap_or(100).min(500);
let logs = state.extensions.logs(ext_id, lines).await;
Ok(Json(ExtensionLogs { id: ext_id, logs }))
}
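// Illustrative request: GET /api/extensions/ttyd/logs?lines=200 returns up
// to 200 lines; omitting `lines` defaults to 100, and anything above 500 is
// clamped by the `.min(500)` above.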
// ============================================================================
// Update extension config
// ============================================================================
/// Update ttyd config
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct TtydConfigUpdate {
pub enabled: Option<bool>,
pub port: Option<u16>,
pub shell: Option<String>,
pub credential: Option<String>,
}
/// Update gostc config
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct GostcConfigUpdate {
pub enabled: Option<bool>,
pub addr: Option<String>,
pub key: Option<String>,
pub tls: Option<bool>,
}
/// Update easytier config
#[typeshare]
#[derive(Debug, Deserialize)]
pub struct EasytierConfigUpdate {
pub enabled: Option<bool>,
pub network_name: Option<String>,
pub network_secret: Option<String>,
pub peer_urls: Option<Vec<String>>,
pub virtual_ip: Option<String>,
}
/// Update ttyd configuration
/// PATCH /api/extensions/ttyd/config
pub async fn update_ttyd_config(
State(state): State<Arc<AppState>>,
Json(req): Json<TtydConfigUpdate>,
) -> Result<Json<TtydConfig>> {
// Get current config
let was_enabled = state.config.get().extensions.ttyd.enabled;
// Update config
state
.config
.update(|config| {
let ttyd = &mut config.extensions.ttyd;
if let Some(enabled) = req.enabled {
ttyd.enabled = enabled;
}
if let Some(port) = req.port {
ttyd.port = port;
}
if let Some(ref shell) = req.shell {
ttyd.shell = shell.clone();
}
if req.credential.is_some() {
ttyd.credential = req.credential.clone();
}
})
.await?;
let new_config = state.config.get();
let is_enabled = new_config.extensions.ttyd.enabled;
// Handle enable/disable state change
if was_enabled && !is_enabled {
// Was running, now disabled - stop it
state.extensions.stop(ExtensionId::Ttyd).await.ok();
} else if !was_enabled && is_enabled {
// Was disabled, now enabled - start it
if state.extensions.check_available(ExtensionId::Ttyd) {
state
.extensions
.start(ExtensionId::Ttyd, &new_config.extensions)
.await
.ok();
}
}
Ok(Json(new_config.extensions.ttyd.clone()))
}
/// Update gostc configuration
/// PATCH /api/extensions/gostc/config
pub async fn update_gostc_config(
State(state): State<Arc<AppState>>,
Json(req): Json<GostcConfigUpdate>,
) -> Result<Json<GostcConfig>> {
let was_enabled = state.config.get().extensions.gostc.enabled;
state
.config
.update(|config| {
let gostc = &mut config.extensions.gostc;
if let Some(enabled) = req.enabled {
gostc.enabled = enabled;
}
if let Some(ref addr) = req.addr {
gostc.addr = addr.clone();
}
if let Some(ref key) = req.key {
gostc.key = key.clone();
}
if let Some(tls) = req.tls {
gostc.tls = tls;
}
})
.await?;
let new_config = state.config.get();
let is_enabled = new_config.extensions.gostc.enabled;
let has_key = !new_config.extensions.gostc.key.is_empty();
if was_enabled && !is_enabled {
state.extensions.stop(ExtensionId::Gostc).await.ok();
} else if !was_enabled && is_enabled && has_key {
if state.extensions.check_available(ExtensionId::Gostc) {
state
.extensions
.start(ExtensionId::Gostc, &new_config.extensions)
.await
.ok();
}
}
Ok(Json(new_config.extensions.gostc.clone()))
}
/// Update easytier configuration
/// PATCH /api/extensions/easytier/config
pub async fn update_easytier_config(
State(state): State<Arc<AppState>>,
Json(req): Json<EasytierConfigUpdate>,
) -> Result<Json<EasytierConfig>> {
let was_enabled = state.config.get().extensions.easytier.enabled;
state
.config
.update(|config| {
let et = &mut config.extensions.easytier;
if let Some(enabled) = req.enabled {
et.enabled = enabled;
}
if let Some(ref name) = req.network_name {
et.network_name = name.clone();
}
if let Some(ref secret) = req.network_secret {
et.network_secret = secret.clone();
}
if let Some(ref peers) = req.peer_urls {
et.peer_urls = peers.clone();
}
if req.virtual_ip.is_some() {
et.virtual_ip = req.virtual_ip.clone();
}
})
.await?;
let new_config = state.config.get();
let is_enabled = new_config.extensions.easytier.enabled;
let has_name = !new_config.extensions.easytier.network_name.is_empty();
if was_enabled && !is_enabled {
state.extensions.stop(ExtensionId::Easytier).await.ok();
} else if !was_enabled && is_enabled && has_name {
if state.extensions.check_available(ExtensionId::Easytier) {
state
.extensions
.start(ExtensionId::Easytier, &new_config.extensions)
.await
.ok();
}
}
Ok(Json(new_config.extensions.easytier.clone()))
}
// ============================================================================
// Ttyd status for console (simplified)
// ============================================================================
/// Simple ttyd status for console view
#[typeshare]
#[derive(Debug, Serialize)]
pub struct TtydStatus {
pub available: bool,
pub running: bool,
}
/// Get ttyd status for console view
/// GET /api/extensions/ttyd/status
pub async fn get_ttyd_status(State(state): State<Arc<AppState>>) -> Json<TtydStatus> {
let mgr = &state.extensions;
let status = mgr.status(ExtensionId::Ttyd).await;
Json(TtydStatus {
available: mgr.check_available(ExtensionId::Ttyd),
running: status.is_running(),
})
}

2583
src/web/handlers/mod.rs Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,239 @@
//! Terminal proxy handler - reverse proxy to ttyd via Unix socket
use axum::{
body::Body,
extract::{
ws::{Message as AxumMessage, WebSocket, WebSocketUpgrade},
OriginalUri, Path, State,
},
http::{Request, StatusCode},
response::Response,
};
use futures::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
use tokio_tungstenite::tungstenite::{
client::IntoClientRequest,
http::HeaderValue,
Message as TungsteniteMessage,
};
use crate::error::AppError;
use crate::extensions::TTYD_SOCKET_PATH;
use crate::state::AppState;
/// Handle WebSocket upgrade for terminal
pub async fn terminal_ws(
State(_state): State<Arc<AppState>>,
OriginalUri(original_uri): OriginalUri,
ws: WebSocketUpgrade,
) -> Response {
let query_string = original_uri
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default();
// Use the tty subprotocol that ttyd expects
ws.protocols(["tty"])
.on_upgrade(move |socket| handle_terminal_websocket(socket, query_string))
}
/// Handle terminal WebSocket connection - bridge browser and ttyd
async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) {
// Connect to ttyd Unix socket
let unix_stream = match UnixStream::connect(TTYD_SOCKET_PATH).await {
Ok(s) => s,
Err(e) => {
tracing::error!("Failed to connect to ttyd socket: {}", e);
return;
}
};
// Build WebSocket request for ttyd with tty subprotocol
let uri_str = format!("ws://localhost/api/terminal/ws{}", query_string);
let mut request = match uri_str.into_client_request() {
Ok(r) => r,
Err(e) => {
tracing::error!("Failed to create WebSocket request: {}", e);
return;
}
};
request.headers_mut().insert(
"Sec-WebSocket-Protocol",
HeaderValue::from_static("tty"),
);
// Create WebSocket connection to ttyd
let ws_stream = match tokio_tungstenite::client_async(request, unix_stream).await {
Ok((ws, _)) => ws,
Err(e) => {
tracing::error!("Failed to establish WebSocket with ttyd: {}", e);
return;
}
};
// Split both WebSocket connections
let (mut client_tx, mut client_rx) = client_ws.split();
let (mut ttyd_tx, mut ttyd_rx) = ws_stream.split();
// Forward messages from browser to ttyd
let client_to_ttyd = tokio::spawn(async move {
while let Some(msg) = client_rx.next().await {
let ttyd_msg = match msg {
Ok(AxumMessage::Text(text)) => TungsteniteMessage::Text(text),
Ok(AxumMessage::Binary(data)) => TungsteniteMessage::Binary(data),
Ok(AxumMessage::Ping(data)) => TungsteniteMessage::Ping(data),
Ok(AxumMessage::Pong(data)) => TungsteniteMessage::Pong(data),
Ok(AxumMessage::Close(_)) => {
let _ = ttyd_tx.send(TungsteniteMessage::Close(None)).await;
break;
}
Err(_) => break,
};
if ttyd_tx.send(ttyd_msg).await.is_err() {
break;
}
}
});
// Forward messages from ttyd to browser
let ttyd_to_client = tokio::spawn(async move {
while let Some(msg) = ttyd_rx.next().await {
let client_msg = match msg {
Ok(TungsteniteMessage::Text(text)) => AxumMessage::Text(text),
Ok(TungsteniteMessage::Binary(data)) => AxumMessage::Binary(data),
Ok(TungsteniteMessage::Ping(data)) => AxumMessage::Ping(data),
Ok(TungsteniteMessage::Pong(data)) => AxumMessage::Pong(data),
Ok(TungsteniteMessage::Close(_)) => {
let _ = client_tx.send(AxumMessage::Close(None)).await;
break;
}
Ok(TungsteniteMessage::Frame(_)) => continue,
Err(_) => break,
};
if client_tx.send(client_msg).await.is_err() {
break;
}
}
});
// Wait for either direction to complete
tokio::select! {
_ = client_to_ttyd => {}
_ = ttyd_to_client => {}
}
}
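// Design note: the bridge returns as soon as either forwarder finishes, but
// dropping the other `JoinHandle` does not abort that task; it keeps running
// until its own send or receive fails, at which point its half of the
// connection is closed as well.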
/// Proxy HTTP requests to ttyd
pub async fn terminal_proxy(
State(_state): State<Arc<AppState>>,
path: Option<Path<String>>,
req: Request<Body>,
) -> Result<Response, AppError> {
let path_str = path.map(|p| p.0).unwrap_or_default();
// Connect to ttyd Unix socket
let mut unix_stream = UnixStream::connect(TTYD_SOCKET_PATH)
.await
.map_err(|e| AppError::ServiceUnavailable(format!("ttyd not running: {}", e)))?;
// Build HTTP request to forward
let method = req.method().as_str();
let query = req.uri().query().map(|q| format!("?{}", q)).unwrap_or_default();
let uri_path = if path_str.is_empty() {
format!("/api/terminal/{}", query)
} else {
format!("/api/terminal/{}{}", path_str, query)
};
// Forward relevant headers
let mut headers_str = String::new();
for (name, value) in req.headers() {
if let Ok(v) = value.to_str() {
let name_lower = name.as_str().to_lowercase();
if !matches!(
name_lower.as_str(),
"connection" | "keep-alive" | "transfer-encoding" | "upgrade"
) {
headers_str.push_str(&format!("{}: {}\r\n", name, v));
}
}
}
let http_request = format!(
"{} {} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n{}\r\n",
method, uri_path, headers_str
);
// Send request
unix_stream
.write_all(http_request.as_bytes())
.await
.map_err(|e| AppError::Internal(format!("Failed to send request: {}", e)))?;
// Read response
let mut response_buf = Vec::new();
unix_stream
.read_to_end(&mut response_buf)
.await
.map_err(|e| AppError::Internal(format!("Failed to read response: {}", e)))?;
// Parse HTTP response
let response_str = String::from_utf8_lossy(&response_buf);
let header_end = response_str
.find("\r\n\r\n")
.ok_or_else(|| AppError::Internal("Invalid HTTP response".to_string()))?;
let headers_part = &response_str[..header_end];
let body_start = header_end + 4;
// Parse status line
let status_line = headers_part
.lines()
.next()
.ok_or_else(|| AppError::Internal("Missing status line".to_string()))?;
let status_code: u16 = status_line
.split_whitespace()
.nth(1)
.and_then(|s| s.parse().ok())
.unwrap_or(200);
// Build response
let mut builder = Response::builder().status(StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK));
// Forward response headers
for line in headers_part.lines().skip(1) {
if let Some((name, value)) = line.split_once(':') {
let name = name.trim();
let value = value.trim();
if !matches!(
name.to_lowercase().as_str(),
"connection" | "keep-alive" | "transfer-encoding"
) {
builder = builder.header(name, value);
}
}
}
let body = if body_start < response_buf.len() {
Body::from(response_buf[body_start..].to_vec())
} else {
Body::empty()
};
builder
.body(body)
.map_err(|e| AppError::Internal(format!("Failed to build response: {}", e)))
}
/// Terminal index page
pub async fn terminal_index(
State(state): State<Arc<AppState>>,
req: Request<Body>,
) -> Result<Response, AppError> {
terminal_proxy(State(state), None, req).await
}

12
src/web/mod.rs Normal file
View File

@@ -0,0 +1,12 @@
mod audio_ws;
mod routes;
mod handlers;
mod static_files;
mod ws;
pub use audio_ws::audio_ws_handler;
pub use routes::create_router;
// StaticAssets is only available in release mode (embedded assets)
#[cfg(not(debug_assertions))]
pub use static_files::StaticAssets;
pub use ws::ws_handler;

178
src/web/routes.rs Normal file
View File

@@ -0,0 +1,178 @@
use axum::{
extract::DefaultBodyLimit,
middleware,
routing::{any, delete, get, patch, post, put},
Router,
};
use std::sync::Arc;
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use super::audio_ws::audio_ws_handler;
use super::handlers;
use super::ws::ws_handler;
use crate::auth::{auth_middleware, require_admin};
use crate::hid::websocket::ws_hid_handler;
use crate::state::AppState;
/// Create the main application router
pub fn create_router(state: Arc<AppState>) -> Router {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
// Public routes (no auth required)
// Note: /info moved to user_routes for security (contains hostname, IPs, etc.)
let public_routes = Router::new()
.route("/health", get(handlers::health_check))
.route("/auth/login", post(handlers::login))
.route("/setup", get(handlers::setup_status))
.route("/setup/init", post(handlers::setup_init));
// User routes (authenticated users - both regular and admin)
let user_routes = Router::new()
.route("/info", get(handlers::system_info))
.route("/auth/logout", post(handlers::logout))
.route("/auth/check", get(handlers::auth_check))
.route("/devices", get(handlers::list_devices))
// WebSocket endpoint for real-time events
.route("/ws", any(ws_handler))
// Stream control endpoints
.route("/stream/status", get(handlers::stream_state))
.route("/stream/start", post(handlers::stream_start))
.route("/stream/stop", post(handlers::stream_stop))
.route("/stream/mode", get(handlers::stream_mode_get))
.route("/stream/mode", post(handlers::stream_mode_set))
.route("/stream/bitrate", post(handlers::stream_set_bitrate))
.route("/stream/codecs", get(handlers::stream_codecs_list))
// WebRTC endpoints
.route("/webrtc/session", post(handlers::webrtc_create_session))
.route("/webrtc/offer", post(handlers::webrtc_offer))
.route("/webrtc/ice", post(handlers::webrtc_ice_candidate))
.route("/webrtc/ice-servers", get(handlers::webrtc_ice_servers))
.route("/webrtc/status", get(handlers::webrtc_status))
.route("/webrtc/close", post(handlers::webrtc_close_session))
// HID endpoints
.route("/hid/status", get(handlers::hid_status))
.route("/hid/reset", post(handlers::hid_reset))
// WebSocket HID endpoint (for MJPEG mode)
.route("/ws/hid", any(ws_hid_handler))
// Audio endpoints
.route("/audio/status", get(handlers::audio_status))
.route("/audio/start", post(handlers::start_audio_streaming))
.route("/audio/stop", post(handlers::stop_audio_streaming))
.route("/audio/quality", post(handlers::set_audio_quality))
.route("/audio/device", post(handlers::select_audio_device))
.route("/audio/devices", get(handlers::list_audio_devices))
// Audio WebSocket endpoint
.route("/ws/audio", any(audio_ws_handler))
// User can change their own password (handler will check ownership)
.route("/users/:id/password", post(handlers::change_user_password));
// Admin-only routes (require admin privileges)
let admin_routes = Router::new()
// Configuration management (domain-separated endpoints)
.route("/config", get(handlers::config::get_all_config))
.route("/config", post(handlers::update_config))
.route("/config/video", get(handlers::config::get_video_config))
.route("/config/video", patch(handlers::config::update_video_config))
.route("/config/stream", get(handlers::config::get_stream_config))
.route("/config/stream", patch(handlers::config::update_stream_config))
.route("/config/hid", get(handlers::config::get_hid_config))
.route("/config/hid", patch(handlers::config::update_hid_config))
.route("/config/msd", get(handlers::config::get_msd_config))
.route("/config/msd", patch(handlers::config::update_msd_config))
.route("/config/atx", get(handlers::config::get_atx_config))
.route("/config/atx", patch(handlers::config::update_atx_config))
.route("/config/audio", get(handlers::config::get_audio_config))
.route("/config/audio", patch(handlers::config::update_audio_config))
// MSD (Mass Storage Device) endpoints
.route("/msd/status", get(handlers::msd_status))
.route("/msd/images", get(handlers::msd_images_list))
.route("/msd/images/download", post(handlers::msd_image_download))
.route("/msd/images/download/cancel", post(handlers::msd_image_download_cancel))
.route("/msd/images/:id", get(handlers::msd_image_get))
.route("/msd/images/:id", delete(handlers::msd_image_delete))
.route("/msd/connect", post(handlers::msd_connect))
.route("/msd/disconnect", post(handlers::msd_disconnect))
// MSD Virtual Drive endpoints
.route("/msd/drive", get(handlers::msd_drive_info))
.route("/msd/drive", delete(handlers::msd_drive_delete))
.route("/msd/drive/init", post(handlers::msd_drive_init))
.route("/msd/drive/files", get(handlers::msd_drive_files))
.route("/msd/drive/files/*path", get(handlers::msd_drive_download))
.route("/msd/drive/files/*path", delete(handlers::msd_drive_file_delete))
.route("/msd/drive/mkdir/*path", post(handlers::msd_drive_mkdir))
// ATX (Power Control) endpoints
.route("/atx/status", get(handlers::atx_status))
.route("/atx/power", post(handlers::atx_power))
.route("/atx/wol", post(handlers::atx_wol))
// Device discovery endpoints
.route("/devices/atx", get(handlers::devices::list_atx_devices))
// User management endpoints
.route("/users", get(handlers::list_users))
.route("/users", post(handlers::create_user))
.route("/users/:id", put(handlers::update_user))
.route("/users/:id", delete(handlers::delete_user))
// Extension management endpoints
.route("/extensions", get(handlers::extensions::list_extensions))
.route("/extensions/:id", get(handlers::extensions::get_extension))
.route("/extensions/:id/start", post(handlers::extensions::start_extension))
.route("/extensions/:id/stop", post(handlers::extensions::stop_extension))
.route("/extensions/:id/logs", get(handlers::extensions::get_extension_logs))
.route("/extensions/ttyd/config", patch(handlers::extensions::update_ttyd_config))
.route("/extensions/ttyd/status", get(handlers::extensions::get_ttyd_status))
.route("/extensions/gostc/config", patch(handlers::extensions::update_gostc_config))
.route("/extensions/easytier/config", patch(handlers::extensions::update_easytier_config))
// Terminal (ttyd) reverse proxy - WebSocket and HTTP
.route("/terminal", get(handlers::terminal::terminal_index))
.route("/terminal/", get(handlers::terminal::terminal_index))
.route("/terminal/ws", get(handlers::terminal::terminal_ws))
.route("/terminal/*path", get(handlers::terminal::terminal_proxy))
// Apply admin middleware to all admin routes
.layer(middleware::from_fn_with_state(state.clone(), require_admin));
// Combine protected routes (user + admin)
let protected_routes = Router::new()
.merge(user_routes)
.merge(admin_routes);
// Stream endpoints (accessible with auth, but typically embedded in pages)
let stream_routes = Router::new()
.route("/stream", get(handlers::mjpeg_stream))
.route("/stream/mjpeg", get(handlers::mjpeg_stream))
.route("/snapshot", get(handlers::snapshot));
// Large file upload routes (MSD images and drive files)
// Use streaming upload to support files larger than available RAM
// Disable body limit for streaming uploads - files are written directly to disk
let upload_routes = Router::new()
.route("/msd/images", post(handlers::msd_image_upload))
.route("/msd/drive/files", post(handlers::msd_drive_upload))
.layer(DefaultBodyLimit::disable());
// Combine API routes
let api_routes = Router::new()
.merge(public_routes)
.merge(protected_routes)
.merge(stream_routes)
.merge(upload_routes)
.layer(middleware::from_fn_with_state(
state.clone(),
auth_middleware,
));
// Static file serving
let static_routes = super::static_files::static_file_router();
// Main router
Router::new()
.nest("/api", api_routes)
.merge(static_routes)
.layer(TraceLayer::new_for_http())
.layer(cors)
.with_state(state)
}
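
For context, a minimal sketch of mounting this router in a server entry point; the bind address and error type are assumptions, and the `axum::Server` API matches the axum 0.6-style `:id`/`*path` route syntax used above (axum 0.7 would use `axum::serve` and `{id}` captures instead):

use std::net::SocketAddr;
use std::sync::Arc;

use crate::state::AppState;
use crate::web::create_router;

async fn serve(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error>> {
    let app = create_router(state);
    // Hypothetical listen address; the real one presumably comes from config.
    let addr: SocketAddr = "0.0.0.0:8080".parse()?;
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await?;
    Ok(())
}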

Some files were not shown because too many files have changed in this diff