mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2026-01-28 16:41:52 +08:00
init
This commit is contained in:
356
src/atx/controller.rs
Normal file
356
src/atx/controller.rs
Normal file
@@ -0,0 +1,356 @@
|
||||
//! ATX Controller
|
||||
//!
|
||||
//! High-level controller for ATX power management with flexible hardware binding.
|
||||
//! Each action (power short, power long, reset) can be configured independently.
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::executor::{timing, AtxKeyExecutor};
|
||||
use super::led::LedSensor;
|
||||
use super::types::{AtxKeyConfig, AtxLedConfig, AtxState, PowerStatus};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// ATX power control configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AtxControllerConfig {
|
||||
/// Whether ATX is enabled
|
||||
pub enabled: bool,
|
||||
/// Power button configuration (used for both short and long press)
|
||||
pub power: AtxKeyConfig,
|
||||
/// Reset button configuration
|
||||
pub reset: AtxKeyConfig,
|
||||
/// LED sensing configuration
|
||||
pub led: AtxLedConfig,
|
||||
}
|
||||
|
||||
impl Default for AtxControllerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
power: AtxKeyConfig::default(),
|
||||
reset: AtxKeyConfig::default(),
|
||||
led: AtxLedConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal state holding all ATX components
|
||||
/// Grouped together to reduce lock acquisitions
|
||||
struct AtxInner {
|
||||
config: AtxControllerConfig,
|
||||
power_executor: Option<AtxKeyExecutor>,
|
||||
reset_executor: Option<AtxKeyExecutor>,
|
||||
led_sensor: Option<LedSensor>,
|
||||
}
|
||||
|
||||
/// ATX Controller
|
||||
///
|
||||
/// Manages ATX power control through independent executors for each action.
|
||||
/// Supports hot-reload of configuration.
|
||||
pub struct AtxController {
|
||||
/// Single lock for all internal state to reduce lock contention
|
||||
inner: RwLock<AtxInner>,
|
||||
}
|
||||
|
||||
impl AtxController {
|
||||
/// Create a new ATX controller with the specified configuration
|
||||
pub fn new(config: AtxControllerConfig) -> Self {
|
||||
Self {
|
||||
inner: RwLock::new(AtxInner {
|
||||
config,
|
||||
power_executor: None,
|
||||
reset_executor: None,
|
||||
led_sensor: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a disabled ATX controller
|
||||
pub fn disabled() -> Self {
|
||||
Self::new(AtxControllerConfig::default())
|
||||
}
|
||||
|
||||
/// Initialize the ATX controller and its executors
|
||||
pub async fn init(&self) -> Result<()> {
|
||||
let mut inner = self.inner.write().await;
|
||||
|
||||
if !inner.config.enabled {
|
||||
info!("ATX disabled in configuration");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Initializing ATX controller");
|
||||
|
||||
// Initialize power executor
|
||||
if inner.config.power.is_configured() {
|
||||
let mut executor = AtxKeyExecutor::new(inner.config.power.clone());
|
||||
if let Err(e) = executor.init().await {
|
||||
warn!("Failed to initialize power executor: {}", e);
|
||||
} else {
|
||||
info!(
|
||||
"Power executor initialized: {:?} on {} pin {}",
|
||||
inner.config.power.driver, inner.config.power.device, inner.config.power.pin
|
||||
);
|
||||
inner.power_executor = Some(executor);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize reset executor
|
||||
if inner.config.reset.is_configured() {
|
||||
let mut executor = AtxKeyExecutor::new(inner.config.reset.clone());
|
||||
if let Err(e) = executor.init().await {
|
||||
warn!("Failed to initialize reset executor: {}", e);
|
||||
} else {
|
||||
info!(
|
||||
"Reset executor initialized: {:?} on {} pin {}",
|
||||
inner.config.reset.driver, inner.config.reset.device, inner.config.reset.pin
|
||||
);
|
||||
inner.reset_executor = Some(executor);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize LED sensor
|
||||
if inner.config.led.is_configured() {
|
||||
let mut sensor = LedSensor::new(inner.config.led.clone());
|
||||
if let Err(e) = sensor.init().await {
|
||||
warn!("Failed to initialize LED sensor: {}", e);
|
||||
} else {
|
||||
info!(
|
||||
"LED sensor initialized on {} pin {}",
|
||||
inner.config.led.gpio_chip, inner.config.led.gpio_pin
|
||||
);
|
||||
inner.led_sensor = Some(sensor);
|
||||
}
|
||||
}
|
||||
|
||||
info!("ATX controller initialized successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reload the ATX controller with new configuration
|
||||
///
|
||||
/// This is called when configuration changes and supports hot-reload.
|
||||
pub async fn reload(&self, new_config: AtxControllerConfig) -> Result<()> {
|
||||
info!("Reloading ATX controller with new configuration");
|
||||
|
||||
// Shutdown existing executors
|
||||
self.shutdown_internal().await?;
|
||||
|
||||
// Update configuration and re-initialize
|
||||
{
|
||||
let mut inner = self.inner.write().await;
|
||||
inner.config = new_config;
|
||||
}
|
||||
|
||||
// Re-initialize
|
||||
self.init().await?;
|
||||
|
||||
info!("ATX controller reloaded successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current ATX state (single lock acquisition)
|
||||
pub async fn state(&self) -> AtxState {
|
||||
let inner = self.inner.read().await;
|
||||
|
||||
let power_status = if let Some(sensor) = inner.led_sensor.as_ref() {
|
||||
sensor.read().await.unwrap_or(PowerStatus::Unknown)
|
||||
} else {
|
||||
PowerStatus::Unknown
|
||||
};
|
||||
|
||||
AtxState {
|
||||
available: inner.config.enabled,
|
||||
power_configured: inner
|
||||
.power_executor
|
||||
.as_ref()
|
||||
.map(|e| e.is_initialized())
|
||||
.unwrap_or(false),
|
||||
reset_configured: inner
|
||||
.reset_executor
|
||||
.as_ref()
|
||||
.map(|e| e.is_initialized())
|
||||
.unwrap_or(false),
|
||||
power_status,
|
||||
led_supported: inner
|
||||
.led_sensor
|
||||
.as_ref()
|
||||
.map(|s| s.is_initialized())
|
||||
.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current state as SystemEvent
|
||||
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
|
||||
let state = self.state().await;
|
||||
crate::events::SystemEvent::AtxStateChanged {
|
||||
power_status: state.power_status,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if ATX is available
|
||||
pub async fn is_available(&self) -> bool {
|
||||
let inner = self.inner.read().await;
|
||||
inner.config.enabled
|
||||
}
|
||||
|
||||
/// Check if power button is configured and initialized
|
||||
pub async fn is_power_ready(&self) -> bool {
|
||||
let inner = self.inner.read().await;
|
||||
inner
|
||||
.power_executor
|
||||
.as_ref()
|
||||
.map(|e| e.is_initialized())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Check if reset button is configured and initialized
|
||||
pub async fn is_reset_ready(&self) -> bool {
|
||||
let inner = self.inner.read().await;
|
||||
inner
|
||||
.reset_executor
|
||||
.as_ref()
|
||||
.map(|e| e.is_initialized())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Short press power button (turn on or graceful shutdown)
|
||||
pub async fn power_short(&self) -> Result<()> {
|
||||
let inner = self.inner.read().await;
|
||||
let executor = inner
|
||||
.power_executor
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::Internal("Power button not configured".to_string()))?;
|
||||
|
||||
info!(
|
||||
"ATX: Short press power button ({}ms)",
|
||||
timing::SHORT_PRESS.as_millis()
|
||||
);
|
||||
executor.pulse(timing::SHORT_PRESS).await
|
||||
}
|
||||
|
||||
/// Long press power button (force power off)
|
||||
pub async fn power_long(&self) -> Result<()> {
|
||||
let inner = self.inner.read().await;
|
||||
let executor = inner
|
||||
.power_executor
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::Internal("Power button not configured".to_string()))?;
|
||||
|
||||
info!(
|
||||
"ATX: Long press power button ({}ms)",
|
||||
timing::LONG_PRESS.as_millis()
|
||||
);
|
||||
executor.pulse(timing::LONG_PRESS).await
|
||||
}
|
||||
|
||||
/// Press reset button
|
||||
pub async fn reset(&self) -> Result<()> {
|
||||
let inner = self.inner.read().await;
|
||||
let executor = inner
|
||||
.reset_executor
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::Internal("Reset button not configured".to_string()))?;
|
||||
|
||||
info!(
|
||||
"ATX: Press reset button ({}ms)",
|
||||
timing::RESET_PRESS.as_millis()
|
||||
);
|
||||
executor.pulse(timing::RESET_PRESS).await
|
||||
}
|
||||
|
||||
/// Get current power status from LED sensor
|
||||
pub async fn power_status(&self) -> Result<PowerStatus> {
|
||||
let inner = self.inner.read().await;
|
||||
match inner.led_sensor.as_ref() {
|
||||
Some(sensor) => sensor.read().await,
|
||||
None => Ok(PowerStatus::Unknown),
|
||||
}
|
||||
}
|
||||
|
||||
/// Shutdown the ATX controller
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down ATX controller");
|
||||
self.shutdown_internal().await?;
|
||||
info!("ATX controller shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Internal shutdown helper
|
||||
async fn shutdown_internal(&self) -> Result<()> {
|
||||
let mut inner = self.inner.write().await;
|
||||
|
||||
// Shutdown power executor
|
||||
if let Some(mut executor) = inner.power_executor.take() {
|
||||
executor.shutdown().await.ok();
|
||||
}
|
||||
|
||||
// Shutdown reset executor
|
||||
if let Some(mut executor) = inner.reset_executor.take() {
|
||||
executor.shutdown().await.ok();
|
||||
}
|
||||
|
||||
// Shutdown LED sensor
|
||||
if let Some(mut sensor) = inner.led_sensor.take() {
|
||||
sensor.shutdown().await.ok();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AtxController {
|
||||
fn drop(&mut self) {
|
||||
debug!("ATX controller dropped");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_controller_config_default() {
|
||||
let config = AtxControllerConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert!(!config.power.is_configured());
|
||||
assert!(!config.reset.is_configured());
|
||||
assert!(!config.led.is_configured());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_controller_creation() {
|
||||
let controller = AtxController::disabled();
|
||||
assert!(controller.inner.try_read().is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_controller_disabled_state() {
|
||||
let controller = AtxController::disabled();
|
||||
let state = controller.state().await;
|
||||
assert!(!state.available);
|
||||
assert!(!state.power_configured);
|
||||
assert!(!state.reset_configured);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_controller_init_disabled() {
|
||||
let controller = AtxController::disabled();
|
||||
let result = controller.init().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_controller_is_available() {
|
||||
let controller = AtxController::disabled();
|
||||
assert!(!controller.is_available().await);
|
||||
|
||||
let config = AtxControllerConfig {
|
||||
enabled: true,
|
||||
..Default::default()
|
||||
};
|
||||
let controller = AtxController::new(config);
|
||||
assert!(controller.is_available().await);
|
||||
}
|
||||
}
|
||||
305
src/atx/executor.rs
Normal file
305
src/atx/executor.rs
Normal file
@@ -0,0 +1,305 @@
|
||||
//! ATX Key Executor
|
||||
//!
|
||||
//! Lightweight executor for a single ATX key operation.
|
||||
//! Each executor handles one button (power or reset) with its own hardware binding.
|
||||
|
||||
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::types::{ActiveLevel, AtxDriverType, AtxKeyConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Timing constants for ATX operations
|
||||
pub mod timing {
|
||||
use std::time::Duration;
|
||||
|
||||
/// Short press duration (power on/graceful shutdown)
|
||||
pub const SHORT_PRESS: Duration = Duration::from_millis(500);
|
||||
|
||||
/// Long press duration (force power off)
|
||||
pub const LONG_PRESS: Duration = Duration::from_millis(5000);
|
||||
|
||||
/// Reset press duration
|
||||
pub const RESET_PRESS: Duration = Duration::from_millis(500);
|
||||
}
|
||||
|
||||
/// Executor for a single ATX key operation
|
||||
///
|
||||
/// Each executor manages one hardware button (power or reset).
|
||||
/// It handles both GPIO and USB relay backends.
|
||||
pub struct AtxKeyExecutor {
|
||||
config: AtxKeyConfig,
|
||||
gpio_handle: Mutex<Option<LineHandle>>,
|
||||
/// Cached USB relay file handle to avoid repeated open/close syscalls
|
||||
usb_relay_handle: Mutex<Option<File>>,
|
||||
initialized: AtomicBool,
|
||||
}
|
||||
|
||||
impl AtxKeyExecutor {
|
||||
/// Create a new executor with the given configuration
|
||||
pub fn new(config: AtxKeyConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
gpio_handle: Mutex::new(None),
|
||||
usb_relay_handle: Mutex::new(None),
|
||||
initialized: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this executor is configured
|
||||
pub fn is_configured(&self) -> bool {
|
||||
self.config.is_configured()
|
||||
}
|
||||
|
||||
/// Check if this executor is initialized
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.initialized.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Initialize the executor
|
||||
pub async fn init(&mut self) -> Result<()> {
|
||||
if !self.config.is_configured() {
|
||||
debug!("ATX key executor not configured, skipping init");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.config.driver {
|
||||
AtxDriverType::Gpio => self.init_gpio().await?,
|
||||
AtxDriverType::UsbRelay => self.init_usb_relay().await?,
|
||||
AtxDriverType::None => {}
|
||||
}
|
||||
|
||||
self.initialized.store(true, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize GPIO backend
|
||||
async fn init_gpio(&mut self) -> Result<()> {
|
||||
info!(
|
||||
"Initializing GPIO ATX executor on {} pin {}",
|
||||
self.config.device, self.config.pin
|
||||
);
|
||||
|
||||
let mut chip = Chip::new(&self.config.device)
|
||||
.map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?;
|
||||
|
||||
let line = chip.get_line(self.config.pin).map_err(|e| {
|
||||
AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e))
|
||||
})?;
|
||||
|
||||
// Initial value depends on active level (start in inactive state)
|
||||
let initial_value = match self.config.active_level {
|
||||
ActiveLevel::High => 0, // Inactive = low
|
||||
ActiveLevel::Low => 1, // Inactive = high
|
||||
};
|
||||
|
||||
let handle = line
|
||||
.request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx")
|
||||
.map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?;
|
||||
|
||||
*self.gpio_handle.lock().unwrap() = Some(handle);
|
||||
debug!("GPIO pin {} configured successfully", self.config.pin);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize USB relay backend
|
||||
async fn init_usb_relay(&self) -> Result<()> {
|
||||
info!(
|
||||
"Initializing USB relay ATX executor on {} channel {}",
|
||||
self.config.device, self.config.pin
|
||||
);
|
||||
|
||||
// Open and cache the device handle
|
||||
let device = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.open(&self.config.device)
|
||||
.map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?;
|
||||
|
||||
*self.usb_relay_handle.lock().unwrap() = Some(device);
|
||||
|
||||
// Ensure relay is off initially
|
||||
self.send_usb_relay_command(false)?;
|
||||
|
||||
debug!(
|
||||
"USB relay channel {} configured successfully",
|
||||
self.config.pin
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Pulse the button for the specified duration
|
||||
pub async fn pulse(&self, duration: Duration) -> Result<()> {
|
||||
if !self.is_configured() {
|
||||
return Err(AppError::Internal("ATX key not configured".to_string()));
|
||||
}
|
||||
|
||||
if !self.is_initialized() {
|
||||
return Err(AppError::Internal("ATX key not initialized".to_string()));
|
||||
}
|
||||
|
||||
match self.config.driver {
|
||||
AtxDriverType::Gpio => self.pulse_gpio(duration).await,
|
||||
AtxDriverType::UsbRelay => self.pulse_usb_relay(duration).await,
|
||||
AtxDriverType::None => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Pulse GPIO pin
|
||||
async fn pulse_gpio(&self, duration: Duration) -> Result<()> {
|
||||
let (active, inactive) = match self.config.active_level {
|
||||
ActiveLevel::High => (1u8, 0u8),
|
||||
ActiveLevel::Low => (0u8, 1u8),
|
||||
};
|
||||
|
||||
// Set to active state
|
||||
{
|
||||
let guard = self.gpio_handle.lock().unwrap();
|
||||
let handle = guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?;
|
||||
handle
|
||||
.set_value(active)
|
||||
.map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?;
|
||||
}
|
||||
|
||||
// Wait for duration (no lock held)
|
||||
sleep(duration).await;
|
||||
|
||||
// Set to inactive state
|
||||
{
|
||||
let guard = self.gpio_handle.lock().unwrap();
|
||||
if let Some(handle) = guard.as_ref() {
|
||||
handle.set_value(inactive).ok();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Pulse USB relay
|
||||
async fn pulse_usb_relay(&self, duration: Duration) -> Result<()> {
|
||||
// Turn relay on
|
||||
self.send_usb_relay_command(true)?;
|
||||
|
||||
// Wait for duration
|
||||
sleep(duration).await;
|
||||
|
||||
// Turn relay off
|
||||
self.send_usb_relay_command(false)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send USB relay command using cached handle
|
||||
fn send_usb_relay_command(&self, on: bool) -> Result<()> {
|
||||
let channel = self.config.pin as u8;
|
||||
|
||||
// Standard HID relay command format
|
||||
let cmd = if on {
|
||||
[0x00, channel + 1, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
} else {
|
||||
[0x00, channel + 1, 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
};
|
||||
|
||||
let mut guard = self.usb_relay_handle.lock().unwrap();
|
||||
let device = guard
|
||||
.as_mut()
|
||||
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
|
||||
|
||||
device
|
||||
.write_all(&cmd)
|
||||
.map_err(|e| AppError::Internal(format!("USB relay write failed: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the executor
|
||||
pub async fn shutdown(&mut self) -> Result<()> {
|
||||
if !self.is_initialized() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.config.driver {
|
||||
AtxDriverType::Gpio => {
|
||||
// Release GPIO handle
|
||||
*self.gpio_handle.lock().unwrap() = None;
|
||||
}
|
||||
AtxDriverType::UsbRelay => {
|
||||
// Ensure relay is off before closing handle
|
||||
let _ = self.send_usb_relay_command(false);
|
||||
// Release USB relay handle
|
||||
*self.usb_relay_handle.lock().unwrap() = None;
|
||||
}
|
||||
AtxDriverType::None => {}
|
||||
}
|
||||
|
||||
self.initialized.store(false, Ordering::Relaxed);
|
||||
debug!("ATX key executor shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AtxKeyExecutor {
|
||||
fn drop(&mut self) {
|
||||
// Ensure GPIO lines are released
|
||||
*self.gpio_handle.lock().unwrap() = None;
|
||||
|
||||
// Ensure USB relay is off and handle released
|
||||
if self.config.driver == AtxDriverType::UsbRelay && self.is_initialized() {
|
||||
let _ = self.send_usb_relay_command(false);
|
||||
}
|
||||
*self.usb_relay_handle.lock().unwrap() = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_executor_creation() {
|
||||
let config = AtxKeyConfig::default();
|
||||
let executor = AtxKeyExecutor::new(config);
|
||||
assert!(!executor.is_configured());
|
||||
assert!(!executor.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_executor_with_gpio_config() {
|
||||
let config = AtxKeyConfig {
|
||||
driver: AtxDriverType::Gpio,
|
||||
device: "/dev/gpiochip0".to_string(),
|
||||
pin: 5,
|
||||
active_level: ActiveLevel::High,
|
||||
};
|
||||
let executor = AtxKeyExecutor::new(config);
|
||||
assert!(executor.is_configured());
|
||||
assert!(!executor.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_executor_with_usb_relay_config() {
|
||||
let config = AtxKeyConfig {
|
||||
driver: AtxDriverType::UsbRelay,
|
||||
device: "/dev/hidraw0".to_string(),
|
||||
pin: 0,
|
||||
active_level: ActiveLevel::High, // Ignored for USB relay
|
||||
};
|
||||
let executor = AtxKeyExecutor::new(config);
|
||||
assert!(executor.is_configured());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timing_constants() {
|
||||
assert_eq!(timing::SHORT_PRESS.as_millis(), 500);
|
||||
assert_eq!(timing::LONG_PRESS.as_millis(), 5000);
|
||||
assert_eq!(timing::RESET_PRESS.as_millis(), 500);
|
||||
}
|
||||
}
|
||||
154
src/atx/led.rs
Normal file
154
src/atx/led.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
//! ATX LED Sensor
|
||||
//!
|
||||
//! Reads power LED status from GPIO to determine if the target system is powered on.
|
||||
|
||||
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Mutex;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::types::{AtxLedConfig, PowerStatus};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// LED sensor for reading power status
|
||||
///
|
||||
/// Uses GPIO to read the power LED state and determine if the system is on or off.
|
||||
pub struct LedSensor {
|
||||
config: AtxLedConfig,
|
||||
handle: Mutex<Option<LineHandle>>,
|
||||
initialized: AtomicBool,
|
||||
}
|
||||
|
||||
impl LedSensor {
|
||||
/// Create a new LED sensor with the given configuration
|
||||
pub fn new(config: AtxLedConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
handle: Mutex::new(None),
|
||||
initialized: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the sensor is configured
|
||||
pub fn is_configured(&self) -> bool {
|
||||
self.config.is_configured()
|
||||
}
|
||||
|
||||
/// Check if the sensor is initialized
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.initialized.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Initialize the LED sensor
|
||||
pub async fn init(&mut self) -> Result<()> {
|
||||
if !self.config.is_configured() {
|
||||
debug!("LED sensor not configured, skipping init");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(
|
||||
"Initializing LED sensor on {} pin {}",
|
||||
self.config.gpio_chip, self.config.gpio_pin
|
||||
);
|
||||
|
||||
let mut chip = Chip::new(&self.config.gpio_chip)
|
||||
.map_err(|e| AppError::Internal(format!("LED GPIO chip failed: {}", e)))?;
|
||||
|
||||
let line = chip.get_line(self.config.gpio_pin).map_err(|e| {
|
||||
AppError::Internal(format!("LED GPIO line {} failed: {}", self.config.gpio_pin, e))
|
||||
})?;
|
||||
|
||||
let handle = line
|
||||
.request(LineRequestFlags::INPUT, 0, "one-kvm-led")
|
||||
.map_err(|e| AppError::Internal(format!("LED GPIO request failed: {}", e)))?;
|
||||
|
||||
*self.handle.lock().unwrap() = Some(handle);
|
||||
self.initialized.store(true, Ordering::Relaxed);
|
||||
|
||||
debug!("LED sensor initialized successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read the current power status
|
||||
pub async fn read(&self) -> Result<PowerStatus> {
|
||||
if !self.is_configured() || !self.is_initialized() {
|
||||
return Ok(PowerStatus::Unknown);
|
||||
}
|
||||
|
||||
let guard = self.handle.lock().unwrap();
|
||||
match guard.as_ref() {
|
||||
Some(handle) => {
|
||||
let value = handle
|
||||
.get_value()
|
||||
.map_err(|e| AppError::Internal(format!("LED read failed: {}", e)))?;
|
||||
|
||||
// Apply inversion if configured
|
||||
let is_on = if self.config.inverted {
|
||||
value == 0 // Active low: 0 means on
|
||||
} else {
|
||||
value == 1 // Active high: 1 means on
|
||||
};
|
||||
|
||||
Ok(if is_on {
|
||||
PowerStatus::On
|
||||
} else {
|
||||
PowerStatus::Off
|
||||
})
|
||||
}
|
||||
None => Ok(PowerStatus::Unknown),
|
||||
}
|
||||
}
|
||||
|
||||
/// Shutdown the LED sensor
|
||||
pub async fn shutdown(&mut self) -> Result<()> {
|
||||
*self.handle.lock().unwrap() = None;
|
||||
self.initialized.store(false, Ordering::Relaxed);
|
||||
debug!("LED sensor shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LedSensor {
|
||||
fn drop(&mut self) {
|
||||
*self.handle.lock().unwrap() = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_led_sensor_creation() {
|
||||
let config = AtxLedConfig::default();
|
||||
let sensor = LedSensor::new(config);
|
||||
assert!(!sensor.is_configured());
|
||||
assert!(!sensor.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_led_sensor_with_config() {
|
||||
let config = AtxLedConfig {
|
||||
enabled: true,
|
||||
gpio_chip: "/dev/gpiochip0".to_string(),
|
||||
gpio_pin: 7,
|
||||
inverted: false,
|
||||
};
|
||||
let sensor = LedSensor::new(config);
|
||||
assert!(sensor.is_configured());
|
||||
assert!(!sensor.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_led_sensor_inverted_config() {
|
||||
let config = AtxLedConfig {
|
||||
enabled: true,
|
||||
gpio_chip: "/dev/gpiochip0".to_string(),
|
||||
gpio_pin: 7,
|
||||
inverted: true,
|
||||
};
|
||||
let sensor = LedSensor::new(config);
|
||||
assert!(sensor.is_configured());
|
||||
assert!(sensor.config.inverted);
|
||||
}
|
||||
}
|
||||
107
src/atx/mod.rs
Normal file
107
src/atx/mod.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
//! ATX Power Control Module
|
||||
//!
|
||||
//! Provides ATX power management functionality for IP-KVM.
|
||||
//! Supports flexible hardware binding with independent configuration for each action.
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - Power button control (short press for on/graceful shutdown, long press for force off)
|
||||
//! - Reset button control
|
||||
//! - Power status monitoring via LED sensing (GPIO only)
|
||||
//! - Independent hardware binding for each action (GPIO or USB relay)
|
||||
//! - Hot-reload configuration support
|
||||
//!
|
||||
//! # Hardware Support
|
||||
//!
|
||||
//! - **GPIO**: Uses Linux GPIO character device (/dev/gpiochipX) for direct hardware control
|
||||
//! - **USB Relay**: Uses HID USB relay modules for isolated switching
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```ignore
|
||||
//! use one_kvm::atx::{AtxController, AtxControllerConfig, AtxKeyConfig, AtxDriverType, ActiveLevel};
|
||||
//!
|
||||
//! let config = AtxControllerConfig {
|
||||
//! enabled: true,
|
||||
//! power: AtxKeyConfig {
|
||||
//! driver: AtxDriverType::Gpio,
|
||||
//! device: "/dev/gpiochip0".to_string(),
|
||||
//! pin: 5,
|
||||
//! active_level: ActiveLevel::High,
|
||||
//! },
|
||||
//! reset: AtxKeyConfig {
|
||||
//! driver: AtxDriverType::UsbRelay,
|
||||
//! device: "/dev/hidraw0".to_string(),
|
||||
//! pin: 0,
|
||||
//! active_level: ActiveLevel::High,
|
||||
//! },
|
||||
//! led: Default::default(),
|
||||
//! };
|
||||
//!
|
||||
//! let controller = AtxController::new(config);
|
||||
//! controller.init().await?;
|
||||
//! controller.power_short().await?; // Turn on or graceful shutdown
|
||||
//! ```
|
||||
|
||||
mod controller;
|
||||
mod executor;
|
||||
mod led;
|
||||
mod types;
|
||||
mod wol;
|
||||
|
||||
pub use controller::{AtxController, AtxControllerConfig};
|
||||
pub use executor::timing;
|
||||
pub use types::{
|
||||
ActiveLevel, AtxAction, AtxDevices, AtxDriverType, AtxKeyConfig, AtxLedConfig,
|
||||
AtxPowerRequest, AtxState, PowerStatus,
|
||||
};
|
||||
pub use wol::send_wol;
|
||||
|
||||
/// Discover available ATX devices on the system
|
||||
///
|
||||
/// Scans for GPIO chips and USB HID relay devices in a single pass.
|
||||
pub fn discover_devices() -> AtxDevices {
|
||||
let mut devices = AtxDevices::default();
|
||||
|
||||
// Single pass through /dev directory
|
||||
if let Ok(entries) = std::fs::read_dir("/dev") {
|
||||
for entry in entries.flatten() {
|
||||
let name = entry.file_name();
|
||||
let name_str = name.to_string_lossy();
|
||||
if name_str.starts_with("gpiochip") {
|
||||
devices.gpio_chips.push(format!("/dev/{}", name_str));
|
||||
} else if name_str.starts_with("hidraw") {
|
||||
devices.usb_relays.push(format!("/dev/{}", name_str));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
devices.gpio_chips.sort();
|
||||
devices.usb_relays.sort();
|
||||
|
||||
devices
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_discover_devices() {
|
||||
let devices = discover_devices();
|
||||
// Just verify the function runs without error
|
||||
assert!(devices.gpio_chips.len() >= 0);
|
||||
assert!(devices.usb_relays.len() >= 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_module_exports() {
|
||||
// Verify all public exports are accessible
|
||||
let _: AtxDriverType = AtxDriverType::None;
|
||||
let _: ActiveLevel = ActiveLevel::High;
|
||||
let _: AtxKeyConfig = AtxKeyConfig::default();
|
||||
let _: AtxLedConfig = AtxLedConfig::default();
|
||||
let _: AtxState = AtxState::default();
|
||||
let _: AtxDevices = AtxDevices::default();
|
||||
}
|
||||
}
|
||||
270
src/atx/types.rs
Normal file
270
src/atx/types.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
//! ATX data types and structures
|
||||
//!
|
||||
//! Defines the configuration and state types for the flexible ATX power control system.
|
||||
//! Each ATX action (power, reset) can be independently configured with different hardware.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use typeshare::typeshare;
|
||||
|
||||
/// Power status
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum PowerStatus {
|
||||
/// Power is on
|
||||
On,
|
||||
/// Power is off
|
||||
Off,
|
||||
/// Power status unknown (no LED connected)
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Default for PowerStatus {
|
||||
fn default() -> Self {
|
||||
Self::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
/// Driver type for ATX key operations
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AtxDriverType {
|
||||
/// GPIO control via Linux character device
|
||||
Gpio,
|
||||
/// USB HID relay module
|
||||
UsbRelay,
|
||||
/// Disabled / Not configured
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for AtxDriverType {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Active level for GPIO pins
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ActiveLevel {
|
||||
/// Active high (default for most cases)
|
||||
High,
|
||||
/// Active low (inverted)
|
||||
Low,
|
||||
}
|
||||
|
||||
impl Default for ActiveLevel {
|
||||
fn default() -> Self {
|
||||
Self::High
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a single ATX key (power or reset)
|
||||
/// This is the "four-tuple" configuration: (driver, device, pin/channel, level)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(default)]
|
||||
pub struct AtxKeyConfig {
|
||||
/// Driver type (GPIO or USB Relay)
|
||||
pub driver: AtxDriverType,
|
||||
/// Device path:
|
||||
/// - For GPIO: /dev/gpiochipX
|
||||
/// - For USB Relay: /dev/hidrawX
|
||||
pub device: String,
|
||||
/// Pin or channel number:
|
||||
/// - For GPIO: GPIO pin number
|
||||
/// - For USB Relay: relay channel (0-based)
|
||||
pub pin: u32,
|
||||
/// Active level (only applicable to GPIO, ignored for USB Relay)
|
||||
pub active_level: ActiveLevel,
|
||||
}
|
||||
|
||||
impl Default for AtxKeyConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
driver: AtxDriverType::None,
|
||||
device: String::new(),
|
||||
pin: 0,
|
||||
active_level: ActiveLevel::High,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AtxKeyConfig {
|
||||
/// Check if this key is configured
|
||||
pub fn is_configured(&self) -> bool {
|
||||
self.driver != AtxDriverType::None && !self.device.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// LED sensing configuration (optional)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(default)]
|
||||
pub struct AtxLedConfig {
|
||||
/// Whether LED sensing is enabled
|
||||
pub enabled: bool,
|
||||
/// GPIO chip for LED sensing
|
||||
pub gpio_chip: String,
|
||||
/// GPIO pin for LED input
|
||||
pub gpio_pin: u32,
|
||||
/// Whether LED is active low (inverted logic)
|
||||
pub inverted: bool,
|
||||
}
|
||||
|
||||
impl Default for AtxLedConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
gpio_chip: String::new(),
|
||||
gpio_pin: 0,
|
||||
inverted: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AtxLedConfig {
|
||||
/// Check if LED sensing is configured
|
||||
pub fn is_configured(&self) -> bool {
|
||||
self.enabled && !self.gpio_chip.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// ATX state information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AtxState {
|
||||
/// Whether ATX feature is available/enabled
|
||||
pub available: bool,
|
||||
/// Whether power button is configured
|
||||
pub power_configured: bool,
|
||||
/// Whether reset button is configured
|
||||
pub reset_configured: bool,
|
||||
/// Current power status
|
||||
pub power_status: PowerStatus,
|
||||
/// Whether power LED sensing is supported
|
||||
pub led_supported: bool,
|
||||
}
|
||||
|
||||
impl Default for AtxState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
available: false,
|
||||
power_configured: false,
|
||||
reset_configured: false,
|
||||
power_status: PowerStatus::Unknown,
|
||||
led_supported: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ATX power action request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AtxPowerRequest {
|
||||
/// Action to perform: "short", "long", "reset"
|
||||
pub action: AtxAction,
|
||||
}
|
||||
|
||||
/// ATX power action
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AtxAction {
|
||||
/// Short press power button (turn on or graceful shutdown)
|
||||
Short,
|
||||
/// Long press power button (force power off)
|
||||
Long,
|
||||
/// Press reset button
|
||||
Reset,
|
||||
}
|
||||
|
||||
/// Available ATX devices for discovery
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AtxDevices {
|
||||
/// Available GPIO chips (/dev/gpiochip*)
|
||||
pub gpio_chips: Vec<String>,
|
||||
/// Available USB HID relay devices (/dev/hidraw*)
|
||||
pub usb_relays: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for AtxDevices {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gpio_chips: Vec::new(),
|
||||
usb_relays: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_power_status_default() {
|
||||
assert_eq!(PowerStatus::default(), PowerStatus::Unknown);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atx_driver_type_default() {
|
||||
assert_eq!(AtxDriverType::default(), AtxDriverType::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_active_level_default() {
|
||||
assert_eq!(ActiveLevel::default(), ActiveLevel::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atx_key_config_default() {
|
||||
let config = AtxKeyConfig::default();
|
||||
assert_eq!(config.driver, AtxDriverType::None);
|
||||
assert!(config.device.is_empty());
|
||||
assert_eq!(config.pin, 0);
|
||||
assert!(!config.is_configured());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atx_key_config_is_configured() {
|
||||
let mut config = AtxKeyConfig::default();
|
||||
assert!(!config.is_configured());
|
||||
|
||||
config.driver = AtxDriverType::Gpio;
|
||||
assert!(!config.is_configured()); // device still empty
|
||||
|
||||
config.device = "/dev/gpiochip0".to_string();
|
||||
assert!(config.is_configured());
|
||||
|
||||
config.driver = AtxDriverType::None;
|
||||
assert!(!config.is_configured()); // driver is None
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atx_led_config_default() {
|
||||
let config = AtxLedConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert!(config.gpio_chip.is_empty());
|
||||
assert!(!config.is_configured());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atx_led_config_is_configured() {
|
||||
let mut config = AtxLedConfig::default();
|
||||
assert!(!config.is_configured());
|
||||
|
||||
config.enabled = true;
|
||||
assert!(!config.is_configured()); // gpio_chip still empty
|
||||
|
||||
config.gpio_chip = "/dev/gpiochip0".to_string();
|
||||
assert!(config.is_configured());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atx_state_default() {
|
||||
let state = AtxState::default();
|
||||
assert!(!state.available);
|
||||
assert!(!state.power_configured);
|
||||
assert!(!state.reset_configured);
|
||||
assert_eq!(state.power_status, PowerStatus::Unknown);
|
||||
}
|
||||
}
|
||||
171
src/atx/wol.rs
Normal file
171
src/atx/wol.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
//! Wake-on-LAN (WOL) implementation
|
||||
//!
|
||||
//! Sends magic packets to wake up remote machines.
|
||||
|
||||
use std::net::{SocketAddr, UdpSocket};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// WOL magic packet structure:
|
||||
/// - 6 bytes of 0xFF
|
||||
/// - 16 repetitions of the target MAC address (6 bytes each)
|
||||
/// Total: 6 + 16 * 6 = 102 bytes
|
||||
const MAGIC_PACKET_SIZE: usize = 102;
|
||||
|
||||
/// Parse MAC address string into bytes
|
||||
/// Supports formats: "AA:BB:CC:DD:EE:FF" or "AA-BB-CC-DD-EE-FF"
|
||||
fn parse_mac_address(mac: &str) -> Result<[u8; 6]> {
|
||||
let mac = mac.trim().to_uppercase();
|
||||
let parts: Vec<&str> = if mac.contains(':') {
|
||||
mac.split(':').collect()
|
||||
} else if mac.contains('-') {
|
||||
mac.split('-').collect()
|
||||
} else {
|
||||
return Err(AppError::Config(format!("Invalid MAC address format: {}", mac)));
|
||||
};
|
||||
|
||||
if parts.len() != 6 {
|
||||
return Err(AppError::Config(format!(
|
||||
"Invalid MAC address: expected 6 parts, got {}",
|
||||
parts.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut bytes = [0u8; 6];
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
bytes[i] = u8::from_str_radix(part, 16).map_err(|_| {
|
||||
AppError::Config(format!("Invalid MAC address byte: {}", part))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
/// Build WOL magic packet
|
||||
fn build_magic_packet(mac: &[u8; 6]) -> [u8; MAGIC_PACKET_SIZE] {
|
||||
let mut packet = [0u8; MAGIC_PACKET_SIZE];
|
||||
|
||||
// First 6 bytes are 0xFF
|
||||
for byte in packet.iter_mut().take(6) {
|
||||
*byte = 0xFF;
|
||||
}
|
||||
|
||||
// Next 96 bytes are 16 repetitions of the MAC address
|
||||
for i in 0..16 {
|
||||
let offset = 6 + i * 6;
|
||||
packet[offset..offset + 6].copy_from_slice(mac);
|
||||
}
|
||||
|
||||
packet
|
||||
}
|
||||
|
||||
/// Send WOL magic packet
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `mac_address` - Target MAC address (e.g., "AA:BB:CC:DD:EE:FF")
|
||||
/// * `interface` - Optional network interface name (e.g., "eth0"). If None, uses default routing.
|
||||
pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
|
||||
let mac = parse_mac_address(mac_address)?;
|
||||
let packet = build_magic_packet(&mac);
|
||||
|
||||
info!("Sending WOL packet to {} via {:?}", mac_address, interface);
|
||||
|
||||
// Create UDP socket
|
||||
let socket = UdpSocket::bind("0.0.0.0:0")
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create UDP socket: {}", e)))?;
|
||||
|
||||
// Enable broadcast
|
||||
socket
|
||||
.set_broadcast(true)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to enable broadcast: {}", e)))?;
|
||||
|
||||
// Bind to specific interface if specified
|
||||
#[cfg(target_os = "linux")]
|
||||
if let Some(iface) = interface {
|
||||
if !iface.is_empty() {
|
||||
use std::os::unix::io::AsRawFd;
|
||||
let fd = socket.as_raw_fd();
|
||||
let iface_bytes = iface.as_bytes();
|
||||
|
||||
// SO_BINDTODEVICE requires interface name as null-terminated string
|
||||
let mut iface_buf = [0u8; 16]; // IFNAMSIZ is typically 16
|
||||
let len = iface_bytes.len().min(15);
|
||||
iface_buf[..len].copy_from_slice(&iface_bytes[..len]);
|
||||
|
||||
let ret = unsafe {
|
||||
libc::setsockopt(
|
||||
fd,
|
||||
libc::SOL_SOCKET,
|
||||
libc::SO_BINDTODEVICE,
|
||||
iface_buf.as_ptr() as *const libc::c_void,
|
||||
(len + 1) as libc::socklen_t,
|
||||
)
|
||||
};
|
||||
|
||||
if ret < 0 {
|
||||
let err = std::io::Error::last_os_error();
|
||||
return Err(AppError::Internal(format!(
|
||||
"Failed to bind to interface {}: {}",
|
||||
iface, err
|
||||
)));
|
||||
}
|
||||
debug!("Bound to interface: {}", iface);
|
||||
}
|
||||
}
|
||||
|
||||
// Send to broadcast address on port 9 (discard protocol, commonly used for WOL)
|
||||
let broadcast_addr: SocketAddr = "255.255.255.255:9".parse().unwrap();
|
||||
|
||||
socket
|
||||
.send_to(&packet, broadcast_addr)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to send WOL packet: {}", e)))?;
|
||||
|
||||
// Also try sending to port 7 (echo protocol, alternative WOL port)
|
||||
let broadcast_addr_7: SocketAddr = "255.255.255.255:7".parse().unwrap();
|
||||
let _ = socket.send_to(&packet, broadcast_addr_7);
|
||||
|
||||
info!("WOL packet sent successfully to {}", mac_address);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_mac_address_colon() {
|
||||
let mac = parse_mac_address("AA:BB:CC:DD:EE:FF").unwrap();
|
||||
assert_eq!(mac, [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mac_address_dash() {
|
||||
let mac = parse_mac_address("aa-bb-cc-dd-ee-ff").unwrap();
|
||||
assert_eq!(mac, [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mac_address_invalid() {
|
||||
assert!(parse_mac_address("invalid").is_err());
|
||||
assert!(parse_mac_address("AA:BB:CC:DD:EE").is_err());
|
||||
assert!(parse_mac_address("AA:BB:CC:DD:EE:GG").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_magic_packet() {
|
||||
let mac = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF];
|
||||
let packet = build_magic_packet(&mac);
|
||||
|
||||
// Check header (6 bytes of 0xFF)
|
||||
for i in 0..6 {
|
||||
assert_eq!(packet[i], 0xFF);
|
||||
}
|
||||
|
||||
// Check MAC repetitions
|
||||
for i in 0..16 {
|
||||
let offset = 6 + i * 6;
|
||||
assert_eq!(&packet[offset..offset + 6], &mac);
|
||||
}
|
||||
}
|
||||
}
|
||||
390
src/audio/capture.rs
Normal file
390
src/audio/capture.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
//! ALSA audio capture implementation
|
||||
|
||||
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
|
||||
use alsa::{Direction, ValueOr, PCM};
|
||||
use bytes::Bytes;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{broadcast, watch, Mutex};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::device::AudioDeviceInfo;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Audio capture configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioConfig {
|
||||
/// ALSA device name (e.g., "hw:0,0" or "default")
|
||||
pub device_name: String,
|
||||
/// Sample rate in Hz
|
||||
pub sample_rate: u32,
|
||||
/// Number of channels (1 = mono, 2 = stereo)
|
||||
pub channels: u32,
|
||||
/// Samples per frame (for Opus, typically 480 for 10ms at 48kHz)
|
||||
pub frame_size: u32,
|
||||
/// Buffer size in frames
|
||||
pub buffer_frames: u32,
|
||||
/// Period size in frames
|
||||
pub period_frames: u32,
|
||||
}
|
||||
|
||||
impl Default for AudioConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_name: "default".to_string(),
|
||||
sample_rate: 48000,
|
||||
channels: 2,
|
||||
frame_size: 960, // 20ms at 48kHz (good for Opus)
|
||||
buffer_frames: 4096,
|
||||
period_frames: 960,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AudioConfig {
|
||||
/// Create config for a specific device
|
||||
pub fn for_device(device: &AudioDeviceInfo) -> Self {
|
||||
let sample_rate = if device.sample_rates.contains(&48000) {
|
||||
48000
|
||||
} else {
|
||||
*device.sample_rates.first().unwrap_or(&48000)
|
||||
};
|
||||
|
||||
let channels = if device.channels.contains(&2) {
|
||||
2
|
||||
} else {
|
||||
*device.channels.first().unwrap_or(&2)
|
||||
};
|
||||
|
||||
Self {
|
||||
device_name: device.name.clone(),
|
||||
sample_rate,
|
||||
channels,
|
||||
frame_size: sample_rate / 50, // 20ms
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Bytes per sample (16-bit signed)
|
||||
pub fn bytes_per_sample(&self) -> u32 {
|
||||
2 * self.channels
|
||||
}
|
||||
|
||||
/// Bytes per frame
|
||||
pub fn bytes_per_frame(&self) -> usize {
|
||||
(self.frame_size * self.bytes_per_sample()) as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio frame data
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioFrame {
|
||||
/// Raw PCM data (S16LE interleaved)
|
||||
pub data: Bytes,
|
||||
/// Sample rate
|
||||
pub sample_rate: u32,
|
||||
/// Number of channels
|
||||
pub channels: u32,
|
||||
/// Number of samples per channel
|
||||
pub samples: u32,
|
||||
/// Frame sequence number
|
||||
pub sequence: u64,
|
||||
/// Capture timestamp
|
||||
pub timestamp: Instant,
|
||||
}
|
||||
|
||||
impl AudioFrame {
|
||||
pub fn new(data: Bytes, config: &AudioConfig, sequence: u64) -> Self {
|
||||
Self {
|
||||
samples: data.len() as u32 / config.bytes_per_sample(),
|
||||
data,
|
||||
sample_rate: config.sample_rate,
|
||||
channels: config.channels,
|
||||
sequence,
|
||||
timestamp: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio capture state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CaptureState {
|
||||
Stopped,
|
||||
Running,
|
||||
Error,
|
||||
}
|
||||
|
||||
/// Audio capture statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AudioStats {
|
||||
pub frames_captured: u64,
|
||||
pub frames_dropped: u64,
|
||||
pub buffer_overruns: u64,
|
||||
pub current_latency_ms: f32,
|
||||
}
|
||||
|
||||
/// ALSA audio capturer
|
||||
pub struct AudioCapturer {
|
||||
config: AudioConfig,
|
||||
state: Arc<watch::Sender<CaptureState>>,
|
||||
state_rx: watch::Receiver<CaptureState>,
|
||||
stats: Arc<Mutex<AudioStats>>,
|
||||
frame_tx: broadcast::Sender<AudioFrame>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
sequence: Arc<AtomicU64>,
|
||||
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
impl AudioCapturer {
|
||||
/// Create a new audio capturer
|
||||
pub fn new(config: AudioConfig) -> Self {
|
||||
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
|
||||
let (frame_tx, _) = broadcast::channel(32);
|
||||
|
||||
Self {
|
||||
config,
|
||||
state: Arc::new(state_tx),
|
||||
state_rx,
|
||||
stats: Arc::new(Mutex::new(AudioStats::default())),
|
||||
frame_tx,
|
||||
stop_flag: Arc::new(AtomicBool::new(false)),
|
||||
sequence: Arc::new(AtomicU64::new(0)),
|
||||
capture_handle: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current state
|
||||
pub fn state(&self) -> CaptureState {
|
||||
*self.state_rx.borrow()
|
||||
}
|
||||
|
||||
/// Subscribe to state changes
|
||||
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
|
||||
self.state_rx.clone()
|
||||
}
|
||||
|
||||
/// Subscribe to audio frames
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
|
||||
self.frame_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get statistics
|
||||
pub async fn stats(&self) -> AudioStats {
|
||||
self.stats.lock().await.clone()
|
||||
}
|
||||
|
||||
/// Start capturing
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
if self.state() == CaptureState::Running {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(
|
||||
"Starting audio capture on {} at {}Hz {}ch",
|
||||
self.config.device_name, self.config.sample_rate, self.config.channels
|
||||
);
|
||||
|
||||
self.stop_flag.store(false, Ordering::SeqCst);
|
||||
|
||||
let config = self.config.clone();
|
||||
let state = self.state.clone();
|
||||
let stats = self.stats.clone();
|
||||
let frame_tx = self.frame_tx.clone();
|
||||
let stop_flag = self.stop_flag.clone();
|
||||
let sequence = self.sequence.clone();
|
||||
|
||||
let handle = tokio::task::spawn_blocking(move || {
|
||||
capture_loop(config, state, stats, frame_tx, stop_flag, sequence);
|
||||
});
|
||||
|
||||
*self.capture_handle.lock().await = Some(handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop capturing
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
info!("Stopping audio capture");
|
||||
self.stop_flag.store(true, Ordering::SeqCst);
|
||||
|
||||
if let Some(handle) = self.capture_handle.lock().await.take() {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
let _ = self.state.send(CaptureState::Stopped);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if running
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.state() == CaptureState::Running
|
||||
}
|
||||
}
|
||||
|
||||
/// Main capture loop
|
||||
fn capture_loop(
|
||||
config: AudioConfig,
|
||||
state: Arc<watch::Sender<CaptureState>>,
|
||||
stats: Arc<Mutex<AudioStats>>,
|
||||
frame_tx: broadcast::Sender<AudioFrame>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
sequence: Arc<AtomicU64>,
|
||||
) {
|
||||
let result = run_capture(&config, &state, &stats, &frame_tx, &stop_flag, &sequence);
|
||||
|
||||
if let Err(e) = result {
|
||||
error!("Audio capture error: {}", e);
|
||||
let _ = state.send(CaptureState::Error);
|
||||
} else {
|
||||
let _ = state.send(CaptureState::Stopped);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_capture(
|
||||
config: &AudioConfig,
|
||||
state: &watch::Sender<CaptureState>,
|
||||
stats: &Arc<Mutex<AudioStats>>,
|
||||
frame_tx: &broadcast::Sender<AudioFrame>,
|
||||
stop_flag: &AtomicBool,
|
||||
sequence: &AtomicU64,
|
||||
) -> Result<()> {
|
||||
// Open ALSA device
|
||||
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
|
||||
AppError::AudioError(format!(
|
||||
"Failed to open audio device {}: {}",
|
||||
config.device_name, e
|
||||
))
|
||||
})?;
|
||||
|
||||
// Configure hardware parameters
|
||||
{
|
||||
let hwp = HwParams::any(&pcm).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to get HwParams: {}", e))
|
||||
})?;
|
||||
|
||||
hwp.set_channels(config.channels).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to set channels: {}", e))
|
||||
})?;
|
||||
|
||||
hwp.set_rate(config.sample_rate, ValueOr::Nearest).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to set sample rate: {}", e))
|
||||
})?;
|
||||
|
||||
hwp.set_format(Format::s16()).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to set format: {}", e))
|
||||
})?;
|
||||
|
||||
hwp.set_access(Access::RWInterleaved).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to set access: {}", e))
|
||||
})?;
|
||||
|
||||
hwp.set_buffer_size_near(config.buffer_frames as Frames).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to set buffer size: {}", e))
|
||||
})?;
|
||||
|
||||
hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest)
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?;
|
||||
|
||||
pcm.hw_params(&hwp).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to apply hw params: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
// Get actual configuration
|
||||
let actual_rate = pcm.hw_params_current()
|
||||
.map(|h| h.get_rate().unwrap_or(config.sample_rate))
|
||||
.unwrap_or(config.sample_rate);
|
||||
|
||||
info!(
|
||||
"Audio capture configured: {}Hz {}ch (requested {}Hz)",
|
||||
actual_rate, config.channels, config.sample_rate
|
||||
);
|
||||
|
||||
// Prepare for capture
|
||||
pcm.prepare().map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to prepare PCM: {}", e))
|
||||
})?;
|
||||
|
||||
let _ = state.send(CaptureState::Running);
|
||||
|
||||
// Allocate buffer - use u8 directly for zero-copy
|
||||
let frame_bytes = config.bytes_per_frame();
|
||||
let mut buffer = vec![0u8; frame_bytes];
|
||||
|
||||
// Capture loop
|
||||
while !stop_flag.load(Ordering::Relaxed) {
|
||||
// Check PCM state
|
||||
match pcm.state() {
|
||||
State::XRun => {
|
||||
warn!("Audio buffer overrun, recovering");
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.buffer_overruns += 1;
|
||||
}
|
||||
let _ = pcm.prepare();
|
||||
continue;
|
||||
}
|
||||
State::Suspended => {
|
||||
warn!("Audio device suspended, recovering");
|
||||
let _ = pcm.resume();
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Get IO handle and read audio data directly as bytes
|
||||
// Note: Use io() instead of io_checked() because USB audio devices
|
||||
// typically don't support mmap, which io_checked() requires
|
||||
let io: IO<u8> = pcm.io_bytes();
|
||||
|
||||
match io.readi(&mut buffer) {
|
||||
Ok(frames_read) => {
|
||||
if frames_read == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Calculate actual byte count
|
||||
let byte_count = frames_read * config.channels as usize * 2;
|
||||
|
||||
// Directly use the buffer slice (already in correct byte format)
|
||||
let seq = sequence.fetch_add(1, Ordering::Relaxed);
|
||||
let frame = AudioFrame::new(
|
||||
Bytes::copy_from_slice(&buffer[..byte_count]),
|
||||
config,
|
||||
seq,
|
||||
);
|
||||
|
||||
// Send to subscribers
|
||||
if frame_tx.receiver_count() > 0 {
|
||||
if let Err(e) = frame_tx.send(frame) {
|
||||
debug!("No audio receivers: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Update stats
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.frames_captured += 1;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Check for buffer overrun (EPIPE = 32 on Linux)
|
||||
let desc = e.to_string();
|
||||
if desc.contains("EPIPE") || desc.contains("Broken pipe") {
|
||||
// Buffer overrun
|
||||
warn!("Audio buffer overrun");
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.buffer_overruns += 1;
|
||||
}
|
||||
let _ = pcm.prepare();
|
||||
} else {
|
||||
error!("Audio read error: {}", e);
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.frames_dropped += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Audio capture stopped");
|
||||
Ok(())
|
||||
}
|
||||
495
src/audio/controller.rs
Normal file
495
src/audio/controller.rs
Normal file
@@ -0,0 +1,495 @@
|
||||
//! Audio controller for high-level audio management
|
||||
//!
|
||||
//! Provides device enumeration, selection, quality control, and streaming management.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::info;
|
||||
|
||||
use super::capture::AudioConfig;
|
||||
use super::device::{enumerate_audio_devices_with_current, AudioDeviceInfo};
|
||||
use super::encoder::{OpusConfig, OpusFrame};
|
||||
use super::monitor::{AudioHealthMonitor, AudioHealthStatus};
|
||||
use super::streamer::{AudioStreamer, AudioStreamerConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::events::{EventBus, SystemEvent};
|
||||
|
||||
/// Audio quality presets
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AudioQuality {
|
||||
/// Low bandwidth voice (32kbps)
|
||||
Voice,
|
||||
/// Balanced quality (64kbps) - default
|
||||
#[default]
|
||||
Balanced,
|
||||
/// High quality audio (128kbps)
|
||||
High,
|
||||
}
|
||||
|
||||
impl AudioQuality {
|
||||
/// Get the bitrate for this quality level
|
||||
pub fn bitrate(&self) -> u32 {
|
||||
match self {
|
||||
AudioQuality::Voice => 32000,
|
||||
AudioQuality::Balanced => 64000,
|
||||
AudioQuality::High => 128000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from string
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"voice" | "low" => AudioQuality::Voice,
|
||||
"high" | "music" => AudioQuality::High,
|
||||
_ => AudioQuality::Balanced,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to OpusConfig
|
||||
pub fn to_opus_config(&self) -> OpusConfig {
|
||||
match self {
|
||||
AudioQuality::Voice => OpusConfig::voice(),
|
||||
AudioQuality::Balanced => OpusConfig::default(),
|
||||
AudioQuality::High => OpusConfig::music(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AudioQuality {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AudioQuality::Voice => write!(f, "voice"),
|
||||
AudioQuality::Balanced => write!(f, "balanced"),
|
||||
AudioQuality::High => write!(f, "high"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio controller configuration
|
||||
///
|
||||
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
|
||||
/// These are optimal for Opus encoding and match WebRTC requirements.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioControllerConfig {
|
||||
/// Whether audio is enabled
|
||||
pub enabled: bool,
|
||||
/// Selected device name
|
||||
pub device: String,
|
||||
/// Audio quality preset
|
||||
pub quality: AudioQuality,
|
||||
}
|
||||
|
||||
impl Default for AudioControllerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
device: "default".to_string(),
|
||||
quality: AudioQuality::Balanced,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Current audio status
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct AudioStatus {
|
||||
/// Whether audio feature is enabled
|
||||
pub enabled: bool,
|
||||
/// Whether audio is currently streaming
|
||||
pub streaming: bool,
|
||||
/// Currently selected device
|
||||
pub device: Option<String>,
|
||||
/// Current quality preset
|
||||
pub quality: AudioQuality,
|
||||
/// Number of connected subscribers
|
||||
pub subscriber_count: usize,
|
||||
/// Frames encoded
|
||||
pub frames_encoded: u64,
|
||||
/// Bytes output
|
||||
pub bytes_output: u64,
|
||||
/// Error message if any
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Audio controller
|
||||
///
|
||||
/// High-level interface for audio management, providing:
|
||||
/// - Device enumeration and selection
|
||||
/// - Quality control
|
||||
/// - Stream start/stop
|
||||
/// - Status reporting
|
||||
pub struct AudioController {
|
||||
config: RwLock<AudioControllerConfig>,
|
||||
streamer: RwLock<Option<Arc<AudioStreamer>>>,
|
||||
devices: RwLock<Vec<AudioDeviceInfo>>,
|
||||
event_bus: RwLock<Option<Arc<EventBus>>>,
|
||||
last_error: RwLock<Option<String>>,
|
||||
/// Health monitor for error tracking and recovery
|
||||
monitor: Arc<AudioHealthMonitor>,
|
||||
}
|
||||
|
||||
impl AudioController {
|
||||
/// Create a new audio controller with configuration
|
||||
pub fn new(config: AudioControllerConfig) -> Self {
|
||||
Self {
|
||||
config: RwLock::new(config),
|
||||
streamer: RwLock::new(None),
|
||||
devices: RwLock::new(Vec::new()),
|
||||
event_bus: RwLock::new(None),
|
||||
last_error: RwLock::new(None),
|
||||
monitor: Arc::new(AudioHealthMonitor::with_defaults()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set event bus for publishing audio events
|
||||
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
|
||||
*self.event_bus.write().await = Some(event_bus.clone());
|
||||
// Also set event bus on the monitor for health notifications
|
||||
self.monitor.set_event_bus(event_bus).await;
|
||||
}
|
||||
|
||||
/// Publish an event to the event bus
|
||||
async fn publish_event(&self, event: SystemEvent) {
|
||||
if let Some(ref bus) = *self.event_bus.read().await {
|
||||
bus.publish(event);
|
||||
}
|
||||
}
|
||||
|
||||
/// List available audio capture devices
|
||||
pub async fn list_devices(&self) -> Result<Vec<AudioDeviceInfo>> {
|
||||
// Get current device if streaming (it may be busy and unable to be opened)
|
||||
let current_device = if self.is_streaming().await {
|
||||
Some(self.config.read().await.device.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
|
||||
*self.devices.write().await = devices.clone();
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
/// Refresh device list and cache it
|
||||
pub async fn refresh_devices(&self) -> Result<()> {
|
||||
// Get current device if streaming (it may be busy and unable to be opened)
|
||||
let current_device = if self.is_streaming().await {
|
||||
Some(self.config.read().await.device.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
|
||||
*self.devices.write().await = devices;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get cached device list
|
||||
pub async fn get_cached_devices(&self) -> Vec<AudioDeviceInfo> {
|
||||
self.devices.read().await.clone()
|
||||
}
|
||||
|
||||
/// Select audio device
|
||||
pub async fn select_device(&self, device: &str) -> Result<()> {
|
||||
// Validate device exists
|
||||
let devices = self.list_devices().await?;
|
||||
let found = devices.iter().any(|d| d.name == device || d.description.contains(device));
|
||||
|
||||
if !found && device != "default" {
|
||||
return Err(AppError::AudioError(format!(
|
||||
"Audio device not found: {}",
|
||||
device
|
||||
)));
|
||||
}
|
||||
|
||||
// Update config
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
config.device = device.to_string();
|
||||
}
|
||||
|
||||
// Publish event
|
||||
self.publish_event(SystemEvent::AudioDeviceSelected {
|
||||
device: device.to_string(),
|
||||
})
|
||||
.await;
|
||||
|
||||
info!("Audio device selected: {}", device);
|
||||
|
||||
// If streaming, restart with new device
|
||||
if self.is_streaming().await {
|
||||
self.stop_streaming().await?;
|
||||
self.start_streaming().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set audio quality
|
||||
pub async fn set_quality(&self, quality: AudioQuality) -> Result<()> {
|
||||
// Update config
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
config.quality = quality;
|
||||
}
|
||||
|
||||
// Update streamer if running
|
||||
if let Some(ref streamer) = *self.streamer.read().await {
|
||||
streamer.set_bitrate(quality.bitrate()).await?;
|
||||
}
|
||||
|
||||
// Publish event
|
||||
self.publish_event(SystemEvent::AudioQualityChanged {
|
||||
quality: quality.to_string(),
|
||||
})
|
||||
.await;
|
||||
|
||||
info!("Audio quality set to: {:?} ({}bps)", quality, quality.bitrate());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start audio streaming
|
||||
pub async fn start_streaming(&self) -> Result<()> {
|
||||
let config = self.config.read().await.clone();
|
||||
|
||||
if !config.enabled {
|
||||
return Err(AppError::AudioError("Audio is disabled".to_string()));
|
||||
}
|
||||
|
||||
// Check if already streaming
|
||||
if self.is_streaming().await {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Starting audio streaming with device: {}", config.device);
|
||||
|
||||
// Clear any previous error
|
||||
*self.last_error.write().await = None;
|
||||
|
||||
// Create streamer config (fixed 48kHz stereo)
|
||||
let streamer_config = AudioStreamerConfig {
|
||||
capture: AudioConfig {
|
||||
device_name: config.device.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
opus: config.quality.to_opus_config(),
|
||||
};
|
||||
|
||||
// Create and start streamer
|
||||
let streamer = Arc::new(AudioStreamer::with_config(streamer_config));
|
||||
|
||||
if let Err(e) = streamer.start().await {
|
||||
let error_msg = format!("Failed to start audio: {}", e);
|
||||
*self.last_error.write().await = Some(error_msg.clone());
|
||||
|
||||
// Report error to health monitor
|
||||
self.monitor
|
||||
.report_error(Some(&config.device), &error_msg, "start_failed")
|
||||
.await;
|
||||
|
||||
self.publish_event(SystemEvent::AudioStateChanged {
|
||||
streaming: false,
|
||||
device: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
return Err(AppError::AudioError(error_msg));
|
||||
}
|
||||
|
||||
*self.streamer.write().await = Some(streamer);
|
||||
|
||||
// Report recovery if we were in an error state
|
||||
if self.monitor.is_error().await {
|
||||
self.monitor.report_recovered(Some(&config.device)).await;
|
||||
}
|
||||
|
||||
// Publish event
|
||||
self.publish_event(SystemEvent::AudioStateChanged {
|
||||
streaming: true,
|
||||
device: Some(config.device),
|
||||
})
|
||||
.await;
|
||||
|
||||
info!("Audio streaming started");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop audio streaming
|
||||
pub async fn stop_streaming(&self) -> Result<()> {
|
||||
if let Some(streamer) = self.streamer.write().await.take() {
|
||||
streamer.stop().await?;
|
||||
}
|
||||
|
||||
// Publish event
|
||||
self.publish_event(SystemEvent::AudioStateChanged {
|
||||
streaming: false,
|
||||
device: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
info!("Audio streaming stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if currently streaming
|
||||
pub async fn is_streaming(&self) -> bool {
|
||||
if let Some(ref streamer) = *self.streamer.read().await {
|
||||
streamer.is_running()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current status
|
||||
pub async fn status(&self) -> AudioStatus {
|
||||
let config = self.config.read().await;
|
||||
let streaming = self.is_streaming().await;
|
||||
let error = self.last_error.read().await.clone();
|
||||
|
||||
let (subscriber_count, frames_encoded, bytes_output) = if let Some(ref streamer) =
|
||||
*self.streamer.read().await
|
||||
{
|
||||
let stats = streamer.stats().await;
|
||||
(stats.subscriber_count, stats.frames_encoded, stats.bytes_output)
|
||||
} else {
|
||||
(0, 0, 0)
|
||||
};
|
||||
|
||||
AudioStatus {
|
||||
enabled: config.enabled,
|
||||
streaming,
|
||||
device: if streaming || config.enabled {
|
||||
Some(config.device.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
quality: config.quality,
|
||||
subscriber_count,
|
||||
frames_encoded,
|
||||
bytes_output,
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to Opus frames (for WebSocket clients)
|
||||
pub fn subscribe_opus(&self) -> Option<broadcast::Receiver<OpusFrame>> {
|
||||
// Use try_read to avoid blocking - this is called from sync context sometimes
|
||||
if let Ok(guard) = self.streamer.try_read() {
|
||||
guard.as_ref().map(|s| s.subscribe_opus())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to Opus frames (async version)
|
||||
pub async fn subscribe_opus_async(&self) -> Option<broadcast::Receiver<OpusFrame>> {
|
||||
self.streamer.read().await.as_ref().map(|s| s.subscribe_opus())
|
||||
}
|
||||
|
||||
/// Enable or disable audio
|
||||
pub async fn set_enabled(&self, enabled: bool) -> Result<()> {
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
config.enabled = enabled;
|
||||
}
|
||||
|
||||
if !enabled && self.is_streaming().await {
|
||||
self.stop_streaming().await?;
|
||||
}
|
||||
|
||||
info!("Audio enabled: {}", enabled);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update full configuration
|
||||
pub async fn update_config(&self, new_config: AudioControllerConfig) -> Result<()> {
|
||||
let was_streaming = self.is_streaming().await;
|
||||
let old_config = self.config.read().await.clone();
|
||||
|
||||
// Stop streaming if running
|
||||
if was_streaming {
|
||||
self.stop_streaming().await?;
|
||||
}
|
||||
|
||||
// Update config
|
||||
*self.config.write().await = new_config.clone();
|
||||
|
||||
// Restart streaming if it was running and still enabled
|
||||
if was_streaming && new_config.enabled {
|
||||
self.start_streaming().await?;
|
||||
}
|
||||
|
||||
// Publish events for changes
|
||||
if old_config.device != new_config.device {
|
||||
self.publish_event(SystemEvent::AudioDeviceSelected {
|
||||
device: new_config.device.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
if old_config.quality != new_config.quality {
|
||||
self.publish_event(SystemEvent::AudioQualityChanged {
|
||||
quality: new_config.quality.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the controller
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
self.stop_streaming().await
|
||||
}
|
||||
|
||||
/// Get the health monitor reference
|
||||
pub fn monitor(&self) -> &Arc<AudioHealthMonitor> {
|
||||
&self.monitor
|
||||
}
|
||||
|
||||
/// Get current health status
|
||||
pub async fn health_status(&self) -> AudioHealthStatus {
|
||||
self.monitor.status().await
|
||||
}
|
||||
|
||||
/// Check if the audio is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.monitor.is_healthy().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AudioController {
|
||||
fn default() -> Self {
|
||||
Self::new(AudioControllerConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_audio_quality_bitrate() {
|
||||
assert_eq!(AudioQuality::Voice.bitrate(), 32000);
|
||||
assert_eq!(AudioQuality::Balanced.bitrate(), 64000);
|
||||
assert_eq!(AudioQuality::High.bitrate(), 128000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_audio_quality_from_str() {
|
||||
assert_eq!(AudioQuality::from_str("voice"), AudioQuality::Voice);
|
||||
assert_eq!(AudioQuality::from_str("low"), AudioQuality::Voice);
|
||||
assert_eq!(AudioQuality::from_str("balanced"), AudioQuality::Balanced);
|
||||
assert_eq!(AudioQuality::from_str("high"), AudioQuality::High);
|
||||
assert_eq!(AudioQuality::from_str("music"), AudioQuality::High);
|
||||
assert_eq!(AudioQuality::from_str("unknown"), AudioQuality::Balanced);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_controller_default() {
|
||||
let controller = AudioController::default();
|
||||
let status = controller.status().await;
|
||||
assert!(!status.enabled);
|
||||
assert!(!status.streaming);
|
||||
}
|
||||
}
|
||||
234
src/audio/device.rs
Normal file
234
src/audio/device.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
//! Audio device enumeration using ALSA
|
||||
|
||||
use alsa::pcm::HwParams;
|
||||
use alsa::{Direction, PCM};
|
||||
use serde::Serialize;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Audio device information
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct AudioDeviceInfo {
|
||||
/// Device name (e.g., "hw:0,0" or "default")
|
||||
pub name: String,
|
||||
/// Human-readable description
|
||||
pub description: String,
|
||||
/// Card index
|
||||
pub card_index: i32,
|
||||
/// Device index
|
||||
pub device_index: i32,
|
||||
/// Supported sample rates
|
||||
pub sample_rates: Vec<u32>,
|
||||
/// Supported channel counts
|
||||
pub channels: Vec<u32>,
|
||||
/// Is this a capture device
|
||||
pub is_capture: bool,
|
||||
/// Is this an HDMI audio device (likely from capture card)
|
||||
pub is_hdmi: bool,
|
||||
}
|
||||
|
||||
impl AudioDeviceInfo {
|
||||
/// Get ALSA device name
|
||||
pub fn alsa_name(&self) -> String {
|
||||
format!("hw:{},{}", self.card_index, self.device_index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Enumerate available audio capture devices
|
||||
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
|
||||
enumerate_audio_devices_with_current(None)
|
||||
}
|
||||
|
||||
/// Enumerate available audio capture devices, with option to include a currently-in-use device
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `current_device` - Optional device name that is currently in use. This device will be
|
||||
/// included in the list even if it cannot be opened (because it's already open by us).
|
||||
pub fn enumerate_audio_devices_with_current(
|
||||
current_device: Option<&str>,
|
||||
) -> Result<Vec<AudioDeviceInfo>> {
|
||||
let mut devices = Vec::new();
|
||||
|
||||
// Try to enumerate cards
|
||||
let cards = match alsa::card::Iter::new() {
|
||||
i => i,
|
||||
};
|
||||
|
||||
for card_result in cards {
|
||||
let card = match card_result {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
debug!("Error iterating card: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let card_index = card.get_index();
|
||||
let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string());
|
||||
let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone());
|
||||
|
||||
debug!("Found audio card {}: {}", card_index, card_longname);
|
||||
|
||||
// Check if this looks like an HDMI capture device
|
||||
let is_hdmi = card_longname.to_lowercase().contains("hdmi")
|
||||
|| card_longname.to_lowercase().contains("capture")
|
||||
|| card_longname.to_lowercase().contains("usb");
|
||||
|
||||
// Try to open each device on this card for capture
|
||||
for device_index in 0..8 {
|
||||
let device_name = format!("hw:{},{}", card_index, device_index);
|
||||
|
||||
// Check if this is the currently-in-use device
|
||||
let is_current_device = current_device == Some(device_name.as_str());
|
||||
|
||||
// Try to open for capture
|
||||
match PCM::new(&device_name, Direction::Capture, false) {
|
||||
Ok(pcm) => {
|
||||
// Query capabilities
|
||||
let (sample_rates, channels) = query_device_caps(&pcm);
|
||||
|
||||
if !sample_rates.is_empty() && !channels.is_empty() {
|
||||
devices.push(AudioDeviceInfo {
|
||||
name: device_name,
|
||||
description: format!("{} - Device {}", card_longname, device_index),
|
||||
card_index,
|
||||
device_index,
|
||||
sample_rates,
|
||||
channels,
|
||||
is_capture: true,
|
||||
is_hdmi,
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Device doesn't exist or can't be opened for capture
|
||||
// But if it's the current device, include it anyway (it's busy because we're using it)
|
||||
if is_current_device {
|
||||
debug!(
|
||||
"Device {} is busy (in use by us), adding with default caps",
|
||||
device_name
|
||||
);
|
||||
devices.push(AudioDeviceInfo {
|
||||
name: device_name,
|
||||
description: format!(
|
||||
"{} - Device {} (in use)",
|
||||
card_longname, device_index
|
||||
),
|
||||
card_index,
|
||||
device_index,
|
||||
// Use common default capabilities for HDMI capture devices
|
||||
sample_rates: vec![44100, 48000],
|
||||
channels: vec![2],
|
||||
is_capture: true,
|
||||
is_hdmi,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for "default" device
|
||||
if let Ok(pcm) = PCM::new("default", Direction::Capture, false) {
|
||||
let (sample_rates, channels) = query_device_caps(&pcm);
|
||||
if !sample_rates.is_empty() {
|
||||
devices.insert(
|
||||
0,
|
||||
AudioDeviceInfo {
|
||||
name: "default".to_string(),
|
||||
description: "Default Audio Device".to_string(),
|
||||
card_index: -1,
|
||||
device_index: -1,
|
||||
sample_rates,
|
||||
channels,
|
||||
is_capture: true,
|
||||
is_hdmi: false,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
info!("Found {} audio capture devices", devices.len());
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
/// Query device capabilities
|
||||
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
|
||||
let hwp = match HwParams::any(pcm) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return (vec![], vec![]),
|
||||
};
|
||||
|
||||
// Common sample rates to check
|
||||
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
|
||||
let mut supported_rates = Vec::new();
|
||||
|
||||
for rate in &common_rates {
|
||||
if hwp.test_rate(*rate).is_ok() {
|
||||
supported_rates.push(*rate);
|
||||
}
|
||||
}
|
||||
|
||||
// Check channel counts
|
||||
let mut supported_channels = Vec::new();
|
||||
for ch in 1..=8 {
|
||||
if hwp.test_channels(ch).is_ok() {
|
||||
supported_channels.push(ch);
|
||||
}
|
||||
}
|
||||
|
||||
(supported_rates, supported_channels)
|
||||
}
|
||||
|
||||
/// Find the best audio device for capture
|
||||
/// Prefers HDMI/capture devices over built-in microphones
|
||||
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
|
||||
let devices = enumerate_audio_devices()?;
|
||||
|
||||
if devices.is_empty() {
|
||||
return Err(AppError::AudioError("No audio capture devices found".to_string()));
|
||||
}
|
||||
|
||||
// First, look for HDMI/capture card devices that support 48kHz stereo
|
||||
for device in &devices {
|
||||
if device.is_hdmi
|
||||
&& device.sample_rates.contains(&48000)
|
||||
&& device.channels.contains(&2)
|
||||
{
|
||||
info!("Selected HDMI audio device: {}", device.description);
|
||||
return Ok(device.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Then look for any device supporting 48kHz stereo
|
||||
for device in &devices {
|
||||
if device.sample_rates.contains(&48000) && device.channels.contains(&2) {
|
||||
info!("Selected audio device: {}", device.description);
|
||||
return Ok(device.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to first device
|
||||
let device = devices.into_iter().next().unwrap();
|
||||
warn!(
|
||||
"Using fallback audio device: {} (may not support optimal settings)",
|
||||
device.description
|
||||
);
|
||||
Ok(device)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_enumerate_devices() {
|
||||
// This test may not find devices in CI environment
|
||||
let result = enumerate_audio_devices();
|
||||
println!("Audio devices: {:?}", result);
|
||||
// Just verify it doesn't panic
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
280
src/audio/encoder.rs
Normal file
280
src/audio/encoder.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
//! Opus audio encoder for WebRTC
|
||||
|
||||
use audiopus::coder::GenericCtl;
|
||||
use audiopus::{coder::Encoder, Application, Bitrate, Channels, SampleRate};
|
||||
use bytes::Bytes;
|
||||
use std::time::Instant;
|
||||
use tracing::{info, trace};
|
||||
|
||||
use super::capture::AudioFrame;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Opus encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpusConfig {
|
||||
/// Sample rate (must be 8000, 12000, 16000, 24000, or 48000)
|
||||
pub sample_rate: u32,
|
||||
/// Channels (1 or 2)
|
||||
pub channels: u32,
|
||||
/// Target bitrate in bps
|
||||
pub bitrate: u32,
|
||||
/// Application mode
|
||||
pub application: OpusApplication,
|
||||
/// Enable forward error correction
|
||||
pub fec: bool,
|
||||
}
|
||||
|
||||
impl Default for OpusConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sample_rate: 48000,
|
||||
channels: 2,
|
||||
bitrate: 64000, // 64 kbps
|
||||
application: OpusApplication::Audio,
|
||||
fec: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OpusConfig {
|
||||
/// Create config for voice (lower latency)
|
||||
pub fn voice() -> Self {
|
||||
Self {
|
||||
application: OpusApplication::Voip,
|
||||
bitrate: 32000,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for music (higher quality)
|
||||
pub fn music() -> Self {
|
||||
Self {
|
||||
application: OpusApplication::Audio,
|
||||
bitrate: 128000,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn to_audiopus_sample_rate(&self) -> SampleRate {
|
||||
match self.sample_rate {
|
||||
8000 => SampleRate::Hz8000,
|
||||
12000 => SampleRate::Hz12000,
|
||||
16000 => SampleRate::Hz16000,
|
||||
24000 => SampleRate::Hz24000,
|
||||
_ => SampleRate::Hz48000,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_audiopus_channels(&self) -> Channels {
|
||||
if self.channels == 1 {
|
||||
Channels::Mono
|
||||
} else {
|
||||
Channels::Stereo
|
||||
}
|
||||
}
|
||||
|
||||
fn to_audiopus_application(&self) -> Application {
|
||||
match self.application {
|
||||
OpusApplication::Voip => Application::Voip,
|
||||
OpusApplication::Audio => Application::Audio,
|
||||
OpusApplication::LowDelay => Application::LowDelay,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Opus application mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum OpusApplication {
|
||||
/// Voice over IP
|
||||
Voip,
|
||||
/// General audio
|
||||
Audio,
|
||||
/// Low delay mode
|
||||
LowDelay,
|
||||
}
|
||||
|
||||
/// Encoded Opus frame
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpusFrame {
|
||||
/// Encoded Opus data
|
||||
pub data: Bytes,
|
||||
/// Duration in milliseconds
|
||||
pub duration_ms: u32,
|
||||
/// Sequence number
|
||||
pub sequence: u64,
|
||||
/// Timestamp
|
||||
pub timestamp: Instant,
|
||||
/// RTP timestamp (samples)
|
||||
pub rtp_timestamp: u32,
|
||||
}
|
||||
|
||||
impl OpusFrame {
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Opus encoder
|
||||
pub struct OpusEncoder {
|
||||
config: OpusConfig,
|
||||
encoder: Encoder,
|
||||
/// Output buffer
|
||||
output_buffer: Vec<u8>,
|
||||
/// Frame counter for RTP timestamp
|
||||
frame_count: u64,
|
||||
/// Samples per frame
|
||||
samples_per_frame: u32,
|
||||
}
|
||||
|
||||
impl OpusEncoder {
|
||||
/// Create a new Opus encoder
|
||||
pub fn new(config: OpusConfig) -> Result<Self> {
|
||||
let sample_rate = config.to_audiopus_sample_rate();
|
||||
let channels = config.to_audiopus_channels();
|
||||
let application = config.to_audiopus_application();
|
||||
|
||||
let mut encoder = Encoder::new(sample_rate, channels, application).map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e))
|
||||
})?;
|
||||
|
||||
// Configure encoder
|
||||
encoder
|
||||
.set_bitrate(Bitrate::BitsPerSecond(config.bitrate as i32))
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
|
||||
|
||||
if config.fec {
|
||||
encoder
|
||||
.set_inband_fec(true)
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to enable FEC: {:?}", e)))?;
|
||||
}
|
||||
|
||||
// Calculate samples per frame (20ms at sample_rate)
|
||||
let samples_per_frame = config.sample_rate / 50;
|
||||
|
||||
info!(
|
||||
"Opus encoder created: {}Hz {}ch {}bps",
|
||||
config.sample_rate, config.channels, config.bitrate
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
encoder,
|
||||
output_buffer: vec![0u8; 4000], // Max Opus frame size
|
||||
frame_count: 0,
|
||||
samples_per_frame,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn default_config() -> Result<Self> {
|
||||
Self::new(OpusConfig::default())
|
||||
}
|
||||
|
||||
/// Encode PCM audio data (S16LE interleaved)
|
||||
pub fn encode(&mut self, pcm_data: &[i16]) -> Result<OpusFrame> {
|
||||
let encoded_len = self
|
||||
.encoder
|
||||
.encode(pcm_data, &mut self.output_buffer)
|
||||
.map_err(|e| AppError::AudioError(format!("Opus encode failed: {:?}", e)))?;
|
||||
|
||||
let samples = pcm_data.len() as u32 / self.config.channels;
|
||||
let duration_ms = (samples * 1000) / self.config.sample_rate;
|
||||
let rtp_timestamp = (self.frame_count * self.samples_per_frame as u64) as u32;
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
trace!(
|
||||
"Encoded {} samples to {} bytes Opus",
|
||||
pcm_data.len(),
|
||||
encoded_len
|
||||
);
|
||||
|
||||
Ok(OpusFrame {
|
||||
data: Bytes::copy_from_slice(&self.output_buffer[..encoded_len]),
|
||||
duration_ms,
|
||||
sequence: self.frame_count - 1,
|
||||
timestamp: Instant::now(),
|
||||
rtp_timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode from AudioFrame
|
||||
///
|
||||
/// Uses zero-copy conversion from bytes to i16 samples via bytemuck.
|
||||
pub fn encode_frame(&mut self, frame: &AudioFrame) -> Result<OpusFrame> {
|
||||
// Zero-copy: directly cast bytes to i16 slice
|
||||
// AudioFrame.data is S16LE format, which matches native little-endian i16
|
||||
let samples: &[i16] = bytemuck::cast_slice(&frame.data);
|
||||
self.encode(samples)
|
||||
}
|
||||
|
||||
/// Get encoder configuration
|
||||
pub fn config(&self) -> &OpusConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Reset encoder state
|
||||
pub fn reset(&mut self) -> Result<()> {
|
||||
self.encoder
|
||||
.reset_state()
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to reset encoder: {:?}", e)))?;
|
||||
self.frame_count = 0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set bitrate dynamically
|
||||
pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> {
|
||||
self.encoder
|
||||
.set_bitrate(Bitrate::BitsPerSecond(bitrate as i32))
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio encoder statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EncoderStats {
|
||||
pub frames_encoded: u64,
|
||||
pub bytes_output: u64,
|
||||
pub avg_frame_size: usize,
|
||||
pub current_bitrate: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_opus_config_default() {
|
||||
let config = OpusConfig::default();
|
||||
assert_eq!(config.sample_rate, 48000);
|
||||
assert_eq!(config.channels, 2);
|
||||
assert_eq!(config.bitrate, 64000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_encoder() {
|
||||
let config = OpusConfig::default();
|
||||
let encoder = OpusEncoder::new(config);
|
||||
assert!(encoder.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_silence() {
|
||||
let config = OpusConfig::default();
|
||||
let mut encoder = OpusEncoder::new(config).unwrap();
|
||||
|
||||
// 20ms of stereo silence at 48kHz
|
||||
let silence = vec![0i16; 960 * 2];
|
||||
let result = encoder.encode(&silence);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let frame = result.unwrap();
|
||||
assert!(!frame.is_empty());
|
||||
assert!(frame.len() < silence.len() * 2); // Should be compressed
|
||||
}
|
||||
}
|
||||
26
src/audio/mod.rs
Normal file
26
src/audio/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
//! Audio capture and encoding module
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - ALSA audio capture
|
||||
//! - Opus encoding for WebRTC
|
||||
//! - Audio device enumeration
|
||||
//! - Audio streaming pipeline
|
||||
//! - High-level audio controller
|
||||
//! - Shared audio pipeline for WebRTC multi-session support
|
||||
//! - Device health monitoring
|
||||
|
||||
pub mod capture;
|
||||
pub mod controller;
|
||||
pub mod device;
|
||||
pub mod encoder;
|
||||
pub mod monitor;
|
||||
pub mod shared_pipeline;
|
||||
pub mod streamer;
|
||||
|
||||
pub use capture::{AudioCapturer, AudioConfig, AudioFrame};
|
||||
pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus};
|
||||
pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo};
|
||||
pub use encoder::{OpusConfig, OpusEncoder, OpusFrame};
|
||||
pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig};
|
||||
pub use shared_pipeline::{SharedAudioPipeline, SharedAudioPipelineConfig, SharedAudioPipelineStats};
|
||||
pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
|
||||
352
src/audio/monitor.rs
Normal file
352
src/audio/monitor.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
//! Audio device health monitoring
|
||||
//!
|
||||
//! This module provides health monitoring for audio capture devices, including:
|
||||
//! - Device connectivity checks
|
||||
//! - Automatic reconnection on failure
|
||||
//! - Error tracking and notification
|
||||
//! - Log throttling to prevent log flooding
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::events::{EventBus, SystemEvent};
|
||||
use crate::utils::LogThrottler;
|
||||
|
||||
/// Audio health status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AudioHealthStatus {
|
||||
/// Device is healthy and operational
|
||||
Healthy,
|
||||
/// Device has an error, attempting recovery
|
||||
Error {
|
||||
/// Human-readable error reason
|
||||
reason: String,
|
||||
/// Error code for programmatic handling
|
||||
error_code: String,
|
||||
/// Number of recovery attempts made
|
||||
retry_count: u32,
|
||||
},
|
||||
/// Device is disconnected or not available
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
impl Default for AudioHealthStatus {
|
||||
fn default() -> Self {
|
||||
Self::Healthy
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio health monitor configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioMonitorConfig {
|
||||
/// Retry interval when device is lost (milliseconds)
|
||||
pub retry_interval_ms: u64,
|
||||
/// Maximum retry attempts before giving up (0 = infinite)
|
||||
pub max_retries: u32,
|
||||
/// Log throttle interval in seconds
|
||||
pub log_throttle_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for AudioMonitorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
retry_interval_ms: 1000,
|
||||
max_retries: 0, // infinite retry
|
||||
log_throttle_secs: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio health monitor
|
||||
///
|
||||
/// Monitors audio device health and manages error recovery.
|
||||
/// Publishes WebSocket events when device status changes.
|
||||
pub struct AudioHealthMonitor {
|
||||
/// Current health status
|
||||
status: RwLock<AudioHealthStatus>,
|
||||
/// Event bus for notifications
|
||||
events: RwLock<Option<Arc<EventBus>>>,
|
||||
/// Log throttler to prevent log flooding
|
||||
throttler: LogThrottler,
|
||||
/// Configuration
|
||||
config: AudioMonitorConfig,
|
||||
/// Whether monitoring is active (reserved for future use)
|
||||
#[allow(dead_code)]
|
||||
running: AtomicBool,
|
||||
/// Current retry count
|
||||
retry_count: AtomicU32,
|
||||
/// Last error code (for change detection)
|
||||
last_error_code: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl AudioHealthMonitor {
|
||||
/// Create a new audio health monitor with the specified configuration
|
||||
pub fn new(config: AudioMonitorConfig) -> Self {
|
||||
let throttle_secs = config.log_throttle_secs;
|
||||
Self {
|
||||
status: RwLock::new(AudioHealthStatus::Healthy),
|
||||
events: RwLock::new(None),
|
||||
throttler: LogThrottler::with_secs(throttle_secs),
|
||||
config,
|
||||
running: AtomicBool::new(false),
|
||||
retry_count: AtomicU32::new(0),
|
||||
last_error_code: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new audio health monitor with default configuration
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(AudioMonitorConfig::default())
|
||||
}
|
||||
|
||||
/// Set the event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Report an error from audio operations
|
||||
///
|
||||
/// This method is called when an audio operation fails. It:
|
||||
/// 1. Updates the health status
|
||||
/// 2. Logs the error (with throttling)
|
||||
/// 3. Publishes a WebSocket event if the error is new or changed
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The audio device name (if known)
|
||||
/// * `reason` - Human-readable error description
|
||||
/// * `error_code` - Error code for programmatic handling
|
||||
pub async fn report_error(&self, device: Option<&str>, reason: &str, error_code: &str) {
|
||||
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if error code changed
|
||||
let error_changed = {
|
||||
let last = self.last_error_code.read().await;
|
||||
last.as_ref().map(|s| s.as_str()) != Some(error_code)
|
||||
};
|
||||
|
||||
// Log with throttling (always log if error type changed)
|
||||
let throttle_key = format!("audio_{}", error_code);
|
||||
if error_changed || self.throttler.should_log(&throttle_key) {
|
||||
warn!(
|
||||
"Audio error: {} (code: {}, attempt: {})",
|
||||
reason, error_code, count
|
||||
);
|
||||
}
|
||||
|
||||
// Update last error code
|
||||
*self.last_error_code.write().await = Some(error_code.to_string());
|
||||
|
||||
// Update status
|
||||
*self.status.write().await = AudioHealthStatus::Error {
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
retry_count: count,
|
||||
};
|
||||
|
||||
// Publish event (only if error changed or first occurrence)
|
||||
if error_changed || count == 1 {
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::AudioDeviceLost {
|
||||
device: device.map(|s| s.to_string()),
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Report that a reconnection attempt is starting
|
||||
///
|
||||
/// Publishes a reconnecting event to notify clients.
|
||||
pub async fn report_reconnecting(&self) {
|
||||
let attempt = self.retry_count.load(Ordering::Relaxed);
|
||||
|
||||
// Only publish every 5 attempts to avoid event spam
|
||||
if attempt == 1 || attempt % 5 == 0 {
|
||||
debug!("Audio reconnecting, attempt {}", attempt);
|
||||
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::AudioReconnecting { attempt });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Report that the device has recovered
|
||||
///
|
||||
/// This method is called when the audio device successfully reconnects.
|
||||
/// It resets the error state and publishes a recovery event.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The audio device name
|
||||
pub async fn report_recovered(&self, device: Option<&str>) {
|
||||
let prev_status = self.status.read().await.clone();
|
||||
|
||||
// Only report recovery if we were in an error state
|
||||
if prev_status != AudioHealthStatus::Healthy {
|
||||
let retry_count = self.retry_count.load(Ordering::Relaxed);
|
||||
info!("Audio recovered after {} retries", retry_count);
|
||||
|
||||
// Reset state
|
||||
self.retry_count.store(0, Ordering::Relaxed);
|
||||
self.throttler.clear("audio_");
|
||||
*self.last_error_code.write().await = None;
|
||||
*self.status.write().await = AudioHealthStatus::Healthy;
|
||||
|
||||
// Publish recovery event
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::AudioRecovered {
|
||||
device: device.map(|s| s.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current health status
|
||||
pub async fn status(&self) -> AudioHealthStatus {
|
||||
self.status.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get the current retry count
|
||||
pub fn retry_count(&self) -> u32 {
|
||||
self.retry_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Check if the monitor is in an error state
|
||||
pub async fn is_error(&self) -> bool {
|
||||
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
|
||||
}
|
||||
|
||||
/// Check if the monitor is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
matches!(*self.status.read().await, AudioHealthStatus::Healthy)
|
||||
}
|
||||
|
||||
/// Reset the monitor to healthy state without publishing events
|
||||
///
|
||||
/// This is useful during initialization.
|
||||
pub async fn reset(&self) {
|
||||
self.retry_count.store(0, Ordering::Relaxed);
|
||||
*self.last_error_code.write().await = None;
|
||||
*self.status.write().await = AudioHealthStatus::Healthy;
|
||||
self.throttler.clear_all();
|
||||
}
|
||||
|
||||
/// Get the configuration
|
||||
pub fn config(&self) -> &AudioMonitorConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Check if we should continue retrying
|
||||
///
|
||||
/// Returns `false` if max_retries is set and we've exceeded it.
|
||||
pub fn should_retry(&self) -> bool {
|
||||
if self.config.max_retries == 0 {
|
||||
return true; // Infinite retry
|
||||
}
|
||||
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
|
||||
}
|
||||
|
||||
/// Get the retry interval
|
||||
pub fn retry_interval(&self) -> Duration {
|
||||
Duration::from_millis(self.config.retry_interval_ms)
|
||||
}
|
||||
|
||||
/// Get the current error message if in error state
|
||||
pub async fn error_message(&self) -> Option<String> {
|
||||
match &*self.status.read().await {
|
||||
AudioHealthStatus::Error { reason, .. } => Some(reason.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AudioHealthMonitor {
|
||||
fn default() -> Self {
|
||||
Self::with_defaults()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initial_status() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert!(!monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_error() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
|
||||
monitor
|
||||
.report_error(Some("hw:0,0"), "Device not found", "device_disconnected")
|
||||
.await;
|
||||
|
||||
assert!(monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 1);
|
||||
|
||||
if let AudioHealthStatus::Error {
|
||||
reason,
|
||||
error_code,
|
||||
retry_count,
|
||||
} = monitor.status().await
|
||||
{
|
||||
assert_eq!(reason, "Device not found");
|
||||
assert_eq!(error_code, "device_disconnected");
|
||||
assert_eq!(retry_count, 1);
|
||||
} else {
|
||||
panic!("Expected Error status");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_recovered() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
|
||||
// First report an error
|
||||
monitor
|
||||
.report_error(Some("default"), "Capture failed", "capture_error")
|
||||
.await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
// Then report recovery
|
||||
monitor.report_recovered(Some("default")).await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_count_increments() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
|
||||
for i in 1..=5 {
|
||||
monitor
|
||||
.report_error(None, "Error", "io_error")
|
||||
.await;
|
||||
assert_eq!(monitor.retry_count(), i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
|
||||
monitor
|
||||
.report_error(None, "Error", "io_error")
|
||||
.await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
monitor.reset().await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
}
|
||||
453
src/audio/shared_pipeline.rs
Normal file
453
src/audio/shared_pipeline.rs
Normal file
@@ -0,0 +1,453 @@
|
||||
//! Shared Audio Pipeline for WebRTC
|
||||
//!
|
||||
//! This module provides a shared audio encoding pipeline that can serve
|
||||
//! multiple WebRTC sessions with a single encoder instance.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! AudioCapturer (ALSA)
|
||||
//! |
|
||||
//! v (broadcast::Receiver<AudioFrame>)
|
||||
//! SharedAudioPipeline (single Opus encoder)
|
||||
//! |
|
||||
//! v (broadcast::Sender<OpusFrame>)
|
||||
//! ┌────┴────┬────────┬────────┐
|
||||
//! v v v v
|
||||
//! Session1 Session2 Session3 ...
|
||||
//! (RTP) (RTP) (RTP) (RTP)
|
||||
//! ```
|
||||
//!
|
||||
//! # Key Features
|
||||
//!
|
||||
//! - **Single encoder**: All sessions share one Opus encoder
|
||||
//! - **Broadcast distribution**: Encoded frames are broadcast to all subscribers
|
||||
//! - **Dynamic bitrate**: Bitrate can be changed at runtime
|
||||
//! - **Statistics**: Tracks encoding performance metrics
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{broadcast, Mutex, RwLock};
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use super::capture::AudioFrame;
|
||||
use super::encoder::{OpusConfig, OpusEncoder, OpusFrame};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Shared audio pipeline configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SharedAudioPipelineConfig {
|
||||
/// Sample rate (must match audio capture)
|
||||
pub sample_rate: u32,
|
||||
/// Number of channels (1 or 2)
|
||||
pub channels: u32,
|
||||
/// Target bitrate in bps
|
||||
pub bitrate: u32,
|
||||
/// Opus application mode
|
||||
pub application: OpusApplicationMode,
|
||||
/// Enable forward error correction
|
||||
pub fec: bool,
|
||||
/// Broadcast channel capacity
|
||||
pub channel_capacity: usize,
|
||||
}
|
||||
|
||||
impl Default for SharedAudioPipelineConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sample_rate: 48000,
|
||||
channels: 2,
|
||||
bitrate: 64000,
|
||||
application: OpusApplicationMode::Audio,
|
||||
fec: true,
|
||||
channel_capacity: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SharedAudioPipelineConfig {
|
||||
/// Create config optimized for voice
|
||||
pub fn voice() -> Self {
|
||||
Self {
|
||||
bitrate: 32000,
|
||||
application: OpusApplicationMode::Voip,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config optimized for music/high quality
|
||||
pub fn high_quality() -> Self {
|
||||
Self {
|
||||
bitrate: 128000,
|
||||
application: OpusApplicationMode::Audio,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to OpusConfig
|
||||
pub fn to_opus_config(&self) -> OpusConfig {
|
||||
OpusConfig {
|
||||
sample_rate: self.sample_rate,
|
||||
channels: self.channels,
|
||||
bitrate: self.bitrate,
|
||||
application: match self.application {
|
||||
OpusApplicationMode::Voip => super::encoder::OpusApplication::Voip,
|
||||
OpusApplicationMode::Audio => super::encoder::OpusApplication::Audio,
|
||||
OpusApplicationMode::LowDelay => super::encoder::OpusApplication::LowDelay,
|
||||
},
|
||||
fec: self.fec,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Opus application mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum OpusApplicationMode {
|
||||
/// Voice over IP - optimized for speech
|
||||
Voip,
|
||||
/// General audio - balanced quality
|
||||
Audio,
|
||||
/// Low delay mode - minimal latency
|
||||
LowDelay,
|
||||
}
|
||||
|
||||
/// Shared audio pipeline statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SharedAudioPipelineStats {
|
||||
/// Frames received from audio capture
|
||||
pub frames_received: u64,
|
||||
/// Frames successfully encoded
|
||||
pub frames_encoded: u64,
|
||||
/// Frames dropped (encode errors)
|
||||
pub frames_dropped: u64,
|
||||
/// Total bytes encoded
|
||||
pub bytes_encoded: u64,
|
||||
/// Number of active subscribers
|
||||
pub subscribers: u64,
|
||||
/// Average encode time in milliseconds
|
||||
pub avg_encode_time_ms: f32,
|
||||
/// Current bitrate in bps
|
||||
pub current_bitrate: u32,
|
||||
/// Pipeline running time in seconds
|
||||
pub running_time_secs: f64,
|
||||
}
|
||||
|
||||
/// Shared Audio Pipeline
|
||||
///
|
||||
/// Provides a single Opus encoder that serves multiple WebRTC sessions.
|
||||
/// All sessions receive the same encoded audio stream via broadcast channel.
|
||||
pub struct SharedAudioPipeline {
|
||||
/// Configuration
|
||||
config: RwLock<SharedAudioPipelineConfig>,
|
||||
/// Opus encoder (protected by mutex for encoding)
|
||||
encoder: Mutex<Option<OpusEncoder>>,
|
||||
/// Broadcast sender for encoded Opus frames
|
||||
opus_tx: broadcast::Sender<OpusFrame>,
|
||||
/// Running state
|
||||
running: AtomicBool,
|
||||
/// Statistics
|
||||
stats: Mutex<SharedAudioPipelineStats>,
|
||||
/// Start time for running time calculation
|
||||
start_time: RwLock<Option<Instant>>,
|
||||
/// Encode time accumulator for averaging
|
||||
encode_time_sum_us: AtomicU64,
|
||||
/// Encode count for averaging
|
||||
encode_count: AtomicU64,
|
||||
/// Stop signal (atomic for lock-free checking)
|
||||
stop_flag: AtomicBool,
|
||||
/// Encoding task handle
|
||||
task_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
impl SharedAudioPipeline {
|
||||
/// Create a new shared audio pipeline
|
||||
pub fn new(config: SharedAudioPipelineConfig) -> Result<Arc<Self>> {
|
||||
let (opus_tx, _) = broadcast::channel(config.channel_capacity);
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
config: RwLock::new(config),
|
||||
encoder: Mutex::new(None),
|
||||
opus_tx,
|
||||
running: AtomicBool::new(false),
|
||||
stats: Mutex::new(SharedAudioPipelineStats::default()),
|
||||
start_time: RwLock::new(None),
|
||||
encode_time_sum_us: AtomicU64::new(0),
|
||||
encode_count: AtomicU64::new(0),
|
||||
stop_flag: AtomicBool::new(false),
|
||||
task_handle: Mutex::new(None),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn default_config() -> Result<Arc<Self>> {
|
||||
Self::new(SharedAudioPipelineConfig::default())
|
||||
}
|
||||
|
||||
/// Start the audio encoding pipeline
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `audio_rx` - Receiver for raw audio frames from AudioCapturer
|
||||
pub async fn start(self: &Arc<Self>, audio_rx: broadcast::Receiver<AudioFrame>) -> Result<()> {
|
||||
if self.running.load(Ordering::SeqCst) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let config = self.config.read().await.clone();
|
||||
|
||||
info!(
|
||||
"Starting shared audio pipeline: {}Hz {}ch {}bps",
|
||||
config.sample_rate, config.channels, config.bitrate
|
||||
);
|
||||
|
||||
// Create encoder
|
||||
let opus_config = config.to_opus_config();
|
||||
let encoder = OpusEncoder::new(opus_config)?;
|
||||
*self.encoder.lock().await = Some(encoder);
|
||||
|
||||
// Reset stats
|
||||
{
|
||||
let mut stats = self.stats.lock().await;
|
||||
*stats = SharedAudioPipelineStats::default();
|
||||
stats.current_bitrate = config.bitrate;
|
||||
}
|
||||
|
||||
// Reset counters
|
||||
self.encode_time_sum_us.store(0, Ordering::SeqCst);
|
||||
self.encode_count.store(0, Ordering::SeqCst);
|
||||
*self.start_time.write().await = Some(Instant::now());
|
||||
self.stop_flag.store(false, Ordering::SeqCst);
|
||||
|
||||
self.running.store(true, Ordering::SeqCst);
|
||||
|
||||
// Start encoding task
|
||||
let pipeline = self.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
pipeline.encoding_task(audio_rx).await;
|
||||
});
|
||||
|
||||
*self.task_handle.lock().await = Some(handle);
|
||||
|
||||
info!("Shared audio pipeline started");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the audio encoding pipeline
|
||||
pub fn stop(&self) {
|
||||
if !self.running.load(Ordering::SeqCst) {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Stopping shared audio pipeline");
|
||||
|
||||
// Signal stop (atomic, no lock needed)
|
||||
self.stop_flag.store(true, Ordering::SeqCst);
|
||||
|
||||
self.running.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if pipeline is running
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.running.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Subscribe to encoded Opus frames
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<OpusFrame> {
|
||||
self.opus_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get number of active subscribers
|
||||
pub fn subscriber_count(&self) -> usize {
|
||||
self.opus_tx.receiver_count()
|
||||
}
|
||||
|
||||
/// Get current statistics
|
||||
pub async fn stats(&self) -> SharedAudioPipelineStats {
|
||||
let mut stats = self.stats.lock().await.clone();
|
||||
stats.subscribers = self.subscriber_count() as u64;
|
||||
|
||||
// Calculate average encode time
|
||||
let count = self.encode_count.load(Ordering::SeqCst);
|
||||
if count > 0 {
|
||||
let sum_us = self.encode_time_sum_us.load(Ordering::SeqCst);
|
||||
stats.avg_encode_time_ms = (sum_us as f64 / count as f64 / 1000.0) as f32;
|
||||
}
|
||||
|
||||
// Calculate running time
|
||||
if let Some(start) = *self.start_time.read().await {
|
||||
stats.running_time_secs = start.elapsed().as_secs_f64();
|
||||
}
|
||||
|
||||
stats
|
||||
}
|
||||
|
||||
/// Set bitrate dynamically
|
||||
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
|
||||
// Update config
|
||||
self.config.write().await.bitrate = bitrate;
|
||||
|
||||
// Update encoder if running
|
||||
if let Some(ref mut encoder) = *self.encoder.lock().await {
|
||||
encoder.set_bitrate(bitrate)?;
|
||||
}
|
||||
|
||||
// Update stats
|
||||
self.stats.lock().await.current_bitrate = bitrate;
|
||||
|
||||
info!("Shared audio pipeline bitrate changed to {}bps", bitrate);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update configuration (requires restart)
|
||||
pub async fn update_config(&self, config: SharedAudioPipelineConfig) -> Result<()> {
|
||||
if self.is_running() {
|
||||
return Err(AppError::AudioError(
|
||||
"Cannot update config while pipeline is running".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
*self.config.write().await = config;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Internal encoding task
|
||||
async fn encoding_task(self: Arc<Self>, mut audio_rx: broadcast::Receiver<AudioFrame>) {
|
||||
info!("Audio encoding task started");
|
||||
|
||||
loop {
|
||||
// Check stop flag (atomic, no async lock needed)
|
||||
if self.stop_flag.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Receive audio frame with timeout
|
||||
let recv_result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
audio_rx.recv(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match recv_result {
|
||||
Ok(Ok(audio_frame)) => {
|
||||
// Update received count
|
||||
{
|
||||
let mut stats = self.stats.lock().await;
|
||||
stats.frames_received += 1;
|
||||
}
|
||||
|
||||
// Encode frame
|
||||
let encode_start = Instant::now();
|
||||
let encode_result = {
|
||||
let mut encoder_guard = self.encoder.lock().await;
|
||||
if let Some(ref mut encoder) = *encoder_guard {
|
||||
Some(encoder.encode_frame(&audio_frame))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
let encode_time = encode_start.elapsed();
|
||||
|
||||
// Update encode time stats
|
||||
self.encode_time_sum_us
|
||||
.fetch_add(encode_time.as_micros() as u64, Ordering::SeqCst);
|
||||
self.encode_count.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
match encode_result {
|
||||
Some(Ok(opus_frame)) => {
|
||||
// Update stats
|
||||
{
|
||||
let mut stats = self.stats.lock().await;
|
||||
stats.frames_encoded += 1;
|
||||
stats.bytes_encoded += opus_frame.data.len() as u64;
|
||||
}
|
||||
|
||||
// Broadcast to subscribers
|
||||
if self.opus_tx.receiver_count() > 0 {
|
||||
if let Err(e) = self.opus_tx.send(opus_frame) {
|
||||
trace!("No audio subscribers: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("Opus encode error: {}", e);
|
||||
let mut stats = self.stats.lock().await;
|
||||
stats.frames_dropped += 1;
|
||||
}
|
||||
None => {
|
||||
warn!("Encoder not available");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(broadcast::error::RecvError::Closed)) => {
|
||||
info!("Audio source channel closed");
|
||||
break;
|
||||
}
|
||||
Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
|
||||
warn!("Audio pipeline lagged by {} frames", n);
|
||||
let mut stats = self.stats.lock().await;
|
||||
stats.frames_dropped += n;
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout - check if still running
|
||||
if !self.running.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
debug!("Audio receive timeout, continuing...");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
self.running.store(false, Ordering::SeqCst);
|
||||
*self.encoder.lock().await = None;
|
||||
|
||||
let stats = self.stats().await;
|
||||
info!(
|
||||
"Audio encoding task ended: {} frames encoded, {} dropped, {:.1}s runtime",
|
||||
stats.frames_encoded, stats.frames_dropped, stats.running_time_secs
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SharedAudioPipeline {
|
||||
fn drop(&mut self) {
|
||||
self.stop();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_config_default() {
|
||||
let config = SharedAudioPipelineConfig::default();
|
||||
assert_eq!(config.sample_rate, 48000);
|
||||
assert_eq!(config.channels, 2);
|
||||
assert_eq!(config.bitrate, 64000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_voice() {
|
||||
let config = SharedAudioPipelineConfig::voice();
|
||||
assert_eq!(config.bitrate, 32000);
|
||||
assert_eq!(config.application, OpusApplicationMode::Voip);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_high_quality() {
|
||||
let config = SharedAudioPipelineConfig::high_quality();
|
||||
assert_eq!(config.bitrate, 128000);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pipeline_creation() {
|
||||
let config = SharedAudioPipelineConfig::default();
|
||||
let pipeline = SharedAudioPipeline::new(config);
|
||||
assert!(pipeline.is_ok());
|
||||
|
||||
let pipeline = pipeline.unwrap();
|
||||
assert!(!pipeline.is_running());
|
||||
assert_eq!(pipeline.subscriber_count(), 0);
|
||||
}
|
||||
}
|
||||
401
src/audio/streamer.rs
Normal file
401
src/audio/streamer.rs
Normal file
@@ -0,0 +1,401 @@
|
||||
//! Audio streaming pipeline
|
||||
//!
|
||||
//! Coordinates audio capture and Opus encoding, distributing encoded
|
||||
//! frames to multiple subscribers via broadcast channel.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{broadcast, watch, Mutex, RwLock};
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
use super::capture::{AudioCapturer, AudioConfig, CaptureState};
|
||||
use super::encoder::{OpusConfig, OpusEncoder, OpusFrame};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Audio stream state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AudioStreamState {
|
||||
/// Stream is stopped
|
||||
Stopped,
|
||||
/// Stream is starting up
|
||||
Starting,
|
||||
/// Stream is running
|
||||
Running,
|
||||
/// Stream encountered an error
|
||||
Error,
|
||||
}
|
||||
|
||||
impl Default for AudioStreamState {
|
||||
fn default() -> Self {
|
||||
Self::Stopped
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio streamer configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioStreamerConfig {
|
||||
/// Audio capture configuration
|
||||
pub capture: AudioConfig,
|
||||
/// Opus encoder configuration
|
||||
pub opus: OpusConfig,
|
||||
}
|
||||
|
||||
impl Default for AudioStreamerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
capture: AudioConfig::default(),
|
||||
opus: OpusConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AudioStreamerConfig {
|
||||
/// Create config for a specific device with default quality
|
||||
pub fn for_device(device_name: &str) -> Self {
|
||||
Self {
|
||||
capture: AudioConfig {
|
||||
device_name: device_name.to_string(),
|
||||
..Default::default()
|
||||
},
|
||||
opus: OpusConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config with specified bitrate
|
||||
pub fn with_bitrate(mut self, bitrate: u32) -> Self {
|
||||
self.opus.bitrate = bitrate;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio stream statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AudioStreamStats {
|
||||
/// Frames captured from ALSA
|
||||
pub frames_captured: u64,
|
||||
/// Frames encoded to Opus
|
||||
pub frames_encoded: u64,
|
||||
/// Total bytes output (Opus)
|
||||
pub bytes_output: u64,
|
||||
/// Current encoding bitrate
|
||||
pub current_bitrate: u32,
|
||||
/// Number of active subscribers
|
||||
pub subscriber_count: usize,
|
||||
/// Buffer overruns
|
||||
pub buffer_overruns: u64,
|
||||
}
|
||||
|
||||
/// Audio streamer
|
||||
///
|
||||
/// Manages the audio capture -> encode -> broadcast pipeline.
|
||||
pub struct AudioStreamer {
|
||||
config: RwLock<AudioStreamerConfig>,
|
||||
state: watch::Sender<AudioStreamState>,
|
||||
state_rx: watch::Receiver<AudioStreamState>,
|
||||
capturer: RwLock<Option<Arc<AudioCapturer>>>,
|
||||
encoder: Arc<Mutex<Option<OpusEncoder>>>,
|
||||
opus_tx: broadcast::Sender<OpusFrame>,
|
||||
stats: Arc<Mutex<AudioStreamStats>>,
|
||||
sequence: AtomicU64,
|
||||
stream_start_time: RwLock<Option<Instant>>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl AudioStreamer {
|
||||
/// Create a new audio streamer with default configuration
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(AudioStreamerConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new audio streamer with specified configuration
|
||||
pub fn with_config(config: AudioStreamerConfig) -> Self {
|
||||
let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped);
|
||||
let (opus_tx, _) = broadcast::channel(64);
|
||||
|
||||
Self {
|
||||
config: RwLock::new(config),
|
||||
state: state_tx,
|
||||
state_rx,
|
||||
capturer: RwLock::new(None),
|
||||
encoder: Arc::new(Mutex::new(None)),
|
||||
opus_tx,
|
||||
stats: Arc::new(Mutex::new(AudioStreamStats::default())),
|
||||
sequence: AtomicU64::new(0),
|
||||
stream_start_time: RwLock::new(None),
|
||||
stop_flag: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current state
|
||||
pub fn state(&self) -> AudioStreamState {
|
||||
*self.state_rx.borrow()
|
||||
}
|
||||
|
||||
/// Subscribe to state changes
|
||||
pub fn state_watch(&self) -> watch::Receiver<AudioStreamState> {
|
||||
self.state_rx.clone()
|
||||
}
|
||||
|
||||
/// Subscribe to Opus frames
|
||||
pub fn subscribe_opus(&self) -> broadcast::Receiver<OpusFrame> {
|
||||
self.opus_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get number of active subscribers
|
||||
pub fn subscriber_count(&self) -> usize {
|
||||
self.opus_tx.receiver_count()
|
||||
}
|
||||
|
||||
/// Get current statistics
|
||||
pub async fn stats(&self) -> AudioStreamStats {
|
||||
let mut stats = self.stats.lock().await.clone();
|
||||
stats.subscriber_count = self.subscriber_count();
|
||||
stats
|
||||
}
|
||||
|
||||
/// Update configuration (only when stopped)
|
||||
pub async fn set_config(&self, config: AudioStreamerConfig) -> Result<()> {
|
||||
if self.state() != AudioStreamState::Stopped {
|
||||
return Err(AppError::AudioError(
|
||||
"Cannot change config while streaming".to_string(),
|
||||
));
|
||||
}
|
||||
*self.config.write().await = config;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update bitrate dynamically (can be done while streaming)
|
||||
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
|
||||
// Update config
|
||||
self.config.write().await.opus.bitrate = bitrate;
|
||||
|
||||
// Update encoder if running
|
||||
if let Some(ref mut encoder) = *self.encoder.lock().await {
|
||||
encoder.set_bitrate(bitrate)?;
|
||||
}
|
||||
|
||||
// Update stats
|
||||
self.stats.lock().await.current_bitrate = bitrate;
|
||||
|
||||
info!("Audio bitrate changed to {}bps", bitrate);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start the audio stream
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
if self.state() == AudioStreamState::Running {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let _ = self.state.send(AudioStreamState::Starting);
|
||||
self.stop_flag.store(false, Ordering::SeqCst);
|
||||
|
||||
let config = self.config.read().await.clone();
|
||||
|
||||
info!(
|
||||
"Starting audio stream: {} @ {}Hz {}ch, {}bps Opus",
|
||||
config.capture.device_name,
|
||||
config.capture.sample_rate,
|
||||
config.capture.channels,
|
||||
config.opus.bitrate
|
||||
);
|
||||
|
||||
// Create capturer
|
||||
let capturer = Arc::new(AudioCapturer::new(config.capture.clone()));
|
||||
*self.capturer.write().await = Some(capturer.clone());
|
||||
|
||||
// Create encoder
|
||||
let encoder = OpusEncoder::new(config.opus.clone())?;
|
||||
*self.encoder.lock().await = Some(encoder);
|
||||
|
||||
// Start capture
|
||||
capturer.start().await?;
|
||||
|
||||
// Reset stats
|
||||
{
|
||||
let mut stats = self.stats.lock().await;
|
||||
*stats = AudioStreamStats::default();
|
||||
stats.current_bitrate = config.opus.bitrate;
|
||||
}
|
||||
|
||||
// Record start time
|
||||
*self.stream_start_time.write().await = Some(Instant::now());
|
||||
self.sequence.store(0, Ordering::SeqCst);
|
||||
|
||||
// Start encoding task
|
||||
let capturer_for_task = capturer.clone();
|
||||
let encoder = self.encoder.clone();
|
||||
let opus_tx = self.opus_tx.clone();
|
||||
let stats = self.stats.clone();
|
||||
let state = self.state.clone();
|
||||
let stop_flag = self.stop_flag.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
Self::stream_task(capturer_for_task, encoder, opus_tx, stats, state, stop_flag).await;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the audio stream
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
if self.state() == AudioStreamState::Stopped {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Stopping audio stream");
|
||||
|
||||
// Signal stop
|
||||
self.stop_flag.store(true, Ordering::SeqCst);
|
||||
|
||||
// Stop capturer
|
||||
if let Some(ref capturer) = *self.capturer.read().await {
|
||||
capturer.stop().await?;
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
*self.capturer.write().await = None;
|
||||
*self.encoder.lock().await = None;
|
||||
*self.stream_start_time.write().await = None;
|
||||
|
||||
let _ = self.state.send(AudioStreamState::Stopped);
|
||||
info!("Audio stream stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if streaming
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.state() == AudioStreamState::Running
|
||||
}
|
||||
|
||||
/// Internal streaming task
|
||||
async fn stream_task(
|
||||
capturer: Arc<AudioCapturer>,
|
||||
encoder: Arc<Mutex<Option<OpusEncoder>>>,
|
||||
opus_tx: broadcast::Sender<OpusFrame>,
|
||||
stats: Arc<Mutex<AudioStreamStats>>,
|
||||
state: watch::Sender<AudioStreamState>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
) {
|
||||
let mut pcm_rx = capturer.subscribe();
|
||||
let _ = state.send(AudioStreamState::Running);
|
||||
|
||||
info!("Audio stream task started");
|
||||
|
||||
loop {
|
||||
// Check stop flag (atomic, no async lock needed)
|
||||
if stop_flag.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Check capturer state
|
||||
if capturer.state() == CaptureState::Error {
|
||||
error!("Audio capture error, stopping stream");
|
||||
let _ = state.send(AudioStreamState::Error);
|
||||
break;
|
||||
}
|
||||
|
||||
// Receive PCM frame with timeout
|
||||
let recv_result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
pcm_rx.recv(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match recv_result {
|
||||
Ok(Ok(audio_frame)) => {
|
||||
// Update capture stats
|
||||
{
|
||||
let mut s = stats.lock().await;
|
||||
s.frames_captured += 1;
|
||||
}
|
||||
|
||||
// Encode to Opus
|
||||
let opus_result = {
|
||||
let mut enc_guard = encoder.lock().await;
|
||||
if let Some(ref mut enc) = *enc_guard {
|
||||
Some(enc.encode_frame(&audio_frame))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
match opus_result {
|
||||
Some(Ok(opus_frame)) => {
|
||||
// Update stats
|
||||
{
|
||||
let mut s = stats.lock().await;
|
||||
s.frames_encoded += 1;
|
||||
s.bytes_output += opus_frame.data.len() as u64;
|
||||
}
|
||||
|
||||
// Broadcast to subscribers
|
||||
if opus_tx.receiver_count() > 0 {
|
||||
if let Err(e) = opus_tx.send(opus_frame) {
|
||||
trace!("No audio subscribers: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("Opus encode error: {}", e);
|
||||
}
|
||||
None => {
|
||||
warn!("Encoder not available");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(broadcast::error::RecvError::Closed)) => {
|
||||
info!("Audio capture channel closed");
|
||||
break;
|
||||
}
|
||||
Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
|
||||
warn!("Audio receiver lagged by {} frames", n);
|
||||
let mut s = stats.lock().await;
|
||||
s.buffer_overruns += n;
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout - check if still capturing
|
||||
if capturer.state() != CaptureState::Running {
|
||||
info!("Audio capture stopped, ending stream task");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = state.send(AudioStreamState::Stopped);
|
||||
info!("Audio stream task ended");
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AudioStreamer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_streamer_config_default() {
|
||||
let config = AudioStreamerConfig::default();
|
||||
assert_eq!(config.capture.sample_rate, 48000);
|
||||
assert_eq!(config.opus.bitrate, 64000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streamer_config_for_device() {
|
||||
let config = AudioStreamerConfig::for_device("hw:0,0");
|
||||
assert_eq!(config.capture.device_name, "hw:0,0");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streamer_state() {
|
||||
let streamer = AudioStreamer::new();
|
||||
assert_eq!(streamer.state(), AudioStreamState::Stopped);
|
||||
}
|
||||
}
|
||||
142
src/auth/middleware.rs
Normal file
142
src/auth/middleware.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use axum_extra::extract::CookieJar;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Session cookie name
|
||||
pub const SESSION_COOKIE: &str = "one_kvm_session";
|
||||
|
||||
/// Auth layer for extracting session from request
|
||||
#[derive(Clone)]
|
||||
pub struct AuthLayer;
|
||||
|
||||
/// Extract session ID from request
|
||||
pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option<String> {
|
||||
// First try cookie
|
||||
if let Some(cookie) = cookies.get(SESSION_COOKIE) {
|
||||
return Some(cookie.value().to_string());
|
||||
}
|
||||
|
||||
// Then try Authorization header (Bearer token)
|
||||
if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
|
||||
if let Ok(auth_str) = auth_header.to_str() {
|
||||
if let Some(token) = auth_str.strip_prefix("Bearer ") {
|
||||
return Some(token.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Authentication middleware
|
||||
pub async fn auth_middleware(
|
||||
State(state): State<Arc<AppState>>,
|
||||
cookies: CookieJar,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
// Check if system is initialized
|
||||
if !state.config.is_initialized() {
|
||||
// Allow access to setup endpoints when not initialized
|
||||
let path = request.uri().path();
|
||||
if path.starts_with("/api/setup") || path == "/api/info" || path.starts_with("/") && !path.starts_with("/api/") {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
|
||||
// Public endpoints that don't require auth
|
||||
let path = request.uri().path();
|
||||
if is_public_endpoint(path) {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
|
||||
// Extract session ID
|
||||
let session_id = extract_session_id(&cookies, request.headers());
|
||||
|
||||
if let Some(session_id) = session_id {
|
||||
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
|
||||
// Add session to request extensions
|
||||
request.extensions_mut().insert(session);
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
|
||||
/// Check if endpoint is public (no auth required)
|
||||
fn is_public_endpoint(path: &str) -> bool {
|
||||
// Note: paths here are relative to /api since middleware is applied before nest
|
||||
matches!(
|
||||
path,
|
||||
"/"
|
||||
| "/auth/login"
|
||||
| "/info"
|
||||
| "/health"
|
||||
| "/setup"
|
||||
| "/setup/init"
|
||||
// Also check with /api prefix for direct access
|
||||
| "/api/auth/login"
|
||||
| "/api/info"
|
||||
| "/api/health"
|
||||
| "/api/setup"
|
||||
| "/api/setup/init"
|
||||
) || path.starts_with("/assets/")
|
||||
|| path.starts_with("/static/")
|
||||
|| path.ends_with(".js")
|
||||
|| path.ends_with(".css")
|
||||
|| path.ends_with(".ico")
|
||||
|| path.ends_with(".png")
|
||||
|| path.ends_with(".svg")
|
||||
}
|
||||
|
||||
/// Require authentication - returns 401 if not authenticated
|
||||
pub async fn require_auth(
|
||||
State(state): State<Arc<AppState>>,
|
||||
cookies: CookieJar,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let session_id = extract_session_id(&cookies, request.headers());
|
||||
|
||||
if let Some(session_id) = session_id {
|
||||
if let Ok(Some(_session)) = state.sessions.get(&session_id).await {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
|
||||
/// Require admin privileges - returns 403 if not admin
|
||||
pub async fn require_admin(
|
||||
State(state): State<Arc<AppState>>,
|
||||
cookies: CookieJar,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let session_id = extract_session_id(&cookies, request.headers());
|
||||
|
||||
if let Some(session_id) = session_id {
|
||||
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
|
||||
// Get user and check admin status
|
||||
if let Ok(Some(user)) = state.users.get(&session.user_id).await {
|
||||
if user.is_admin {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
// User is authenticated but not admin
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not authenticated at all
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
9
src/auth/mod.rs
Normal file
9
src/auth/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod password;
|
||||
mod session;
|
||||
mod user;
|
||||
pub mod middleware;
|
||||
|
||||
pub use password::{hash_password, verify_password};
|
||||
pub use session::{Session, SessionStore};
|
||||
pub use user::{User, UserStore};
|
||||
pub use middleware::{AuthLayer, SESSION_COOKIE, auth_middleware, require_admin};
|
||||
41
src/auth/password.rs
Normal file
41
src/auth/password.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2,
|
||||
};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Hash a password using Argon2
|
||||
pub fn hash_password(password: &str) -> Result<String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map(|hash| hash.to_string())
|
||||
.map_err(|e| AppError::Internal(format!("Password hashing failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Verify a password against a hash
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
|
||||
let parsed_hash = PasswordHash::new(hash)
|
||||
.map_err(|e| AppError::Internal(format!("Invalid password hash: {}", e)))?;
|
||||
|
||||
Ok(Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_password_hash_verify() {
|
||||
let password = "test_password_123";
|
||||
let hash = hash_password(password).unwrap();
|
||||
|
||||
assert!(verify_password(password, &hash).unwrap());
|
||||
assert!(!verify_password("wrong_password", &hash).unwrap());
|
||||
}
|
||||
}
|
||||
129
src/auth/session.rs
Normal file
129
src/auth/session.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Pool, Sqlite};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
/// Session data
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
pub id: String,
|
||||
pub user_id: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Check if session is expired
|
||||
pub fn is_expired(&self) -> bool {
|
||||
Utc::now() > self.expires_at
|
||||
}
|
||||
}
|
||||
|
||||
/// Session store backed by SQLite
|
||||
#[derive(Clone)]
|
||||
pub struct SessionStore {
|
||||
pool: Pool<Sqlite>,
|
||||
default_ttl: Duration,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
/// Create a new session store
|
||||
pub fn new(pool: Pool<Sqlite>, ttl_secs: i64) -> Self {
|
||||
Self {
|
||||
pool,
|
||||
default_ttl: Duration::seconds(ttl_secs),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session
|
||||
pub async fn create(&self, user_id: &str) -> Result<Session> {
|
||||
let session = Session {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
created_at: Utc::now(),
|
||||
expires_at: Utc::now() + self.default_ttl,
|
||||
data: None,
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO sessions (id, user_id, created_at, expires_at, data)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||
"#,
|
||||
)
|
||||
.bind(&session.id)
|
||||
.bind(&session.user_id)
|
||||
.bind(session.created_at.to_rfc3339())
|
||||
.bind(session.expires_at.to_rfc3339())
|
||||
.bind(session.data.as_ref().map(|d| d.to_string()))
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// Get a session by ID
|
||||
pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
|
||||
let row: Option<(String, String, String, String, Option<String>)> = sqlx::query_as(
|
||||
"SELECT id, user_id, created_at, expires_at, data FROM sessions WHERE id = ?1",
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
match row {
|
||||
Some((id, user_id, created_at, expires_at, data)) => {
|
||||
let session = Session {
|
||||
id,
|
||||
user_id,
|
||||
created_at: DateTime::parse_from_rfc3339(&created_at)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
expires_at: DateTime::parse_from_rfc3339(&expires_at)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
data: data.and_then(|d| serde_json::from_str(&d).ok()),
|
||||
};
|
||||
|
||||
if session.is_expired() {
|
||||
self.delete(&session.id).await?;
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(session))
|
||||
}
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete a session
|
||||
pub async fn delete(&self, session_id: &str) -> Result<()> {
|
||||
sqlx::query("DELETE FROM sessions WHERE id = ?1")
|
||||
.bind(session_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all expired sessions
|
||||
pub async fn cleanup_expired(&self) -> Result<u64> {
|
||||
let result = sqlx::query("DELETE FROM sessions WHERE expires_at < datetime('now')")
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Extend session expiration
|
||||
pub async fn extend(&self, session_id: &str) -> Result<()> {
|
||||
let new_expires = Utc::now() + self.default_ttl;
|
||||
sqlx::query("UPDATE sessions SET expires_at = ?1 WHERE id = ?2")
|
||||
.bind(new_expires.to_rfc3339())
|
||||
.bind(session_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
185
src/auth/user.rs
Normal file
185
src/auth/user.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Pool, Sqlite};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use super::password::{hash_password, verify_password};
|
||||
|
||||
/// User row type from database
|
||||
type UserRow = (String, String, String, i32, String, String);
|
||||
|
||||
/// User data
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
#[serde(skip_serializing)]
|
||||
pub password_hash: String,
|
||||
pub is_admin: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl User {
|
||||
/// Convert from database row to User
|
||||
fn from_row(row: UserRow) -> Self {
|
||||
let (id, username, password_hash, is_admin, created_at, updated_at) = row;
|
||||
Self {
|
||||
id,
|
||||
username,
|
||||
password_hash,
|
||||
is_admin: is_admin != 0,
|
||||
created_at: DateTime::parse_from_rfc3339(&created_at)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
updated_at: DateTime::parse_from_rfc3339(&updated_at)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// User store backed by SQLite
|
||||
#[derive(Clone)]
|
||||
pub struct UserStore {
|
||||
pool: Pool<Sqlite>,
|
||||
}
|
||||
|
||||
impl UserStore {
|
||||
/// Create a new user store
|
||||
pub fn new(pool: Pool<Sqlite>) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
/// Create a new user
|
||||
pub async fn create(&self, username: &str, password: &str, is_admin: bool) -> Result<User> {
|
||||
// Check if username already exists
|
||||
if self.get_by_username(username).await?.is_some() {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"Username '{}' already exists",
|
||||
username
|
||||
)));
|
||||
}
|
||||
|
||||
let password_hash = hash_password(password)?;
|
||||
let now = Utc::now();
|
||||
let user = User {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
username: username.to_string(),
|
||||
password_hash,
|
||||
is_admin,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO users (id, username, password_hash, is_admin, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
"#,
|
||||
)
|
||||
.bind(&user.id)
|
||||
.bind(&user.username)
|
||||
.bind(&user.password_hash)
|
||||
.bind(user.is_admin as i32)
|
||||
.bind(user.created_at.to_rfc3339())
|
||||
.bind(user.updated_at.to_rfc3339())
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// Get user by ID
|
||||
pub async fn get(&self, user_id: &str) -> Result<Option<User>> {
|
||||
let row: Option<UserRow> = sqlx::query_as(
|
||||
"SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users WHERE id = ?1",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(User::from_row))
|
||||
}
|
||||
|
||||
/// Get user by username
|
||||
pub async fn get_by_username(&self, username: &str) -> Result<Option<User>> {
|
||||
let row: Option<UserRow> = sqlx::query_as(
|
||||
"SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users WHERE username = ?1",
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(User::from_row))
|
||||
}
|
||||
|
||||
/// Verify user credentials
|
||||
pub async fn verify(&self, username: &str, password: &str) -> Result<Option<User>> {
|
||||
let user = match self.get_by_username(username).await? {
|
||||
Some(user) => user,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
if verify_password(password, &user.password_hash)? {
|
||||
Ok(Some(user))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Update user password
|
||||
pub async fn update_password(&self, user_id: &str, new_password: &str) -> Result<()> {
|
||||
let password_hash = hash_password(new_password)?;
|
||||
let now = Utc::now();
|
||||
|
||||
let result = sqlx::query(
|
||||
"UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3",
|
||||
)
|
||||
.bind(&password_hash)
|
||||
.bind(now.to_rfc3339())
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(AppError::NotFound("User not found".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all users
|
||||
pub async fn list(&self) -> Result<Vec<User>> {
|
||||
let rows: Vec<UserRow> = sqlx::query_as(
|
||||
"SELECT id, username, password_hash, is_admin, created_at, updated_at FROM users ORDER BY created_at",
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(User::from_row).collect())
|
||||
}
|
||||
|
||||
/// Delete user by ID
|
||||
pub async fn delete(&self, user_id: &str) -> Result<()> {
|
||||
let result = sqlx::query("DELETE FROM users WHERE id = ?1")
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(AppError::NotFound("User not found".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if any users exist
|
||||
pub async fn has_users(&self) -> Result<bool> {
|
||||
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
Ok(count.0 > 0)
|
||||
}
|
||||
}
|
||||
5
src/config/mod.rs
Normal file
5
src/config/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod schema;
|
||||
mod store;
|
||||
|
||||
pub use schema::*;
|
||||
pub use store::ConfigStore;
|
||||
416
src/config/schema.rs
Normal file
416
src/config/schema.rs
Normal file
@@ -0,0 +1,416 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use typeshare::typeshare;
|
||||
|
||||
// Re-export ExtensionsConfig from extensions module
|
||||
pub use crate::extensions::ExtensionsConfig;
|
||||
|
||||
/// Main application configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct AppConfig {
|
||||
/// Whether initial setup has been completed
|
||||
pub initialized: bool,
|
||||
/// Authentication settings
|
||||
pub auth: AuthConfig,
|
||||
/// Video capture settings
|
||||
pub video: VideoConfig,
|
||||
/// HID (keyboard/mouse) settings
|
||||
pub hid: HidConfig,
|
||||
/// Mass Storage Device settings
|
||||
pub msd: MsdConfig,
|
||||
/// ATX power control settings
|
||||
pub atx: AtxConfig,
|
||||
/// Audio settings
|
||||
pub audio: AudioConfig,
|
||||
/// Streaming settings
|
||||
pub stream: StreamConfig,
|
||||
/// Web server settings
|
||||
pub web: WebConfig,
|
||||
/// Extensions settings (ttyd, gostc, easytier)
|
||||
pub extensions: ExtensionsConfig,
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
initialized: false,
|
||||
auth: AuthConfig::default(),
|
||||
video: VideoConfig::default(),
|
||||
hid: HidConfig::default(),
|
||||
msd: MsdConfig::default(),
|
||||
atx: AtxConfig::default(),
|
||||
audio: AudioConfig::default(),
|
||||
stream: StreamConfig::default(),
|
||||
web: WebConfig::default(),
|
||||
extensions: ExtensionsConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Authentication configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct AuthConfig {
|
||||
/// Session timeout in seconds
|
||||
pub session_timeout_secs: u32,
|
||||
/// Enable 2FA
|
||||
pub totp_enabled: bool,
|
||||
/// TOTP secret (encrypted)
|
||||
pub totp_secret: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for AuthConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
session_timeout_secs: 3600 * 24, // 24 hours
|
||||
totp_enabled: false,
|
||||
totp_secret: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Video capture configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(default)]
|
||||
pub struct VideoConfig {
|
||||
/// Video device path (e.g., /dev/video0)
|
||||
pub device: Option<String>,
|
||||
/// Video pixel format (e.g., "MJPEG", "YUYV", "NV12")
|
||||
pub format: Option<String>,
|
||||
/// Resolution width
|
||||
pub width: u32,
|
||||
/// Resolution height
|
||||
pub height: u32,
|
||||
/// Frame rate
|
||||
pub fps: u32,
|
||||
/// JPEG quality (1-100)
|
||||
pub quality: u32,
|
||||
}
|
||||
|
||||
impl Default for VideoConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device: None,
|
||||
format: None, // Auto-detect or use MJPEG as default
|
||||
width: 1920,
|
||||
height: 1080,
|
||||
fps: 30,
|
||||
quality: 80,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HID backend type
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum HidBackend {
|
||||
/// USB OTG HID gadget
|
||||
Otg,
|
||||
/// CH9329 serial HID controller
|
||||
Ch9329,
|
||||
/// Disabled
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for HidBackend {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// HID configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(default)]
|
||||
pub struct HidConfig {
|
||||
/// HID backend type
|
||||
pub backend: HidBackend,
|
||||
/// OTG keyboard device path
|
||||
pub otg_keyboard: String,
|
||||
/// OTG mouse device path
|
||||
pub otg_mouse: String,
|
||||
/// OTG UDC (USB Device Controller) name
|
||||
pub otg_udc: Option<String>,
|
||||
/// CH9329 serial port
|
||||
pub ch9329_port: String,
|
||||
/// CH9329 baud rate
|
||||
pub ch9329_baudrate: u32,
|
||||
/// Mouse mode: absolute or relative
|
||||
pub mouse_absolute: bool,
|
||||
}
|
||||
|
||||
impl Default for HidConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
backend: HidBackend::None,
|
||||
otg_keyboard: "/dev/hidg0".to_string(),
|
||||
otg_mouse: "/dev/hidg1".to_string(),
|
||||
otg_udc: None,
|
||||
ch9329_port: "/dev/ttyUSB0".to_string(),
|
||||
ch9329_baudrate: 9600,
|
||||
mouse_absolute: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct MsdConfig {
|
||||
/// Enable MSD functionality
|
||||
pub enabled: bool,
|
||||
/// Storage path for ISO/IMG images
|
||||
pub images_path: String,
|
||||
/// Path for Ventoy bootable drive file
|
||||
pub drive_path: String,
|
||||
/// Ventoy drive size in MB (minimum 1024 MB / 1 GB)
|
||||
pub virtual_drive_size_mb: u32,
|
||||
}
|
||||
|
||||
impl Default for MsdConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
images_path: "./data/msd/images".to_string(),
|
||||
drive_path: "./data/msd/ventoy.img".to_string(),
|
||||
virtual_drive_size_mb: 16 * 1024, // 16GB default
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-export ATX types from atx module for configuration
|
||||
pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig};
|
||||
|
||||
/// ATX power control configuration
|
||||
///
|
||||
/// Each ATX action (power, reset) can be independently configured with its own
|
||||
/// hardware binding using the four-tuple: (driver, device, pin, active_level).
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct AtxConfig {
|
||||
/// Enable ATX functionality
|
||||
pub enabled: bool,
|
||||
/// Power button configuration (used for both short and long press)
|
||||
pub power: AtxKeyConfig,
|
||||
/// Reset button configuration
|
||||
pub reset: AtxKeyConfig,
|
||||
/// LED sensing configuration (optional)
|
||||
pub led: AtxLedConfig,
|
||||
/// Network interface for WOL packets (empty = auto)
|
||||
pub wol_interface: String,
|
||||
}
|
||||
|
||||
impl Default for AtxConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
power: AtxKeyConfig::default(),
|
||||
reset: AtxKeyConfig::default(),
|
||||
led: AtxLedConfig::default(),
|
||||
wol_interface: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AtxConfig {
|
||||
/// Convert to AtxControllerConfig for the controller
|
||||
pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig {
|
||||
crate::atx::AtxControllerConfig {
|
||||
enabled: self.enabled,
|
||||
power: self.power.clone(),
|
||||
reset: self.reset.clone(),
|
||||
led: self.led.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio configuration
|
||||
///
|
||||
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
|
||||
/// These are optimal for Opus encoding and match WebRTC requirements.
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct AudioConfig {
|
||||
/// Enable audio capture
|
||||
pub enabled: bool,
|
||||
/// ALSA device name
|
||||
pub device: String,
|
||||
/// Audio quality preset: "voice", "balanced", "high"
|
||||
pub quality: String,
|
||||
}
|
||||
|
||||
impl Default for AudioConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
device: "default".to_string(),
|
||||
quality: "balanced".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream mode
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum StreamMode {
|
||||
/// WebRTC with H264/H265
|
||||
WebRTC,
|
||||
/// MJPEG over HTTP
|
||||
Mjpeg,
|
||||
}
|
||||
|
||||
impl Default for StreamMode {
|
||||
fn default() -> Self {
|
||||
Self::Mjpeg
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoder type
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EncoderType {
|
||||
/// Auto-detect best encoder
|
||||
Auto,
|
||||
/// Software encoder (libx264)
|
||||
Software,
|
||||
/// VAAPI hardware encoder
|
||||
Vaapi,
|
||||
/// NVIDIA NVENC hardware encoder
|
||||
Nvenc,
|
||||
/// Intel Quick Sync hardware encoder
|
||||
Qsv,
|
||||
/// AMD AMF hardware encoder
|
||||
Amf,
|
||||
/// Rockchip MPP hardware encoder
|
||||
Rkmpp,
|
||||
/// V4L2 M2M hardware encoder
|
||||
V4l2m2m,
|
||||
}
|
||||
|
||||
impl Default for EncoderType {
|
||||
fn default() -> Self {
|
||||
Self::Auto
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderType {
|
||||
/// Convert to EncoderBackend for registry queries
|
||||
pub fn to_backend(&self) -> Option<crate::video::encoder::registry::EncoderBackend> {
|
||||
use crate::video::encoder::registry::EncoderBackend;
|
||||
match self {
|
||||
EncoderType::Auto => None,
|
||||
EncoderType::Software => Some(EncoderBackend::Software),
|
||||
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
|
||||
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
|
||||
EncoderType::Qsv => Some(EncoderBackend::Qsv),
|
||||
EncoderType::Amf => Some(EncoderBackend::Amf),
|
||||
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
|
||||
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get display name for UI
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
EncoderType::Auto => "Auto (Recommended)",
|
||||
EncoderType::Software => "Software (CPU)",
|
||||
EncoderType::Vaapi => "VAAPI",
|
||||
EncoderType::Nvenc => "NVIDIA NVENC",
|
||||
EncoderType::Qsv => "Intel Quick Sync",
|
||||
EncoderType::Amf => "AMD AMF",
|
||||
EncoderType::Rkmpp => "Rockchip MPP",
|
||||
EncoderType::V4l2m2m => "V4L2 M2M",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streaming configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct StreamConfig {
|
||||
/// Stream mode
|
||||
pub mode: StreamMode,
|
||||
/// Encoder type for H264/H265
|
||||
pub encoder: EncoderType,
|
||||
/// Target bitrate in kbps (for H264/H265)
|
||||
pub bitrate_kbps: u32,
|
||||
/// GOP size
|
||||
pub gop_size: u32,
|
||||
/// Custom STUN server (e.g., "stun:stun.l.google.com:19302")
|
||||
pub stun_server: Option<String>,
|
||||
/// Custom TURN server (e.g., "turn:turn.example.com:3478")
|
||||
pub turn_server: Option<String>,
|
||||
/// TURN username
|
||||
pub turn_username: Option<String>,
|
||||
/// TURN password (stored encrypted in DB, not exposed via API)
|
||||
pub turn_password: Option<String>,
|
||||
/// Auto-pause when no clients connected
|
||||
#[typeshare(skip)]
|
||||
pub auto_pause_enabled: bool,
|
||||
/// Auto-pause delay (seconds)
|
||||
#[typeshare(skip)]
|
||||
pub auto_pause_delay_secs: u64,
|
||||
/// Client timeout for cleanup (seconds)
|
||||
#[typeshare(skip)]
|
||||
pub client_timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for StreamConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: StreamMode::Mjpeg,
|
||||
encoder: EncoderType::Auto,
|
||||
bitrate_kbps: 8000,
|
||||
gop_size: 30,
|
||||
stun_server: Some("stun:stun.l.google.com:19302".to_string()),
|
||||
turn_server: None,
|
||||
turn_username: None,
|
||||
turn_password: None,
|
||||
auto_pause_enabled: false,
|
||||
auto_pause_delay_secs: 10,
|
||||
client_timeout_secs: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Web server configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct WebConfig {
|
||||
/// HTTP port
|
||||
pub http_port: u16,
|
||||
/// HTTPS port
|
||||
pub https_port: u16,
|
||||
/// Bind address
|
||||
pub bind_address: String,
|
||||
/// Enable HTTPS
|
||||
pub https_enabled: bool,
|
||||
/// Custom SSL certificate path
|
||||
pub ssl_cert_path: Option<String>,
|
||||
/// Custom SSL key path
|
||||
pub ssl_key_path: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for WebConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
http_port: 8080,
|
||||
https_port: 8443,
|
||||
bind_address: "0.0.0.0".to_string(),
|
||||
https_enabled: false,
|
||||
ssl_cert_path: None,
|
||||
ssl_key_path: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
264
src/config/store.rs
Normal file
264
src/config/store.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use arc_swap::ArcSwap;
|
||||
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use super::AppConfig;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Configuration store backed by SQLite
|
||||
///
|
||||
/// Uses `ArcSwap` for lock-free reads, providing high performance
|
||||
/// for frequent configuration access in hot paths.
|
||||
#[derive(Clone)]
|
||||
pub struct ConfigStore {
|
||||
pool: Pool<Sqlite>,
|
||||
/// Lock-free cache using ArcSwap for zero-cost reads
|
||||
cache: Arc<ArcSwap<AppConfig>>,
|
||||
change_tx: broadcast::Sender<ConfigChange>,
|
||||
}
|
||||
|
||||
/// Configuration change event
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConfigChange {
|
||||
pub key: String,
|
||||
}
|
||||
|
||||
impl ConfigStore {
|
||||
/// Create a new configuration store
|
||||
pub async fn new(db_path: &Path) -> Result<Self> {
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = db_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
// SQLite uses single-writer mode, 2 connections is sufficient for embedded devices
|
||||
// One for reads, one for writes to avoid blocking
|
||||
.max_connections(2)
|
||||
// Set reasonable timeouts for embedded environments
|
||||
.acquire_timeout(Duration::from_secs(5))
|
||||
.idle_timeout(Duration::from_secs(300))
|
||||
.connect(&db_url)
|
||||
.await?;
|
||||
|
||||
// Initialize database schema
|
||||
Self::init_schema(&pool).await?;
|
||||
|
||||
// Load or create default config
|
||||
let config = Self::load_config(&pool).await?;
|
||||
let cache = Arc::new(ArcSwap::from_pointee(config));
|
||||
|
||||
let (change_tx, _) = broadcast::channel(16);
|
||||
|
||||
Ok(Self {
|
||||
pool,
|
||||
cache,
|
||||
change_tx,
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialize database schema
|
||||
async fn init_schema(pool: &Pool<Sqlite>) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_admin INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
expires_at TEXT NOT NULL,
|
||||
data TEXT
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS api_tokens (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
permissions TEXT NOT NULL,
|
||||
expires_at TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
last_used TEXT
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load configuration from database
|
||||
async fn load_config(pool: &Pool<Sqlite>) -> Result<AppConfig> {
|
||||
let row: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT value FROM config WHERE key = 'app_config'"
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
match row {
|
||||
Some((json,)) => {
|
||||
serde_json::from_str(&json).map_err(|e| AppError::Config(e.to_string()))
|
||||
}
|
||||
None => {
|
||||
// Create default config
|
||||
let config = AppConfig::default();
|
||||
Self::save_config_to_db(pool, &config).await?;
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Save configuration to database
|
||||
async fn save_config_to_db(pool: &Pool<Sqlite>, config: &AppConfig) -> Result<()> {
|
||||
let json = serde_json::to_string(config)?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO config (key, value, updated_at)
|
||||
VALUES ('app_config', ?1, datetime('now'))
|
||||
ON CONFLICT(key) DO UPDATE SET value = ?1, updated_at = datetime('now')
|
||||
"#,
|
||||
)
|
||||
.bind(&json)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current configuration (lock-free, zero-copy)
|
||||
///
|
||||
/// Returns an `Arc<AppConfig>` for efficient sharing without cloning.
|
||||
/// This is a lock-free operation with minimal overhead.
|
||||
pub fn get(&self) -> Arc<AppConfig> {
|
||||
self.cache.load_full()
|
||||
}
|
||||
|
||||
/// Set entire configuration
|
||||
pub async fn set(&self, config: AppConfig) -> Result<()> {
|
||||
Self::save_config_to_db(&self.pool, &config).await?;
|
||||
self.cache.store(Arc::new(config));
|
||||
|
||||
// Notify subscribers
|
||||
let _ = self.change_tx.send(ConfigChange {
|
||||
key: "app_config".to_string(),
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update configuration with a closure
|
||||
///
|
||||
/// Note: This uses a read-modify-write pattern. For concurrent updates,
|
||||
/// the last write wins. This is acceptable for configuration changes
|
||||
/// which are infrequent and typically user-initiated.
|
||||
pub async fn update<F>(&self, f: F) -> Result<()>
|
||||
where
|
||||
F: FnOnce(&mut AppConfig),
|
||||
{
|
||||
// Load current config, clone it for modification
|
||||
let current = self.cache.load();
|
||||
let mut config = (**current).clone();
|
||||
f(&mut config);
|
||||
|
||||
// Persist to database first
|
||||
Self::save_config_to_db(&self.pool, &config).await?;
|
||||
|
||||
// Then update cache atomically
|
||||
self.cache.store(Arc::new(config));
|
||||
|
||||
// Notify subscribers
|
||||
let _ = self.change_tx.send(ConfigChange {
|
||||
key: "app_config".to_string(),
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Subscribe to configuration changes
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<ConfigChange> {
|
||||
self.change_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Check if system is initialized (lock-free)
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.cache.load().initialized
|
||||
}
|
||||
|
||||
/// Get database pool for session management
|
||||
pub fn pool(&self) -> &Pool<Sqlite> {
|
||||
&self.pool
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config_store() {
|
||||
let dir = tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
|
||||
let store = ConfigStore::new(&db_path).await.unwrap();
|
||||
|
||||
// Check default config (now lock-free, no await needed)
|
||||
let config = store.get();
|
||||
assert!(!config.initialized);
|
||||
|
||||
// Update config
|
||||
store.update(|c| {
|
||||
c.initialized = true;
|
||||
c.web.http_port = 9000;
|
||||
}).await.unwrap();
|
||||
|
||||
// Verify update
|
||||
let config = store.get();
|
||||
assert!(config.initialized);
|
||||
assert_eq!(config.web.http_port, 9000);
|
||||
|
||||
// Create new store instance and verify persistence
|
||||
let store2 = ConfigStore::new(&db_path).await.unwrap();
|
||||
let config = store2.get();
|
||||
assert!(config.initialized);
|
||||
assert_eq!(config.web.http_port, 9000);
|
||||
}
|
||||
}
|
||||
98
src/error.rs
Normal file
98
src/error.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Application-wide error type
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AppError {
|
||||
#[error("Authentication failed: {0}")]
|
||||
AuthError(String),
|
||||
|
||||
#[error("Not authenticated")]
|
||||
Unauthorized,
|
||||
|
||||
#[error("Forbidden: {0}")]
|
||||
Forbidden(String),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Bad request: {0}")]
|
||||
BadRequest(String),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(#[from] serde_json::Error),
|
||||
|
||||
#[error("Video error: {0}")]
|
||||
VideoError(String),
|
||||
|
||||
#[error("Video device lost [{device}]: {reason}")]
|
||||
VideoDeviceLost { device: String, reason: String },
|
||||
|
||||
#[error("Audio error: {0}")]
|
||||
AudioError(String),
|
||||
|
||||
#[error("HID error [{backend}]: {reason} (code: {error_code})")]
|
||||
HidError {
|
||||
backend: String,
|
||||
reason: String,
|
||||
error_code: String,
|
||||
},
|
||||
|
||||
#[error("WebRTC error: {0}")]
|
||||
WebRtcError(String),
|
||||
|
||||
#[error("Service unavailable: {0}")]
|
||||
ServiceUnavailable(String),
|
||||
}
|
||||
|
||||
/// Error response body (unified success format)
|
||||
#[derive(Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
// Always return 200 OK - success/failure is indicated by the success field
|
||||
StatusCode::OK
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status_code();
|
||||
let body = ErrorResponse {
|
||||
success: false,
|
||||
message: self.to_string(),
|
||||
};
|
||||
|
||||
tracing::error!(
|
||||
error_type = std::any::type_name_of_val(&self),
|
||||
error_message = %body.message,
|
||||
"Request failed"
|
||||
);
|
||||
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type alias for handlers
|
||||
pub type Result<T> = std::result::Result<T, AppError>;
|
||||
137
src/events/mod.rs
Normal file
137
src/events/mod.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
//! Event system for real-time state notifications
|
||||
//!
|
||||
//! This module provides a global event bus for broadcasting system events
|
||||
//! to WebSocket clients and other subscribers.
|
||||
|
||||
pub mod types;
|
||||
|
||||
pub use types::{
|
||||
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent, VideoDeviceInfo,
|
||||
};
|
||||
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
/// Event channel capacity (ring buffer size)
|
||||
const EVENT_CHANNEL_CAPACITY: usize = 256;
|
||||
|
||||
/// Global event bus for broadcasting system events
|
||||
///
|
||||
/// The event bus uses tokio's broadcast channel to distribute events
|
||||
/// to multiple subscribers. Events are delivered to all active subscribers.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use one_kvm::events::{EventBus, SystemEvent};
|
||||
///
|
||||
/// let bus = EventBus::new();
|
||||
///
|
||||
/// // Publish an event
|
||||
/// bus.publish(SystemEvent::StreamStateChanged {
|
||||
/// state: "streaming".to_string(),
|
||||
/// device: Some("/dev/video0".to_string()),
|
||||
/// });
|
||||
///
|
||||
/// // Subscribe to events
|
||||
/// let mut rx = bus.subscribe();
|
||||
/// tokio::spawn(async move {
|
||||
/// while let Ok(event) = rx.recv().await {
|
||||
/// println!("Received event: {:?}", event);
|
||||
/// }
|
||||
/// });
|
||||
/// ```
|
||||
pub struct EventBus {
|
||||
tx: broadcast::Sender<SystemEvent>,
|
||||
}
|
||||
|
||||
impl EventBus {
|
||||
/// Create a new event bus
|
||||
pub fn new() -> Self {
|
||||
let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
||||
Self { tx }
|
||||
}
|
||||
|
||||
/// Publish an event to all subscribers
|
||||
///
|
||||
/// If there are no active subscribers, the event is silently dropped.
|
||||
/// This is by design - events are fire-and-forget notifications.
|
||||
pub fn publish(&self, event: SystemEvent) {
|
||||
// If no subscribers, send returns Err which is normal
|
||||
let _ = self.tx.send(event);
|
||||
}
|
||||
|
||||
/// Subscribe to events
|
||||
///
|
||||
/// Returns a receiver that will receive all future events.
|
||||
/// The receiver uses a ring buffer, so if a subscriber falls too far
|
||||
/// behind, it will receive a `Lagged` error and miss some events.
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<SystemEvent> {
|
||||
self.tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get the current number of active subscribers
|
||||
///
|
||||
/// Useful for monitoring and debugging.
|
||||
pub fn subscriber_count(&self) -> usize {
|
||||
self.tx.receiver_count()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EventBus {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_publish_subscribe() {
|
||||
let bus = EventBus::new();
|
||||
let mut rx = bus.subscribe();
|
||||
|
||||
bus.publish(SystemEvent::StreamStateChanged {
|
||||
state: "streaming".to_string(),
|
||||
device: Some("/dev/video0".to_string()),
|
||||
});
|
||||
|
||||
let event = rx.recv().await.unwrap();
|
||||
assert!(matches!(event, SystemEvent::StreamStateChanged { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_subscribers() {
|
||||
let bus = EventBus::new();
|
||||
let mut rx1 = bus.subscribe();
|
||||
let mut rx2 = bus.subscribe();
|
||||
|
||||
assert_eq!(bus.subscriber_count(), 2);
|
||||
|
||||
bus.publish(SystemEvent::SystemError {
|
||||
module: "test".to_string(),
|
||||
severity: "info".to_string(),
|
||||
message: "test message".to_string(),
|
||||
});
|
||||
|
||||
let event1 = rx1.recv().await.unwrap();
|
||||
let event2 = rx2.recv().await.unwrap();
|
||||
|
||||
assert!(matches!(event1, SystemEvent::SystemError { .. }));
|
||||
assert!(matches!(event2, SystemEvent::SystemError { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_subscribers() {
|
||||
let bus = EventBus::new();
|
||||
assert_eq!(bus.subscriber_count(), 0);
|
||||
|
||||
// Should not panic when publishing with no subscribers
|
||||
bus.publish(SystemEvent::SystemError {
|
||||
module: "test".to_string(),
|
||||
severity: "info".to_string(),
|
||||
message: "test".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
592
src/events/types.rs
Normal file
592
src/events/types.rs
Normal file
@@ -0,0 +1,592 @@
|
||||
//! System event types
|
||||
//!
|
||||
//! Defines all event types that can be broadcast through the event bus.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::atx::PowerStatus;
|
||||
use crate::msd::MsdMode;
|
||||
|
||||
// ============================================================================
|
||||
// Device Info Structures (for system.device_info event)
|
||||
// ============================================================================
|
||||
|
||||
/// Video device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VideoDeviceInfo {
|
||||
/// Whether video device is available
|
||||
pub available: bool,
|
||||
/// Device path (e.g., /dev/video0)
|
||||
pub device: Option<String>,
|
||||
/// Pixel format (e.g., "MJPEG", "YUYV")
|
||||
pub format: Option<String>,
|
||||
/// Resolution (width, height)
|
||||
pub resolution: Option<(u32, u32)>,
|
||||
/// Frames per second
|
||||
pub fps: u32,
|
||||
/// Whether stream is currently active
|
||||
pub online: bool,
|
||||
/// Current streaming mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
|
||||
pub stream_mode: String,
|
||||
/// Whether video config is currently being changed (frontend should skip mode sync)
|
||||
pub config_changing: bool,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// HID device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HidDeviceInfo {
|
||||
/// Whether HID backend is available
|
||||
pub available: bool,
|
||||
/// Backend type: "otg", "ch9329", "none"
|
||||
pub backend: String,
|
||||
/// Whether backend is initialized and ready
|
||||
pub initialized: bool,
|
||||
/// Whether absolute mouse positioning is supported
|
||||
pub supports_absolute_mouse: bool,
|
||||
/// Device path (e.g., serial port for CH9329)
|
||||
pub device: Option<String>,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// MSD device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MsdDeviceInfo {
|
||||
/// Whether MSD is available
|
||||
pub available: bool,
|
||||
/// Operating mode: "none", "image", "drive"
|
||||
pub mode: String,
|
||||
/// Whether storage is connected to target
|
||||
pub connected: bool,
|
||||
/// Currently mounted image ID
|
||||
pub image_id: Option<String>,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// ATX device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AtxDeviceInfo {
|
||||
/// Whether ATX controller is available
|
||||
pub available: bool,
|
||||
/// Backend type: "gpio", "usb_relay", "none"
|
||||
pub backend: String,
|
||||
/// Whether backend is initialized
|
||||
pub initialized: bool,
|
||||
/// Whether power is currently on
|
||||
pub power_on: bool,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Audio device information
|
||||
///
|
||||
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AudioDeviceInfo {
|
||||
/// Whether audio is enabled/available
|
||||
pub available: bool,
|
||||
/// Whether audio is currently streaming
|
||||
pub streaming: bool,
|
||||
/// Current audio device name
|
||||
pub device: Option<String>,
|
||||
/// Quality preset: "voice", "balanced", "high"
|
||||
pub quality: String,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Per-client statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClientStats {
|
||||
/// Client ID
|
||||
pub id: String,
|
||||
/// Current FPS for this client (frames sent in last second)
|
||||
pub fps: u32,
|
||||
/// Connected duration (seconds)
|
||||
pub connected_secs: u64,
|
||||
}
|
||||
|
||||
/// System event enumeration
|
||||
///
|
||||
/// All events are tagged with their event name for serialization.
|
||||
/// The `serde(tag = "event", content = "data")` attribute creates a
|
||||
/// JSON structure like:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "event": "stream.state_changed",
|
||||
/// "data": { "state": "streaming", "device": "/dev/video0" }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "event", content = "data")]
|
||||
pub enum SystemEvent {
|
||||
// ============================================================================
|
||||
// Video Stream Events
|
||||
// ============================================================================
|
||||
/// Stream state changed (e.g., started, stopped, error)
|
||||
#[serde(rename = "stream.state_changed")]
|
||||
StreamStateChanged {
|
||||
/// Current state: "uninitialized", "ready", "streaming", "no_signal", "error"
|
||||
state: String,
|
||||
/// Device path if available
|
||||
device: Option<String>,
|
||||
},
|
||||
|
||||
/// Stream configuration is being changed
|
||||
///
|
||||
/// Sent before applying new configuration to notify clients that
|
||||
/// the stream will be interrupted temporarily.
|
||||
#[serde(rename = "stream.config_changing")]
|
||||
StreamConfigChanging {
|
||||
/// Reason for change: "device_switch", "resolution_change", "format_change"
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// Stream configuration has been applied successfully
|
||||
///
|
||||
/// Sent after new configuration is active. Clients can reconnect now.
|
||||
#[serde(rename = "stream.config_applied")]
|
||||
StreamConfigApplied {
|
||||
/// Device path
|
||||
device: String,
|
||||
/// Resolution (width, height)
|
||||
resolution: (u32, u32),
|
||||
/// Pixel format: "mjpeg", "yuyv", etc.
|
||||
format: String,
|
||||
/// Frames per second
|
||||
fps: u32,
|
||||
},
|
||||
|
||||
/// Stream device was lost (disconnected or error)
|
||||
#[serde(rename = "stream.device_lost")]
|
||||
StreamDeviceLost {
|
||||
/// Device path that was lost
|
||||
device: String,
|
||||
/// Reason for loss
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// Stream device is reconnecting
|
||||
#[serde(rename = "stream.reconnecting")]
|
||||
StreamReconnecting {
|
||||
/// Device path being reconnected
|
||||
device: String,
|
||||
/// Retry attempt number
|
||||
attempt: u32,
|
||||
},
|
||||
|
||||
/// Stream device has recovered
|
||||
#[serde(rename = "stream.recovered")]
|
||||
StreamRecovered {
|
||||
/// Device path that was recovered
|
||||
device: String,
|
||||
},
|
||||
|
||||
/// Stream statistics update (sent periodically for client stats)
|
||||
#[serde(rename = "stream.stats_update")]
|
||||
StreamStatsUpdate {
|
||||
/// Number of connected clients
|
||||
clients: u64,
|
||||
/// Per-client statistics (client_id -> client stats)
|
||||
/// Each client's FPS reflects the actual frames sent in the last second
|
||||
clients_stat: HashMap<String, ClientStats>,
|
||||
},
|
||||
|
||||
/// Stream mode changed (MJPEG <-> WebRTC)
|
||||
///
|
||||
/// Sent when the streaming mode is switched. Clients should disconnect
|
||||
/// from the current stream and reconnect using the new mode.
|
||||
#[serde(rename = "stream.mode_changed")]
|
||||
StreamModeChanged {
|
||||
/// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
|
||||
mode: String,
|
||||
/// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
|
||||
previous_mode: String,
|
||||
},
|
||||
|
||||
// ============================================================================
|
||||
// HID Events
|
||||
// ============================================================================
|
||||
/// HID backend state changed
|
||||
#[serde(rename = "hid.state_changed")]
|
||||
HidStateChanged {
|
||||
/// Backend type: "otg", "ch9329", "none"
|
||||
backend: String,
|
||||
/// Whether backend is initialized and ready
|
||||
initialized: bool,
|
||||
/// Error message if any, None if OK
|
||||
error: Option<String>,
|
||||
/// Error code for programmatic handling: "epipe", "eagain", "port_not_found", etc.
|
||||
error_code: Option<String>,
|
||||
},
|
||||
|
||||
/// HID backend is being switched
|
||||
#[serde(rename = "hid.backend_switching")]
|
||||
HidBackendSwitching {
|
||||
/// Current backend
|
||||
from: String,
|
||||
/// New backend
|
||||
to: String,
|
||||
},
|
||||
|
||||
/// HID device lost (device file missing or I/O error)
|
||||
#[serde(rename = "hid.device_lost")]
|
||||
HidDeviceLost {
|
||||
/// Backend type: "otg", "ch9329"
|
||||
backend: String,
|
||||
/// Device path that was lost (e.g., /dev/hidg0 or /dev/ttyUSB0)
|
||||
device: Option<String>,
|
||||
/// Human-readable reason for loss
|
||||
reason: String,
|
||||
/// Error code: "epipe", "eshutdown", "eagain", "enxio", "port_not_found", "io_error"
|
||||
error_code: String,
|
||||
},
|
||||
|
||||
/// HID device is reconnecting
|
||||
#[serde(rename = "hid.reconnecting")]
|
||||
HidReconnecting {
|
||||
/// Backend type: "otg", "ch9329"
|
||||
backend: String,
|
||||
/// Current retry attempt number
|
||||
attempt: u32,
|
||||
},
|
||||
|
||||
/// HID device has recovered after error
|
||||
#[serde(rename = "hid.recovered")]
|
||||
HidRecovered {
|
||||
/// Backend type: "otg", "ch9329"
|
||||
backend: String,
|
||||
},
|
||||
|
||||
// ============================================================================
|
||||
// MSD (Mass Storage Device) Events
|
||||
// ============================================================================
|
||||
/// MSD state changed
|
||||
#[serde(rename = "msd.state_changed")]
|
||||
MsdStateChanged {
|
||||
/// Operating mode
|
||||
mode: MsdMode,
|
||||
/// Whether storage is connected to target
|
||||
connected: bool,
|
||||
},
|
||||
|
||||
/// Image has been mounted
|
||||
#[serde(rename = "msd.image_mounted")]
|
||||
MsdImageMounted {
|
||||
/// Image ID
|
||||
image_id: String,
|
||||
/// Image filename
|
||||
image_name: String,
|
||||
/// Image size in bytes
|
||||
size: u64,
|
||||
/// Mount as CD-ROM (read-only)
|
||||
cdrom: bool,
|
||||
},
|
||||
|
||||
/// Image has been unmounted
|
||||
#[serde(rename = "msd.image_unmounted")]
|
||||
MsdImageUnmounted,
|
||||
|
||||
/// File upload progress (for large file uploads)
|
||||
#[serde(rename = "msd.upload_progress")]
|
||||
MsdUploadProgress {
|
||||
/// Upload operation ID
|
||||
upload_id: String,
|
||||
/// Filename being uploaded
|
||||
filename: String,
|
||||
/// Bytes uploaded so far
|
||||
bytes_uploaded: u64,
|
||||
/// Total file size
|
||||
total_bytes: u64,
|
||||
/// Progress percentage (0.0 - 100.0)
|
||||
progress_pct: f32,
|
||||
},
|
||||
|
||||
/// Image download progress (for URL downloads)
|
||||
#[serde(rename = "msd.download_progress")]
|
||||
MsdDownloadProgress {
|
||||
/// Download operation ID
|
||||
download_id: String,
|
||||
/// Source URL
|
||||
url: String,
|
||||
/// Target filename
|
||||
filename: String,
|
||||
/// Bytes downloaded so far
|
||||
bytes_downloaded: u64,
|
||||
/// Total file size (None if unknown)
|
||||
total_bytes: Option<u64>,
|
||||
/// Progress percentage (0.0 - 100.0, None if total unknown)
|
||||
progress_pct: Option<f32>,
|
||||
/// Download status: "started", "in_progress", "completed", "failed"
|
||||
status: String,
|
||||
},
|
||||
|
||||
/// USB gadget connection status changed (host connected/disconnected)
|
||||
#[serde(rename = "msd.usb_status_changed")]
|
||||
MsdUsbStatusChanged {
|
||||
/// Whether host is connected to USB device
|
||||
connected: bool,
|
||||
/// USB device state from kernel (e.g., "configured", "not attached")
|
||||
device_state: String,
|
||||
},
|
||||
|
||||
/// MSD operation error (configfs, image mount, etc.)
|
||||
#[serde(rename = "msd.error")]
|
||||
MsdError {
|
||||
/// Human-readable reason for error
|
||||
reason: String,
|
||||
/// Error code: "configfs_error", "image_not_found", "mount_failed", "io_error"
|
||||
error_code: String,
|
||||
},
|
||||
|
||||
/// MSD has recovered after error
|
||||
#[serde(rename = "msd.recovered")]
|
||||
MsdRecovered,
|
||||
|
||||
// ============================================================================
|
||||
// ATX (Power Control) Events
|
||||
// ============================================================================
|
||||
/// ATX power state changed
|
||||
#[serde(rename = "atx.state_changed")]
|
||||
AtxStateChanged {
|
||||
/// Power status
|
||||
power_status: PowerStatus,
|
||||
},
|
||||
|
||||
/// ATX action was executed
|
||||
#[serde(rename = "atx.action_executed")]
|
||||
AtxActionExecuted {
|
||||
/// Action: "short", "long", "reset"
|
||||
action: String,
|
||||
/// When the action was executed
|
||||
timestamp: DateTime<Utc>,
|
||||
},
|
||||
|
||||
// ============================================================================
|
||||
// Audio Events
|
||||
// ============================================================================
|
||||
/// Audio state changed (streaming started/stopped)
|
||||
#[serde(rename = "audio.state_changed")]
|
||||
AudioStateChanged {
|
||||
/// Whether audio is currently streaming
|
||||
streaming: bool,
|
||||
/// Current device (None if stopped)
|
||||
device: Option<String>,
|
||||
},
|
||||
|
||||
/// Audio device was selected
|
||||
#[serde(rename = "audio.device_selected")]
|
||||
AudioDeviceSelected {
|
||||
/// Selected device name
|
||||
device: String,
|
||||
},
|
||||
|
||||
/// Audio quality was changed
|
||||
#[serde(rename = "audio.quality_changed")]
|
||||
AudioQualityChanged {
|
||||
/// New quality setting: "voice", "balanced", "high"
|
||||
quality: String,
|
||||
},
|
||||
|
||||
/// Audio device lost (capture error or device disconnected)
|
||||
#[serde(rename = "audio.device_lost")]
|
||||
AudioDeviceLost {
|
||||
/// Audio device name (e.g., "hw:0,0")
|
||||
device: Option<String>,
|
||||
/// Human-readable reason for loss
|
||||
reason: String,
|
||||
/// Error code: "device_busy", "device_disconnected", "capture_error", "io_error"
|
||||
error_code: String,
|
||||
},
|
||||
|
||||
/// Audio device is reconnecting
|
||||
#[serde(rename = "audio.reconnecting")]
|
||||
AudioReconnecting {
|
||||
/// Current retry attempt number
|
||||
attempt: u32,
|
||||
},
|
||||
|
||||
/// Audio device has recovered after error
|
||||
#[serde(rename = "audio.recovered")]
|
||||
AudioRecovered {
|
||||
/// Audio device name
|
||||
device: Option<String>,
|
||||
},
|
||||
|
||||
// ============================================================================
|
||||
// System Events
|
||||
// ============================================================================
|
||||
/// A device was added (hot-plug)
|
||||
#[serde(rename = "system.device_added")]
|
||||
SystemDeviceAdded {
|
||||
/// Device type: "video", "audio", "hid", etc.
|
||||
device_type: String,
|
||||
/// Device path
|
||||
device_path: String,
|
||||
/// Device name/description
|
||||
device_name: String,
|
||||
},
|
||||
|
||||
/// A device was removed (hot-unplug)
|
||||
#[serde(rename = "system.device_removed")]
|
||||
SystemDeviceRemoved {
|
||||
/// Device type
|
||||
device_type: String,
|
||||
/// Device path that was removed
|
||||
device_path: String,
|
||||
},
|
||||
|
||||
/// System error or warning
|
||||
#[serde(rename = "system.error")]
|
||||
SystemError {
|
||||
/// Module that generated the error: "stream", "hid", "msd", "atx"
|
||||
module: String,
|
||||
/// Severity: "warning", "error", "critical"
|
||||
severity: String,
|
||||
/// Error message
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Complete device information (sent on WebSocket connect and state changes)
|
||||
#[serde(rename = "system.device_info")]
|
||||
DeviceInfo {
|
||||
/// Video device information
|
||||
video: VideoDeviceInfo,
|
||||
/// HID device information
|
||||
hid: HidDeviceInfo,
|
||||
/// MSD device information (None if MSD not enabled)
|
||||
msd: Option<MsdDeviceInfo>,
|
||||
/// ATX device information (None if ATX not enabled)
|
||||
atx: Option<AtxDeviceInfo>,
|
||||
/// Audio device information (None if audio not enabled)
|
||||
audio: Option<AudioDeviceInfo>,
|
||||
},
|
||||
|
||||
/// WebSocket error notification (for connection-level errors like lag)
|
||||
#[serde(rename = "error")]
|
||||
Error {
|
||||
/// Error message
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl SystemEvent {
|
||||
/// Get the event name (for filtering/routing)
|
||||
pub fn event_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::StreamStateChanged { .. } => "stream.state_changed",
|
||||
Self::StreamConfigChanging { .. } => "stream.config_changing",
|
||||
Self::StreamConfigApplied { .. } => "stream.config_applied",
|
||||
Self::StreamDeviceLost { .. } => "stream.device_lost",
|
||||
Self::StreamReconnecting { .. } => "stream.reconnecting",
|
||||
Self::StreamRecovered { .. } => "stream.recovered",
|
||||
Self::StreamStatsUpdate { .. } => "stream.stats_update",
|
||||
Self::StreamModeChanged { .. } => "stream.mode_changed",
|
||||
Self::HidStateChanged { .. } => "hid.state_changed",
|
||||
Self::HidBackendSwitching { .. } => "hid.backend_switching",
|
||||
Self::HidDeviceLost { .. } => "hid.device_lost",
|
||||
Self::HidReconnecting { .. } => "hid.reconnecting",
|
||||
Self::HidRecovered { .. } => "hid.recovered",
|
||||
Self::MsdStateChanged { .. } => "msd.state_changed",
|
||||
Self::MsdImageMounted { .. } => "msd.image_mounted",
|
||||
Self::MsdImageUnmounted => "msd.image_unmounted",
|
||||
Self::MsdUploadProgress { .. } => "msd.upload_progress",
|
||||
Self::MsdDownloadProgress { .. } => "msd.download_progress",
|
||||
Self::MsdUsbStatusChanged { .. } => "msd.usb_status_changed",
|
||||
Self::MsdError { .. } => "msd.error",
|
||||
Self::MsdRecovered => "msd.recovered",
|
||||
Self::AtxStateChanged { .. } => "atx.state_changed",
|
||||
Self::AtxActionExecuted { .. } => "atx.action_executed",
|
||||
Self::AudioStateChanged { .. } => "audio.state_changed",
|
||||
Self::AudioDeviceSelected { .. } => "audio.device_selected",
|
||||
Self::AudioQualityChanged { .. } => "audio.quality_changed",
|
||||
Self::AudioDeviceLost { .. } => "audio.device_lost",
|
||||
Self::AudioReconnecting { .. } => "audio.reconnecting",
|
||||
Self::AudioRecovered { .. } => "audio.recovered",
|
||||
Self::SystemDeviceAdded { .. } => "system.device_added",
|
||||
Self::SystemDeviceRemoved { .. } => "system.device_removed",
|
||||
Self::SystemError { .. } => "system.error",
|
||||
Self::DeviceInfo { .. } => "system.device_info",
|
||||
Self::Error { .. } => "error",
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if event name matches a topic pattern
|
||||
///
|
||||
/// Supports wildcards:
|
||||
/// - `*` matches all events
|
||||
/// - `stream.*` matches all stream events
|
||||
/// - `stream.state_changed` matches exact event
|
||||
pub fn matches_topic(&self, topic: &str) -> bool {
|
||||
if topic == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
let event_name = self.event_name();
|
||||
|
||||
if topic.ends_with(".*") {
|
||||
let prefix = topic.trim_end_matches(".*");
|
||||
event_name.starts_with(prefix)
|
||||
} else {
|
||||
event_name == topic
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_event_name() {
|
||||
let event = SystemEvent::StreamStateChanged {
|
||||
state: "streaming".to_string(),
|
||||
device: Some("/dev/video0".to_string()),
|
||||
};
|
||||
assert_eq!(event.event_name(), "stream.state_changed");
|
||||
|
||||
let event = SystemEvent::MsdImageMounted {
|
||||
image_id: "123".to_string(),
|
||||
image_name: "ubuntu.iso".to_string(),
|
||||
size: 1024,
|
||||
cdrom: true,
|
||||
};
|
||||
assert_eq!(event.event_name(), "msd.image_mounted");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_topic() {
|
||||
let event = SystemEvent::StreamStateChanged {
|
||||
state: "streaming".to_string(),
|
||||
device: None,
|
||||
};
|
||||
|
||||
assert!(event.matches_topic("*"));
|
||||
assert!(event.matches_topic("stream.*"));
|
||||
assert!(event.matches_topic("stream.state_changed"));
|
||||
assert!(!event.matches_topic("msd.*"));
|
||||
assert!(!event.matches_topic("stream.config_changed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialization() {
|
||||
let event = SystemEvent::StreamConfigApplied {
|
||||
device: "/dev/video0".to_string(),
|
||||
resolution: (1920, 1080),
|
||||
format: "mjpeg".to_string(),
|
||||
fps: 30,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("stream.config_applied"));
|
||||
assert!(json.contains("/dev/video0"));
|
||||
|
||||
let deserialized: SystemEvent = serde_json::from_str(&json).unwrap();
|
||||
assert!(matches!(deserialized, SystemEvent::StreamConfigApplied { .. }));
|
||||
}
|
||||
}
|
||||
425
src/extensions/manager.rs
Normal file
425
src/extensions/manager.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
//! Extension process manager
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use super::types::*;
|
||||
|
||||
/// Maximum number of log lines to keep per extension
|
||||
const LOG_BUFFER_SIZE: usize = 200;
|
||||
|
||||
/// Number of log lines to buffer before flushing to shared storage
|
||||
const LOG_BATCH_SIZE: usize = 16;
|
||||
|
||||
/// Unix socket path for ttyd
|
||||
pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock";
|
||||
|
||||
/// Extension process with log buffer
|
||||
struct ExtensionProcess {
|
||||
child: Child,
|
||||
logs: Arc<RwLock<VecDeque<String>>>,
|
||||
}
|
||||
|
||||
/// Extension manager handles lifecycle of external processes
|
||||
pub struct ExtensionManager {
|
||||
processes: RwLock<HashMap<ExtensionId, ExtensionProcess>>,
|
||||
/// Cached availability status (checked once at startup)
|
||||
availability: HashMap<ExtensionId, bool>,
|
||||
}
|
||||
|
||||
impl Default for ExtensionManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionManager {
|
||||
/// Create a new extension manager with cached availability
|
||||
pub fn new() -> Self {
|
||||
// Check availability once at startup
|
||||
let availability = ExtensionId::all()
|
||||
.iter()
|
||||
.map(|id| (*id, Path::new(id.binary_path()).exists()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
processes: RwLock::new(HashMap::new()),
|
||||
availability,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the binary for an extension is available (cached)
|
||||
pub fn check_available(&self, id: ExtensionId) -> bool {
|
||||
*self.availability.get(&id).unwrap_or(&false)
|
||||
}
|
||||
|
||||
/// Get the current status of an extension
|
||||
pub async fn status(&self, id: ExtensionId) -> ExtensionStatus {
|
||||
if !self.check_available(id) {
|
||||
return ExtensionStatus::Unavailable;
|
||||
}
|
||||
|
||||
let processes = self.processes.read().await;
|
||||
match processes.get(&id) {
|
||||
Some(proc) => {
|
||||
if let Some(pid) = proc.child.id() {
|
||||
ExtensionStatus::Running { pid }
|
||||
} else {
|
||||
ExtensionStatus::Stopped
|
||||
}
|
||||
}
|
||||
None => ExtensionStatus::Stopped,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start an extension with the given configuration
|
||||
pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> {
|
||||
if !self.check_available(id) {
|
||||
return Err(format!(
|
||||
"{} not found at {}",
|
||||
id.display_name(),
|
||||
id.binary_path()
|
||||
));
|
||||
}
|
||||
|
||||
// Stop existing process first
|
||||
self.stop(id).await.ok();
|
||||
|
||||
// Build command arguments
|
||||
let args = self.build_args(id, config).await?;
|
||||
|
||||
tracing::info!(
|
||||
"Starting extension {}: {} {}",
|
||||
id,
|
||||
id.binary_path(),
|
||||
args.join(" ")
|
||||
);
|
||||
|
||||
let mut child = Command::new(id.binary_path())
|
||||
.args(&args)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to start {}: {}", id.display_name(), e))?;
|
||||
|
||||
let logs = Arc::new(RwLock::new(VecDeque::with_capacity(LOG_BUFFER_SIZE)));
|
||||
|
||||
// Spawn log collector for stdout
|
||||
if let Some(stdout) = child.stdout.take() {
|
||||
let logs_clone = logs.clone();
|
||||
let id_clone = id;
|
||||
tokio::spawn(async move {
|
||||
Self::collect_logs(id_clone, stdout, logs_clone).await;
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn log collector for stderr
|
||||
if let Some(stderr) = child.stderr.take() {
|
||||
let logs_clone = logs.clone();
|
||||
let id_clone = id;
|
||||
tokio::spawn(async move {
|
||||
Self::collect_logs(id_clone, stderr, logs_clone).await;
|
||||
});
|
||||
}
|
||||
|
||||
let pid = child.id();
|
||||
tracing::info!("Extension {} started with PID {:?}", id, pid);
|
||||
|
||||
let mut processes = self.processes.write().await;
|
||||
processes.insert(id, ExtensionProcess { child, logs });
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop an extension
|
||||
pub async fn stop(&self, id: ExtensionId) -> Result<(), String> {
|
||||
let mut processes = self.processes.write().await;
|
||||
if let Some(mut proc) = processes.remove(&id) {
|
||||
tracing::info!("Stopping extension {}", id);
|
||||
if let Err(e) = proc.child.kill().await {
|
||||
tracing::warn!("Failed to kill {}: {}", id, e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get recent logs for an extension
|
||||
pub async fn logs(&self, id: ExtensionId, lines: usize) -> Vec<String> {
|
||||
let processes = self.processes.read().await;
|
||||
if let Some(proc) = processes.get(&id) {
|
||||
let logs = proc.logs.read().await;
|
||||
let start = logs.len().saturating_sub(lines);
|
||||
logs.range(start..).cloned().collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect logs from a stream with batched writes to reduce lock contention
|
||||
async fn collect_logs<R: tokio::io::AsyncRead + Unpin>(
|
||||
id: ExtensionId,
|
||||
reader: R,
|
||||
logs: Arc<RwLock<VecDeque<String>>>,
|
||||
) {
|
||||
let reader = BufReader::new(reader);
|
||||
let mut lines = reader.lines();
|
||||
let mut local_buffer = Vec::with_capacity(LOG_BATCH_SIZE);
|
||||
|
||||
loop {
|
||||
match lines.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
tracing::debug!("[{}] {}", id, line);
|
||||
local_buffer.push(line);
|
||||
|
||||
// Flush when batch is full
|
||||
if local_buffer.len() >= LOG_BATCH_SIZE {
|
||||
Self::flush_logs(&logs, &mut local_buffer).await;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Stream ended, flush remaining logs
|
||||
if !local_buffer.is_empty() {
|
||||
Self::flush_logs(&logs, &mut local_buffer).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[{}] Error reading log: {}", id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush buffered logs to shared storage
|
||||
async fn flush_logs(logs: &RwLock<VecDeque<String>>, buffer: &mut Vec<String>) {
|
||||
let mut logs = logs.write().await;
|
||||
for line in buffer.drain(..) {
|
||||
if logs.len() >= LOG_BUFFER_SIZE {
|
||||
logs.pop_front();
|
||||
}
|
||||
logs.push_back(line);
|
||||
}
|
||||
}
|
||||
|
||||
/// Build command arguments for an extension
|
||||
async fn build_args(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<Vec<String>, String> {
|
||||
match id {
|
||||
ExtensionId::Ttyd => {
|
||||
let c = &config.ttyd;
|
||||
|
||||
// Prepare socket directory and clean up old socket (async)
|
||||
Self::prepare_ttyd_socket().await?;
|
||||
|
||||
let mut args = vec![
|
||||
"-i".to_string(), TTYD_SOCKET_PATH.to_string(), // Unix socket
|
||||
"-b".to_string(), "/api/terminal".to_string(), // Base path for reverse proxy
|
||||
"-W".to_string(), // Writable (allow input)
|
||||
];
|
||||
|
||||
// Add credential if set (still useful for additional security layer)
|
||||
if let Some(ref cred) = c.credential {
|
||||
if !cred.is_empty() {
|
||||
args.extend(["-c".to_string(), cred.clone()]);
|
||||
}
|
||||
}
|
||||
|
||||
// Add shell as last argument
|
||||
args.push(c.shell.clone());
|
||||
Ok(args)
|
||||
}
|
||||
|
||||
ExtensionId::Gostc => {
|
||||
let c = &config.gostc;
|
||||
if c.key.is_empty() {
|
||||
return Err("GOSTC client key is required".into());
|
||||
}
|
||||
|
||||
let mut args = Vec::new();
|
||||
|
||||
// Add TLS flag
|
||||
if c.tls {
|
||||
args.push("--tls=true".to_string());
|
||||
}
|
||||
|
||||
// Add server address
|
||||
if !c.addr.is_empty() {
|
||||
args.extend(["-addr".to_string(), c.addr.clone()]);
|
||||
}
|
||||
|
||||
// Add client key
|
||||
args.extend(["-key".to_string(), c.key.clone()]);
|
||||
|
||||
Ok(args)
|
||||
}
|
||||
|
||||
ExtensionId::Easytier => {
|
||||
let c = &config.easytier;
|
||||
if c.network_name.is_empty() {
|
||||
return Err("EasyTier network name is required".into());
|
||||
}
|
||||
|
||||
let mut args = vec![
|
||||
"--network-name".to_string(),
|
||||
c.network_name.clone(),
|
||||
"--network-secret".to_string(),
|
||||
c.network_secret.clone(),
|
||||
];
|
||||
|
||||
// Add peer URLs
|
||||
for peer in &c.peer_urls {
|
||||
if !peer.is_empty() {
|
||||
args.extend(["--peers".to_string(), peer.clone()]);
|
||||
}
|
||||
}
|
||||
|
||||
// Add virtual IP: use -d for DHCP if empty, or -i for specific IP
|
||||
if let Some(ref ip) = c.virtual_ip {
|
||||
if !ip.is_empty() {
|
||||
// Use specific IP with -i (must include CIDR, e.g., 10.0.0.1/24)
|
||||
args.extend(["-i".to_string(), ip.clone()]);
|
||||
} else {
|
||||
// Empty string means use DHCP
|
||||
args.push("-d".to_string());
|
||||
}
|
||||
} else {
|
||||
// None means use DHCP
|
||||
args.push("-d".to_string());
|
||||
}
|
||||
|
||||
Ok(args)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare ttyd socket directory and clean up old socket file
|
||||
async fn prepare_ttyd_socket() -> Result<(), String> {
|
||||
let socket_path = Path::new(TTYD_SOCKET_PATH);
|
||||
|
||||
// Ensure socket directory exists
|
||||
if let Some(socket_dir) = socket_path.parent() {
|
||||
if !socket_dir.exists() {
|
||||
tokio::fs::create_dir_all(socket_dir)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to create socket directory: {}", e))?;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old socket file if exists
|
||||
if tokio::fs::try_exists(TTYD_SOCKET_PATH).await.unwrap_or(false) {
|
||||
tokio::fs::remove_file(TTYD_SOCKET_PATH)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to remove old socket: {}", e))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Health check - restart crashed processes that should be running
|
||||
pub async fn health_check(&self, config: &ExtensionsConfig) {
|
||||
// Collect extensions that need restart check
|
||||
let checks: Vec<_> = ExtensionId::all()
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
let should_run = match id {
|
||||
ExtensionId::Ttyd => config.ttyd.enabled,
|
||||
ExtensionId::Gostc => config.gostc.enabled && !config.gostc.key.is_empty(),
|
||||
ExtensionId::Easytier => {
|
||||
config.easytier.enabled && !config.easytier.network_name.is_empty()
|
||||
}
|
||||
};
|
||||
if should_run && self.check_available(*id) {
|
||||
Some(*id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Check which ones need restart (single read lock)
|
||||
let needs_restart: Vec<_> = {
|
||||
let processes = self.processes.read().await;
|
||||
checks
|
||||
.into_iter()
|
||||
.filter(|id| {
|
||||
if let Some(proc) = processes.get(id) {
|
||||
proc.child.id().is_none()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Restart all crashed extensions in parallel
|
||||
let restart_futures: Vec<_> = needs_restart
|
||||
.into_iter()
|
||||
.map(|id| async move {
|
||||
tracing::info!("Health check: restarting {}", id);
|
||||
if let Err(e) = self.start(id, config).await {
|
||||
tracing::error!("Failed to restart {}: {}", id, e);
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
futures::future::join_all(restart_futures).await;
|
||||
}
|
||||
|
||||
/// Start all enabled extensions in parallel
|
||||
pub async fn start_enabled(&self, config: &ExtensionsConfig) {
|
||||
use std::pin::Pin;
|
||||
use futures::Future;
|
||||
|
||||
let mut start_futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + '_>>> = Vec::new();
|
||||
|
||||
// Collect enabled extensions
|
||||
if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) {
|
||||
start_futures.push(Box::pin(async {
|
||||
if let Err(e) = self.start(ExtensionId::Ttyd, config).await {
|
||||
tracing::error!("Failed to start ttyd: {}", e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
if config.gostc.enabled
|
||||
&& !config.gostc.key.is_empty()
|
||||
&& self.check_available(ExtensionId::Gostc)
|
||||
{
|
||||
start_futures.push(Box::pin(async {
|
||||
if let Err(e) = self.start(ExtensionId::Gostc, config).await {
|
||||
tracing::error!("Failed to start gostc: {}", e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
if config.easytier.enabled
|
||||
&& !config.easytier.network_name.is_empty()
|
||||
&& self.check_available(ExtensionId::Easytier)
|
||||
{
|
||||
start_futures.push(Box::pin(async {
|
||||
if let Err(e) = self.start(ExtensionId::Easytier, config).await {
|
||||
tracing::error!("Failed to start easytier: {}", e);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Start all in parallel
|
||||
futures::future::join_all(start_futures).await;
|
||||
}
|
||||
|
||||
/// Stop all running extensions in parallel
|
||||
pub async fn stop_all(&self) {
|
||||
let stop_futures: Vec<_> = ExtensionId::all()
|
||||
.iter()
|
||||
.map(|id| self.stop(*id))
|
||||
.collect();
|
||||
futures::future::join_all(stop_futures).await;
|
||||
}
|
||||
}
|
||||
7
src/extensions/mod.rs
Normal file
7
src/extensions/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Extensions module - manage external processes like ttyd, gostc, easytier
|
||||
|
||||
mod manager;
|
||||
mod types;
|
||||
|
||||
pub use manager::{ExtensionManager, TTYD_SOCKET_PATH};
|
||||
pub use types::*;
|
||||
251
src/extensions/types.rs
Normal file
251
src/extensions/types.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
//! Extension types and configurations
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use typeshare::typeshare;
|
||||
|
||||
/// Extension identifier (fixed set of supported extensions)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ExtensionId {
|
||||
/// Web terminal (ttyd)
|
||||
Ttyd,
|
||||
/// NAT traversal client (gostc)
|
||||
Gostc,
|
||||
/// P2P VPN (easytier)
|
||||
Easytier,
|
||||
}
|
||||
|
||||
impl ExtensionId {
|
||||
/// Get the binary path for this extension
|
||||
pub fn binary_path(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ttyd => "/usr/bin/ttyd",
|
||||
Self::Gostc => "/usr/bin/gostc",
|
||||
Self::Easytier => "/usr/bin/easytier-core",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the display name for this extension
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ttyd => "Web Terminal",
|
||||
Self::Gostc => "GOSTC Tunnel",
|
||||
Self::Easytier => "EasyTier VPN",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all extension IDs
|
||||
pub fn all() -> &'static [ExtensionId] {
|
||||
&[Self::Ttyd, Self::Gostc, Self::Easytier]
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ExtensionId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Ttyd => write!(f, "ttyd"),
|
||||
Self::Gostc => write!(f, "gostc"),
|
||||
Self::Easytier => write!(f, "easytier"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for ExtensionId {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"ttyd" => Ok(Self::Ttyd),
|
||||
"gostc" => Ok(Self::Gostc),
|
||||
"easytier" => Ok(Self::Easytier),
|
||||
_ => Err(format!("Unknown extension: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension running status
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "state", content = "data", rename_all = "lowercase")]
|
||||
pub enum ExtensionStatus {
|
||||
/// Binary not found at expected path
|
||||
Unavailable,
|
||||
/// Extension is stopped
|
||||
Stopped,
|
||||
/// Extension is running
|
||||
Running {
|
||||
/// Process ID
|
||||
pid: u32,
|
||||
},
|
||||
/// Extension failed to start
|
||||
Failed {
|
||||
/// Error message
|
||||
error: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl ExtensionStatus {
|
||||
pub fn is_running(&self) -> bool {
|
||||
matches!(self, Self::Running { .. })
|
||||
}
|
||||
}
|
||||
|
||||
/// ttyd configuration (Web Terminal)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct TtydConfig {
|
||||
/// Enable auto-start
|
||||
pub enabled: bool,
|
||||
/// Port to listen on
|
||||
pub port: u16,
|
||||
/// Shell to execute
|
||||
pub shell: String,
|
||||
/// Credential in format "user:password" (optional)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub credential: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for TtydConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
port: 7681,
|
||||
shell: "/bin/bash".to_string(),
|
||||
credential: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// gostc configuration (NAT traversal based on FRP)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct GostcConfig {
|
||||
/// Enable auto-start
|
||||
pub enabled: bool,
|
||||
/// Server address (e.g., gostc.mofeng.run)
|
||||
pub addr: String,
|
||||
/// Client key from GOSTC management panel
|
||||
#[serde(skip_serializing_if = "String::is_empty")]
|
||||
pub key: String,
|
||||
/// Enable TLS
|
||||
pub tls: bool,
|
||||
}
|
||||
|
||||
impl Default for GostcConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
addr: "gostc.mofeng.run".to_string(),
|
||||
key: String::new(),
|
||||
tls: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// EasyTier configuration (P2P VPN)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct EasytierConfig {
|
||||
/// Enable auto-start
|
||||
pub enabled: bool,
|
||||
/// Network name
|
||||
pub network_name: String,
|
||||
/// Network secret/password
|
||||
#[serde(skip_serializing_if = "String::is_empty")]
|
||||
pub network_secret: String,
|
||||
/// Peer node URLs
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub peer_urls: Vec<String>,
|
||||
/// Virtual IP address (optional, auto-assigned if not set)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub virtual_ip: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for EasytierConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
network_name: String::new(),
|
||||
network_secret: String::new(),
|
||||
peer_urls: Vec::new(),
|
||||
virtual_ip: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Combined extensions configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[serde(default)]
|
||||
pub struct ExtensionsConfig {
|
||||
pub ttyd: TtydConfig,
|
||||
pub gostc: GostcConfig,
|
||||
pub easytier: EasytierConfig,
|
||||
}
|
||||
|
||||
/// Extension info with status and config
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExtensionInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
}
|
||||
|
||||
/// ttyd extension info
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TtydInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
/// Configuration
|
||||
pub config: TtydConfig,
|
||||
}
|
||||
|
||||
/// gostc extension info
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GostcInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
/// Configuration
|
||||
pub config: GostcConfig,
|
||||
}
|
||||
|
||||
/// easytier extension info
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EasytierInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
/// Configuration
|
||||
pub config: EasytierConfig,
|
||||
}
|
||||
|
||||
/// All extensions status response
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExtensionsStatus {
|
||||
pub ttyd: TtydInfo,
|
||||
pub gostc: GostcInfo,
|
||||
pub easytier: EasytierInfo,
|
||||
}
|
||||
|
||||
/// Extension logs response
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExtensionLogs {
|
||||
pub id: ExtensionId,
|
||||
pub logs: Vec<String>,
|
||||
}
|
||||
130
src/hid/backend.rs
Normal file
130
src/hid/backend.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
//! HID backend trait definition
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::types::{KeyboardEvent, MouseEvent};
|
||||
use crate::error::Result;
|
||||
|
||||
/// Default CH9329 baud rate
|
||||
fn default_ch9329_baud_rate() -> u32 {
|
||||
9600
|
||||
}
|
||||
|
||||
/// HID backend type
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum HidBackendType {
|
||||
/// USB OTG gadget mode
|
||||
Otg,
|
||||
/// CH9329 serial HID controller
|
||||
Ch9329 {
|
||||
/// Serial port path
|
||||
port: String,
|
||||
/// Baud rate (default: 9600)
|
||||
#[serde(default = "default_ch9329_baud_rate")]
|
||||
baud_rate: u32,
|
||||
},
|
||||
/// No HID backend (disabled)
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for HidBackendType {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
impl HidBackendType {
|
||||
/// Check if OTG backend is available on this system
|
||||
pub fn otg_available() -> bool {
|
||||
// Check for USB gadget support
|
||||
std::path::Path::new("/sys/class/udc").exists()
|
||||
}
|
||||
|
||||
/// Detect the best available backend
|
||||
pub fn detect() -> Self {
|
||||
// Check for OTG gadget support
|
||||
if Self::otg_available() {
|
||||
return Self::Otg;
|
||||
}
|
||||
|
||||
// Check for common CH9329 serial ports
|
||||
let common_ports = [
|
||||
"/dev/ttyUSB0",
|
||||
"/dev/ttyUSB1",
|
||||
"/dev/ttyAMA0",
|
||||
"/dev/serial0",
|
||||
];
|
||||
|
||||
for port in &common_ports {
|
||||
if std::path::Path::new(port).exists() {
|
||||
return Self::Ch9329 {
|
||||
port: port.to_string(),
|
||||
baud_rate: 9600, // Use default baud rate for auto-detection
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Self::None
|
||||
}
|
||||
|
||||
/// Get backend name as string
|
||||
pub fn name_str(&self) -> &str {
|
||||
match self {
|
||||
Self::Otg => "otg",
|
||||
Self::Ch9329 { .. } => "ch9329",
|
||||
Self::None => "none",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HID backend trait
|
||||
#[async_trait]
|
||||
pub trait HidBackend: Send + Sync {
|
||||
/// Get backend name
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Initialize the backend
|
||||
async fn init(&self) -> Result<()>;
|
||||
|
||||
/// Send a keyboard event
|
||||
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()>;
|
||||
|
||||
/// Send a mouse event
|
||||
async fn send_mouse(&self, event: MouseEvent) -> Result<()>;
|
||||
|
||||
/// Reset all inputs (release all keys/buttons)
|
||||
async fn reset(&self) -> Result<()>;
|
||||
|
||||
/// Shutdown the backend
|
||||
async fn shutdown(&self) -> Result<()>;
|
||||
|
||||
/// Check if backend supports absolute mouse positioning
|
||||
fn supports_absolute_mouse(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Get screen resolution (for absolute mouse)
|
||||
fn screen_resolution(&self) -> Option<(u32, u32)> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Set screen resolution (for absolute mouse)
|
||||
fn set_screen_resolution(&mut self, _width: u32, _height: u32) {}
|
||||
}
|
||||
|
||||
/// HID backend information
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct HidBackendInfo {
|
||||
/// Backend name
|
||||
pub name: String,
|
||||
/// Backend type
|
||||
pub backend_type: String,
|
||||
/// Is initialized
|
||||
pub initialized: bool,
|
||||
/// Supports absolute mouse
|
||||
pub absolute_mouse: bool,
|
||||
/// Screen resolution (if absolute mouse)
|
||||
pub resolution: Option<(u32, u32)>,
|
||||
}
|
||||
1324
src/hid/ch9329.rs
Normal file
1324
src/hid/ch9329.rs
Normal file
File diff suppressed because it is too large
Load Diff
281
src/hid/datachannel.rs
Normal file
281
src/hid/datachannel.rs
Normal file
@@ -0,0 +1,281 @@
|
||||
//! DataChannel HID message parsing and handling
|
||||
//!
|
||||
//! Binary message format:
|
||||
//! - Byte 0: Message type
|
||||
//! - 0x01: Keyboard event
|
||||
//! - 0x02: Mouse event
|
||||
//! - Remaining bytes: Event data
|
||||
//!
|
||||
//! Keyboard event (type 0x01):
|
||||
//! - Byte 1: Event type (0x00 = down, 0x01 = up)
|
||||
//! - Byte 2: Key code (USB HID usage code or JS keyCode)
|
||||
//! - Byte 3: Modifiers bitmask
|
||||
//! - Bit 0: Left Ctrl
|
||||
//! - Bit 1: Left Shift
|
||||
//! - Bit 2: Left Alt
|
||||
//! - Bit 3: Left Meta
|
||||
//! - Bit 4: Right Ctrl
|
||||
//! - Bit 5: Right Shift
|
||||
//! - Bit 6: Right Alt
|
||||
//! - Bit 7: Right Meta
|
||||
//!
|
||||
//! Mouse event (type 0x02):
|
||||
//! - Byte 1: Event type
|
||||
//! - 0x00: Move (relative)
|
||||
//! - 0x01: MoveAbs (absolute)
|
||||
//! - 0x02: Down
|
||||
//! - 0x03: Up
|
||||
//! - 0x04: Scroll
|
||||
//! - Bytes 2-3: X coordinate (i16 LE for relative, u16 LE for absolute)
|
||||
//! - Bytes 4-5: Y coordinate (i16 LE for relative, u16 LE for absolute)
|
||||
//! - Byte 6: Button (0=left, 1=middle, 2=right) or Scroll delta (i8)
|
||||
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::{
|
||||
KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent, MouseEventType,
|
||||
};
|
||||
|
||||
/// Message types
|
||||
pub const MSG_KEYBOARD: u8 = 0x01;
|
||||
pub const MSG_MOUSE: u8 = 0x02;
|
||||
|
||||
/// Keyboard event types
|
||||
pub const KB_EVENT_DOWN: u8 = 0x00;
|
||||
pub const KB_EVENT_UP: u8 = 0x01;
|
||||
|
||||
/// Mouse event types
|
||||
pub const MS_EVENT_MOVE: u8 = 0x00;
|
||||
pub const MS_EVENT_MOVE_ABS: u8 = 0x01;
|
||||
pub const MS_EVENT_DOWN: u8 = 0x02;
|
||||
pub const MS_EVENT_UP: u8 = 0x03;
|
||||
pub const MS_EVENT_SCROLL: u8 = 0x04;
|
||||
|
||||
/// Parsed HID event from DataChannel
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum HidChannelEvent {
|
||||
Keyboard(KeyboardEvent),
|
||||
Mouse(MouseEvent),
|
||||
}
|
||||
|
||||
/// Parse a binary HID message from DataChannel
|
||||
pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
if data.is_empty() {
|
||||
warn!("Empty HID message");
|
||||
return None;
|
||||
}
|
||||
|
||||
let msg_type = data[0];
|
||||
|
||||
match msg_type {
|
||||
MSG_KEYBOARD => parse_keyboard_message(&data[1..]),
|
||||
MSG_MOUSE => parse_mouse_message(&data[1..]),
|
||||
_ => {
|
||||
warn!("Unknown HID message type: 0x{:02X}", msg_type);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse keyboard message payload
|
||||
fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
if data.len() < 3 {
|
||||
warn!("Keyboard message too short: {} bytes", data.len());
|
||||
return None;
|
||||
}
|
||||
|
||||
let event_type = match data[0] {
|
||||
KB_EVENT_DOWN => KeyEventType::Down,
|
||||
KB_EVENT_UP => KeyEventType::Up,
|
||||
_ => {
|
||||
warn!("Unknown keyboard event type: 0x{:02X}", data[0]);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let key = data[1];
|
||||
let modifiers_byte = data[2];
|
||||
|
||||
let modifiers = KeyboardModifiers {
|
||||
left_ctrl: modifiers_byte & 0x01 != 0,
|
||||
left_shift: modifiers_byte & 0x02 != 0,
|
||||
left_alt: modifiers_byte & 0x04 != 0,
|
||||
left_meta: modifiers_byte & 0x08 != 0,
|
||||
right_ctrl: modifiers_byte & 0x10 != 0,
|
||||
right_shift: modifiers_byte & 0x20 != 0,
|
||||
right_alt: modifiers_byte & 0x40 != 0,
|
||||
right_meta: modifiers_byte & 0x80 != 0,
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Parsed keyboard: {:?} key=0x{:02X} modifiers=0x{:02X}",
|
||||
event_type, key, modifiers_byte
|
||||
);
|
||||
|
||||
Some(HidChannelEvent::Keyboard(KeyboardEvent {
|
||||
event_type,
|
||||
key,
|
||||
modifiers,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Parse mouse message payload
|
||||
fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
if data.len() < 6 {
|
||||
warn!("Mouse message too short: {} bytes", data.len());
|
||||
return None;
|
||||
}
|
||||
|
||||
let event_type = match data[0] {
|
||||
MS_EVENT_MOVE => MouseEventType::Move,
|
||||
MS_EVENT_MOVE_ABS => MouseEventType::MoveAbs,
|
||||
MS_EVENT_DOWN => MouseEventType::Down,
|
||||
MS_EVENT_UP => MouseEventType::Up,
|
||||
MS_EVENT_SCROLL => MouseEventType::Scroll,
|
||||
_ => {
|
||||
warn!("Unknown mouse event type: 0x{:02X}", data[0]);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse coordinates as i16 LE (works for both relative and absolute)
|
||||
let x = i16::from_le_bytes([data[1], data[2]]) as i32;
|
||||
let y = i16::from_le_bytes([data[3], data[4]]) as i32;
|
||||
|
||||
// Button or scroll delta
|
||||
let (button, scroll) = match event_type {
|
||||
MouseEventType::Down | MouseEventType::Up => {
|
||||
let btn = match data[5] {
|
||||
0 => Some(MouseButton::Left),
|
||||
1 => Some(MouseButton::Middle),
|
||||
2 => Some(MouseButton::Right),
|
||||
3 => Some(MouseButton::Back),
|
||||
4 => Some(MouseButton::Forward),
|
||||
_ => Some(MouseButton::Left),
|
||||
};
|
||||
(btn, 0i8)
|
||||
}
|
||||
MouseEventType::Scroll => (None, data[5] as i8),
|
||||
_ => (None, 0i8),
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Parsed mouse: {:?} x={} y={} button={:?} scroll={}",
|
||||
event_type, x, y, button, scroll
|
||||
);
|
||||
|
||||
Some(HidChannelEvent::Mouse(MouseEvent {
|
||||
event_type,
|
||||
x,
|
||||
y,
|
||||
button,
|
||||
scroll,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Encode a keyboard event to binary format (for sending to client if needed)
|
||||
pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
|
||||
let event_type = match event.event_type {
|
||||
KeyEventType::Down => KB_EVENT_DOWN,
|
||||
KeyEventType::Up => KB_EVENT_UP,
|
||||
};
|
||||
|
||||
let modifiers = event.modifiers.to_hid_byte();
|
||||
|
||||
vec![MSG_KEYBOARD, event_type, event.key, modifiers]
|
||||
}
|
||||
|
||||
/// Encode a mouse event to binary format (for sending to client if needed)
|
||||
pub fn encode_mouse_event(event: &MouseEvent) -> Vec<u8> {
|
||||
let event_type = match event.event_type {
|
||||
MouseEventType::Move => MS_EVENT_MOVE,
|
||||
MouseEventType::MoveAbs => MS_EVENT_MOVE_ABS,
|
||||
MouseEventType::Down => MS_EVENT_DOWN,
|
||||
MouseEventType::Up => MS_EVENT_UP,
|
||||
MouseEventType::Scroll => MS_EVENT_SCROLL,
|
||||
};
|
||||
|
||||
let x_bytes = (event.x as i16).to_le_bytes();
|
||||
let y_bytes = (event.y as i16).to_le_bytes();
|
||||
|
||||
let extra = match event.event_type {
|
||||
MouseEventType::Down | MouseEventType::Up => {
|
||||
event.button.as_ref().map(|b| match b {
|
||||
MouseButton::Left => 0u8,
|
||||
MouseButton::Middle => 1u8,
|
||||
MouseButton::Right => 2u8,
|
||||
MouseButton::Back => 3u8,
|
||||
MouseButton::Forward => 4u8,
|
||||
}).unwrap_or(0)
|
||||
}
|
||||
MouseEventType::Scroll => event.scroll as u8,
|
||||
_ => 0,
|
||||
};
|
||||
|
||||
vec![
|
||||
MSG_MOUSE,
|
||||
event_type,
|
||||
x_bytes[0],
|
||||
x_bytes[1],
|
||||
y_bytes[0],
|
||||
y_bytes[1],
|
||||
extra,
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_keyboard_down() {
|
||||
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // A key with left ctrl
|
||||
let event = parse_hid_message(&data).unwrap();
|
||||
|
||||
match event {
|
||||
HidChannelEvent::Keyboard(kb) => {
|
||||
assert!(matches!(kb.event_type, KeyEventType::Down));
|
||||
assert_eq!(kb.key, 0x04);
|
||||
assert!(kb.modifiers.left_ctrl);
|
||||
assert!(!kb.modifiers.left_shift);
|
||||
}
|
||||
_ => panic!("Expected keyboard event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mouse_move() {
|
||||
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
|
||||
let event = parse_hid_message(&data).unwrap();
|
||||
|
||||
match event {
|
||||
HidChannelEvent::Mouse(ms) => {
|
||||
assert!(matches!(ms.event_type, MouseEventType::Move));
|
||||
assert_eq!(ms.x, 10);
|
||||
assert_eq!(ms.y, -10);
|
||||
}
|
||||
_ => panic!("Expected mouse event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_keyboard() {
|
||||
let event = KeyboardEvent {
|
||||
event_type: KeyEventType::Down,
|
||||
key: 0x04,
|
||||
modifiers: KeyboardModifiers {
|
||||
left_ctrl: true,
|
||||
left_shift: false,
|
||||
left_alt: false,
|
||||
left_meta: false,
|
||||
right_ctrl: false,
|
||||
right_shift: false,
|
||||
right_alt: false,
|
||||
right_meta: false,
|
||||
},
|
||||
};
|
||||
|
||||
let encoded = encode_keyboard_event(&event);
|
||||
assert_eq!(encoded, vec![MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]);
|
||||
}
|
||||
}
|
||||
430
src/hid/keymap.rs
Normal file
430
src/hid/keymap.rs
Normal file
@@ -0,0 +1,430 @@
|
||||
//! USB HID keyboard key codes mapping
|
||||
//!
|
||||
//! This module provides mapping between JavaScript key codes and USB HID usage codes.
|
||||
//! Reference: USB HID Usage Tables 1.12, Section 10 (Keyboard/Keypad Page)
|
||||
|
||||
/// USB HID key codes (Usage Page 0x07)
|
||||
#[allow(dead_code)]
|
||||
pub mod usb {
|
||||
// Letters A-Z (0x04 - 0x1D)
|
||||
pub const KEY_A: u8 = 0x04;
|
||||
pub const KEY_B: u8 = 0x05;
|
||||
pub const KEY_C: u8 = 0x06;
|
||||
pub const KEY_D: u8 = 0x07;
|
||||
pub const KEY_E: u8 = 0x08;
|
||||
pub const KEY_F: u8 = 0x09;
|
||||
pub const KEY_G: u8 = 0x0A;
|
||||
pub const KEY_H: u8 = 0x0B;
|
||||
pub const KEY_I: u8 = 0x0C;
|
||||
pub const KEY_J: u8 = 0x0D;
|
||||
pub const KEY_K: u8 = 0x0E;
|
||||
pub const KEY_L: u8 = 0x0F;
|
||||
pub const KEY_M: u8 = 0x10;
|
||||
pub const KEY_N: u8 = 0x11;
|
||||
pub const KEY_O: u8 = 0x12;
|
||||
pub const KEY_P: u8 = 0x13;
|
||||
pub const KEY_Q: u8 = 0x14;
|
||||
pub const KEY_R: u8 = 0x15;
|
||||
pub const KEY_S: u8 = 0x16;
|
||||
pub const KEY_T: u8 = 0x17;
|
||||
pub const KEY_U: u8 = 0x18;
|
||||
pub const KEY_V: u8 = 0x19;
|
||||
pub const KEY_W: u8 = 0x1A;
|
||||
pub const KEY_X: u8 = 0x1B;
|
||||
pub const KEY_Y: u8 = 0x1C;
|
||||
pub const KEY_Z: u8 = 0x1D;
|
||||
|
||||
// Numbers 1-9, 0 (0x1E - 0x27)
|
||||
pub const KEY_1: u8 = 0x1E;
|
||||
pub const KEY_2: u8 = 0x1F;
|
||||
pub const KEY_3: u8 = 0x20;
|
||||
pub const KEY_4: u8 = 0x21;
|
||||
pub const KEY_5: u8 = 0x22;
|
||||
pub const KEY_6: u8 = 0x23;
|
||||
pub const KEY_7: u8 = 0x24;
|
||||
pub const KEY_8: u8 = 0x25;
|
||||
pub const KEY_9: u8 = 0x26;
|
||||
pub const KEY_0: u8 = 0x27;
|
||||
|
||||
// Control keys
|
||||
pub const KEY_ENTER: u8 = 0x28;
|
||||
pub const KEY_ESCAPE: u8 = 0x29;
|
||||
pub const KEY_BACKSPACE: u8 = 0x2A;
|
||||
pub const KEY_TAB: u8 = 0x2B;
|
||||
pub const KEY_SPACE: u8 = 0x2C;
|
||||
pub const KEY_MINUS: u8 = 0x2D;
|
||||
pub const KEY_EQUAL: u8 = 0x2E;
|
||||
pub const KEY_LEFT_BRACKET: u8 = 0x2F;
|
||||
pub const KEY_RIGHT_BRACKET: u8 = 0x30;
|
||||
pub const KEY_BACKSLASH: u8 = 0x31;
|
||||
pub const KEY_HASH: u8 = 0x32; // Non-US # and ~
|
||||
pub const KEY_SEMICOLON: u8 = 0x33;
|
||||
pub const KEY_APOSTROPHE: u8 = 0x34;
|
||||
pub const KEY_GRAVE: u8 = 0x35;
|
||||
pub const KEY_COMMA: u8 = 0x36;
|
||||
pub const KEY_PERIOD: u8 = 0x37;
|
||||
pub const KEY_SLASH: u8 = 0x38;
|
||||
pub const KEY_CAPS_LOCK: u8 = 0x39;
|
||||
|
||||
// Function keys F1-F12
|
||||
pub const KEY_F1: u8 = 0x3A;
|
||||
pub const KEY_F2: u8 = 0x3B;
|
||||
pub const KEY_F3: u8 = 0x3C;
|
||||
pub const KEY_F4: u8 = 0x3D;
|
||||
pub const KEY_F5: u8 = 0x3E;
|
||||
pub const KEY_F6: u8 = 0x3F;
|
||||
pub const KEY_F7: u8 = 0x40;
|
||||
pub const KEY_F8: u8 = 0x41;
|
||||
pub const KEY_F9: u8 = 0x42;
|
||||
pub const KEY_F10: u8 = 0x43;
|
||||
pub const KEY_F11: u8 = 0x44;
|
||||
pub const KEY_F12: u8 = 0x45;
|
||||
|
||||
// Special keys
|
||||
pub const KEY_PRINT_SCREEN: u8 = 0x46;
|
||||
pub const KEY_SCROLL_LOCK: u8 = 0x47;
|
||||
pub const KEY_PAUSE: u8 = 0x48;
|
||||
pub const KEY_INSERT: u8 = 0x49;
|
||||
pub const KEY_HOME: u8 = 0x4A;
|
||||
pub const KEY_PAGE_UP: u8 = 0x4B;
|
||||
pub const KEY_DELETE: u8 = 0x4C;
|
||||
pub const KEY_END: u8 = 0x4D;
|
||||
pub const KEY_PAGE_DOWN: u8 = 0x4E;
|
||||
pub const KEY_RIGHT_ARROW: u8 = 0x4F;
|
||||
pub const KEY_LEFT_ARROW: u8 = 0x50;
|
||||
pub const KEY_DOWN_ARROW: u8 = 0x51;
|
||||
pub const KEY_UP_ARROW: u8 = 0x52;
|
||||
|
||||
// Numpad
|
||||
pub const KEY_NUM_LOCK: u8 = 0x53;
|
||||
pub const KEY_NUMPAD_DIVIDE: u8 = 0x54;
|
||||
pub const KEY_NUMPAD_MULTIPLY: u8 = 0x55;
|
||||
pub const KEY_NUMPAD_MINUS: u8 = 0x56;
|
||||
pub const KEY_NUMPAD_PLUS: u8 = 0x57;
|
||||
pub const KEY_NUMPAD_ENTER: u8 = 0x58;
|
||||
pub const KEY_NUMPAD_1: u8 = 0x59;
|
||||
pub const KEY_NUMPAD_2: u8 = 0x5A;
|
||||
pub const KEY_NUMPAD_3: u8 = 0x5B;
|
||||
pub const KEY_NUMPAD_4: u8 = 0x5C;
|
||||
pub const KEY_NUMPAD_5: u8 = 0x5D;
|
||||
pub const KEY_NUMPAD_6: u8 = 0x5E;
|
||||
pub const KEY_NUMPAD_7: u8 = 0x5F;
|
||||
pub const KEY_NUMPAD_8: u8 = 0x60;
|
||||
pub const KEY_NUMPAD_9: u8 = 0x61;
|
||||
pub const KEY_NUMPAD_0: u8 = 0x62;
|
||||
pub const KEY_NUMPAD_DECIMAL: u8 = 0x63;
|
||||
|
||||
// Additional keys
|
||||
pub const KEY_NON_US_BACKSLASH: u8 = 0x64;
|
||||
pub const KEY_APPLICATION: u8 = 0x65; // Context menu
|
||||
pub const KEY_POWER: u8 = 0x66;
|
||||
pub const KEY_NUMPAD_EQUAL: u8 = 0x67;
|
||||
|
||||
// F13-F24
|
||||
pub const KEY_F13: u8 = 0x68;
|
||||
pub const KEY_F14: u8 = 0x69;
|
||||
pub const KEY_F15: u8 = 0x6A;
|
||||
pub const KEY_F16: u8 = 0x6B;
|
||||
pub const KEY_F17: u8 = 0x6C;
|
||||
pub const KEY_F18: u8 = 0x6D;
|
||||
pub const KEY_F19: u8 = 0x6E;
|
||||
pub const KEY_F20: u8 = 0x6F;
|
||||
pub const KEY_F21: u8 = 0x70;
|
||||
pub const KEY_F22: u8 = 0x71;
|
||||
pub const KEY_F23: u8 = 0x72;
|
||||
pub const KEY_F24: u8 = 0x73;
|
||||
|
||||
// Modifier keys (these are handled separately in the modifier byte)
|
||||
pub const KEY_LEFT_CTRL: u8 = 0xE0;
|
||||
pub const KEY_LEFT_SHIFT: u8 = 0xE1;
|
||||
pub const KEY_LEFT_ALT: u8 = 0xE2;
|
||||
pub const KEY_LEFT_META: u8 = 0xE3;
|
||||
pub const KEY_RIGHT_CTRL: u8 = 0xE4;
|
||||
pub const KEY_RIGHT_SHIFT: u8 = 0xE5;
|
||||
pub const KEY_RIGHT_ALT: u8 = 0xE6;
|
||||
pub const KEY_RIGHT_META: u8 = 0xE7;
|
||||
}
|
||||
|
||||
/// JavaScript key codes (event.keyCode / event.code)
|
||||
#[allow(dead_code)]
|
||||
pub mod js {
|
||||
// Letters
|
||||
pub const KEY_A: u8 = 65;
|
||||
pub const KEY_B: u8 = 66;
|
||||
pub const KEY_C: u8 = 67;
|
||||
pub const KEY_D: u8 = 68;
|
||||
pub const KEY_E: u8 = 69;
|
||||
pub const KEY_F: u8 = 70;
|
||||
pub const KEY_G: u8 = 71;
|
||||
pub const KEY_H: u8 = 72;
|
||||
pub const KEY_I: u8 = 73;
|
||||
pub const KEY_J: u8 = 74;
|
||||
pub const KEY_K: u8 = 75;
|
||||
pub const KEY_L: u8 = 76;
|
||||
pub const KEY_M: u8 = 77;
|
||||
pub const KEY_N: u8 = 78;
|
||||
pub const KEY_O: u8 = 79;
|
||||
pub const KEY_P: u8 = 80;
|
||||
pub const KEY_Q: u8 = 81;
|
||||
pub const KEY_R: u8 = 82;
|
||||
pub const KEY_S: u8 = 83;
|
||||
pub const KEY_T: u8 = 84;
|
||||
pub const KEY_U: u8 = 85;
|
||||
pub const KEY_V: u8 = 86;
|
||||
pub const KEY_W: u8 = 87;
|
||||
pub const KEY_X: u8 = 88;
|
||||
pub const KEY_Y: u8 = 89;
|
||||
pub const KEY_Z: u8 = 90;
|
||||
|
||||
// Numbers (top row)
|
||||
pub const KEY_0: u8 = 48;
|
||||
pub const KEY_1: u8 = 49;
|
||||
pub const KEY_2: u8 = 50;
|
||||
pub const KEY_3: u8 = 51;
|
||||
pub const KEY_4: u8 = 52;
|
||||
pub const KEY_5: u8 = 53;
|
||||
pub const KEY_6: u8 = 54;
|
||||
pub const KEY_7: u8 = 55;
|
||||
pub const KEY_8: u8 = 56;
|
||||
pub const KEY_9: u8 = 57;
|
||||
|
||||
// Function keys
|
||||
pub const KEY_F1: u8 = 112;
|
||||
pub const KEY_F2: u8 = 113;
|
||||
pub const KEY_F3: u8 = 114;
|
||||
pub const KEY_F4: u8 = 115;
|
||||
pub const KEY_F5: u8 = 116;
|
||||
pub const KEY_F6: u8 = 117;
|
||||
pub const KEY_F7: u8 = 118;
|
||||
pub const KEY_F8: u8 = 119;
|
||||
pub const KEY_F9: u8 = 120;
|
||||
pub const KEY_F10: u8 = 121;
|
||||
pub const KEY_F11: u8 = 122;
|
||||
pub const KEY_F12: u8 = 123;
|
||||
|
||||
// Control keys
|
||||
pub const KEY_BACKSPACE: u8 = 8;
|
||||
pub const KEY_TAB: u8 = 9;
|
||||
pub const KEY_ENTER: u8 = 13;
|
||||
pub const KEY_SHIFT: u8 = 16;
|
||||
pub const KEY_CTRL: u8 = 17;
|
||||
pub const KEY_ALT: u8 = 18;
|
||||
pub const KEY_PAUSE: u8 = 19;
|
||||
pub const KEY_CAPS_LOCK: u8 = 20;
|
||||
pub const KEY_ESCAPE: u8 = 27;
|
||||
pub const KEY_SPACE: u8 = 32;
|
||||
pub const KEY_PAGE_UP: u8 = 33;
|
||||
pub const KEY_PAGE_DOWN: u8 = 34;
|
||||
pub const KEY_END: u8 = 35;
|
||||
pub const KEY_HOME: u8 = 36;
|
||||
pub const KEY_LEFT: u8 = 37;
|
||||
pub const KEY_UP: u8 = 38;
|
||||
pub const KEY_RIGHT: u8 = 39;
|
||||
pub const KEY_DOWN: u8 = 40;
|
||||
pub const KEY_INSERT: u8 = 45;
|
||||
pub const KEY_DELETE: u8 = 46;
|
||||
|
||||
// Punctuation
|
||||
pub const KEY_SEMICOLON: u8 = 186;
|
||||
pub const KEY_EQUAL: u8 = 187;
|
||||
pub const KEY_COMMA: u8 = 188;
|
||||
pub const KEY_MINUS: u8 = 189;
|
||||
pub const KEY_PERIOD: u8 = 190;
|
||||
pub const KEY_SLASH: u8 = 191;
|
||||
pub const KEY_GRAVE: u8 = 192;
|
||||
pub const KEY_LEFT_BRACKET: u8 = 219;
|
||||
pub const KEY_BACKSLASH: u8 = 220;
|
||||
pub const KEY_RIGHT_BRACKET: u8 = 221;
|
||||
pub const KEY_APOSTROPHE: u8 = 222;
|
||||
|
||||
// Numpad
|
||||
pub const KEY_NUMPAD_0: u8 = 96;
|
||||
pub const KEY_NUMPAD_1: u8 = 97;
|
||||
pub const KEY_NUMPAD_2: u8 = 98;
|
||||
pub const KEY_NUMPAD_3: u8 = 99;
|
||||
pub const KEY_NUMPAD_4: u8 = 100;
|
||||
pub const KEY_NUMPAD_5: u8 = 101;
|
||||
pub const KEY_NUMPAD_6: u8 = 102;
|
||||
pub const KEY_NUMPAD_7: u8 = 103;
|
||||
pub const KEY_NUMPAD_8: u8 = 104;
|
||||
pub const KEY_NUMPAD_9: u8 = 105;
|
||||
pub const KEY_NUMPAD_MULTIPLY: u8 = 106;
|
||||
pub const KEY_NUMPAD_ADD: u8 = 107;
|
||||
pub const KEY_NUMPAD_SUBTRACT: u8 = 109;
|
||||
pub const KEY_NUMPAD_DECIMAL: u8 = 110;
|
||||
pub const KEY_NUMPAD_DIVIDE: u8 = 111;
|
||||
|
||||
// Lock keys
|
||||
pub const KEY_NUM_LOCK: u8 = 144;
|
||||
pub const KEY_SCROLL_LOCK: u8 = 145;
|
||||
|
||||
// Windows keys
|
||||
pub const KEY_META_LEFT: u8 = 91;
|
||||
pub const KEY_META_RIGHT: u8 = 92;
|
||||
pub const KEY_CONTEXT_MENU: u8 = 93;
|
||||
}
|
||||
|
||||
/// JavaScript keyCode to USB HID keyCode mapping table
|
||||
/// Using a fixed-size array for O(1) lookup instead of HashMap
|
||||
/// Index = JavaScript keyCode, Value = USB HID keyCode (0 means unmapped)
|
||||
static JS_TO_USB_TABLE: [u8; 256] = {
|
||||
let mut table = [0u8; 256];
|
||||
|
||||
// Letters A-Z (JS 65-90 -> USB 0x04-0x1D)
|
||||
let mut i = 0u8;
|
||||
while i < 26 {
|
||||
table[(65 + i) as usize] = usb::KEY_A + i;
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Numbers 1-9, 0 (JS 49-57, 48 -> USB 0x1E-0x27)
|
||||
table[49] = usb::KEY_1; // 1
|
||||
table[50] = usb::KEY_2; // 2
|
||||
table[51] = usb::KEY_3; // 3
|
||||
table[52] = usb::KEY_4; // 4
|
||||
table[53] = usb::KEY_5; // 5
|
||||
table[54] = usb::KEY_6; // 6
|
||||
table[55] = usb::KEY_7; // 7
|
||||
table[56] = usb::KEY_8; // 8
|
||||
table[57] = usb::KEY_9; // 9
|
||||
table[48] = usb::KEY_0; // 0
|
||||
|
||||
// Function keys F1-F12 (JS 112-123 -> USB 0x3A-0x45)
|
||||
table[112] = usb::KEY_F1;
|
||||
table[113] = usb::KEY_F2;
|
||||
table[114] = usb::KEY_F3;
|
||||
table[115] = usb::KEY_F4;
|
||||
table[116] = usb::KEY_F5;
|
||||
table[117] = usb::KEY_F6;
|
||||
table[118] = usb::KEY_F7;
|
||||
table[119] = usb::KEY_F8;
|
||||
table[120] = usb::KEY_F9;
|
||||
table[121] = usb::KEY_F10;
|
||||
table[122] = usb::KEY_F11;
|
||||
table[123] = usb::KEY_F12;
|
||||
|
||||
// Control keys
|
||||
table[13] = usb::KEY_ENTER; // Enter
|
||||
table[27] = usb::KEY_ESCAPE; // Escape
|
||||
table[8] = usb::KEY_BACKSPACE; // Backspace
|
||||
table[9] = usb::KEY_TAB; // Tab
|
||||
table[32] = usb::KEY_SPACE; // Space
|
||||
table[20] = usb::KEY_CAPS_LOCK; // Caps Lock
|
||||
|
||||
// Punctuation (JS codes vary by browser/layout)
|
||||
table[189] = usb::KEY_MINUS; // -
|
||||
table[187] = usb::KEY_EQUAL; // =
|
||||
table[219] = usb::KEY_LEFT_BRACKET; // [
|
||||
table[221] = usb::KEY_RIGHT_BRACKET; // ]
|
||||
table[220] = usb::KEY_BACKSLASH; // \
|
||||
table[186] = usb::KEY_SEMICOLON; // ;
|
||||
table[222] = usb::KEY_APOSTROPHE; // '
|
||||
table[192] = usb::KEY_GRAVE; // `
|
||||
table[188] = usb::KEY_COMMA; // ,
|
||||
table[190] = usb::KEY_PERIOD; // .
|
||||
table[191] = usb::KEY_SLASH; // /
|
||||
|
||||
// Navigation keys
|
||||
table[45] = usb::KEY_INSERT;
|
||||
table[46] = usb::KEY_DELETE;
|
||||
table[36] = usb::KEY_HOME;
|
||||
table[35] = usb::KEY_END;
|
||||
table[33] = usb::KEY_PAGE_UP;
|
||||
table[34] = usb::KEY_PAGE_DOWN;
|
||||
|
||||
// Arrow keys
|
||||
table[39] = usb::KEY_RIGHT_ARROW;
|
||||
table[37] = usb::KEY_LEFT_ARROW;
|
||||
table[40] = usb::KEY_DOWN_ARROW;
|
||||
table[38] = usb::KEY_UP_ARROW;
|
||||
|
||||
// Numpad
|
||||
table[144] = usb::KEY_NUM_LOCK;
|
||||
table[111] = usb::KEY_NUMPAD_DIVIDE;
|
||||
table[106] = usb::KEY_NUMPAD_MULTIPLY;
|
||||
table[109] = usb::KEY_NUMPAD_MINUS;
|
||||
table[107] = usb::KEY_NUMPAD_PLUS;
|
||||
table[96] = usb::KEY_NUMPAD_0;
|
||||
table[97] = usb::KEY_NUMPAD_1;
|
||||
table[98] = usb::KEY_NUMPAD_2;
|
||||
table[99] = usb::KEY_NUMPAD_3;
|
||||
table[100] = usb::KEY_NUMPAD_4;
|
||||
table[101] = usb::KEY_NUMPAD_5;
|
||||
table[102] = usb::KEY_NUMPAD_6;
|
||||
table[103] = usb::KEY_NUMPAD_7;
|
||||
table[104] = usb::KEY_NUMPAD_8;
|
||||
table[105] = usb::KEY_NUMPAD_9;
|
||||
table[110] = usb::KEY_NUMPAD_DECIMAL;
|
||||
|
||||
// Special keys
|
||||
table[19] = usb::KEY_PAUSE;
|
||||
table[145] = usb::KEY_SCROLL_LOCK;
|
||||
table[93] = usb::KEY_APPLICATION; // Context menu
|
||||
|
||||
// Modifier keys
|
||||
table[17] = usb::KEY_LEFT_CTRL;
|
||||
table[16] = usb::KEY_LEFT_SHIFT;
|
||||
table[18] = usb::KEY_LEFT_ALT;
|
||||
table[91] = usb::KEY_LEFT_META; // Left Windows/Command
|
||||
table[92] = usb::KEY_RIGHT_META; // Right Windows/Command
|
||||
|
||||
table
|
||||
};
|
||||
|
||||
/// Convert JavaScript keyCode to USB HID keyCode
|
||||
///
|
||||
/// Uses a fixed-size lookup table for O(1) performance.
|
||||
/// Returns None if the key code is not mapped.
|
||||
#[inline]
|
||||
pub fn js_to_usb(js_code: u8) -> Option<u8> {
|
||||
let usb_code = JS_TO_USB_TABLE[js_code as usize];
|
||||
if usb_code != 0 {
|
||||
Some(usb_code)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a key code is a modifier key
|
||||
pub fn is_modifier_key(usb_code: u8) -> bool {
|
||||
(0xE0..=0xE7).contains(&usb_code)
|
||||
}
|
||||
|
||||
/// Get modifier bit for a modifier key
|
||||
pub fn modifier_bit(usb_code: u8) -> Option<u8> {
|
||||
match usb_code {
|
||||
usb::KEY_LEFT_CTRL => Some(0x01),
|
||||
usb::KEY_LEFT_SHIFT => Some(0x02),
|
||||
usb::KEY_LEFT_ALT => Some(0x04),
|
||||
usb::KEY_LEFT_META => Some(0x08),
|
||||
usb::KEY_RIGHT_CTRL => Some(0x10),
|
||||
usb::KEY_RIGHT_SHIFT => Some(0x20),
|
||||
usb::KEY_RIGHT_ALT => Some(0x40),
|
||||
usb::KEY_RIGHT_META => Some(0x80),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_letter_mapping() {
|
||||
assert_eq!(js_to_usb(65), Some(usb::KEY_A)); // A
|
||||
assert_eq!(js_to_usb(90), Some(usb::KEY_Z)); // Z
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_number_mapping() {
|
||||
assert_eq!(js_to_usb(48), Some(usb::KEY_0));
|
||||
assert_eq!(js_to_usb(49), Some(usb::KEY_1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modifier_key() {
|
||||
assert!(is_modifier_key(usb::KEY_LEFT_CTRL));
|
||||
assert!(is_modifier_key(usb::KEY_RIGHT_SHIFT));
|
||||
assert!(!is_modifier_key(usb::KEY_A));
|
||||
}
|
||||
}
|
||||
417
src/hid/mod.rs
Normal file
417
src/hid/mod.rs
Normal file
@@ -0,0 +1,417 @@
|
||||
//! HID (Human Interface Device) control module
|
||||
//!
|
||||
//! This module provides keyboard and mouse control for remote KVM:
|
||||
//! - USB OTG gadget mode (native Linux USB gadget)
|
||||
//! - CH9329 serial HID controller
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! Web Client --> WebSocket/DataChannel --> HID Events --> Backend --> Target PC
|
||||
//! |
|
||||
//! [OTG | CH9329]
|
||||
//! ```
|
||||
|
||||
pub mod backend;
|
||||
pub mod ch9329;
|
||||
pub mod datachannel;
|
||||
pub mod keymap;
|
||||
pub mod monitor;
|
||||
pub mod otg;
|
||||
pub mod types;
|
||||
pub mod websocket;
|
||||
|
||||
pub use backend::{HidBackend, HidBackendType};
|
||||
pub use monitor::{HidHealthMonitor, HidHealthStatus, HidMonitorConfig};
|
||||
pub use otg::LedState;
|
||||
pub use types::{
|
||||
KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent, MouseEventType,
|
||||
};
|
||||
|
||||
/// HID backend information
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HidInfo {
|
||||
/// Backend name
|
||||
pub name: &'static str,
|
||||
/// Whether backend is initialized
|
||||
pub initialized: bool,
|
||||
/// Whether absolute mouse positioning is supported
|
||||
pub supports_absolute_mouse: bool,
|
||||
/// Screen resolution for absolute mouse
|
||||
pub screen_resolution: Option<(u32, u32)>,
|
||||
}
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::otg::OtgService;
|
||||
|
||||
/// HID controller managing keyboard and mouse input
|
||||
pub struct HidController {
|
||||
/// OTG Service reference (only used when backend is OTG)
|
||||
otg_service: Option<Arc<OtgService>>,
|
||||
/// Active backend
|
||||
backend: Arc<RwLock<Option<Box<dyn HidBackend>>>>,
|
||||
/// Backend type (mutable for reload)
|
||||
backend_type: RwLock<HidBackendType>,
|
||||
/// Event bus for broadcasting state changes (optional)
|
||||
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
|
||||
/// Health monitor for error tracking and recovery
|
||||
monitor: Arc<HidHealthMonitor>,
|
||||
}
|
||||
|
||||
impl HidController {
|
||||
/// Create a new HID controller with specified backend
|
||||
///
|
||||
/// For OTG backend, otg_service should be provided to support hot-reload
|
||||
pub fn new(backend_type: HidBackendType, otg_service: Option<Arc<OtgService>>) -> Self {
|
||||
Self {
|
||||
otg_service,
|
||||
backend: Arc::new(RwLock::new(None)),
|
||||
backend_type: RwLock::new(backend_type),
|
||||
events: tokio::sync::RwLock::new(None),
|
||||
monitor: Arc::new(HidHealthMonitor::with_defaults()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: Arc<crate::events::EventBus>) {
|
||||
*self.events.write().await = Some(events.clone());
|
||||
// Also set event bus on the monitor for health notifications
|
||||
self.monitor.set_event_bus(events).await;
|
||||
}
|
||||
|
||||
/// Initialize the HID backend
|
||||
pub async fn init(&self) -> Result<()> {
|
||||
let backend_type = self.backend_type.read().await.clone();
|
||||
let backend: Box<dyn HidBackend> = match backend_type {
|
||||
HidBackendType::Otg => {
|
||||
// Request HID functions from OtgService
|
||||
let otg_service = self
|
||||
.otg_service
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::Internal("OtgService not available".into()))?;
|
||||
|
||||
info!("Requesting HID functions from OtgService");
|
||||
let handles = otg_service.enable_hid().await?;
|
||||
|
||||
// Create OtgBackend from handles (no longer manages gadget itself)
|
||||
info!("Creating OTG HID backend from device paths");
|
||||
Box::new(otg::OtgBackend::from_handles(handles)?)
|
||||
}
|
||||
HidBackendType::Ch9329 { ref port, baud_rate } => {
|
||||
info!("Initializing CH9329 HID backend on {} @ {} baud", port, baud_rate);
|
||||
Box::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?)
|
||||
}
|
||||
HidBackendType::None => {
|
||||
warn!("HID backend disabled");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
backend.init().await?;
|
||||
*self.backend.write().await = Some(backend);
|
||||
|
||||
info!("HID backend initialized: {:?}", backend_type);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the HID backend and release resources
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down HID controller");
|
||||
|
||||
// Close the backend
|
||||
*self.backend.write().await = None;
|
||||
|
||||
// If OTG backend, notify OtgService to disable HID
|
||||
let backend_type = self.backend_type.read().await.clone();
|
||||
if matches!(backend_type, HidBackendType::Otg) {
|
||||
if let Some(ref otg_service) = self.otg_service {
|
||||
info!("Disabling HID functions in OtgService");
|
||||
otg_service.disable_hid().await?;
|
||||
}
|
||||
}
|
||||
|
||||
info!("HID controller shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send keyboard event
|
||||
pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
|
||||
let backend = self.backend.read().await;
|
||||
match backend.as_ref() {
|
||||
Some(b) => {
|
||||
match b.send_keyboard(event).await {
|
||||
Ok(_) => {
|
||||
// Check if we were in an error state and now recovered
|
||||
if self.monitor.is_error().await {
|
||||
let backend_type = self.backend_type.read().await;
|
||||
self.monitor.report_recovered(backend_type.name_str()).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
// Report error to monitor, but skip temporary EAGAIN retries
|
||||
// - "eagain_retry": within threshold, just temporary busy
|
||||
// - "eagain": exceeded threshold, report as error
|
||||
if let AppError::HidError { ref backend, ref reason, ref error_code } = e {
|
||||
if error_code != "eagain_retry" {
|
||||
self.monitor.report_error(backend, None, reason, error_code).await;
|
||||
}
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => Err(AppError::BadRequest("HID backend not available".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send mouse event
|
||||
pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
|
||||
let backend = self.backend.read().await;
|
||||
match backend.as_ref() {
|
||||
Some(b) => {
|
||||
match b.send_mouse(event).await {
|
||||
Ok(_) => {
|
||||
// Check if we were in an error state and now recovered
|
||||
if self.monitor.is_error().await {
|
||||
let backend_type = self.backend_type.read().await;
|
||||
self.monitor.report_recovered(backend_type.name_str()).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
// Report error to monitor, but skip temporary EAGAIN retries
|
||||
// - "eagain_retry": within threshold, just temporary busy
|
||||
// - "eagain": exceeded threshold, report as error
|
||||
if let AppError::HidError { ref backend, ref reason, ref error_code } = e {
|
||||
if error_code != "eagain_retry" {
|
||||
self.monitor.report_error(backend, None, reason, error_code).await;
|
||||
}
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => Err(AppError::BadRequest("HID backend not available".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset all keys (release all pressed keys)
|
||||
pub async fn reset(&self) -> Result<()> {
|
||||
let backend = self.backend.read().await;
|
||||
match backend.as_ref() {
|
||||
Some(b) => b.reset().await,
|
||||
None => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if backend is available
|
||||
pub async fn is_available(&self) -> bool {
|
||||
self.backend.read().await.is_some()
|
||||
}
|
||||
|
||||
/// Get backend type
|
||||
pub async fn backend_type(&self) -> HidBackendType {
|
||||
self.backend_type.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get backend info
|
||||
pub async fn info(&self) -> Option<HidInfo> {
|
||||
let backend = self.backend.read().await;
|
||||
backend.as_ref().map(|b| HidInfo {
|
||||
name: b.name(),
|
||||
initialized: true,
|
||||
supports_absolute_mouse: b.supports_absolute_mouse(),
|
||||
screen_resolution: b.screen_resolution(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current state as SystemEvent
|
||||
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
|
||||
let backend = self.backend.read().await;
|
||||
let backend_type = self.backend_type().await;
|
||||
let (backend_name, initialized) = match backend.as_ref() {
|
||||
Some(b) => (b.name(), true),
|
||||
None => (backend_type.name_str(), false),
|
||||
};
|
||||
|
||||
// Include error information from monitor
|
||||
let (error, error_code) = match self.monitor.status().await {
|
||||
HidHealthStatus::Error { reason, error_code, .. } => {
|
||||
(Some(reason), Some(error_code))
|
||||
}
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
crate::events::SystemEvent::HidStateChanged {
|
||||
backend: backend_name.to_string(),
|
||||
initialized,
|
||||
error,
|
||||
error_code,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the health monitor reference
|
||||
pub fn monitor(&self) -> &Arc<HidHealthMonitor> {
|
||||
&self.monitor
|
||||
}
|
||||
|
||||
/// Get current health status
|
||||
pub async fn health_status(&self) -> HidHealthStatus {
|
||||
self.monitor.status().await
|
||||
}
|
||||
|
||||
/// Check if the HID backend is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.monitor.is_healthy().await
|
||||
}
|
||||
|
||||
/// Reload the HID backend with new type
|
||||
pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> {
|
||||
info!("Reloading HID backend: {:?}", new_backend_type);
|
||||
|
||||
// Shutdown existing backend first
|
||||
if let Some(backend) = self.backend.write().await.take() {
|
||||
if let Err(e) = backend.shutdown().await {
|
||||
warn!("Error shutting down old HID backend: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Create and initialize new backend
|
||||
let new_backend: Option<Box<dyn HidBackend>> = match new_backend_type {
|
||||
HidBackendType::Otg => {
|
||||
info!("Initializing OTG HID backend");
|
||||
|
||||
// Get OtgService reference
|
||||
let otg_service = match self.otg_service.as_ref() {
|
||||
Some(svc) => svc,
|
||||
None => {
|
||||
warn!("OTG backend requires OtgService, but it's not available");
|
||||
return Err(AppError::Config(
|
||||
"OTG backend not available (OtgService missing)".to_string()
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Request HID functions from OtgService
|
||||
match otg_service.enable_hid().await {
|
||||
Ok(handles) => {
|
||||
// Create OtgBackend from handles
|
||||
match otg::OtgBackend::from_handles(handles) {
|
||||
Ok(backend) => {
|
||||
let boxed: Box<dyn HidBackend> = Box::new(backend);
|
||||
match boxed.init().await {
|
||||
Ok(_) => {
|
||||
info!("OTG backend initialized successfully");
|
||||
Some(boxed)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to initialize OTG backend: {}", e);
|
||||
// Cleanup: disable HID in OtgService
|
||||
if let Err(e2) = otg_service.disable_hid().await {
|
||||
warn!("Failed to cleanup HID after init failure: {}", e2);
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create OTG backend: {}", e);
|
||||
// Cleanup: disable HID in OtgService
|
||||
if let Err(e2) = otg_service.disable_hid().await {
|
||||
warn!("Failed to cleanup HID after creation failure: {}", e2);
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to enable HID in OtgService: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
HidBackendType::Ch9329 { ref port, baud_rate } => {
|
||||
info!("Initializing CH9329 HID backend on {} @ {} baud", port, baud_rate);
|
||||
match ch9329::Ch9329Backend::with_baud_rate(port, baud_rate) {
|
||||
Ok(b) => {
|
||||
let boxed = Box::new(b);
|
||||
match boxed.init().await {
|
||||
Ok(_) => Some(boxed),
|
||||
Err(e) => {
|
||||
warn!("Failed to initialize CH9329 backend: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create CH9329 backend: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
HidBackendType::None => {
|
||||
warn!("HID backend disabled");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
*self.backend.write().await = new_backend;
|
||||
|
||||
if self.backend.read().await.is_some() {
|
||||
info!("HID backend reloaded successfully: {:?}", new_backend_type);
|
||||
|
||||
// Update backend_type on success
|
||||
*self.backend_type.write().await = new_backend_type.clone();
|
||||
|
||||
// Reset monitor state on successful reload
|
||||
self.monitor.reset().await;
|
||||
|
||||
// Publish HID state changed event
|
||||
let backend_name = new_backend_type.name_str().to_string();
|
||||
self.publish_event(crate::events::SystemEvent::HidStateChanged {
|
||||
backend: backend_name,
|
||||
initialized: true,
|
||||
error: None,
|
||||
error_code: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
warn!("HID backend reload resulted in no active backend");
|
||||
|
||||
// Update backend_type even on failure (to reflect the attempted change)
|
||||
*self.backend_type.write().await = new_backend_type.clone();
|
||||
|
||||
// Publish event with initialized=false
|
||||
self.publish_event(crate::events::SystemEvent::HidStateChanged {
|
||||
backend: new_backend_type.name_str().to_string(),
|
||||
initialized: false,
|
||||
error: Some("Failed to initialize HID backend".to_string()),
|
||||
error_code: Some("init_failed".to_string()),
|
||||
})
|
||||
.await;
|
||||
|
||||
Err(AppError::Internal(
|
||||
"Failed to reload HID backend".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Publish event to event bus if available
|
||||
async fn publish_event(&self, event: crate::events::SystemEvent) {
|
||||
if let Some(events) = self.events.read().await.as_ref() {
|
||||
events.publish(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HidController {
|
||||
fn default() -> Self {
|
||||
Self::new(HidBackendType::None, None)
|
||||
}
|
||||
}
|
||||
429
src/hid/monitor.rs
Normal file
429
src/hid/monitor.rs
Normal file
@@ -0,0 +1,429 @@
|
||||
//! HID device health monitoring
|
||||
//!
|
||||
//! This module provides health monitoring for HID devices, including:
|
||||
//! - Device connectivity checks
|
||||
//! - Automatic reconnection on failure
|
||||
//! - Error tracking and notification
|
||||
//! - Log throttling to prevent log flooding
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::events::{EventBus, SystemEvent};
|
||||
use crate::utils::LogThrottler;
|
||||
|
||||
/// HID health status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum HidHealthStatus {
|
||||
/// Device is healthy and operational
|
||||
Healthy,
|
||||
/// Device has an error, attempting recovery
|
||||
Error {
|
||||
/// Human-readable error reason
|
||||
reason: String,
|
||||
/// Error code for programmatic handling
|
||||
error_code: String,
|
||||
/// Number of recovery attempts made
|
||||
retry_count: u32,
|
||||
},
|
||||
/// Device is disconnected
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
impl Default for HidHealthStatus {
|
||||
fn default() -> Self {
|
||||
Self::Healthy
|
||||
}
|
||||
}
|
||||
|
||||
/// HID health monitor configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HidMonitorConfig {
|
||||
/// Health check interval in milliseconds
|
||||
pub check_interval_ms: u64,
|
||||
/// Retry interval when device is lost (milliseconds)
|
||||
pub retry_interval_ms: u64,
|
||||
/// Maximum retry attempts before giving up (0 = infinite)
|
||||
pub max_retries: u32,
|
||||
/// Log throttle interval in seconds
|
||||
pub log_throttle_secs: u64,
|
||||
/// Recovery cooldown in milliseconds (suppress logs after recovery)
|
||||
pub recovery_cooldown_ms: u64,
|
||||
}
|
||||
|
||||
impl Default for HidMonitorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
check_interval_ms: 1000,
|
||||
retry_interval_ms: 1000,
|
||||
max_retries: 0, // infinite retry
|
||||
log_throttle_secs: 5,
|
||||
recovery_cooldown_ms: 1000, // 1 second cooldown after recovery
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HID health monitor
|
||||
///
|
||||
/// Monitors HID device health and manages error recovery.
|
||||
/// Publishes WebSocket events when device status changes.
|
||||
pub struct HidHealthMonitor {
|
||||
/// Current health status
|
||||
status: RwLock<HidHealthStatus>,
|
||||
/// Event bus for notifications
|
||||
events: RwLock<Option<Arc<EventBus>>>,
|
||||
/// Log throttler to prevent log flooding
|
||||
throttler: LogThrottler,
|
||||
/// Configuration
|
||||
config: HidMonitorConfig,
|
||||
/// Whether monitoring is active (reserved for future use)
|
||||
#[allow(dead_code)]
|
||||
running: AtomicBool,
|
||||
/// Current retry count
|
||||
retry_count: AtomicU32,
|
||||
/// Last error code (for change detection)
|
||||
last_error_code: RwLock<Option<String>>,
|
||||
/// Last recovery timestamp (milliseconds since start, for cooldown)
|
||||
last_recovery_ms: AtomicU64,
|
||||
/// Start instant for timing
|
||||
start_instant: Instant,
|
||||
}
|
||||
|
||||
impl HidHealthMonitor {
|
||||
/// Create a new HID health monitor with the specified configuration
|
||||
pub fn new(config: HidMonitorConfig) -> Self {
|
||||
let throttle_secs = config.log_throttle_secs;
|
||||
Self {
|
||||
status: RwLock::new(HidHealthStatus::Healthy),
|
||||
events: RwLock::new(None),
|
||||
throttler: LogThrottler::with_secs(throttle_secs),
|
||||
config,
|
||||
running: AtomicBool::new(false),
|
||||
retry_count: AtomicU32::new(0),
|
||||
last_error_code: RwLock::new(None),
|
||||
last_recovery_ms: AtomicU64::new(0),
|
||||
start_instant: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new HID health monitor with default configuration
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(HidMonitorConfig::default())
|
||||
}
|
||||
|
||||
/// Set the event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Report an error from HID operations
|
||||
///
|
||||
/// This method is called when an HID operation fails. It:
|
||||
/// 1. Updates the health status
|
||||
/// 2. Logs the error (with throttling and cooldown respect)
|
||||
/// 3. Publishes a WebSocket event if the error is new or changed
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `backend` - The HID backend type ("otg" or "ch9329")
|
||||
/// * `device` - The device path (if known)
|
||||
/// * `reason` - Human-readable error description
|
||||
/// * `error_code` - Error code for programmatic handling
|
||||
pub async fn report_error(
|
||||
&self,
|
||||
backend: &str,
|
||||
device: Option<&str>,
|
||||
reason: &str,
|
||||
error_code: &str,
|
||||
) {
|
||||
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if we're in cooldown period after recent recovery
|
||||
let current_ms = self.start_instant.elapsed().as_millis() as u64;
|
||||
let last_recovery = self.last_recovery_ms.load(Ordering::Relaxed);
|
||||
let in_cooldown = last_recovery > 0 && current_ms < last_recovery + self.config.recovery_cooldown_ms;
|
||||
|
||||
// Check if error code changed
|
||||
let error_changed = {
|
||||
let last = self.last_error_code.read().await;
|
||||
last.as_ref().map(|s| s.as_str()) != Some(error_code)
|
||||
};
|
||||
|
||||
// Log with throttling (skip if in cooldown period unless error type changed)
|
||||
let throttle_key = format!("hid_{}_{}", backend, error_code);
|
||||
if !in_cooldown && (error_changed || self.throttler.should_log(&throttle_key)) {
|
||||
warn!(
|
||||
"HID {} error: {} (code: {}, attempt: {})",
|
||||
backend, reason, error_code, count
|
||||
);
|
||||
}
|
||||
|
||||
// Update last error code
|
||||
*self.last_error_code.write().await = Some(error_code.to_string());
|
||||
|
||||
// Update status
|
||||
*self.status.write().await = HidHealthStatus::Error {
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
retry_count: count,
|
||||
};
|
||||
|
||||
// Publish event (only if error changed or first occurrence, and not in cooldown)
|
||||
if !in_cooldown && (error_changed || count == 1) {
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::HidDeviceLost {
|
||||
backend: backend.to_string(),
|
||||
device: device.map(|s| s.to_string()),
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Report that a reconnection attempt is starting
|
||||
///
|
||||
/// Publishes a reconnecting event to notify clients.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `backend` - The HID backend type
|
||||
pub async fn report_reconnecting(&self, backend: &str) {
|
||||
let attempt = self.retry_count.load(Ordering::Relaxed);
|
||||
|
||||
// Only publish every 5 attempts to avoid event spam
|
||||
if attempt == 1 || attempt % 5 == 0 {
|
||||
debug!("HID {} reconnecting, attempt {}", backend, attempt);
|
||||
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::HidReconnecting {
|
||||
backend: backend.to_string(),
|
||||
attempt,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Report that the device has recovered
|
||||
///
|
||||
/// This method is called when the HID device successfully reconnects.
|
||||
/// It resets the error state and publishes a recovery event.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `backend` - The HID backend type
|
||||
pub async fn report_recovered(&self, backend: &str) {
|
||||
let prev_status = self.status.read().await.clone();
|
||||
|
||||
// Only report recovery if we were in an error state
|
||||
if prev_status != HidHealthStatus::Healthy {
|
||||
let retry_count = self.retry_count.load(Ordering::Relaxed);
|
||||
|
||||
// Set cooldown timestamp
|
||||
let current_ms = self.start_instant.elapsed().as_millis() as u64;
|
||||
self.last_recovery_ms.store(current_ms, Ordering::Relaxed);
|
||||
|
||||
// Only log and publish events if there were multiple retries
|
||||
// (avoid log spam for transient single-retry recoveries)
|
||||
if retry_count > 1 {
|
||||
debug!(
|
||||
"HID {} recovered after {} retries",
|
||||
backend, retry_count
|
||||
);
|
||||
|
||||
// Publish recovery event
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::HidRecovered {
|
||||
backend: backend.to_string(),
|
||||
});
|
||||
|
||||
// Also publish state changed to indicate healthy state
|
||||
events.publish(SystemEvent::HidStateChanged {
|
||||
backend: backend.to_string(),
|
||||
initialized: true,
|
||||
error: None,
|
||||
error_code: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Reset state (always reset, even for single-retry recoveries)
|
||||
self.retry_count.store(0, Ordering::Relaxed);
|
||||
*self.last_error_code.write().await = None;
|
||||
*self.status.write().await = HidHealthStatus::Healthy;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current health status
|
||||
pub async fn status(&self) -> HidHealthStatus {
|
||||
self.status.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get the current retry count
|
||||
pub fn retry_count(&self) -> u32 {
|
||||
self.retry_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Check if the monitor is in an error state
|
||||
pub async fn is_error(&self) -> bool {
|
||||
matches!(*self.status.read().await, HidHealthStatus::Error { .. })
|
||||
}
|
||||
|
||||
/// Check if the monitor is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
matches!(*self.status.read().await, HidHealthStatus::Healthy)
|
||||
}
|
||||
|
||||
/// Reset the monitor to healthy state without publishing events
|
||||
///
|
||||
/// This is useful during initialization.
|
||||
pub async fn reset(&self) {
|
||||
self.retry_count.store(0, Ordering::Relaxed);
|
||||
*self.last_error_code.write().await = None;
|
||||
*self.status.write().await = HidHealthStatus::Healthy;
|
||||
self.throttler.clear_all();
|
||||
}
|
||||
|
||||
/// Get the configuration
|
||||
pub fn config(&self) -> &HidMonitorConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Check if we should continue retrying
|
||||
///
|
||||
/// Returns `false` if max_retries is set and we've exceeded it.
|
||||
pub fn should_retry(&self) -> bool {
|
||||
if self.config.max_retries == 0 {
|
||||
return true; // Infinite retry
|
||||
}
|
||||
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
|
||||
}
|
||||
|
||||
/// Get the retry interval
|
||||
pub fn retry_interval(&self) -> Duration {
|
||||
Duration::from_millis(self.config.retry_interval_ms)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HidHealthMonitor {
|
||||
fn default() -> Self {
|
||||
Self::with_defaults()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initial_status() {
|
||||
let monitor = HidHealthMonitor::with_defaults();
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert!(!monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_error() {
|
||||
let monitor = HidHealthMonitor::with_defaults();
|
||||
|
||||
monitor
|
||||
.report_error("otg", Some("/dev/hidg0"), "Device not found", "enoent")
|
||||
.await;
|
||||
|
||||
assert!(monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 1);
|
||||
|
||||
if let HidHealthStatus::Error {
|
||||
reason,
|
||||
error_code,
|
||||
retry_count,
|
||||
} = monitor.status().await
|
||||
{
|
||||
assert_eq!(reason, "Device not found");
|
||||
assert_eq!(error_code, "enoent");
|
||||
assert_eq!(retry_count, 1);
|
||||
} else {
|
||||
panic!("Expected Error status");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_recovered() {
|
||||
let monitor = HidHealthMonitor::with_defaults();
|
||||
|
||||
// First report an error
|
||||
monitor
|
||||
.report_error("ch9329", None, "Port not found", "port_not_found")
|
||||
.await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
// Then report recovery
|
||||
monitor.report_recovered("ch9329").await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_count_increments() {
|
||||
let monitor = HidHealthMonitor::with_defaults();
|
||||
|
||||
for i in 1..=5 {
|
||||
monitor
|
||||
.report_error("otg", None, "Error", "io_error")
|
||||
.await;
|
||||
assert_eq!(monitor.retry_count(), i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_retry_infinite() {
|
||||
let monitor = HidHealthMonitor::new(HidMonitorConfig {
|
||||
max_retries: 0, // infinite
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
for _ in 0..100 {
|
||||
monitor
|
||||
.report_error("otg", None, "Error", "io_error")
|
||||
.await;
|
||||
assert!(monitor.should_retry());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_retry_limited() {
|
||||
let monitor = HidHealthMonitor::new(HidMonitorConfig {
|
||||
max_retries: 3,
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
assert!(monitor.should_retry());
|
||||
|
||||
monitor.report_error("otg", None, "Error", "io_error").await;
|
||||
assert!(monitor.should_retry()); // 1 < 3
|
||||
|
||||
monitor.report_error("otg", None, "Error", "io_error").await;
|
||||
assert!(monitor.should_retry()); // 2 < 3
|
||||
|
||||
monitor.report_error("otg", None, "Error", "io_error").await;
|
||||
assert!(!monitor.should_retry()); // 3 >= 3
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset() {
|
||||
let monitor = HidHealthMonitor::with_defaults();
|
||||
|
||||
monitor
|
||||
.report_error("otg", None, "Error", "io_error")
|
||||
.await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
monitor.reset().await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
}
|
||||
848
src/hid/otg.rs
Normal file
848
src/hid/otg.rs
Normal file
@@ -0,0 +1,848 @@
|
||||
//! OTG USB Gadget HID backend
|
||||
//!
|
||||
//! This backend uses Linux USB Gadget API to emulate USB HID devices.
|
||||
//! It creates and manages three HID devices:
|
||||
//! - hidg0: Keyboard (8-byte reports, with LED feedback)
|
||||
//! - hidg1: Relative Mouse (4-byte reports)
|
||||
//! - hidg2: Absolute Mouse (6-byte reports)
|
||||
//!
|
||||
//! Requirements:
|
||||
//! - USB OTG/Device controller (UDC)
|
||||
//! - ConfigFS with USB gadget support
|
||||
//! - Root privileges for gadget setup
|
||||
//!
|
||||
//! Error Recovery:
|
||||
//! This module implements automatic device reconnection based on PiKVM's approach.
|
||||
//! When ESHUTDOWN or EAGAIN errors occur (common during MSD operations), the device
|
||||
//! file handles are closed and reopened on the next operation.
|
||||
//! See: https://github.com/raspberrypi/linux/issues/4373
|
||||
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::Mutex;
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Write};
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
use super::backend::HidBackend;
|
||||
use super::keymap;
|
||||
use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::otg::{HidDevicePaths, wait_for_hid_devices};
|
||||
|
||||
/// Device type for ensure_device operations
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum DeviceType {
|
||||
Keyboard,
|
||||
MouseRelative,
|
||||
MouseAbsolute,
|
||||
}
|
||||
|
||||
/// Keyboard LED state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct LedState {
|
||||
/// Num Lock LED
|
||||
pub num_lock: bool,
|
||||
/// Caps Lock LED
|
||||
pub caps_lock: bool,
|
||||
/// Scroll Lock LED
|
||||
pub scroll_lock: bool,
|
||||
/// Compose LED
|
||||
pub compose: bool,
|
||||
/// Kana LED
|
||||
pub kana: bool,
|
||||
}
|
||||
|
||||
impl LedState {
|
||||
/// Create from raw byte
|
||||
pub fn from_byte(b: u8) -> Self {
|
||||
Self {
|
||||
num_lock: b & 0x01 != 0,
|
||||
caps_lock: b & 0x02 != 0,
|
||||
scroll_lock: b & 0x04 != 0,
|
||||
compose: b & 0x08 != 0,
|
||||
kana: b & 0x10 != 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to raw byte
|
||||
pub fn to_byte(&self) -> u8 {
|
||||
let mut b = 0u8;
|
||||
if self.num_lock { b |= 0x01; }
|
||||
if self.caps_lock { b |= 0x02; }
|
||||
if self.scroll_lock { b |= 0x04; }
|
||||
if self.compose { b |= 0x08; }
|
||||
if self.kana { b |= 0x10; }
|
||||
b
|
||||
}
|
||||
}
|
||||
|
||||
/// OTG HID backend with 3 devices
|
||||
///
|
||||
/// This backend opens HID device files created by OtgService.
|
||||
/// It does NOT manage the USB gadget itself - that's handled by OtgService.
|
||||
///
|
||||
/// ## Error Recovery
|
||||
///
|
||||
/// Based on PiKVM's implementation, this backend automatically handles:
|
||||
/// - EAGAIN (errno 11): Resource temporarily unavailable - just retry later, don't close device
|
||||
/// - ESHUTDOWN (errno 108): Transport endpoint shutdown - close and reopen device
|
||||
///
|
||||
/// When ESHUTDOWN occurs, the device file handle is closed and will be
|
||||
/// reopened on the next operation attempt.
|
||||
pub struct OtgBackend {
|
||||
/// Keyboard device path (/dev/hidg0)
|
||||
keyboard_path: PathBuf,
|
||||
/// Relative mouse device path (/dev/hidg1)
|
||||
mouse_rel_path: PathBuf,
|
||||
/// Absolute mouse device path (/dev/hidg2)
|
||||
mouse_abs_path: PathBuf,
|
||||
/// Keyboard device file
|
||||
keyboard_dev: Mutex<Option<File>>,
|
||||
/// Relative mouse device file
|
||||
mouse_rel_dev: Mutex<Option<File>>,
|
||||
/// Absolute mouse device file
|
||||
mouse_abs_dev: Mutex<Option<File>>,
|
||||
/// Current keyboard state
|
||||
keyboard_state: Mutex<KeyboardReport>,
|
||||
/// Current mouse button state
|
||||
mouse_buttons: AtomicU8,
|
||||
/// Last known LED state (using parking_lot::RwLock for sync access)
|
||||
led_state: parking_lot::RwLock<LedState>,
|
||||
/// Screen resolution for absolute mouse (using parking_lot::RwLock for sync access)
|
||||
screen_resolution: parking_lot::RwLock<Option<(u32, u32)>>,
|
||||
/// UDC name for state checking (e.g., "fcc00000.usb")
|
||||
udc_name: parking_lot::RwLock<Option<String>>,
|
||||
/// Whether the device is currently online (UDC configured and devices accessible)
|
||||
online: AtomicBool,
|
||||
/// Last error log time for throttling (using parking_lot for sync)
|
||||
last_error_log: parking_lot::Mutex<std::time::Instant>,
|
||||
/// Error count since last successful operation (for log throttling)
|
||||
error_count: AtomicU8,
|
||||
/// Consecutive EAGAIN count (for offline threshold detection)
|
||||
eagain_count: AtomicU8,
|
||||
}
|
||||
|
||||
/// Threshold for consecutive EAGAIN errors before reporting offline
|
||||
const EAGAIN_OFFLINE_THRESHOLD: u8 = 3;
|
||||
|
||||
impl OtgBackend {
|
||||
/// Create OTG backend from device paths provided by OtgService
|
||||
///
|
||||
/// This is the ONLY way to create an OtgBackend - it no longer manages
|
||||
/// the USB gadget itself. The gadget must already be set up by OtgService.
|
||||
pub fn from_handles(paths: HidDevicePaths) -> Result<Self> {
|
||||
Ok(Self {
|
||||
keyboard_path: paths.keyboard,
|
||||
mouse_rel_path: paths.mouse_relative,
|
||||
mouse_abs_path: paths.mouse_absolute,
|
||||
keyboard_dev: Mutex::new(None),
|
||||
mouse_rel_dev: Mutex::new(None),
|
||||
mouse_abs_dev: Mutex::new(None),
|
||||
keyboard_state: Mutex::new(KeyboardReport::default()),
|
||||
mouse_buttons: AtomicU8::new(0),
|
||||
led_state: parking_lot::RwLock::new(LedState::default()),
|
||||
screen_resolution: parking_lot::RwLock::new(Some((1920, 1080))),
|
||||
udc_name: parking_lot::RwLock::new(None),
|
||||
online: AtomicBool::new(false),
|
||||
last_error_log: parking_lot::Mutex::new(std::time::Instant::now()),
|
||||
error_count: AtomicU8::new(0),
|
||||
eagain_count: AtomicU8::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Log throttled error message (max once per second)
|
||||
fn log_throttled_error(&self, msg: &str) {
|
||||
let mut last_log = self.last_error_log.lock();
|
||||
let now = std::time::Instant::now();
|
||||
if now.duration_since(*last_log).as_secs() >= 1 {
|
||||
let count = self.error_count.swap(0, Ordering::Relaxed);
|
||||
if count > 1 {
|
||||
warn!("{} (repeated {} times)", msg, count);
|
||||
} else {
|
||||
warn!("{}", msg);
|
||||
}
|
||||
*last_log = now;
|
||||
} else {
|
||||
self.error_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset error count on successful operation
|
||||
fn reset_error_count(&self) {
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
// Also reset EAGAIN count - successful operation means device is working
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Set the UDC name for state checking
|
||||
pub fn set_udc_name(&self, udc: &str) {
|
||||
*self.udc_name.write() = Some(udc.to_string());
|
||||
}
|
||||
|
||||
/// Check if the UDC is in "configured" state
|
||||
///
|
||||
/// This is based on PiKVM's `__is_udc_configured()` method.
|
||||
/// The UDC state file indicates whether the USB host has enumerated and configured the gadget.
|
||||
pub fn is_udc_configured(&self) -> bool {
|
||||
let udc_name = self.udc_name.read();
|
||||
if let Some(ref udc) = *udc_name {
|
||||
let state_path = format!("/sys/class/udc/{}/state", udc);
|
||||
match fs::read_to_string(&state_path) {
|
||||
Ok(content) => {
|
||||
let state = content.trim().to_lowercase();
|
||||
trace!("UDC {} state: {}", udc, state);
|
||||
state == "configured"
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to read UDC state from {}: {}", state_path, e);
|
||||
// If we can't read the state, assume it might be configured
|
||||
// to avoid blocking operations unnecessarily
|
||||
true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No UDC name set, try to auto-detect
|
||||
if let Some(udc) = Self::find_udc() {
|
||||
drop(udc_name);
|
||||
*self.udc_name.write() = Some(udc.clone());
|
||||
let state_path = format!("/sys/class/udc/{}/state", udc);
|
||||
fs::read_to_string(&state_path)
|
||||
.map(|s| s.trim().to_lowercase() == "configured")
|
||||
.unwrap_or(true)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the first available UDC
|
||||
fn find_udc() -> Option<String> {
|
||||
let udc_path = PathBuf::from("/sys/class/udc");
|
||||
if let Ok(entries) = fs::read_dir(&udc_path) {
|
||||
for entry in entries.flatten() {
|
||||
if let Some(name) = entry.file_name().to_str() {
|
||||
return Some(name.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if device is online
|
||||
pub fn is_online(&self) -> bool {
|
||||
self.online.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Ensure a device is open and ready for I/O
|
||||
///
|
||||
/// This method is based on PiKVM's `__ensure_device()` pattern:
|
||||
/// 1. Check if device path exists, close handle if not
|
||||
/// 2. If handle is None but path exists, reopen the device
|
||||
/// 3. Return whether the device is ready for I/O
|
||||
fn ensure_device(&self, device_type: DeviceType) -> Result<()> {
|
||||
let (path, dev_mutex) = match device_type {
|
||||
DeviceType::Keyboard => (&self.keyboard_path, &self.keyboard_dev),
|
||||
DeviceType::MouseRelative => (&self.mouse_rel_path, &self.mouse_rel_dev),
|
||||
DeviceType::MouseAbsolute => (&self.mouse_abs_path, &self.mouse_abs_dev),
|
||||
};
|
||||
|
||||
// Check if device path exists
|
||||
if !path.exists() {
|
||||
// Close the device if open (device was removed)
|
||||
let mut dev = dev_mutex.lock();
|
||||
if dev.is_some() {
|
||||
debug!("Device path {} no longer exists, closing handle", path.display());
|
||||
*dev = None;
|
||||
}
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
return Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: format!("Device not found: {}", path.display()),
|
||||
error_code: "enoent".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// If device is not open, try to open it
|
||||
let mut dev = dev_mutex.lock();
|
||||
if dev.is_none() {
|
||||
match Self::open_device(path) {
|
||||
Ok(file) => {
|
||||
info!("Reopened HID device: {}", path.display());
|
||||
*dev = Some(file);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to reopen HID device {}: {}", path.display(), e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.online.store(true, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Close a device (used when ESHUTDOWN is received)
|
||||
#[allow(dead_code)]
|
||||
fn close_device(&self, device_type: DeviceType) {
|
||||
let dev_mutex = match device_type {
|
||||
DeviceType::Keyboard => &self.keyboard_dev,
|
||||
DeviceType::MouseRelative => &self.mouse_rel_dev,
|
||||
DeviceType::MouseAbsolute => &self.mouse_abs_dev,
|
||||
};
|
||||
|
||||
let mut dev = dev_mutex.lock();
|
||||
if dev.is_some() {
|
||||
debug!("Closing {:?} device handle for recovery", device_type);
|
||||
*dev = None;
|
||||
}
|
||||
}
|
||||
|
||||
/// Close all device handles (for recovery)
|
||||
#[allow(dead_code)]
|
||||
fn close_all_devices(&self) {
|
||||
self.close_device(DeviceType::Keyboard);
|
||||
self.close_device(DeviceType::MouseRelative);
|
||||
self.close_device(DeviceType::MouseAbsolute);
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Open a HID device file with read/write access
|
||||
fn open_device(path: &PathBuf) -> Result<File> {
|
||||
OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.custom_flags(libc::O_NONBLOCK)
|
||||
.open(path)
|
||||
.map_err(|e| {
|
||||
AppError::Internal(format!("Failed to open HID device {}: {}", path.display(), e))
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert I/O error to HidError with appropriate error code
|
||||
fn io_error_to_hid_error(e: std::io::Error, operation: &str) -> AppError {
|
||||
let error_code = match e.raw_os_error() {
|
||||
Some(32) => "epipe", // EPIPE - broken pipe
|
||||
Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown
|
||||
Some(11) => "eagain", // EAGAIN - resource temporarily unavailable
|
||||
Some(6) => "enxio", // ENXIO - no such device or address
|
||||
Some(19) => "enodev", // ENODEV - no such device
|
||||
Some(5) => "eio", // EIO - I/O error
|
||||
Some(2) => "enoent", // ENOENT - no such file or directory
|
||||
_ => "io_error",
|
||||
};
|
||||
|
||||
AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: format!("{}: {}", operation, e),
|
||||
error_code: error_code.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if all HID device files exist
|
||||
pub fn check_devices_exist(&self) -> bool {
|
||||
self.keyboard_path.exists()
|
||||
&& self.mouse_rel_path.exists()
|
||||
&& self.mouse_abs_path.exists()
|
||||
}
|
||||
|
||||
/// Get list of missing device paths
|
||||
pub fn get_missing_devices(&self) -> Vec<String> {
|
||||
let mut missing = Vec::new();
|
||||
if !self.keyboard_path.exists() {
|
||||
missing.push(self.keyboard_path.display().to_string());
|
||||
}
|
||||
if !self.mouse_rel_path.exists() {
|
||||
missing.push(self.mouse_rel_path.display().to_string());
|
||||
}
|
||||
if !self.mouse_abs_path.exists() {
|
||||
missing.push(self.mouse_abs_path.display().to_string());
|
||||
}
|
||||
missing
|
||||
}
|
||||
|
||||
/// Send keyboard report (8 bytes)
|
||||
///
|
||||
/// This method ensures the device is open before writing, and handles
|
||||
/// ESHUTDOWN errors by closing the device handle for later reconnection.
|
||||
/// EAGAIN errors are treated as temporary - device stays open.
|
||||
fn send_keyboard_report(&self, report: &KeyboardReport) -> Result<()> {
|
||||
// Ensure device is ready
|
||||
self.ensure_device(DeviceType::Keyboard)?;
|
||||
|
||||
let mut dev = self.keyboard_dev.lock();
|
||||
if let Some(ref mut file) = *dev {
|
||||
let data = report.to_bytes();
|
||||
match file.write_all(&data) {
|
||||
Ok(_) => {
|
||||
self.online.store(true, Ordering::Relaxed);
|
||||
self.reset_error_count();
|
||||
trace!("Sent keyboard report: {:02X?}", data);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
let error_code = e.raw_os_error();
|
||||
|
||||
match error_code {
|
||||
Some(108) => {
|
||||
// ESHUTDOWN - endpoint closed, need to reopen device
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
debug!("Keyboard ESHUTDOWN, closing for recovery");
|
||||
*dev = None;
|
||||
Err(Self::io_error_to_hid_error(e, "Failed to write keyboard report"))
|
||||
}
|
||||
Some(11) => {
|
||||
// EAGAIN - temporary busy, track consecutive count
|
||||
self.log_throttled_error("HID keyboard busy (EAGAIN)");
|
||||
let count = self.eagain_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
if count >= EAGAIN_OFFLINE_THRESHOLD {
|
||||
// Exceeded threshold, report as offline
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: format!("Device busy ({} consecutive EAGAIN)", count),
|
||||
error_code: "eagain".to_string(),
|
||||
})
|
||||
} else {
|
||||
// Within threshold, return retry error (won't trigger offline event)
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: "Device temporarily busy".to_string(),
|
||||
error_code: "eagain_retry".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
warn!("Keyboard write error: {}", e);
|
||||
Err(Self::io_error_to_hid_error(e, "Failed to write keyboard report"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: "Keyboard device not opened".to_string(),
|
||||
error_code: "not_opened".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Send relative mouse report (4 bytes: buttons, dx, dy, wheel)
|
||||
///
|
||||
/// This method ensures the device is open before writing, and handles
|
||||
/// ESHUTDOWN errors by closing the device handle for later reconnection.
|
||||
/// EAGAIN errors are treated as temporary - device stays open.
|
||||
fn send_mouse_report_relative(&self, buttons: u8, dx: i8, dy: i8, wheel: i8) -> Result<()> {
|
||||
// Ensure device is ready
|
||||
self.ensure_device(DeviceType::MouseRelative)?;
|
||||
|
||||
let mut dev = self.mouse_rel_dev.lock();
|
||||
if let Some(ref mut file) = *dev {
|
||||
let data = [buttons, dx as u8, dy as u8, wheel as u8];
|
||||
match file.write_all(&data) {
|
||||
Ok(_) => {
|
||||
self.online.store(true, Ordering::Relaxed);
|
||||
self.reset_error_count();
|
||||
trace!("Sent relative mouse report: {:02X?}", data);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
let error_code = e.raw_os_error();
|
||||
|
||||
match error_code {
|
||||
Some(108) => {
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
debug!("Relative mouse ESHUTDOWN, closing for recovery");
|
||||
*dev = None;
|
||||
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
|
||||
}
|
||||
Some(11) => {
|
||||
// EAGAIN - temporary busy, track consecutive count
|
||||
self.log_throttled_error("HID relative mouse busy (EAGAIN)");
|
||||
let count = self.eagain_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
if count >= EAGAIN_OFFLINE_THRESHOLD {
|
||||
// Exceeded threshold, report as offline
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: format!("Device busy ({} consecutive EAGAIN)", count),
|
||||
error_code: "eagain".to_string(),
|
||||
})
|
||||
} else {
|
||||
// Within threshold, return retry error (won't trigger offline event)
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: "Device temporarily busy".to_string(),
|
||||
error_code: "eagain_retry".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
warn!("Relative mouse write error: {}", e);
|
||||
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: "Relative mouse device not opened".to_string(),
|
||||
error_code: "not_opened".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Send absolute mouse report (6 bytes: buttons, x_lo, x_hi, y_lo, y_hi, wheel)
|
||||
///
|
||||
/// This method ensures the device is open before writing, and handles
|
||||
/// ESHUTDOWN errors by closing the device handle for later reconnection.
|
||||
/// EAGAIN errors are treated as temporary - device stays open.
|
||||
fn send_mouse_report_absolute(&self, buttons: u8, x: u16, y: u16, wheel: i8) -> Result<()> {
|
||||
// Ensure device is ready
|
||||
self.ensure_device(DeviceType::MouseAbsolute)?;
|
||||
|
||||
let mut dev = self.mouse_abs_dev.lock();
|
||||
if let Some(ref mut file) = *dev {
|
||||
let data = [
|
||||
buttons,
|
||||
(x & 0xFF) as u8,
|
||||
(x >> 8) as u8,
|
||||
(y & 0xFF) as u8,
|
||||
(y >> 8) as u8,
|
||||
wheel as u8,
|
||||
];
|
||||
match file.write_all(&data) {
|
||||
Ok(_) => {
|
||||
self.online.store(true, Ordering::Relaxed);
|
||||
self.reset_error_count();
|
||||
trace!("Sent absolute mouse report: {:02X?}", data);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
let error_code = e.raw_os_error();
|
||||
|
||||
match error_code {
|
||||
Some(108) => {
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
debug!("Absolute mouse ESHUTDOWN, closing for recovery");
|
||||
*dev = None;
|
||||
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
|
||||
}
|
||||
Some(11) => {
|
||||
// EAGAIN - temporary busy, track consecutive count
|
||||
self.log_throttled_error("HID absolute mouse busy (EAGAIN)");
|
||||
let count = self.eagain_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
if count >= EAGAIN_OFFLINE_THRESHOLD {
|
||||
// Exceeded threshold, report as offline
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: format!("Device busy ({} consecutive EAGAIN)", count),
|
||||
error_code: "eagain".to_string(),
|
||||
})
|
||||
} else {
|
||||
// Within threshold, return retry error (won't trigger offline event)
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: "Device temporarily busy".to_string(),
|
||||
error_code: "eagain_retry".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
warn!("Absolute mouse write error: {}", e);
|
||||
Err(Self::io_error_to_hid_error(e, "Failed to write mouse report"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err(AppError::HidError {
|
||||
backend: "otg".to_string(),
|
||||
reason: "Absolute mouse device not opened".to_string(),
|
||||
error_code: "not_opened".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Read keyboard LED state (non-blocking)
|
||||
pub fn read_led_state(&self) -> Result<Option<LedState>> {
|
||||
let mut dev = self.keyboard_dev.lock();
|
||||
if let Some(ref mut file) = *dev {
|
||||
let mut buf = [0u8; 1];
|
||||
match file.read(&mut buf) {
|
||||
Ok(1) => {
|
||||
let state = LedState::from_byte(buf[0]);
|
||||
// Update LED state (using parking_lot RwLock)
|
||||
*self.led_state.write() = state;
|
||||
Ok(Some(state))
|
||||
}
|
||||
Ok(_) => Ok(None), // No data available
|
||||
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
|
||||
Err(e) => Err(AppError::Internal(format!("Failed to read LED state: {}", e))),
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get last known LED state
|
||||
pub fn led_state(&self) -> LedState {
|
||||
*self.led_state.read()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HidBackend for OtgBackend {
|
||||
fn name(&self) -> &'static str {
|
||||
"OTG USB Gadget"
|
||||
}
|
||||
|
||||
async fn init(&self) -> Result<()> {
|
||||
info!("Initializing OTG HID backend");
|
||||
|
||||
// Auto-detect UDC name for state checking
|
||||
if let Some(udc) = Self::find_udc() {
|
||||
info!("Auto-detected UDC: {}", udc);
|
||||
self.set_udc_name(&udc);
|
||||
}
|
||||
|
||||
// Wait for devices to appear (they should already exist from OtgService)
|
||||
let device_paths = vec![
|
||||
self.keyboard_path.clone(),
|
||||
self.mouse_rel_path.clone(),
|
||||
self.mouse_abs_path.clone(),
|
||||
];
|
||||
|
||||
if !wait_for_hid_devices(&device_paths, 2000).await {
|
||||
return Err(AppError::Internal("HID devices did not appear".into()));
|
||||
}
|
||||
|
||||
// Open keyboard device
|
||||
if self.keyboard_path.exists() {
|
||||
let file = Self::open_device(&self.keyboard_path)?;
|
||||
*self.keyboard_dev.lock() = Some(file);
|
||||
info!("Keyboard device opened: {}", self.keyboard_path.display());
|
||||
} else {
|
||||
warn!("Keyboard device not found: {}", self.keyboard_path.display());
|
||||
}
|
||||
|
||||
// Open relative mouse device
|
||||
if self.mouse_rel_path.exists() {
|
||||
let file = Self::open_device(&self.mouse_rel_path)?;
|
||||
*self.mouse_rel_dev.lock() = Some(file);
|
||||
info!("Relative mouse device opened: {}", self.mouse_rel_path.display());
|
||||
} else {
|
||||
warn!("Relative mouse device not found: {}", self.mouse_rel_path.display());
|
||||
}
|
||||
|
||||
// Open absolute mouse device
|
||||
if self.mouse_abs_path.exists() {
|
||||
let file = Self::open_device(&self.mouse_abs_path)?;
|
||||
*self.mouse_abs_dev.lock() = Some(file);
|
||||
info!("Absolute mouse device opened: {}", self.mouse_abs_path.display());
|
||||
} else {
|
||||
warn!("Absolute mouse device not found: {}", self.mouse_abs_path.display());
|
||||
}
|
||||
|
||||
// Mark as online if all devices opened successfully
|
||||
self.online.store(true, Ordering::Relaxed);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
|
||||
// Convert JS keycode to USB HID if needed
|
||||
let usb_key = keymap::js_to_usb(event.key).unwrap_or(event.key);
|
||||
|
||||
// Handle modifier keys separately
|
||||
if keymap::is_modifier_key(usb_key) {
|
||||
let mut state = self.keyboard_state.lock();
|
||||
|
||||
if let Some(bit) = keymap::modifier_bit(usb_key) {
|
||||
match event.event_type {
|
||||
KeyEventType::Down => state.modifiers |= bit,
|
||||
KeyEventType::Up => state.modifiers &= !bit,
|
||||
}
|
||||
}
|
||||
|
||||
let report = state.clone();
|
||||
drop(state);
|
||||
|
||||
self.send_keyboard_report(&report)?;
|
||||
} else {
|
||||
let mut state = self.keyboard_state.lock();
|
||||
|
||||
// Update modifiers from event
|
||||
state.modifiers = event.modifiers.to_hid_byte();
|
||||
|
||||
match event.event_type {
|
||||
KeyEventType::Down => {
|
||||
state.add_key(usb_key);
|
||||
}
|
||||
KeyEventType::Up => {
|
||||
state.remove_key(usb_key);
|
||||
}
|
||||
}
|
||||
|
||||
let report = state.clone();
|
||||
drop(state);
|
||||
|
||||
self.send_keyboard_report(&report)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
|
||||
let buttons = self.mouse_buttons.load(Ordering::Relaxed);
|
||||
|
||||
match event.event_type {
|
||||
MouseEventType::Move => {
|
||||
// Relative movement - use hidg1
|
||||
let dx = event.x.clamp(-127, 127) as i8;
|
||||
let dy = event.y.clamp(-127, 127) as i8;
|
||||
self.send_mouse_report_relative(buttons, dx, dy, 0)?;
|
||||
}
|
||||
MouseEventType::MoveAbs => {
|
||||
// Absolute movement - use hidg2
|
||||
// Frontend sends 0-32767 range directly (standard HID absolute mouse range)
|
||||
let x = event.x.clamp(0, 32767) as u16;
|
||||
let y = event.y.clamp(0, 32767) as u16;
|
||||
self.send_mouse_report_absolute(buttons, x, y, 0)?;
|
||||
}
|
||||
MouseEventType::Down => {
|
||||
if let Some(button) = event.button {
|
||||
let bit = button.to_hid_bit();
|
||||
let new_buttons = self.mouse_buttons.fetch_or(bit, Ordering::Relaxed) | bit;
|
||||
// Send on relative device for button clicks
|
||||
self.send_mouse_report_relative(new_buttons, 0, 0, 0)?;
|
||||
}
|
||||
}
|
||||
MouseEventType::Up => {
|
||||
if let Some(button) = event.button {
|
||||
let bit = button.to_hid_bit();
|
||||
let new_buttons = self.mouse_buttons.fetch_and(!bit, Ordering::Relaxed) & !bit;
|
||||
self.send_mouse_report_relative(new_buttons, 0, 0, 0)?;
|
||||
}
|
||||
}
|
||||
MouseEventType::Scroll => {
|
||||
self.send_mouse_report_relative(buttons, 0, 0, event.scroll)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn reset(&self) -> Result<()> {
|
||||
// Reset keyboard
|
||||
{
|
||||
let mut state = self.keyboard_state.lock();
|
||||
state.clear();
|
||||
let report = state.clone();
|
||||
drop(state);
|
||||
self.send_keyboard_report(&report)?;
|
||||
}
|
||||
|
||||
// Reset mouse
|
||||
self.mouse_buttons.store(0, Ordering::Relaxed);
|
||||
self.send_mouse_report_relative(0, 0, 0, 0)?;
|
||||
self.send_mouse_report_absolute(0, 0, 0, 0)?;
|
||||
|
||||
info!("HID state reset");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
// Reset before closing
|
||||
self.reset().await?;
|
||||
|
||||
// Close devices
|
||||
*self.keyboard_dev.lock() = None;
|
||||
*self.mouse_rel_dev.lock() = None;
|
||||
*self.mouse_abs_dev.lock() = None;
|
||||
|
||||
// Gadget cleanup is handled by OtgService, not here
|
||||
|
||||
info!("OTG backend shutdown");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn supports_absolute_mouse(&self) -> bool {
|
||||
self.mouse_abs_path.exists()
|
||||
}
|
||||
|
||||
fn screen_resolution(&self) -> Option<(u32, u32)> {
|
||||
*self.screen_resolution.read()
|
||||
}
|
||||
|
||||
fn set_screen_resolution(&mut self, width: u32, height: u32) {
|
||||
*self.screen_resolution.write() = Some((width, height));
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if OTG HID gadget is available
|
||||
pub fn is_otg_available() -> bool {
|
||||
// Check for existing HID devices (they should be created by OtgService)
|
||||
let kb = PathBuf::from("/dev/hidg0");
|
||||
let mouse_rel = PathBuf::from("/dev/hidg1");
|
||||
let mouse_abs = PathBuf::from("/dev/hidg2");
|
||||
|
||||
kb.exists() && mouse_rel.exists() && mouse_abs.exists()
|
||||
}
|
||||
|
||||
/// Implement Drop for OtgBackend to close device files
|
||||
impl Drop for OtgBackend {
|
||||
fn drop(&mut self) {
|
||||
// Close device files
|
||||
// Note: Gadget cleanup is handled by OtgService, not here
|
||||
*self.keyboard_dev.lock() = None;
|
||||
*self.mouse_rel_dev.lock() = None;
|
||||
*self.mouse_abs_dev.lock() = None;
|
||||
debug!("OtgBackend dropped, device files closed");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_otg_availability_check() {
|
||||
// This just tests the function runs without panicking
|
||||
let _available = is_otg_available();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_led_state() {
|
||||
let state = LedState::from_byte(0b00000011);
|
||||
assert!(state.num_lock);
|
||||
assert!(state.caps_lock);
|
||||
assert!(!state.scroll_lock);
|
||||
|
||||
assert_eq!(state.to_byte(), 0b00000011);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_report_sizes() {
|
||||
// Keyboard report is 8 bytes
|
||||
let kb_report = KeyboardReport::default();
|
||||
assert_eq!(kb_report.to_bytes().len(), 8);
|
||||
}
|
||||
}
|
||||
382
src/hid/types.rs
Normal file
382
src/hid/types.rs
Normal file
@@ -0,0 +1,382 @@
|
||||
//! HID event types for keyboard and mouse
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Keyboard event type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum KeyEventType {
|
||||
/// Key pressed down
|
||||
Down,
|
||||
/// Key released
|
||||
Up,
|
||||
}
|
||||
|
||||
/// Keyboard modifier flags
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct KeyboardModifiers {
|
||||
/// Left Control
|
||||
#[serde(default)]
|
||||
pub left_ctrl: bool,
|
||||
/// Left Shift
|
||||
#[serde(default)]
|
||||
pub left_shift: bool,
|
||||
/// Left Alt
|
||||
#[serde(default)]
|
||||
pub left_alt: bool,
|
||||
/// Left Meta (Windows/Super key)
|
||||
#[serde(default)]
|
||||
pub left_meta: bool,
|
||||
/// Right Control
|
||||
#[serde(default)]
|
||||
pub right_ctrl: bool,
|
||||
/// Right Shift
|
||||
#[serde(default)]
|
||||
pub right_shift: bool,
|
||||
/// Right Alt (AltGr)
|
||||
#[serde(default)]
|
||||
pub right_alt: bool,
|
||||
/// Right Meta
|
||||
#[serde(default)]
|
||||
pub right_meta: bool,
|
||||
}
|
||||
|
||||
impl KeyboardModifiers {
|
||||
/// Convert to USB HID modifier byte
|
||||
pub fn to_hid_byte(&self) -> u8 {
|
||||
let mut byte = 0u8;
|
||||
if self.left_ctrl {
|
||||
byte |= 0x01;
|
||||
}
|
||||
if self.left_shift {
|
||||
byte |= 0x02;
|
||||
}
|
||||
if self.left_alt {
|
||||
byte |= 0x04;
|
||||
}
|
||||
if self.left_meta {
|
||||
byte |= 0x08;
|
||||
}
|
||||
if self.right_ctrl {
|
||||
byte |= 0x10;
|
||||
}
|
||||
if self.right_shift {
|
||||
byte |= 0x20;
|
||||
}
|
||||
if self.right_alt {
|
||||
byte |= 0x40;
|
||||
}
|
||||
if self.right_meta {
|
||||
byte |= 0x80;
|
||||
}
|
||||
byte
|
||||
}
|
||||
|
||||
/// Create from USB HID modifier byte
|
||||
pub fn from_hid_byte(byte: u8) -> Self {
|
||||
Self {
|
||||
left_ctrl: byte & 0x01 != 0,
|
||||
left_shift: byte & 0x02 != 0,
|
||||
left_alt: byte & 0x04 != 0,
|
||||
left_meta: byte & 0x08 != 0,
|
||||
right_ctrl: byte & 0x10 != 0,
|
||||
right_shift: byte & 0x20 != 0,
|
||||
right_alt: byte & 0x40 != 0,
|
||||
right_meta: byte & 0x80 != 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if any modifier is active
|
||||
pub fn any(&self) -> bool {
|
||||
self.left_ctrl
|
||||
|| self.left_shift
|
||||
|| self.left_alt
|
||||
|| self.left_meta
|
||||
|| self.right_ctrl
|
||||
|| self.right_shift
|
||||
|| self.right_alt
|
||||
|| self.right_meta
|
||||
}
|
||||
}
|
||||
|
||||
/// Keyboard event
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyboardEvent {
|
||||
/// Event type (down/up)
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: KeyEventType,
|
||||
/// Key code (USB HID usage code or JavaScript key code)
|
||||
pub key: u8,
|
||||
/// Modifier keys state
|
||||
#[serde(default)]
|
||||
pub modifiers: KeyboardModifiers,
|
||||
}
|
||||
|
||||
impl KeyboardEvent {
|
||||
/// Create a key down event
|
||||
pub fn key_down(key: u8, modifiers: KeyboardModifiers) -> Self {
|
||||
Self {
|
||||
event_type: KeyEventType::Down,
|
||||
key,
|
||||
modifiers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a key up event
|
||||
pub fn key_up(key: u8, modifiers: KeyboardModifiers) -> Self {
|
||||
Self {
|
||||
event_type: KeyEventType::Up,
|
||||
key,
|
||||
modifiers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Mouse button
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MouseButton {
|
||||
Left,
|
||||
Right,
|
||||
Middle,
|
||||
Back,
|
||||
Forward,
|
||||
}
|
||||
|
||||
impl MouseButton {
|
||||
/// Convert to USB HID button bit
|
||||
pub fn to_hid_bit(&self) -> u8 {
|
||||
match self {
|
||||
MouseButton::Left => 0x01,
|
||||
MouseButton::Right => 0x02,
|
||||
MouseButton::Middle => 0x04,
|
||||
MouseButton::Back => 0x08,
|
||||
MouseButton::Forward => 0x10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Mouse event type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MouseEventType {
|
||||
/// Mouse moved (relative movement)
|
||||
Move,
|
||||
/// Mouse moved (absolute position)
|
||||
MoveAbs,
|
||||
/// Button pressed
|
||||
Down,
|
||||
/// Button released
|
||||
Up,
|
||||
/// Mouse wheel scroll
|
||||
Scroll,
|
||||
}
|
||||
|
||||
/// Mouse event
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MouseEvent {
|
||||
/// Event type
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: MouseEventType,
|
||||
/// X coordinate or delta
|
||||
#[serde(default)]
|
||||
pub x: i32,
|
||||
/// Y coordinate or delta
|
||||
#[serde(default)]
|
||||
pub y: i32,
|
||||
/// Button (for down/up events)
|
||||
#[serde(default)]
|
||||
pub button: Option<MouseButton>,
|
||||
/// Scroll delta (for scroll events)
|
||||
#[serde(default)]
|
||||
pub scroll: i8,
|
||||
}
|
||||
|
||||
impl MouseEvent {
|
||||
/// Create a relative move event
|
||||
pub fn move_rel(dx: i32, dy: i32) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Move,
|
||||
x: dx,
|
||||
y: dy,
|
||||
button: None,
|
||||
scroll: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an absolute move event
|
||||
pub fn move_abs(x: i32, y: i32) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::MoveAbs,
|
||||
x,
|
||||
y,
|
||||
button: None,
|
||||
scroll: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a button down event
|
||||
pub fn button_down(button: MouseButton) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Down,
|
||||
x: 0,
|
||||
y: 0,
|
||||
button: Some(button),
|
||||
scroll: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a button up event
|
||||
pub fn button_up(button: MouseButton) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Up,
|
||||
x: 0,
|
||||
y: 0,
|
||||
button: Some(button),
|
||||
scroll: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a scroll event
|
||||
pub fn scroll(delta: i8) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Scroll,
|
||||
x: 0,
|
||||
y: 0,
|
||||
button: None,
|
||||
scroll: delta,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Combined HID event (keyboard or mouse)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "device", rename_all = "lowercase")]
|
||||
pub enum HidEvent {
|
||||
Keyboard(KeyboardEvent),
|
||||
Mouse(MouseEvent),
|
||||
}
|
||||
|
||||
/// USB HID keyboard report (8 bytes)
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct KeyboardReport {
|
||||
/// Modifier byte
|
||||
pub modifiers: u8,
|
||||
/// Reserved byte
|
||||
pub reserved: u8,
|
||||
/// Key codes (up to 6 simultaneous keys)
|
||||
pub keys: [u8; 6],
|
||||
}
|
||||
|
||||
impl KeyboardReport {
|
||||
/// Convert to bytes for USB HID
|
||||
pub fn to_bytes(&self) -> [u8; 8] {
|
||||
[
|
||||
self.modifiers,
|
||||
self.reserved,
|
||||
self.keys[0],
|
||||
self.keys[1],
|
||||
self.keys[2],
|
||||
self.keys[3],
|
||||
self.keys[4],
|
||||
self.keys[5],
|
||||
]
|
||||
}
|
||||
|
||||
/// Add a key to the report
|
||||
pub fn add_key(&mut self, key: u8) -> bool {
|
||||
for slot in &mut self.keys {
|
||||
if *slot == 0 {
|
||||
*slot = key;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false // All slots full
|
||||
}
|
||||
|
||||
/// Remove a key from the report
|
||||
pub fn remove_key(&mut self, key: u8) {
|
||||
for slot in &mut self.keys {
|
||||
if *slot == key {
|
||||
*slot = 0;
|
||||
}
|
||||
}
|
||||
// Compact the array
|
||||
self.keys.sort_by(|a, b| b.cmp(a));
|
||||
}
|
||||
|
||||
/// Clear all keys
|
||||
pub fn clear(&mut self) {
|
||||
self.modifiers = 0;
|
||||
self.keys = [0; 6];
|
||||
}
|
||||
}
|
||||
|
||||
/// USB HID mouse report
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MouseReport {
|
||||
/// Button state
|
||||
pub buttons: u8,
|
||||
/// X movement (-127 to 127)
|
||||
pub x: i8,
|
||||
/// Y movement (-127 to 127)
|
||||
pub y: i8,
|
||||
/// Wheel movement (-127 to 127)
|
||||
pub wheel: i8,
|
||||
}
|
||||
|
||||
impl MouseReport {
|
||||
/// Convert to bytes for USB HID (relative mouse)
|
||||
pub fn to_bytes_relative(&self) -> [u8; 4] {
|
||||
[
|
||||
self.buttons,
|
||||
self.x as u8,
|
||||
self.y as u8,
|
||||
self.wheel as u8,
|
||||
]
|
||||
}
|
||||
|
||||
/// Convert to bytes for USB HID (absolute mouse)
|
||||
pub fn to_bytes_absolute(&self, x: u16, y: u16) -> [u8; 6] {
|
||||
[
|
||||
self.buttons,
|
||||
(x & 0xFF) as u8,
|
||||
(x >> 8) as u8,
|
||||
(y & 0xFF) as u8,
|
||||
(y >> 8) as u8,
|
||||
self.wheel as u8,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_modifier_conversion() {
|
||||
let mods = KeyboardModifiers {
|
||||
left_ctrl: true,
|
||||
left_shift: true,
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(mods.to_hid_byte(), 0x03);
|
||||
|
||||
let mods2 = KeyboardModifiers::from_hid_byte(0x03);
|
||||
assert!(mods2.left_ctrl);
|
||||
assert!(mods2.left_shift);
|
||||
assert!(!mods2.left_alt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keyboard_report() {
|
||||
let mut report = KeyboardReport::default();
|
||||
assert!(report.add_key(0x04)); // 'A'
|
||||
assert!(report.add_key(0x05)); // 'B'
|
||||
assert_eq!(report.keys[0], 0x04);
|
||||
assert_eq!(report.keys[1], 0x05);
|
||||
|
||||
report.remove_key(0x04);
|
||||
assert_eq!(report.keys[0], 0x05);
|
||||
}
|
||||
}
|
||||
160
src/hid/websocket.rs
Normal file
160
src/hid/websocket.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
//! WebSocket HID channel for HTTP/MJPEG mode
|
||||
//!
|
||||
//! This provides an alternative to WebRTC DataChannel for HID input
|
||||
//! when using MJPEG streaming mode.
|
||||
//!
|
||||
//! Uses binary protocol only (same format as DataChannel):
|
||||
//! - Keyboard: [0x01, event_type, key, modifiers] (4 bytes)
|
||||
//! - Mouse: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, button/scroll] (7 bytes)
|
||||
//!
|
||||
//! See datachannel.rs for detailed protocol specification.
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
State,
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::datachannel::{parse_hid_message, HidChannelEvent};
|
||||
use crate::state::AppState;
|
||||
use crate::utils::LogThrottler;
|
||||
|
||||
/// Binary response codes
|
||||
const RESP_OK: u8 = 0x00;
|
||||
const RESP_ERR_HID_UNAVAILABLE: u8 = 0x01;
|
||||
const RESP_ERR_INVALID_MESSAGE: u8 = 0x02;
|
||||
#[allow(dead_code)]
|
||||
const RESP_ERR_SEND_FAILED: u8 = 0x03;
|
||||
|
||||
/// WebSocket HID upgrade handler
|
||||
pub async fn ws_hid_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
|
||||
ws.on_upgrade(move |socket| handle_hid_socket(socket, state))
|
||||
}
|
||||
|
||||
/// Handle HID WebSocket connection
|
||||
async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
// Log throttler for error messages (5 second interval)
|
||||
let log_throttler = LogThrottler::with_secs(5);
|
||||
|
||||
info!("WebSocket HID connection established (binary protocol)");
|
||||
|
||||
// Check if HID controller is available and send initial status
|
||||
let hid_available = state.hid.is_available().await;
|
||||
let initial_response = if hid_available {
|
||||
vec![RESP_OK]
|
||||
} else {
|
||||
vec![RESP_ERR_HID_UNAVAILABLE]
|
||||
};
|
||||
|
||||
if sender.send(Message::Binary(initial_response)).await.is_err() {
|
||||
error!("Failed to send initial HID status");
|
||||
return;
|
||||
}
|
||||
|
||||
// Process incoming messages (binary only)
|
||||
while let Some(msg) = receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Binary(data)) => {
|
||||
// Check HID availability before processing each message
|
||||
let hid_available = state.hid.is_available().await;
|
||||
if !hid_available {
|
||||
if log_throttler.should_log("hid_unavailable") {
|
||||
warn!("HID controller not available, ignoring message");
|
||||
}
|
||||
// Send error response (optional, for client awareness)
|
||||
let _ = sender.send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE])).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = handle_binary_message(&data, &state).await {
|
||||
// Log with throttling to avoid spam
|
||||
if log_throttler.should_log("binary_hid_error") {
|
||||
warn!("Binary HID message error: {}", e);
|
||||
}
|
||||
// Don't send error response for every failed message to reduce overhead
|
||||
}
|
||||
}
|
||||
Ok(Message::Text(text)) => {
|
||||
// Text messages are no longer supported
|
||||
if log_throttler.should_log("text_message_rejected") {
|
||||
debug!("Received text message (not supported): {} bytes", text.len());
|
||||
}
|
||||
let _ = sender.send(Message::Binary(vec![RESP_ERR_INVALID_MESSAGE])).await;
|
||||
}
|
||||
Ok(Message::Ping(data)) => {
|
||||
let _ = sender.send(Message::Pong(data)).await;
|
||||
}
|
||||
Ok(Message::Close(_)) => {
|
||||
info!("WebSocket HID connection closed by client");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("WebSocket error: {}", e);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
info!("WebSocket HID connection ended");
|
||||
}
|
||||
|
||||
/// Handle binary HID message (same format as DataChannel)
|
||||
async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), String> {
|
||||
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
|
||||
|
||||
match event {
|
||||
HidChannelEvent::Keyboard(kb_event) => {
|
||||
state
|
||||
.hid
|
||||
.send_keyboard(kb_event)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
}
|
||||
HidChannelEvent::Mouse(ms_event) => {
|
||||
state
|
||||
.hid
|
||||
.send_mouse(ms_event)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::hid::datachannel::{MSG_KEYBOARD, MSG_MOUSE, KB_EVENT_DOWN, MS_EVENT_MOVE};
|
||||
|
||||
#[test]
|
||||
fn test_response_codes() {
|
||||
assert_eq!(RESP_OK, 0x00);
|
||||
assert_eq!(RESP_ERR_HID_UNAVAILABLE, 0x01);
|
||||
assert_eq!(RESP_ERR_INVALID_MESSAGE, 0x02);
|
||||
assert_eq!(RESP_ERR_SEND_FAILED, 0x03);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keyboard_message_format() {
|
||||
// Keyboard message: [0x01, event_type, key, modifiers]
|
||||
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // 'A' key with left ctrl
|
||||
let event = parse_hid_message(&data);
|
||||
assert!(event.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mouse_message_format() {
|
||||
// Mouse message: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, extra]
|
||||
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
|
||||
let event = parse_hid_message(&data);
|
||||
assert!(event.is_some());
|
||||
}
|
||||
}
|
||||
24
src/lib.rs
Normal file
24
src/lib.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
//! One-KVM - Lightweight IP-KVM solution
|
||||
//!
|
||||
//! This crate provides the core functionality for One-KVM,
|
||||
//! a remote KVM (Keyboard, Video, Mouse) solution written in Rust.
|
||||
|
||||
pub mod atx;
|
||||
pub mod audio;
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod events;
|
||||
pub mod extensions;
|
||||
pub mod hid;
|
||||
pub mod modules;
|
||||
pub mod msd;
|
||||
pub mod otg;
|
||||
pub mod state;
|
||||
pub mod stream;
|
||||
pub mod utils;
|
||||
pub mod video;
|
||||
pub mod web;
|
||||
pub mod webrtc;
|
||||
|
||||
pub use error::{AppError, Result};
|
||||
667
src/main.rs
Normal file
667
src/main.rs
Normal file
@@ -0,0 +1,667 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use tokio::sync::broadcast;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use rustls::crypto::{ring, CryptoProvider};
|
||||
|
||||
use one_kvm::atx::AtxController;
|
||||
use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality};
|
||||
use one_kvm::auth::{SessionStore, UserStore};
|
||||
use one_kvm::config::{self, AppConfig, ConfigStore};
|
||||
use one_kvm::events::EventBus;
|
||||
use one_kvm::extensions::ExtensionManager;
|
||||
use one_kvm::hid::{HidBackendType, HidController};
|
||||
use one_kvm::msd::MsdController;
|
||||
use one_kvm::otg::OtgService;
|
||||
use one_kvm::state::AppState;
|
||||
use one_kvm::video::format::{PixelFormat, Resolution};
|
||||
use one_kvm::video::{Streamer, VideoStreamManager};
|
||||
use one_kvm::web;
|
||||
use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig};
|
||||
|
||||
/// Log level for the application
|
||||
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
|
||||
enum LogLevel {Error, Warn, #[default] Info, Verbose, Debug, Trace,}
|
||||
|
||||
/// One-KVM command line arguments
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "one-kvm")]
|
||||
#[command(version, about = "A open and lightweight IP-KVM solution", long_about = None)]
|
||||
struct CliArgs {
|
||||
/// Listen address
|
||||
#[arg(short = 'a', long, value_name = "ADDRESS", default_value = "0.0.0.0")]
|
||||
address: String,
|
||||
|
||||
/// HTTP port (used when HTTPS is disabled)
|
||||
#[arg(short = 'p', long, value_name = "PORT", default_value = "8080")]
|
||||
http_port: u16,
|
||||
|
||||
/// HTTPS port (used when HTTPS is enabled)
|
||||
#[arg(long, value_name = "PORT", default_value = "8443")]
|
||||
https_port: u16,
|
||||
|
||||
/// Enable HTTPS
|
||||
#[arg(long, default_value = "false")]
|
||||
enable_https: bool,
|
||||
|
||||
/// Path to SSL certificate file (generates self-signed if not provided)
|
||||
#[arg(long, value_name = "FILE", requires = "ssl_key")]
|
||||
ssl_cert: Option<PathBuf>,
|
||||
|
||||
/// Path to SSL private key file
|
||||
#[arg(long, value_name = "FILE", requires = "ssl_cert")]
|
||||
ssl_key: Option<PathBuf>,
|
||||
|
||||
/// Data directory path (default: ./data)
|
||||
#[arg(short = 'd', long, value_name = "DIR")]
|
||||
data_dir: Option<PathBuf>,
|
||||
|
||||
/// Log level (error, warn, info, verbose, debug, trace)
|
||||
#[arg(short = 'l', long, value_name = "LEVEL", default_value = "info")]
|
||||
log_level: LogLevel,
|
||||
|
||||
/// Increase verbosity (-v for verbose, -vv for debug, -vvv for trace)
|
||||
#[arg(short = 'v', long, action = clap::ArgAction::Count)]
|
||||
verbose: u8,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Parse command line arguments
|
||||
let args = CliArgs::parse();
|
||||
|
||||
// Initialize logging with CLI arguments
|
||||
init_logging(args.log_level, args.verbose);
|
||||
|
||||
// Install default crypto provider (required by rustls 0.23+)
|
||||
CryptoProvider::install_default(ring::default_provider())
|
||||
.expect("Failed to install rustls crypto provider");
|
||||
|
||||
tracing::info!(
|
||||
"Starting One-KVM v{}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
);
|
||||
|
||||
// Determine data directory (CLI arg takes precedence)
|
||||
let data_dir = args.data_dir.unwrap_or_else(get_data_dir);
|
||||
tracing::info!("Data directory: {}", data_dir.display());
|
||||
|
||||
// Ensure data directory exists
|
||||
tokio::fs::create_dir_all(&data_dir).await?;
|
||||
|
||||
// Initialize configuration store
|
||||
let db_path = data_dir.join("one-kvm.db");
|
||||
let config_store = ConfigStore::new(&db_path).await?;
|
||||
let mut config = (*config_store.get()).clone();
|
||||
|
||||
// Apply CLI argument overrides to config
|
||||
config.web.bind_address = args.address;
|
||||
config.web.http_port = args.http_port;
|
||||
config.web.https_port = args.https_port;
|
||||
config.web.https_enabled = args.enable_https;
|
||||
|
||||
if let Some(cert_path) = args.ssl_cert {
|
||||
config.web.ssl_cert_path = Some(cert_path.to_string_lossy().to_string());
|
||||
}
|
||||
if let Some(key_path) = args.ssl_key {
|
||||
config.web.ssl_key_path = Some(key_path.to_string_lossy().to_string());
|
||||
}
|
||||
|
||||
// Log final configuration
|
||||
if config.web.https_enabled {
|
||||
tracing::info!(
|
||||
"Server will listen on: https://{}:{}",
|
||||
config.web.bind_address,
|
||||
config.web.https_port
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Server will listen on: http://{}:{}",
|
||||
config.web.bind_address,
|
||||
config.web.http_port
|
||||
);
|
||||
}
|
||||
|
||||
// Initialize session store
|
||||
let session_store = SessionStore::new(
|
||||
config_store.pool().clone(),
|
||||
config.auth.session_timeout_secs as i64,
|
||||
);
|
||||
|
||||
// Initialize user store
|
||||
let user_store = UserStore::new(config_store.pool().clone());
|
||||
|
||||
// Create shutdown channel
|
||||
let (shutdown_tx, _) = broadcast::channel::<()>(1);
|
||||
|
||||
// Create event bus for real-time notifications
|
||||
let events = Arc::new(EventBus::new());
|
||||
tracing::info!("Event bus initialized");
|
||||
|
||||
// Parse video configuration once (avoid duplication)
|
||||
let (video_format, video_resolution) = parse_video_config(&config);
|
||||
tracing::debug!("Parsed video config: {} @ {}x{}", video_format, video_resolution.width, video_resolution.height);
|
||||
|
||||
// Create video streamer and initialize with config if device is set
|
||||
let streamer = Streamer::new();
|
||||
streamer.set_event_bus(events.clone()).await;
|
||||
if let Some(ref device_path) = config.video.device {
|
||||
if let Err(e) = streamer
|
||||
.apply_video_config(device_path, video_format, video_resolution, config.video.fps)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to initialize video with config: {}, will auto-detect", e);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Video configured: {} @ {}x{} {}",
|
||||
device_path, video_resolution.width, video_resolution.height, video_format
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Create WebRTC streamer
|
||||
let webrtc_streamer = {
|
||||
let webrtc_config = WebRtcStreamerConfig {
|
||||
resolution: video_resolution,
|
||||
input_format: video_format,
|
||||
fps: config.video.fps,
|
||||
bitrate_kbps: config.stream.bitrate_kbps,
|
||||
gop_size: config.stream.gop_size,
|
||||
encoder_backend: config.stream.encoder.to_backend(),
|
||||
webrtc: {
|
||||
let mut stun_servers = vec![];
|
||||
let mut turn_servers = vec![];
|
||||
|
||||
// Add STUN server from config
|
||||
if let Some(ref stun) = config.stream.stun_server {
|
||||
if !stun.is_empty() {
|
||||
stun_servers.push(stun.clone());
|
||||
tracing::info!("WebRTC STUN server configured: {}", stun);
|
||||
}
|
||||
}
|
||||
|
||||
// Add TURN server from config
|
||||
if let Some(ref turn) = config.stream.turn_server {
|
||||
if !turn.is_empty() {
|
||||
let username = config.stream.turn_username.clone().unwrap_or_default();
|
||||
let credential = config.stream.turn_password.clone().unwrap_or_default();
|
||||
turn_servers.push(one_kvm::webrtc::config::TurnServer {
|
||||
url: turn.clone(),
|
||||
username: username.clone(),
|
||||
credential,
|
||||
});
|
||||
tracing::info!("WebRTC TURN server configured: {} (user: {})", turn, username);
|
||||
}
|
||||
}
|
||||
|
||||
one_kvm::webrtc::config::WebRtcConfig {
|
||||
stun_servers,
|
||||
turn_servers,
|
||||
..Default::default()
|
||||
}
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
WebRtcStreamer::with_config(webrtc_config)
|
||||
};
|
||||
tracing::info!("WebRTC streamer created (supports H264, extensible to VP8/VP9/H265)");
|
||||
|
||||
|
||||
// Create OTG Service (single instance for centralized USB gadget management)
|
||||
let otg_service = Arc::new(OtgService::new());
|
||||
tracing::info!("OTG Service created");
|
||||
|
||||
// Pre-enable OTG functions to avoid gadget recreation (prevents kernel crashes)
|
||||
let will_use_otg_hid = matches!(config.hid.backend, config::HidBackend::Otg);
|
||||
let will_use_msd = config.msd.enabled || will_use_otg_hid;
|
||||
|
||||
if will_use_otg_hid {
|
||||
if !config.msd.enabled {
|
||||
tracing::info!("OTG HID enabled, automatically enabling MSD functionality");
|
||||
}
|
||||
if let Err(e) = otg_service.enable_hid().await {
|
||||
tracing::warn!("Failed to pre-enable HID: {}", e);
|
||||
}
|
||||
}
|
||||
if will_use_msd {
|
||||
if let Err(e) = otg_service.enable_msd().await {
|
||||
tracing::warn!("Failed to pre-enable MSD: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Create HID controller based on config
|
||||
let hid_backend = match config.hid.backend {
|
||||
config::HidBackend::Otg => HidBackendType::Otg,
|
||||
config::HidBackend::Ch9329 => HidBackendType::Ch9329 {
|
||||
port: config.hid.ch9329_port.clone(),
|
||||
baud_rate: config.hid.ch9329_baudrate,
|
||||
},
|
||||
config::HidBackend::None => HidBackendType::None,
|
||||
};
|
||||
let hid = Arc::new(HidController::new(
|
||||
hid_backend,
|
||||
Some(otg_service.clone()), // Always pass OtgService to support hot-reload to OTG
|
||||
));
|
||||
hid.set_event_bus(events.clone()).await;
|
||||
if let Err(e) = hid.init().await {
|
||||
tracing::warn!("Failed to initialize HID backend: {}", e);
|
||||
}
|
||||
|
||||
// Create MSD controller (optional, based on config)
|
||||
let msd = if config.msd.enabled {
|
||||
// Initialize Ventoy resources from data directory
|
||||
let ventoy_resource_dir = ventoy_img::get_resource_dir(&data_dir);
|
||||
if ventoy_resource_dir.exists() {
|
||||
if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) {
|
||||
tracing::warn!("Failed to initialize Ventoy resources: {}", e);
|
||||
tracing::info!("Ventoy resource files should be placed in: {}", ventoy_resource_dir.display());
|
||||
tracing::info!("Required files: {:?}", ventoy_img::required_files());
|
||||
} else {
|
||||
tracing::info!("Ventoy resources initialized from {}", ventoy_resource_dir.display());
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Ventoy resource directory not found: {}", ventoy_resource_dir.display());
|
||||
tracing::info!("Create the directory and place the following files: {:?}", ventoy_img::required_files());
|
||||
}
|
||||
|
||||
let controller = MsdController::new(
|
||||
otg_service.clone(),
|
||||
&config.msd.images_path,
|
||||
&config.msd.drive_path,
|
||||
);
|
||||
if let Err(e) = controller.init().await {
|
||||
tracing::warn!("Failed to initialize MSD controller: {}", e);
|
||||
None
|
||||
} else {
|
||||
controller.set_event_bus(events.clone()).await;
|
||||
Some(controller)
|
||||
}
|
||||
} else {
|
||||
tracing::info!("MSD disabled in configuration");
|
||||
None
|
||||
};
|
||||
|
||||
// Create ATX controller (optional, based on config)
|
||||
let atx = if config.atx.enabled {
|
||||
let controller_config = config.atx.to_controller_config();
|
||||
let controller = AtxController::new(controller_config);
|
||||
|
||||
if let Err(e) = controller.init().await {
|
||||
tracing::warn!("Failed to initialize ATX controller: {}", e);
|
||||
None
|
||||
} else {
|
||||
Some(controller)
|
||||
}
|
||||
} else {
|
||||
tracing::info!("ATX disabled in configuration");
|
||||
None
|
||||
};
|
||||
|
||||
// Create Audio controller
|
||||
let audio = {
|
||||
let audio_config = AudioControllerConfig {
|
||||
enabled: config.audio.enabled,
|
||||
device: config.audio.device.clone(),
|
||||
quality: AudioQuality::from_str(&config.audio.quality),
|
||||
};
|
||||
|
||||
let controller = AudioController::new(audio_config);
|
||||
controller.set_event_bus(events.clone()).await;
|
||||
|
||||
if config.audio.enabled {
|
||||
tracing::info!(
|
||||
"Audio enabled: {}, quality={}",
|
||||
config.audio.device,
|
||||
config.audio.quality
|
||||
);
|
||||
} else {
|
||||
tracing::info!("Audio disabled in configuration");
|
||||
}
|
||||
|
||||
Arc::new(controller)
|
||||
};
|
||||
|
||||
// Create Extension manager (ttyd, gostc, easytier)
|
||||
let extensions = Arc::new(ExtensionManager::new());
|
||||
tracing::info!("Extension manager initialized");
|
||||
|
||||
// Wire up WebRTC streamer with HID controller
|
||||
// This enables WebRTC DataChannel to process HID events
|
||||
webrtc_streamer.set_hid_controller(hid.clone()).await;
|
||||
|
||||
// Wire up WebRTC streamer with Audio controller
|
||||
// This enables WebRTC audio track to receive Opus frames
|
||||
webrtc_streamer.set_audio_controller(audio.clone()).await;
|
||||
if config.audio.enabled {
|
||||
if let Err(e) = webrtc_streamer.set_audio_enabled(true).await {
|
||||
tracing::warn!("Failed to enable WebRTC audio: {}", e);
|
||||
} else {
|
||||
tracing::info!("WebRTC audio enabled");
|
||||
}
|
||||
}
|
||||
|
||||
// Set up frame source from video streamer (if capturer is available)
|
||||
// The frame source allows WebRTC sessions to receive live video frames
|
||||
if let Some(frame_tx) = streamer.frame_sender().await {
|
||||
// Synchronize WebRTC config with actual capture format before connecting
|
||||
let (actual_format, actual_resolution, actual_fps) = streamer.current_video_config().await;
|
||||
tracing::info!(
|
||||
"Initial video config from capturer: {}x{} {:?} @ {}fps",
|
||||
actual_resolution.width, actual_resolution.height, actual_format, actual_fps
|
||||
);
|
||||
webrtc_streamer.update_video_config(actual_resolution, actual_format, actual_fps).await;
|
||||
webrtc_streamer.set_video_source(frame_tx).await;
|
||||
tracing::info!("WebRTC streamer connected to video frame source");
|
||||
} else {
|
||||
tracing::warn!("Video capturer not ready, WebRTC will connect to frame source when available");
|
||||
}
|
||||
|
||||
// Create video stream manager (unified MJPEG/WebRTC management)
|
||||
// Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance
|
||||
let stream_manager = VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone());
|
||||
stream_manager.set_event_bus(events.clone()).await;
|
||||
stream_manager.set_config_store(config_store.clone()).await;
|
||||
|
||||
// Initialize stream manager with configured mode
|
||||
let initial_mode = config.stream.mode.clone();
|
||||
if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await {
|
||||
tracing::warn!("Failed to initialize stream manager with mode {:?}: {}", initial_mode, e);
|
||||
} else {
|
||||
tracing::info!("Video stream manager initialized with mode: {:?}", initial_mode);
|
||||
}
|
||||
|
||||
// Create application state
|
||||
let state = AppState::new(
|
||||
config_store.clone(),
|
||||
session_store,
|
||||
user_store,
|
||||
otg_service,
|
||||
stream_manager,
|
||||
hid,
|
||||
msd,
|
||||
atx,
|
||||
audio,
|
||||
extensions.clone(),
|
||||
events.clone(),
|
||||
shutdown_tx.clone(),
|
||||
data_dir.clone(),
|
||||
);
|
||||
|
||||
// Start enabled extensions
|
||||
{
|
||||
let ext_config = config_store.get();
|
||||
extensions.start_enabled(&ext_config.extensions).await;
|
||||
}
|
||||
|
||||
// Start extension health check task (every 30 seconds)
|
||||
{
|
||||
let extensions_clone = extensions.clone();
|
||||
let config_store_clone = config_store.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let config = config_store_clone.get();
|
||||
extensions_clone.health_check(&config.extensions).await;
|
||||
}
|
||||
});
|
||||
tracing::info!("Extension health check task started");
|
||||
}
|
||||
|
||||
// Start device info broadcast task
|
||||
// This monitors state change events and broadcasts DeviceInfo to all clients
|
||||
spawn_device_info_broadcaster(state.clone(), events);
|
||||
|
||||
// Create router
|
||||
let app = web::create_router(state.clone());
|
||||
|
||||
// Determine bind address based on HTTPS setting
|
||||
let bind_addr: SocketAddr = if config.web.https_enabled {
|
||||
format!("{}:{}", config.web.bind_address, config.web.https_port).parse()?
|
||||
} else {
|
||||
format!("{}:{}", config.web.bind_address, config.web.http_port).parse()?
|
||||
};
|
||||
|
||||
// Setup graceful shutdown
|
||||
let shutdown_signal = async move {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("Failed to install CTRL+C handler");
|
||||
tracing::info!("Shutdown signal received");
|
||||
let _ = shutdown_tx.send(());
|
||||
};
|
||||
|
||||
// Start server
|
||||
if config.web.https_enabled {
|
||||
// Generate self-signed certificate if no custom cert provided
|
||||
let tls_config = if let (Some(cert_path), Some(key_path)) =
|
||||
(&config.web.ssl_cert_path, &config.web.ssl_key_path)
|
||||
{
|
||||
RustlsConfig::from_pem_file(cert_path, key_path).await?
|
||||
} else {
|
||||
let cert_dir = data_dir.join("certs");
|
||||
let cert_path = cert_dir.join("server.crt");
|
||||
let key_path = cert_dir.join("server.key");
|
||||
|
||||
// Check if certificate already exists, only generate if missing
|
||||
if !cert_path.exists() || !key_path.exists() {
|
||||
tracing::info!("Generating new self-signed TLS certificate");
|
||||
let cert = generate_self_signed_cert()?;
|
||||
tokio::fs::create_dir_all(&cert_dir).await?;
|
||||
tokio::fs::write(&cert_path, cert.cert.pem()).await?;
|
||||
tokio::fs::write(&key_path, cert.key_pair.serialize_pem()).await?;
|
||||
} else {
|
||||
tracing::info!("Using existing TLS certificate from {}", cert_dir.display());
|
||||
}
|
||||
|
||||
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
|
||||
};
|
||||
|
||||
tracing::info!("Starting HTTPS server on {}", bind_addr);
|
||||
|
||||
let server = axum_server::bind_rustls(bind_addr, tls_config)
|
||||
.serve(app.into_make_service());
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_signal => {
|
||||
cleanup(&state).await;
|
||||
}
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
tracing::error!("HTTPS server error: {}", e);
|
||||
}
|
||||
cleanup(&state).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::info!("Starting HTTP server on {}", bind_addr);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_signal => {
|
||||
cleanup(&state).await;
|
||||
}
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
tracing::error!("HTTP server error: {}", e);
|
||||
}
|
||||
cleanup(&state).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Server shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize logging with tracing
|
||||
fn init_logging(level: LogLevel, verbose_count: u8) {
|
||||
// Verbose count overrides log level
|
||||
let effective_level = match verbose_count {
|
||||
0 => level,
|
||||
1 => LogLevel::Verbose,
|
||||
2 => LogLevel::Debug,
|
||||
_ => LogLevel::Trace,
|
||||
};
|
||||
|
||||
// Build filter string based on effective level
|
||||
let filter = match effective_level {
|
||||
LogLevel::Error => "one_kvm=error,tower_http=error",
|
||||
LogLevel::Warn => "one_kvm=warn,tower_http=warn",
|
||||
LogLevel::Info => "one_kvm=info,tower_http=info",
|
||||
LogLevel::Verbose => "one_kvm=debug,tower_http=info",
|
||||
LogLevel::Debug => "one_kvm=debug,tower_http=debug",
|
||||
LogLevel::Trace => "one_kvm=trace,tower_http=debug",
|
||||
};
|
||||
|
||||
// Environment variable takes highest priority
|
||||
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| filter.into());
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
}
|
||||
|
||||
/// Get the application data directory
|
||||
fn get_data_dir() -> PathBuf {
|
||||
// Check environment variable first
|
||||
if let Ok(path) = std::env::var("ONE_KVM_DATA_DIR") {
|
||||
return PathBuf::from(path);
|
||||
}
|
||||
|
||||
// Default to system configuration directory
|
||||
PathBuf::from("/etc/one-kvm")
|
||||
}
|
||||
|
||||
/// Parse video format and resolution from config (avoids code duplication)
|
||||
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
|
||||
let format = config
|
||||
.video
|
||||
.format
|
||||
.as_ref()
|
||||
.and_then(|f: &String| f.parse::<PixelFormat>().ok())
|
||||
.unwrap_or(PixelFormat::Mjpeg);
|
||||
let resolution = Resolution::new(config.video.width, config.video.height);
|
||||
(format, resolution)
|
||||
}
|
||||
|
||||
/// Generate a self-signed TLS certificate
|
||||
fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey> {
|
||||
use rcgen::generate_simple_self_signed;
|
||||
|
||||
let subject_alt_names = vec![
|
||||
"localhost".to_string(),
|
||||
"127.0.0.1".to_string(),
|
||||
"::1".to_string(),
|
||||
];
|
||||
|
||||
let certified_key = generate_simple_self_signed(subject_alt_names)?;
|
||||
Ok(certified_key)
|
||||
}
|
||||
|
||||
/// Spawn a background task that monitors state change events
|
||||
/// and broadcasts DeviceInfo to all WebSocket clients with debouncing
|
||||
fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
|
||||
use one_kvm::events::SystemEvent;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
let mut rx = events.subscribe();
|
||||
const DEBOUNCE_MS: u64 = 100;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut last_broadcast = Instant::now() - Duration::from_millis(DEBOUNCE_MS);
|
||||
let mut pending_broadcast = false;
|
||||
|
||||
loop {
|
||||
// Use timeout to handle pending broadcasts
|
||||
let recv_result = if pending_broadcast {
|
||||
let remaining = DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64);
|
||||
tokio::time::timeout(Duration::from_millis(remaining), rx.recv()).await
|
||||
} else {
|
||||
Ok(rx.recv().await)
|
||||
};
|
||||
|
||||
match recv_result {
|
||||
Ok(Ok(event)) => {
|
||||
let should_broadcast = matches!(
|
||||
event,
|
||||
SystemEvent::StreamStateChanged { .. }
|
||||
| SystemEvent::StreamConfigApplied { .. }
|
||||
| SystemEvent::HidStateChanged { .. }
|
||||
| SystemEvent::MsdStateChanged { .. }
|
||||
| SystemEvent::AtxStateChanged { .. }
|
||||
| SystemEvent::AudioStateChanged { .. }
|
||||
);
|
||||
if should_broadcast {
|
||||
pending_broadcast = true;
|
||||
}
|
||||
}
|
||||
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(n))) => {
|
||||
tracing::warn!("DeviceInfo broadcaster lagged by {} events", n);
|
||||
pending_broadcast = true;
|
||||
}
|
||||
Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
|
||||
tracing::info!("Event bus closed, stopping DeviceInfo broadcaster");
|
||||
break;
|
||||
}
|
||||
Err(_timeout) => {
|
||||
// Debounce timeout reached, broadcast now
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast if pending and debounce time has passed
|
||||
if pending_broadcast && last_broadcast.elapsed() >= Duration::from_millis(DEBOUNCE_MS) {
|
||||
state.publish_device_info().await;
|
||||
tracing::trace!("Broadcasted DeviceInfo (debounced)");
|
||||
last_broadcast = Instant::now();
|
||||
pending_broadcast = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tracing::info!("DeviceInfo broadcaster task started (debounce: {}ms)", DEBOUNCE_MS);
|
||||
}
|
||||
|
||||
/// Clean up subsystems on shutdown
|
||||
async fn cleanup(state: &Arc<AppState>) {
|
||||
// Stop all extensions
|
||||
state.extensions.stop_all().await;
|
||||
tracing::info!("Extensions stopped");
|
||||
|
||||
// Stop video
|
||||
if let Err(e) = state.stream_manager.stop().await {
|
||||
tracing::warn!("Failed to stop streamer: {}", e);
|
||||
}
|
||||
|
||||
// Shutdown HID
|
||||
if let Err(e) = state.hid.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown HID: {}", e);
|
||||
}
|
||||
|
||||
// Shutdown MSD
|
||||
if let Some(msd) = state.msd.write().await.as_mut() {
|
||||
if let Err(e) = msd.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown MSD: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown ATX
|
||||
if let Some(atx) = state.atx.write().await.as_mut() {
|
||||
if let Err(e) = atx.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown ATX: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown Audio
|
||||
if let Err(e) = state.audio.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown audio: {}", e);
|
||||
}
|
||||
}
|
||||
49
src/modules/mod.rs
Normal file
49
src/modules/mod.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! Module management for One-KVM
|
||||
//!
|
||||
//! This module provides infrastructure for managing feature modules
|
||||
//! (video streaming, HID control, MSD, ATX) as independent async tasks.
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
/// Module status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ModuleStatus {
|
||||
Stopped,
|
||||
Starting,
|
||||
Running,
|
||||
Stopping,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Trait for feature modules
|
||||
pub trait Module: Send + Sync {
|
||||
/// Module name
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Current status
|
||||
fn status(&self) -> ModuleStatus;
|
||||
|
||||
/// Start the module
|
||||
fn start(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
|
||||
|
||||
/// Stop the module
|
||||
fn stop(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
|
||||
}
|
||||
|
||||
/// Module manager for coordinating feature modules
|
||||
pub struct ModuleManager {
|
||||
shutdown_rx: broadcast::Receiver<()>,
|
||||
}
|
||||
|
||||
impl ModuleManager {
|
||||
pub fn new(shutdown_rx: broadcast::Receiver<()>) -> Self {
|
||||
Self { shutdown_rx }
|
||||
}
|
||||
|
||||
/// Wait for shutdown signal
|
||||
pub async fn wait_for_shutdown(&mut self) {
|
||||
let _ = self.shutdown_rx.recv().await;
|
||||
}
|
||||
}
|
||||
597
src/msd/controller.rs
Normal file
597
src/msd/controller.rs
Normal file
@@ -0,0 +1,597 @@
|
||||
//! MSD Controller
|
||||
//!
|
||||
//! Manages the mass storage device lifecycle including:
|
||||
//! - Image mounting and unmounting
|
||||
//! - Virtual drive management
|
||||
//! - State tracking
|
||||
//! - Image downloads from URL
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::image::ImageManager;
|
||||
use super::monitor::{MsdHealthMonitor, MsdHealthStatus};
|
||||
use super::types::{DownloadProgress, DownloadStatus, DriveInfo, ImageInfo, MsdMode, MsdState};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::otg::{MsdFunction, MsdLunConfig, OtgService};
|
||||
|
||||
/// USB Gadget path (system constant)
|
||||
const GADGET_PATH: &str = "/sys/kernel/config/usb_gadget/one-kvm";
|
||||
|
||||
/// MSD Controller
|
||||
pub struct MsdController {
|
||||
/// OTG Service reference
|
||||
otg_service: Arc<OtgService>,
|
||||
/// MSD function manager (provided by OtgService)
|
||||
msd_function: RwLock<Option<MsdFunction>>,
|
||||
/// Current state
|
||||
state: RwLock<MsdState>,
|
||||
/// Images storage path
|
||||
images_path: PathBuf,
|
||||
/// Virtual drive path
|
||||
drive_path: PathBuf,
|
||||
/// Event bus for broadcasting state changes (optional)
|
||||
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
|
||||
/// Active downloads (download_id -> CancellationToken)
|
||||
downloads: Arc<RwLock<HashMap<String, CancellationToken>>>,
|
||||
/// Operation mutex lock (prevents concurrent operations)
|
||||
operation_lock: Arc<RwLock<()>>,
|
||||
/// Health monitor for error tracking and recovery
|
||||
monitor: Arc<MsdHealthMonitor>,
|
||||
}
|
||||
|
||||
impl MsdController {
|
||||
/// Create new MSD controller
|
||||
///
|
||||
/// # Parameters
|
||||
/// * `otg_service` - OTG service for gadget management
|
||||
/// * `images_path` - Directory path for storing ISO/IMG files
|
||||
/// * `drive_path` - File path for the virtual FAT32 drive
|
||||
pub fn new(
|
||||
otg_service: Arc<OtgService>,
|
||||
images_path: impl Into<PathBuf>,
|
||||
drive_path: impl Into<PathBuf>,
|
||||
) -> Self {
|
||||
Self {
|
||||
otg_service,
|
||||
msd_function: RwLock::new(None),
|
||||
state: RwLock::new(MsdState::default()),
|
||||
images_path: images_path.into(),
|
||||
drive_path: drive_path.into(),
|
||||
events: tokio::sync::RwLock::new(None),
|
||||
downloads: Arc::new(RwLock::new(HashMap::new())),
|
||||
operation_lock: Arc::new(RwLock::new(())),
|
||||
monitor: Arc::new(MsdHealthMonitor::with_defaults()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the MSD controller
|
||||
pub async fn init(&self) -> Result<()> {
|
||||
info!("Initializing MSD controller");
|
||||
|
||||
// 1. Ensure images directory exists
|
||||
if let Err(e) = std::fs::create_dir_all(&self.images_path) {
|
||||
warn!("Failed to create images directory: {}", e);
|
||||
}
|
||||
|
||||
// 2. Request MSD function from OtgService
|
||||
info!("Requesting MSD function from OtgService");
|
||||
let msd_func = self.otg_service.enable_msd().await?;
|
||||
|
||||
// 3. Store function handle
|
||||
*self.msd_function.write().await = Some(msd_func);
|
||||
|
||||
// 4. Update state
|
||||
let mut state = self.state.write().await;
|
||||
state.available = true;
|
||||
|
||||
// 5. Check for existing virtual drive
|
||||
if self.drive_path.exists() {
|
||||
if let Ok(metadata) = std::fs::metadata(&self.drive_path) {
|
||||
state.drive_info = Some(DriveInfo {
|
||||
size: metadata.len(),
|
||||
used: 0,
|
||||
free: metadata.len(),
|
||||
initialized: true,
|
||||
path: self.drive_path.clone(),
|
||||
});
|
||||
debug!("Found existing virtual drive: {}", self.drive_path.display());
|
||||
}
|
||||
}
|
||||
|
||||
info!("MSD controller initialized");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current state as SystemEvent
|
||||
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
|
||||
let state = self.state.read().await;
|
||||
crate::events::SystemEvent::MsdStateChanged {
|
||||
mode: state.mode.clone(),
|
||||
connected: state.connected,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current MSD state
|
||||
pub async fn state(&self) -> MsdState {
|
||||
self.state.read().await.clone()
|
||||
}
|
||||
|
||||
/// Set event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: std::sync::Arc<crate::events::EventBus>) {
|
||||
*self.events.write().await = Some(events.clone());
|
||||
// Also set event bus on the monitor for health notifications
|
||||
self.monitor.set_event_bus(events).await;
|
||||
}
|
||||
|
||||
/// Publish an event to the event bus
|
||||
async fn publish_event(&self, event: crate::events::SystemEvent) {
|
||||
if let Some(ref bus) = *self.events.read().await {
|
||||
bus.publish(event);
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if MSD is available
|
||||
pub async fn is_available(&self) -> bool {
|
||||
self.state.read().await.available
|
||||
}
|
||||
|
||||
/// Connect an image file
|
||||
///
|
||||
/// # Parameters
|
||||
/// * `image` - Image info to mount
|
||||
/// * `cdrom` - Mount as CD-ROM (read-only, removable)
|
||||
/// * `read_only` - Mount as read-only
|
||||
pub async fn connect_image(&self, image: &ImageInfo, cdrom: bool, read_only: bool) -> Result<()> {
|
||||
// Acquire operation lock to prevent concurrent operations
|
||||
let _op_guard = self.operation_lock.write().await;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
if !state.available {
|
||||
let err = AppError::Internal("MSD not available".to_string());
|
||||
self.monitor.report_error("MSD not available", "not_available").await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
if state.connected {
|
||||
return Err(AppError::Internal(
|
||||
"Already connected. Disconnect first.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Verify image exists
|
||||
if !image.path.exists() {
|
||||
let error_msg = format!("Image file not found: {}", image.path.display());
|
||||
self.monitor.report_error(&error_msg, "image_not_found").await;
|
||||
return Err(AppError::Internal(error_msg));
|
||||
}
|
||||
|
||||
// Configure LUN
|
||||
let config = if cdrom {
|
||||
MsdLunConfig::cdrom(image.path.clone())
|
||||
} else {
|
||||
MsdLunConfig::disk(image.path.clone(), read_only)
|
||||
};
|
||||
|
||||
let gadget_path = PathBuf::from(GADGET_PATH);
|
||||
if let Some(ref msd) = *self.msd_function.read().await {
|
||||
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
|
||||
let error_msg = format!("Failed to configure LUN: {}", e);
|
||||
self.monitor.report_error(&error_msg, "configfs_error").await;
|
||||
return Err(e);
|
||||
}
|
||||
} else {
|
||||
let err = AppError::Internal("MSD function not initialized".to_string());
|
||||
self.monitor.report_error("MSD function not initialized", "not_initialized").await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
state.connected = true;
|
||||
state.mode = MsdMode::Image;
|
||||
state.current_image = Some(image.clone());
|
||||
|
||||
info!(
|
||||
"Connected image: {} (cdrom={}, ro={})",
|
||||
image.name, cdrom, read_only
|
||||
);
|
||||
|
||||
// Release the lock before publishing events
|
||||
drop(state);
|
||||
drop(_op_guard);
|
||||
|
||||
// Report recovery if we were in an error state
|
||||
if self.monitor.is_error().await {
|
||||
self.monitor.report_recovered().await;
|
||||
}
|
||||
|
||||
// Publish events
|
||||
self.publish_event(crate::events::SystemEvent::MsdImageMounted {
|
||||
image_id: image.id.clone(),
|
||||
image_name: image.name.clone(),
|
||||
size: image.size,
|
||||
cdrom,
|
||||
})
|
||||
.await;
|
||||
|
||||
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
|
||||
mode: MsdMode::Image,
|
||||
connected: true,
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Connect the virtual drive
|
||||
pub async fn connect_drive(&self) -> Result<()> {
|
||||
// Acquire operation lock to prevent concurrent operations
|
||||
let _op_guard = self.operation_lock.write().await;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
if !state.available {
|
||||
let err = AppError::Internal("MSD not available".to_string());
|
||||
self.monitor.report_error("MSD not available", "not_available").await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
if state.connected {
|
||||
return Err(AppError::Internal(
|
||||
"Already connected. Disconnect first.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check drive exists
|
||||
if !self.drive_path.exists() {
|
||||
let err = AppError::Internal(
|
||||
"Virtual drive not initialized. Call init first.".to_string(),
|
||||
);
|
||||
self.monitor.report_error("Virtual drive not initialized", "drive_not_found").await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Configure LUN as read-write disk
|
||||
let config = MsdLunConfig::disk(self.drive_path.clone(), false);
|
||||
|
||||
let gadget_path = PathBuf::from(GADGET_PATH);
|
||||
if let Some(ref msd) = *self.msd_function.read().await {
|
||||
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
|
||||
let error_msg = format!("Failed to configure LUN: {}", e);
|
||||
self.monitor.report_error(&error_msg, "configfs_error").await;
|
||||
return Err(e);
|
||||
}
|
||||
} else {
|
||||
let err = AppError::Internal("MSD function not initialized".to_string());
|
||||
self.monitor.report_error("MSD function not initialized", "not_initialized").await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
state.connected = true;
|
||||
state.mode = MsdMode::Drive;
|
||||
state.current_image = None;
|
||||
|
||||
info!("Connected virtual drive: {}", self.drive_path.display());
|
||||
|
||||
// Release the lock before publishing event
|
||||
drop(state);
|
||||
drop(_op_guard);
|
||||
|
||||
// Report recovery if we were in an error state
|
||||
if self.monitor.is_error().await {
|
||||
self.monitor.report_recovered().await;
|
||||
}
|
||||
|
||||
// Publish event
|
||||
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
|
||||
mode: MsdMode::Drive,
|
||||
connected: true,
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect current storage
|
||||
pub async fn disconnect(&self) -> Result<()> {
|
||||
// Acquire operation lock to prevent concurrent operations
|
||||
let _op_guard = self.operation_lock.write().await;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
if !state.connected {
|
||||
debug!("Nothing connected, skipping disconnect");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let gadget_path = PathBuf::from(GADGET_PATH);
|
||||
if let Some(ref msd) = *self.msd_function.read().await {
|
||||
msd.disconnect_lun_async(&gadget_path, 0).await?;
|
||||
}
|
||||
|
||||
state.connected = false;
|
||||
state.mode = MsdMode::None;
|
||||
state.current_image = None;
|
||||
|
||||
info!("Disconnected storage");
|
||||
|
||||
// Release the lock before publishing events
|
||||
drop(state);
|
||||
drop(_op_guard);
|
||||
|
||||
// Publish events
|
||||
self.publish_event(crate::events::SystemEvent::MsdImageUnmounted)
|
||||
.await;
|
||||
|
||||
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
|
||||
mode: MsdMode::None,
|
||||
connected: false,
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get images storage path
|
||||
pub fn images_path(&self) -> &PathBuf {
|
||||
&self.images_path
|
||||
}
|
||||
|
||||
/// Get virtual drive path
|
||||
pub fn drive_path(&self) -> &PathBuf {
|
||||
&self.drive_path
|
||||
}
|
||||
|
||||
/// Check if currently connected
|
||||
pub async fn is_connected(&self) -> bool {
|
||||
self.state.read().await.connected
|
||||
}
|
||||
|
||||
/// Get current mode
|
||||
pub async fn mode(&self) -> MsdMode {
|
||||
self.state.read().await.mode.clone()
|
||||
}
|
||||
|
||||
/// Update drive info
|
||||
pub async fn update_drive_info(&self, info: DriveInfo) {
|
||||
let mut state = self.state.write().await;
|
||||
state.drive_info = Some(info);
|
||||
}
|
||||
|
||||
/// Start downloading an image from URL
|
||||
///
|
||||
/// Returns the download_id that can be used to track or cancel the download.
|
||||
/// Progress is reported via MsdDownloadProgress events.
|
||||
pub async fn download_image(
|
||||
&self,
|
||||
url: String,
|
||||
filename: Option<String>,
|
||||
) -> Result<DownloadProgress> {
|
||||
let download_id = uuid::Uuid::new_v4().to_string();
|
||||
let cancel_token = CancellationToken::new();
|
||||
|
||||
// Register download
|
||||
{
|
||||
let mut downloads = self.downloads.write().await;
|
||||
downloads.insert(download_id.clone(), cancel_token.clone());
|
||||
}
|
||||
|
||||
// Extract filename for initial response
|
||||
let display_filename = filename.clone().unwrap_or_else(|| {
|
||||
url.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("download")
|
||||
.to_string()
|
||||
});
|
||||
|
||||
// Create initial progress
|
||||
let initial_progress = DownloadProgress {
|
||||
download_id: download_id.clone(),
|
||||
url: url.clone(),
|
||||
filename: display_filename.clone(),
|
||||
bytes_downloaded: 0,
|
||||
total_bytes: None,
|
||||
progress_pct: None,
|
||||
status: DownloadStatus::Started,
|
||||
error: None,
|
||||
};
|
||||
|
||||
// Publish started event
|
||||
self.publish_event(crate::events::SystemEvent::MsdDownloadProgress {
|
||||
download_id: download_id.clone(),
|
||||
url: url.clone(),
|
||||
filename: display_filename.clone(),
|
||||
bytes_downloaded: 0,
|
||||
total_bytes: None,
|
||||
progress_pct: None,
|
||||
status: "started".to_string(),
|
||||
})
|
||||
.await;
|
||||
|
||||
// Clone what we need for the spawned task
|
||||
let images_path = self.images_path.clone();
|
||||
let events = self.events.read().await.clone();
|
||||
let downloads = self.downloads.clone();
|
||||
let download_id_clone = download_id.clone();
|
||||
let url_clone = url.clone();
|
||||
|
||||
// Spawn download task
|
||||
tokio::spawn(async move {
|
||||
let manager = ImageManager::new(images_path);
|
||||
|
||||
// Create progress callback
|
||||
let events_for_callback = events.clone();
|
||||
let download_id_for_callback = download_id_clone.clone();
|
||||
let url_for_callback = url_clone.clone();
|
||||
let filename_for_callback = display_filename.clone();
|
||||
|
||||
let progress_callback = move |downloaded: u64, total: Option<u64>| {
|
||||
let progress_pct = total.map(|t| (downloaded as f32 / t as f32) * 100.0);
|
||||
|
||||
if let Some(ref bus) = events_for_callback {
|
||||
bus.publish(crate::events::SystemEvent::MsdDownloadProgress {
|
||||
download_id: download_id_for_callback.clone(),
|
||||
url: url_for_callback.clone(),
|
||||
filename: filename_for_callback.clone(),
|
||||
bytes_downloaded: downloaded,
|
||||
total_bytes: total,
|
||||
progress_pct,
|
||||
status: "in_progress".to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Run download
|
||||
let result = manager
|
||||
.download_from_url(&url_clone, filename, progress_callback)
|
||||
.await;
|
||||
|
||||
// Remove from active downloads
|
||||
{
|
||||
let mut downloads_guard = downloads.write().await;
|
||||
downloads_guard.remove(&download_id_clone);
|
||||
}
|
||||
|
||||
// Publish completion event
|
||||
match result {
|
||||
Ok(image_info) => {
|
||||
if let Some(ref bus) = events {
|
||||
bus.publish(crate::events::SystemEvent::MsdDownloadProgress {
|
||||
download_id: download_id_clone,
|
||||
url: url_clone,
|
||||
filename: image_info.name,
|
||||
bytes_downloaded: image_info.size,
|
||||
total_bytes: Some(image_info.size),
|
||||
progress_pct: Some(100.0),
|
||||
status: "completed".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Download failed: {}", e);
|
||||
if let Some(ref bus) = events {
|
||||
bus.publish(crate::events::SystemEvent::MsdDownloadProgress {
|
||||
download_id: download_id_clone,
|
||||
url: url_clone,
|
||||
filename: display_filename,
|
||||
bytes_downloaded: 0,
|
||||
total_bytes: None,
|
||||
progress_pct: None,
|
||||
status: format!("failed: {}", e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(initial_progress)
|
||||
}
|
||||
|
||||
/// Cancel an active download
|
||||
pub async fn cancel_download(&self, download_id: &str) -> Result<()> {
|
||||
let mut downloads = self.downloads.write().await;
|
||||
|
||||
if let Some(token) = downloads.remove(download_id) {
|
||||
token.cancel();
|
||||
info!("Download cancelled: {}", download_id);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(AppError::NotFound(format!(
|
||||
"Download not found: {}",
|
||||
download_id
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get list of active download IDs
|
||||
pub async fn active_downloads(&self) -> Vec<String> {
|
||||
let downloads = self.downloads.read().await;
|
||||
downloads.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Shutdown the controller
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down MSD controller");
|
||||
|
||||
// 1. Disconnect if connected
|
||||
if let Err(e) = self.disconnect().await {
|
||||
warn!("Error disconnecting during shutdown: {}", e);
|
||||
}
|
||||
|
||||
// 2. Notify OtgService to disable MSD
|
||||
info!("Disabling MSD function in OtgService");
|
||||
self.otg_service.disable_msd().await?;
|
||||
|
||||
// 3. Clear local state
|
||||
*self.msd_function.write().await = None;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
state.available = false;
|
||||
|
||||
info!("MSD controller shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the health monitor reference
|
||||
pub fn monitor(&self) -> &Arc<MsdHealthMonitor> {
|
||||
&self.monitor
|
||||
}
|
||||
|
||||
/// Get current health status
|
||||
pub async fn health_status(&self) -> MsdHealthStatus {
|
||||
self.monitor.status().await
|
||||
}
|
||||
|
||||
/// Check if the MSD is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.monitor.is_healthy().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MsdController {
|
||||
fn drop(&mut self) {
|
||||
// Cleanup is handled by OtgGadgetManager when the gadget is torn down
|
||||
// Individual controllers don't need to cleanup the ConfigFS
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_controller_creation() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let otg_service = Arc::new(OtgService::new());
|
||||
let images_path = temp_dir.path().join("images");
|
||||
let drive_path = temp_dir.path().join("ventoy.img");
|
||||
|
||||
let controller = MsdController::new(otg_service, &images_path, &drive_path);
|
||||
|
||||
// Check that MSD is not initialized (msd_function is None)
|
||||
let state = controller.state().await;
|
||||
assert!(!state.available);
|
||||
assert!(controller.images_path.ends_with("images"));
|
||||
assert!(controller.drive_path.ends_with("ventoy.img"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_state_default() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let otg_service = Arc::new(OtgService::new());
|
||||
let images_path = temp_dir.path().join("images");
|
||||
let drive_path = temp_dir.path().join("ventoy.img");
|
||||
|
||||
let controller = MsdController::new(otg_service, &images_path, &drive_path);
|
||||
|
||||
let state = controller.state().await;
|
||||
assert!(!state.available);
|
||||
assert!(!state.connected);
|
||||
assert_eq!(state.mode, MsdMode::None);
|
||||
}
|
||||
}
|
||||
654
src/msd/image.rs
Normal file
654
src/msd/image.rs
Normal file
@@ -0,0 +1,654 @@
|
||||
//! Image file manager
|
||||
//!
|
||||
//! Handles ISO/IMG image file operations:
|
||||
//! - List available images
|
||||
//! - Upload new images
|
||||
//! - Delete images
|
||||
//! - Metadata management
|
||||
//! - Download from URL
|
||||
|
||||
use chrono::Utc;
|
||||
use futures::StreamExt;
|
||||
use std::fs::{self, File};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::info;
|
||||
|
||||
use super::types::ImageInfo;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Maximum image size (32 GB)
|
||||
const MAX_IMAGE_SIZE: u64 = 32 * 1024 * 1024 * 1024;
|
||||
|
||||
/// Progress report throttle interval (milliseconds)
|
||||
const PROGRESS_THROTTLE_MS: u64 = 200;
|
||||
|
||||
/// Progress report throttle bytes threshold (512 KB)
|
||||
const PROGRESS_THROTTLE_BYTES: u64 = 512 * 1024;
|
||||
|
||||
/// Image Manager
|
||||
pub struct ImageManager {
|
||||
/// Images storage directory
|
||||
images_path: PathBuf,
|
||||
}
|
||||
|
||||
impl ImageManager {
|
||||
/// Create a new image manager
|
||||
pub fn new(images_path: PathBuf) -> Self {
|
||||
Self { images_path }
|
||||
}
|
||||
|
||||
/// Ensure images directory exists
|
||||
pub fn ensure_dir(&self) -> Result<()> {
|
||||
fs::create_dir_all(&self.images_path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to create images directory: {}", e))
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all available images
|
||||
pub fn list(&self) -> Result<Vec<ImageInfo>> {
|
||||
self.ensure_dir()?;
|
||||
|
||||
let mut images = Vec::new();
|
||||
|
||||
for entry in fs::read_dir(&self.images_path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to read images directory: {}", e))
|
||||
})? {
|
||||
let entry = entry.map_err(|e| {
|
||||
AppError::Internal(format!("Failed to read directory entry: {}", e))
|
||||
})?;
|
||||
|
||||
let path = entry.path();
|
||||
if path.is_file() {
|
||||
if let Some(info) = self.get_image_info(&path) {
|
||||
images.push(info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by creation time (newest first)
|
||||
images.sort_by(|a, b| b.created_at.cmp(&a.created_at));
|
||||
|
||||
Ok(images)
|
||||
}
|
||||
|
||||
/// Get image info from path
|
||||
fn get_image_info(&self, path: &Path) -> Option<ImageInfo> {
|
||||
let metadata = fs::metadata(path).ok()?;
|
||||
let name = path.file_name()?.to_string_lossy().to_string();
|
||||
|
||||
// Use filename hash as ID (stable across restarts)
|
||||
let id = format!("{:x}", md5_hash(&name));
|
||||
|
||||
let created_at = metadata
|
||||
.created()
|
||||
.ok()
|
||||
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
|
||||
.map(|d| {
|
||||
chrono::DateTime::from_timestamp(d.as_secs() as i64, 0)
|
||||
.unwrap_or_else(|| Utc::now().into())
|
||||
})
|
||||
.unwrap_or_else(Utc::now);
|
||||
|
||||
Some(ImageInfo {
|
||||
id,
|
||||
name,
|
||||
path: path.to_path_buf(),
|
||||
size: metadata.len(),
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get image by ID
|
||||
pub fn get(&self, id: &str) -> Result<ImageInfo> {
|
||||
for image in self.list()? {
|
||||
if image.id == id {
|
||||
return Ok(image);
|
||||
}
|
||||
}
|
||||
Err(AppError::NotFound(format!("Image not found: {}", id)))
|
||||
}
|
||||
|
||||
/// Get image by name
|
||||
pub fn get_by_name(&self, name: &str) -> Result<ImageInfo> {
|
||||
let path = self.images_path.join(name);
|
||||
self.get_image_info(&path)
|
||||
.ok_or_else(|| AppError::NotFound(format!("Image not found: {}", name)))
|
||||
}
|
||||
|
||||
/// Create a new image from bytes
|
||||
pub fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
|
||||
self.ensure_dir()?;
|
||||
|
||||
// Validate name
|
||||
let name = sanitize_filename(name);
|
||||
if name.is_empty() {
|
||||
return Err(AppError::Internal("Invalid filename".to_string()));
|
||||
}
|
||||
|
||||
// Check size
|
||||
if data.len() as u64 > MAX_IMAGE_SIZE {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image too large. Maximum size: {} GB",
|
||||
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
|
||||
)));
|
||||
}
|
||||
|
||||
// Write file
|
||||
let path = self.images_path.join(&name);
|
||||
if path.exists() {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image already exists: {}",
|
||||
name
|
||||
)));
|
||||
}
|
||||
|
||||
let mut file = File::create(&path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to create image file: {}", e))
|
||||
})?;
|
||||
|
||||
file.write_all(data).map_err(|e| {
|
||||
// Try to clean up on error
|
||||
let _ = fs::remove_file(&path);
|
||||
AppError::Internal(format!("Failed to write image data: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Created image: {} ({} bytes)", name, data.len());
|
||||
|
||||
self.get_by_name(&name)
|
||||
}
|
||||
|
||||
/// Create a new image from a file stream (for chunked uploads)
|
||||
pub fn create_from_stream<R: Read>(
|
||||
&self,
|
||||
name: &str,
|
||||
reader: &mut R,
|
||||
expected_size: Option<u64>,
|
||||
) -> Result<ImageInfo> {
|
||||
self.ensure_dir()?;
|
||||
|
||||
let name = sanitize_filename(name);
|
||||
if name.is_empty() {
|
||||
return Err(AppError::Internal("Invalid filename".to_string()));
|
||||
}
|
||||
|
||||
if let Some(size) = expected_size {
|
||||
if size > MAX_IMAGE_SIZE {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image too large. Maximum size: {} GB",
|
||||
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let path = self.images_path.join(&name);
|
||||
if path.exists() {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image already exists: {}",
|
||||
name
|
||||
)));
|
||||
}
|
||||
|
||||
// Create file and copy data
|
||||
let mut file = File::create(&path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to create image file: {}", e))
|
||||
})?;
|
||||
|
||||
let bytes_written = io::copy(reader, &mut file).map_err(|e| {
|
||||
let _ = fs::remove_file(&path);
|
||||
AppError::Internal(format!("Failed to write image data: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Created image: {} ({} bytes)", name, bytes_written);
|
||||
|
||||
self.get_by_name(&name)
|
||||
}
|
||||
|
||||
/// Create a new image from an async multipart field (streaming, memory-efficient)
|
||||
///
|
||||
/// This method streams data directly to disk without buffering the entire file in memory,
|
||||
/// making it suitable for large files (multi-GB ISOs).
|
||||
pub async fn create_from_multipart_field(
|
||||
&self,
|
||||
name: &str,
|
||||
mut field: axum::extract::multipart::Field<'_>,
|
||||
) -> Result<ImageInfo> {
|
||||
self.ensure_dir()?;
|
||||
|
||||
let name = sanitize_filename(name);
|
||||
if name.is_empty() {
|
||||
return Err(AppError::Internal("Invalid filename".to_string()));
|
||||
}
|
||||
|
||||
// Use a temporary file during upload
|
||||
let temp_name = format!(".upload_{}", uuid::Uuid::new_v4());
|
||||
let temp_path = self.images_path.join(&temp_name);
|
||||
let final_path = self.images_path.join(&name);
|
||||
|
||||
// Check if final file already exists
|
||||
if final_path.exists() {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image already exists: {}",
|
||||
name
|
||||
)));
|
||||
}
|
||||
|
||||
// Create temp file
|
||||
let mut file = tokio::fs::File::create(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
|
||||
|
||||
let mut bytes_written: u64 = 0;
|
||||
|
||||
// Stream chunks directly to disk
|
||||
while let Some(chunk) = field.chunk().await.map_err(|e| {
|
||||
AppError::Internal(format!("Failed to read upload chunk: {}", e))
|
||||
})? {
|
||||
// Check size limit
|
||||
bytes_written += chunk.len() as u64;
|
||||
if bytes_written > MAX_IMAGE_SIZE {
|
||||
// Cleanup and return error
|
||||
drop(file);
|
||||
let _ = tokio::fs::remove_file(&temp_path).await;
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image too large. Maximum size: {} GB",
|
||||
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
|
||||
)));
|
||||
}
|
||||
|
||||
// Write chunk to file
|
||||
file.write_all(&chunk).await.map_err(|e| {
|
||||
AppError::Internal(format!("Failed to write chunk: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
// Flush and close file
|
||||
file.flush().await.map_err(|e| {
|
||||
AppError::Internal(format!("Failed to flush file: {}", e))
|
||||
})?;
|
||||
drop(file);
|
||||
|
||||
// Move temp file to final location
|
||||
tokio::fs::rename(&temp_path, &final_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let _ = std::fs::remove_file(&temp_path);
|
||||
AppError::Internal(format!("Failed to rename temp file: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Created image (streaming): {} ({} bytes)", name, bytes_written);
|
||||
|
||||
self.get_by_name(&name)
|
||||
}
|
||||
|
||||
/// Delete an image by ID
|
||||
pub fn delete(&self, id: &str) -> Result<()> {
|
||||
let image = self.get(id)?;
|
||||
|
||||
fs::remove_file(&image.path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to delete image: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Deleted image: {}", image.name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete an image by name
|
||||
pub fn delete_by_name(&self, name: &str) -> Result<()> {
|
||||
let path = self.images_path.join(name);
|
||||
|
||||
if !path.exists() {
|
||||
return Err(AppError::NotFound(format!("Image not found: {}", name)));
|
||||
}
|
||||
|
||||
fs::remove_file(&path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to delete image: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Deleted image: {}", name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get total storage used
|
||||
pub fn used_space(&self) -> u64 {
|
||||
self.list()
|
||||
.map(|images| images.iter().map(|i| i.size).sum())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Check if storage has space for new image
|
||||
pub fn has_space(&self, size: u64) -> bool {
|
||||
// For now, just check against max size
|
||||
// In the future, could check disk space
|
||||
size <= MAX_IMAGE_SIZE
|
||||
}
|
||||
|
||||
/// Download image from URL with progress callback
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `url` - The URL to download from
|
||||
/// * `filename` - Optional custom filename (extracted from URL or Content-Disposition if not provided)
|
||||
/// * `progress_callback` - Callback function called with (bytes_downloaded, total_bytes)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(ImageInfo)` - The downloaded image info
|
||||
/// * `Err(AppError)` - If download fails
|
||||
pub async fn download_from_url<F>(
|
||||
&self,
|
||||
url: &str,
|
||||
filename: Option<String>,
|
||||
progress_callback: F,
|
||||
) -> Result<ImageInfo>
|
||||
where
|
||||
F: Fn(u64, Option<u64>) + Send + 'static,
|
||||
{
|
||||
self.ensure_dir()?;
|
||||
|
||||
// Validate URL
|
||||
let parsed_url = reqwest::Url::parse(url)
|
||||
.map_err(|e| AppError::BadRequest(format!("Invalid URL: {}", e)))?;
|
||||
|
||||
info!("Starting download from: {}", url);
|
||||
|
||||
// Create HTTP client with timeout
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(3600)) // 1 hour timeout for large files
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create HTTP client: {}", e)))?;
|
||||
|
||||
// Send HEAD request first to get content info
|
||||
let head_response = client
|
||||
.head(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to connect: {}", e)))?;
|
||||
|
||||
if !head_response.status().is_success() {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Server returned error: {}",
|
||||
head_response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let total_size = head_response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_LENGTH)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok());
|
||||
|
||||
// Check file size
|
||||
if let Some(size) = total_size {
|
||||
if size > MAX_IMAGE_SIZE {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"File too large: {} bytes (max {} GB)",
|
||||
size,
|
||||
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Determine filename
|
||||
let final_filename = if let Some(name) = filename {
|
||||
sanitize_filename(&name)
|
||||
} else {
|
||||
// Try Content-Disposition header first
|
||||
let from_header = head_response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_DISPOSITION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| extract_filename_from_content_disposition(s));
|
||||
|
||||
if let Some(name) = from_header {
|
||||
sanitize_filename(&name)
|
||||
} else {
|
||||
// Fall back to URL path
|
||||
let path = parsed_url.path();
|
||||
let name = path.rsplit('/').next().unwrap_or("download");
|
||||
let name = urlencoding::decode(name).unwrap_or_else(|_| name.into());
|
||||
sanitize_filename(&name)
|
||||
}
|
||||
};
|
||||
|
||||
if final_filename.is_empty() {
|
||||
return Err(AppError::BadRequest("Could not determine filename".to_string()));
|
||||
}
|
||||
|
||||
// Check if file already exists
|
||||
let final_path = self.images_path.join(&final_filename);
|
||||
if final_path.exists() {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"Image already exists: {}",
|
||||
final_filename
|
||||
)));
|
||||
}
|
||||
|
||||
// Create temporary file for download
|
||||
let temp_filename = format!(".download_{}", uuid::Uuid::new_v4());
|
||||
let temp_path = self.images_path.join(&temp_filename);
|
||||
|
||||
// Start actual download
|
||||
let response = client
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Download failed: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Download failed: HTTP {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
// Get actual content length from response (may differ from HEAD)
|
||||
let content_length = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_LENGTH)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.or(total_size);
|
||||
|
||||
// Create temp file
|
||||
let mut file = tokio::fs::File::create(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
|
||||
|
||||
// Stream download with progress (throttled)
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut last_report_time = Instant::now();
|
||||
let mut last_reported_bytes: u64 = 0;
|
||||
let throttle_interval = Duration::from_millis(PROGRESS_THROTTLE_MS);
|
||||
|
||||
// Report initial progress
|
||||
progress_callback(0, content_length);
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result
|
||||
.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?;
|
||||
|
||||
file.write_all(&chunk)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
// Cleanup on error
|
||||
let _ = std::fs::remove_file(&temp_path);
|
||||
AppError::Internal(format!("Failed to write data: {}", e))
|
||||
})?;
|
||||
|
||||
downloaded += chunk.len() as u64;
|
||||
|
||||
// Throttled progress reporting: report if enough time or bytes have passed
|
||||
let now = Instant::now();
|
||||
let time_elapsed = now.duration_since(last_report_time) >= throttle_interval;
|
||||
let bytes_elapsed = downloaded - last_reported_bytes >= PROGRESS_THROTTLE_BYTES;
|
||||
|
||||
if time_elapsed || bytes_elapsed {
|
||||
progress_callback(downloaded, content_length);
|
||||
last_report_time = now;
|
||||
last_reported_bytes = downloaded;
|
||||
}
|
||||
}
|
||||
|
||||
// Always report final progress
|
||||
if downloaded != last_reported_bytes {
|
||||
progress_callback(downloaded, content_length);
|
||||
}
|
||||
|
||||
// Ensure all data is flushed
|
||||
file.flush()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
|
||||
drop(file);
|
||||
|
||||
// Verify downloaded size
|
||||
let metadata = tokio::fs::metadata(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read file metadata: {}", e)))?;
|
||||
|
||||
if let Some(expected) = content_length {
|
||||
if metadata.len() != expected {
|
||||
let _ = tokio::fs::remove_file(&temp_path).await;
|
||||
return Err(AppError::Internal(format!(
|
||||
"Download incomplete: got {} bytes, expected {}",
|
||||
metadata.len(),
|
||||
expected
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Move temp file to final location
|
||||
tokio::fs::rename(&temp_path, &final_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let _ = std::fs::remove_file(&temp_path);
|
||||
AppError::Internal(format!("Failed to move file: {}", e))
|
||||
})?;
|
||||
|
||||
info!(
|
||||
"Download complete: {} ({} bytes)",
|
||||
final_filename,
|
||||
metadata.len()
|
||||
);
|
||||
|
||||
// Return image info
|
||||
self.get_by_name(&final_filename)
|
||||
}
|
||||
|
||||
/// Get images storage path
|
||||
pub fn images_path(&self) -> &PathBuf {
|
||||
&self.images_path
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple hash function for generating stable IDs
|
||||
fn md5_hash(s: &str) -> u64 {
|
||||
let mut hash: u64 = 0;
|
||||
for (i, byte) in s.bytes().enumerate() {
|
||||
hash = hash.wrapping_add((byte as u64).wrapping_mul((i as u64).wrapping_add(1)));
|
||||
hash = hash.wrapping_mul(31);
|
||||
}
|
||||
hash
|
||||
}
|
||||
|
||||
/// Sanitize filename to prevent path traversal
|
||||
fn sanitize_filename(name: &str) -> String {
|
||||
let name = name.trim();
|
||||
let name = name.replace(['/', '\\', '\0', ':', '*', '?', '"', '<', '>', '|'], "_");
|
||||
|
||||
// Remove leading dots (hidden files)
|
||||
let name = name.trim_start_matches('.');
|
||||
|
||||
// Limit length
|
||||
if name.len() > 255 {
|
||||
name[..255].to_string()
|
||||
} else {
|
||||
name.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract filename from Content-Disposition header
|
||||
fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
|
||||
// Handle both:
|
||||
// Content-Disposition: attachment; filename="example.iso"
|
||||
// Content-Disposition: attachment; filename*=UTF-8''example.iso
|
||||
|
||||
// Try filename* first (RFC 5987)
|
||||
if let Some(pos) = header.find("filename*=") {
|
||||
let start = pos + 10;
|
||||
let value = &header[start..];
|
||||
// Format: charset'language'value
|
||||
if let Some(quote_start) = value.find("''") {
|
||||
let encoded = value[quote_start + 2..].split(';').next()?;
|
||||
let decoded = urlencoding::decode(encoded.trim()).ok()?;
|
||||
let name = decoded.trim_matches('"').to_string();
|
||||
if !name.is_empty() {
|
||||
return Some(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try filename next
|
||||
if let Some(pos) = header.find("filename=") {
|
||||
let start = pos + 9;
|
||||
let value = &header[start..];
|
||||
let name = value.split(';').next()?;
|
||||
let name = name.trim().trim_matches('"').to_string();
|
||||
if !name.is_empty() {
|
||||
return Some(name);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_filename() {
|
||||
assert_eq!(sanitize_filename("test.iso"), "test.iso");
|
||||
assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); // .. becomes empty after trim_start_matches('.')
|
||||
assert_eq!(sanitize_filename("test/file.iso"), "test_file.iso");
|
||||
assert_eq!(sanitize_filename(".hidden.iso"), "hidden.iso");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_manager_list_empty() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = ImageManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let images = manager.list().unwrap();
|
||||
assert!(images.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_manager_create() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = ImageManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let data = vec![0u8; 1024];
|
||||
let image = manager.create("test.iso", &data).unwrap();
|
||||
|
||||
assert_eq!(image.name, "test.iso");
|
||||
assert_eq!(image.size, 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_manager_delete() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = ImageManager::new(temp_dir.path().to_path_buf());
|
||||
|
||||
let data = vec![0u8; 1024];
|
||||
let image = manager.create("test.iso", &data).unwrap();
|
||||
|
||||
manager.delete(&image.id).unwrap();
|
||||
|
||||
assert!(manager.list().unwrap().is_empty());
|
||||
}
|
||||
}
|
||||
33
src/msd/mod.rs
Normal file
33
src/msd/mod.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! MSD (Mass Storage Device) module
|
||||
//!
|
||||
//! Provides virtual USB storage functionality with two modes:
|
||||
//! - Image mounting: Mount ISO/IMG files for system installation
|
||||
//! - Ventoy drive: Bootable exFAT drive for multiple ISO files
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! Web API --> MSD Controller --> ConfigFS Mass Storage --> Target PC
|
||||
//! |
|
||||
//! ┌──────┴──────┐
|
||||
//! │ │
|
||||
//! Image Manager Ventoy Drive
|
||||
//! (ISO/IMG) (Bootable exFAT)
|
||||
//! ```
|
||||
|
||||
pub mod controller;
|
||||
pub mod ventoy_drive;
|
||||
pub mod image;
|
||||
pub mod monitor;
|
||||
pub mod types;
|
||||
|
||||
pub use controller::MsdController;
|
||||
pub use ventoy_drive::VentoyDrive;
|
||||
pub use image::ImageManager;
|
||||
pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig};
|
||||
pub use types::{
|
||||
DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest,
|
||||
ImageInfo, MsdConnectRequest, MsdMode, MsdState,
|
||||
};
|
||||
|
||||
// Re-export from otg module for backward compatibility
|
||||
pub use crate::otg::{MsdFunction, MsdLunConfig};
|
||||
284
src/msd/monitor.rs
Normal file
284
src/msd/monitor.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
//! MSD (Mass Storage Device) health monitoring
|
||||
//!
|
||||
//! This module provides health monitoring for MSD operations, including:
|
||||
//! - ConfigFS operation error tracking
|
||||
//! - Image mount/unmount error tracking
|
||||
//! - Error notification
|
||||
//! - Log throttling to prevent log flooding
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::events::{EventBus, SystemEvent};
|
||||
use crate::utils::LogThrottler;
|
||||
|
||||
/// MSD health status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum MsdHealthStatus {
|
||||
/// Device is healthy and operational
|
||||
Healthy,
|
||||
/// Device has an error
|
||||
Error {
|
||||
/// Human-readable error reason
|
||||
reason: String,
|
||||
/// Error code for programmatic handling
|
||||
error_code: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for MsdHealthStatus {
|
||||
fn default() -> Self {
|
||||
Self::Healthy
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD health monitor configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MsdMonitorConfig {
|
||||
/// Log throttle interval in seconds
|
||||
pub log_throttle_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for MsdMonitorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
log_throttle_secs: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD health monitor
|
||||
///
|
||||
/// Monitors MSD operation health and manages error notifications.
|
||||
/// Publishes WebSocket events when operation status changes.
|
||||
pub struct MsdHealthMonitor {
|
||||
/// Current health status
|
||||
status: RwLock<MsdHealthStatus>,
|
||||
/// Event bus for notifications
|
||||
events: RwLock<Option<Arc<EventBus>>>,
|
||||
/// Log throttler to prevent log flooding
|
||||
throttler: LogThrottler,
|
||||
/// Configuration
|
||||
#[allow(dead_code)]
|
||||
config: MsdMonitorConfig,
|
||||
/// Whether monitoring is active (reserved for future use)
|
||||
#[allow(dead_code)]
|
||||
running: AtomicBool,
|
||||
/// Error count (for tracking)
|
||||
error_count: AtomicU32,
|
||||
/// Last error code (for change detection)
|
||||
last_error_code: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl MsdHealthMonitor {
|
||||
/// Create a new MSD health monitor with the specified configuration
|
||||
pub fn new(config: MsdMonitorConfig) -> Self {
|
||||
let throttle_secs = config.log_throttle_secs;
|
||||
Self {
|
||||
status: RwLock::new(MsdHealthStatus::Healthy),
|
||||
events: RwLock::new(None),
|
||||
throttler: LogThrottler::with_secs(throttle_secs),
|
||||
config,
|
||||
running: AtomicBool::new(false),
|
||||
error_count: AtomicU32::new(0),
|
||||
last_error_code: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new MSD health monitor with default configuration
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(MsdMonitorConfig::default())
|
||||
}
|
||||
|
||||
/// Set the event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Report an error from MSD operations
|
||||
///
|
||||
/// This method is called when an MSD operation fails. It:
|
||||
/// 1. Updates the health status
|
||||
/// 2. Logs the error (with throttling)
|
||||
/// 3. Publishes a WebSocket event if the error is new or changed
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `reason` - Human-readable error description
|
||||
/// * `error_code` - Error code for programmatic handling
|
||||
pub async fn report_error(&self, reason: &str, error_code: &str) {
|
||||
let count = self.error_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if error code changed
|
||||
let error_changed = {
|
||||
let last = self.last_error_code.read().await;
|
||||
last.as_ref().map(|s| s.as_str()) != Some(error_code)
|
||||
};
|
||||
|
||||
// Log with throttling (always log if error type changed)
|
||||
let throttle_key = format!("msd_{}", error_code);
|
||||
if error_changed || self.throttler.should_log(&throttle_key) {
|
||||
warn!("MSD error: {} (code: {}, count: {})", reason, error_code, count);
|
||||
}
|
||||
|
||||
// Update last error code
|
||||
*self.last_error_code.write().await = Some(error_code.to_string());
|
||||
|
||||
// Update status
|
||||
*self.status.write().await = MsdHealthStatus::Error {
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
};
|
||||
|
||||
// Publish event (only if error changed or first occurrence)
|
||||
if error_changed || count == 1 {
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::MsdError {
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Report that the MSD has recovered from error
|
||||
///
|
||||
/// This method is called when an MSD operation succeeds after errors.
|
||||
/// It resets the error state and publishes a recovery event.
|
||||
pub async fn report_recovered(&self) {
|
||||
let prev_status = self.status.read().await.clone();
|
||||
|
||||
// Only report recovery if we were in an error state
|
||||
if prev_status != MsdHealthStatus::Healthy {
|
||||
let error_count = self.error_count.load(Ordering::Relaxed);
|
||||
info!("MSD recovered after {} errors", error_count);
|
||||
|
||||
// Reset state
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
self.throttler.clear_all();
|
||||
*self.last_error_code.write().await = None;
|
||||
*self.status.write().await = MsdHealthStatus::Healthy;
|
||||
|
||||
// Publish recovery event
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(SystemEvent::MsdRecovered);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current health status
|
||||
pub async fn status(&self) -> MsdHealthStatus {
|
||||
self.status.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get the current error count
|
||||
pub fn error_count(&self) -> u32 {
|
||||
self.error_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Check if the monitor is in an error state
|
||||
pub async fn is_error(&self) -> bool {
|
||||
matches!(*self.status.read().await, MsdHealthStatus::Error { .. })
|
||||
}
|
||||
|
||||
/// Check if the monitor is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
matches!(*self.status.read().await, MsdHealthStatus::Healthy)
|
||||
}
|
||||
|
||||
/// Reset the monitor to healthy state without publishing events
|
||||
///
|
||||
/// This is useful during initialization.
|
||||
pub async fn reset(&self) {
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
*self.last_error_code.write().await = None;
|
||||
*self.status.write().await = MsdHealthStatus::Healthy;
|
||||
self.throttler.clear_all();
|
||||
}
|
||||
|
||||
/// Get the current error message if in error state
|
||||
pub async fn error_message(&self) -> Option<String> {
|
||||
match &*self.status.read().await {
|
||||
MsdHealthStatus::Error { reason, .. } => Some(reason.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MsdHealthMonitor {
|
||||
fn default() -> Self {
|
||||
Self::with_defaults()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initial_status() {
|
||||
let monitor = MsdHealthMonitor::with_defaults();
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert!(!monitor.is_error().await);
|
||||
assert_eq!(monitor.error_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_error() {
|
||||
let monitor = MsdHealthMonitor::with_defaults();
|
||||
|
||||
monitor
|
||||
.report_error("ConfigFS write failed", "configfs_error")
|
||||
.await;
|
||||
|
||||
assert!(monitor.is_error().await);
|
||||
assert_eq!(monitor.error_count(), 1);
|
||||
|
||||
if let MsdHealthStatus::Error { reason, error_code } = monitor.status().await {
|
||||
assert_eq!(reason, "ConfigFS write failed");
|
||||
assert_eq!(error_code, "configfs_error");
|
||||
} else {
|
||||
panic!("Expected Error status");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_recovered() {
|
||||
let monitor = MsdHealthMonitor::with_defaults();
|
||||
|
||||
// First report an error
|
||||
monitor
|
||||
.report_error("Image not found", "image_not_found")
|
||||
.await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
// Then report recovery
|
||||
monitor.report_recovered().await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert_eq!(monitor.error_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_count_increments() {
|
||||
let monitor = MsdHealthMonitor::with_defaults();
|
||||
|
||||
for i in 1..=5 {
|
||||
monitor.report_error("Error", "io_error").await;
|
||||
assert_eq!(monitor.error_count(), i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset() {
|
||||
let monitor = MsdHealthMonitor::with_defaults();
|
||||
|
||||
monitor.report_error("Error", "io_error").await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
monitor.reset().await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert_eq!(monitor.error_count(), 0);
|
||||
}
|
||||
}
|
||||
229
src/msd/types.rs
Normal file
229
src/msd/types.rs
Normal file
@@ -0,0 +1,229 @@
|
||||
//! MSD data types and structures
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// MSD operating mode
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MsdMode {
|
||||
/// No storage connected
|
||||
None,
|
||||
/// Image file mounted (ISO/IMG)
|
||||
Image,
|
||||
/// Virtual drive (FAT32) connected
|
||||
Drive,
|
||||
}
|
||||
|
||||
impl Default for MsdMode {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Image file metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageInfo {
|
||||
/// Unique image ID
|
||||
pub id: String,
|
||||
/// Display name
|
||||
pub name: String,
|
||||
/// File path on disk
|
||||
#[serde(skip_serializing)]
|
||||
pub path: PathBuf,
|
||||
/// File size in bytes
|
||||
pub size: u64,
|
||||
/// Creation timestamp
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl ImageInfo {
|
||||
/// Create new image info
|
||||
pub fn new(id: String, name: String, path: PathBuf, size: u64) -> Self {
|
||||
Self {
|
||||
id,
|
||||
name,
|
||||
path,
|
||||
size,
|
||||
created_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Format size for display
|
||||
pub fn size_display(&self) -> String {
|
||||
const KB: u64 = 1024;
|
||||
const MB: u64 = KB * 1024;
|
||||
const GB: u64 = MB * 1024;
|
||||
|
||||
if self.size >= GB {
|
||||
format!("{:.2} GB", self.size as f64 / GB as f64)
|
||||
} else if self.size >= MB {
|
||||
format!("{:.2} MB", self.size as f64 / MB as f64)
|
||||
} else if self.size >= KB {
|
||||
format!("{:.2} KB", self.size as f64 / KB as f64)
|
||||
} else {
|
||||
format!("{} B", self.size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD state information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MsdState {
|
||||
/// Whether MSD feature is available
|
||||
pub available: bool,
|
||||
/// Current mode
|
||||
pub mode: MsdMode,
|
||||
/// Whether storage is connected to target
|
||||
pub connected: bool,
|
||||
/// Currently mounted image (if mode is Image)
|
||||
pub current_image: Option<ImageInfo>,
|
||||
/// Virtual drive info (if mode is Drive)
|
||||
pub drive_info: Option<DriveInfo>,
|
||||
}
|
||||
|
||||
impl Default for MsdState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
available: false,
|
||||
mode: MsdMode::None,
|
||||
connected: false,
|
||||
current_image: None,
|
||||
drive_info: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Virtual drive information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DriveInfo {
|
||||
/// Drive size in bytes
|
||||
pub size: u64,
|
||||
/// Used space in bytes
|
||||
pub used: u64,
|
||||
/// Free space in bytes
|
||||
pub free: u64,
|
||||
/// Whether drive is initialized
|
||||
pub initialized: bool,
|
||||
/// Drive file path
|
||||
#[serde(skip_serializing)]
|
||||
pub path: PathBuf,
|
||||
}
|
||||
|
||||
impl DriveInfo {
|
||||
/// Create new drive info
|
||||
pub fn new(path: PathBuf, size: u64) -> Self {
|
||||
Self {
|
||||
size,
|
||||
used: 0,
|
||||
free: size,
|
||||
initialized: false,
|
||||
path,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// File entry in virtual drive
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DriveFile {
|
||||
/// File name
|
||||
pub name: String,
|
||||
/// Relative path from drive root
|
||||
pub path: String,
|
||||
/// File size in bytes (0 for directories)
|
||||
pub size: u64,
|
||||
/// Whether this is a directory
|
||||
pub is_dir: bool,
|
||||
/// Last modified timestamp
|
||||
pub modified: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// MSD connect request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MsdConnectRequest {
|
||||
/// Connection mode: "image" or "drive"
|
||||
pub mode: MsdMode,
|
||||
/// Image ID to mount (required for image mode)
|
||||
pub image_id: Option<String>,
|
||||
/// Mount as CD-ROM (optional, defaults based on image type)
|
||||
#[serde(default)]
|
||||
pub cdrom: Option<bool>,
|
||||
/// Mount as read-only
|
||||
#[serde(default)]
|
||||
pub read_only: Option<bool>,
|
||||
}
|
||||
|
||||
/// Virtual drive init request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DriveInitRequest {
|
||||
/// Drive size in megabytes (defaults to 16GB)
|
||||
#[serde(default = "default_drive_size")]
|
||||
pub size_mb: u32,
|
||||
/// Optional custom path for Ventoy installation
|
||||
pub ventoy_path: Option<String>,
|
||||
}
|
||||
|
||||
fn default_drive_size() -> u32 {
|
||||
16 * 1024 // 16GB
|
||||
}
|
||||
|
||||
/// Image download request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ImageDownloadRequest {
|
||||
/// URL to download from
|
||||
pub url: String,
|
||||
/// Optional custom filename
|
||||
pub filename: Option<String>,
|
||||
}
|
||||
|
||||
/// Download status
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum DownloadStatus {
|
||||
/// Download has started
|
||||
Started,
|
||||
/// Download is in progress
|
||||
InProgress,
|
||||
/// Download completed successfully
|
||||
Completed,
|
||||
/// Download failed
|
||||
Failed,
|
||||
}
|
||||
|
||||
/// Download progress information
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct DownloadProgress {
|
||||
/// Unique download ID
|
||||
pub download_id: String,
|
||||
/// Source URL
|
||||
pub url: String,
|
||||
/// Target filename
|
||||
pub filename: String,
|
||||
/// Bytes downloaded so far
|
||||
pub bytes_downloaded: u64,
|
||||
/// Total file size (None if unknown)
|
||||
pub total_bytes: Option<u64>,
|
||||
/// Progress percentage (0.0 - 100.0, None if total unknown)
|
||||
pub progress_pct: Option<f32>,
|
||||
/// Download status
|
||||
pub status: DownloadStatus,
|
||||
/// Error message if failed
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_size_display() {
|
||||
let info = ImageInfo::new(
|
||||
"test".into(),
|
||||
"test.iso".into(),
|
||||
PathBuf::from("/tmp/test.iso"),
|
||||
1024 * 1024 * 1024 * 2, // 2 GB
|
||||
);
|
||||
assert!(info.size_display().contains("GB"));
|
||||
}
|
||||
}
|
||||
690
src/msd/ventoy_drive.rs
Normal file
690
src/msd/ventoy_drive.rs
Normal file
@@ -0,0 +1,690 @@
|
||||
//! Ventoy Virtual Drive
|
||||
//!
|
||||
//! Replaces FAT32 VirtualDrive with a Ventoy bootable image.
|
||||
//! Provides a bootable USB with exFAT data partition for ISO files.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
use ventoy_img::{FileInfo as VentoyFileInfo, VentoyError, VentoyImage};
|
||||
|
||||
use super::types::{DriveFile, DriveInfo};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Chunk size for streaming reads (64 KB)
|
||||
const STREAM_CHUNK_SIZE: usize = 64 * 1024;
|
||||
|
||||
/// Minimum drive size (1 GB) - Ventoy requires space for boot partition
|
||||
const MIN_DRIVE_SIZE_MB: u32 = 1024;
|
||||
|
||||
/// Maximum drive size (128 GB)
|
||||
const MAX_DRIVE_SIZE_MB: u32 = 128 * 1024;
|
||||
|
||||
/// Default drive label
|
||||
const DEFAULT_LABEL: &str = "ONE-KVM";
|
||||
|
||||
/// Ventoy Drive Manager
|
||||
///
|
||||
/// Thread-safe wrapper around VentoyImage providing async file operations.
|
||||
/// Uses spawn_blocking for all ventoy-img-rs operations since they are synchronous.
|
||||
/// Uses RwLock to allow concurrent read operations while serializing writes.
|
||||
pub struct VentoyDrive {
|
||||
/// Drive image path
|
||||
path: PathBuf,
|
||||
/// RwLock for concurrent reads, exclusive writes
|
||||
/// (ventoy-img-rs operations are synchronous and not thread-safe)
|
||||
lock: Arc<RwLock<()>>,
|
||||
}
|
||||
|
||||
impl VentoyDrive {
|
||||
/// Create new Ventoy drive manager
|
||||
pub fn new(path: PathBuf) -> Self {
|
||||
Self {
|
||||
path,
|
||||
lock: Arc::new(RwLock::new(())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if drive image exists
|
||||
pub fn exists(&self) -> bool {
|
||||
self.path.exists()
|
||||
}
|
||||
|
||||
/// Get drive path
|
||||
pub fn path(&self) -> &PathBuf {
|
||||
&self.path
|
||||
}
|
||||
|
||||
/// Initialize a new Ventoy drive image
|
||||
///
|
||||
/// Creates a bootable Ventoy image with the specified size.
|
||||
/// The image includes boot partitions and an exFAT data partition.
|
||||
pub async fn init(&self, size_mb: u32) -> Result<DriveInfo> {
|
||||
let size_mb = size_mb.clamp(MIN_DRIVE_SIZE_MB, MAX_DRIVE_SIZE_MB);
|
||||
let size_str = format!("{}M", size_mb);
|
||||
let path = self.path.clone();
|
||||
let _lock = self.lock.write().await; // Write lock for initialization
|
||||
|
||||
info!("Creating {} MB Ventoy drive at {}", size_mb, path.display());
|
||||
|
||||
// Run Ventoy creation in blocking task
|
||||
let info = tokio::task::spawn_blocking(move || {
|
||||
VentoyImage::create(&path, &size_str, DEFAULT_LABEL)
|
||||
.map_err(ventoy_to_app_error)?;
|
||||
|
||||
// Get file metadata for DriveInfo
|
||||
let metadata = std::fs::metadata(&path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to read drive metadata: {}", e))
|
||||
})?;
|
||||
|
||||
Ok::<DriveInfo, AppError>(DriveInfo {
|
||||
size: metadata.len(),
|
||||
used: 0,
|
||||
free: metadata.len(), // Approximate - exFAT overhead not calculated
|
||||
initialized: true,
|
||||
path,
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))??;
|
||||
|
||||
info!("Ventoy drive created successfully");
|
||||
Ok(info)
|
||||
}
|
||||
|
||||
/// Get drive information
|
||||
pub async fn info(&self) -> Result<DriveInfo> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
let path = self.path.clone();
|
||||
let _lock = self.lock.read().await; // Read lock for info query
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let metadata = std::fs::metadata(&path).map_err(|e| {
|
||||
AppError::Internal(format!("Failed to read drive metadata: {}", e))
|
||||
})?;
|
||||
|
||||
// Open image to get file list and calculate used space
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
let files = image
|
||||
.list_files_recursive()
|
||||
.map_err(ventoy_to_app_error)?;
|
||||
|
||||
let used: u64 = files
|
||||
.iter()
|
||||
.filter(|f| !f.is_directory)
|
||||
.map(|f| f.size)
|
||||
.sum();
|
||||
|
||||
// Note: This is approximate since we don't have exact exFAT overhead
|
||||
let size = metadata.len();
|
||||
let free = size.saturating_sub(used);
|
||||
|
||||
Ok(DriveInfo {
|
||||
size,
|
||||
used,
|
||||
free,
|
||||
initialized: true,
|
||||
path,
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// List files at a given path (or root if empty/"/")
|
||||
pub async fn list_files(&self, dir_path: &str) -> Result<Vec<DriveFile>> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
let path = self.path.clone();
|
||||
let dir_path = dir_path.to_string();
|
||||
let _lock = self.lock.read().await; // Read lock for listing
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
let files = if dir_path.is_empty() || dir_path == "/" {
|
||||
image.list_files()
|
||||
} else {
|
||||
image.list_files_at(&dir_path)
|
||||
}
|
||||
.map_err(ventoy_to_app_error)?;
|
||||
|
||||
Ok(files
|
||||
.into_iter()
|
||||
.map(|f| ventoy_file_to_drive_file(f, &dir_path))
|
||||
.collect())
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Write a file to the drive from multipart upload (streaming)
|
||||
///
|
||||
/// Streams the file directly into the Ventoy image's exFAT partition.
|
||||
pub async fn write_file_from_multipart_field(
|
||||
&self,
|
||||
file_path: &str,
|
||||
mut field: axum::extract::multipart::Field<'_>,
|
||||
) -> Result<u64> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
// First, stream to a temporary file (to get the size)
|
||||
let temp_dir = self.path.parent().unwrap_or(Path::new("/tmp"));
|
||||
let temp_name = format!(".upload_ventoy_{}", uuid::Uuid::new_v4());
|
||||
let temp_path = temp_dir.join(&temp_name);
|
||||
|
||||
// Stream upload to temp file
|
||||
let mut temp_file = tokio::fs::File::create(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
|
||||
|
||||
let mut bytes_written: u64 = 0;
|
||||
|
||||
while let Some(chunk) = field.chunk().await.map_err(|e| {
|
||||
AppError::Internal(format!("Failed to read upload chunk: {}", e))
|
||||
})? {
|
||||
bytes_written += chunk.len() as u64;
|
||||
tokio::io::AsyncWriteExt::write_all(&mut temp_file, &chunk)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?;
|
||||
}
|
||||
|
||||
tokio::io::AsyncWriteExt::flush(&mut temp_file)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to flush temp file: {}", e)))?;
|
||||
drop(temp_file);
|
||||
|
||||
// Now copy from temp file to Ventoy image
|
||||
let path = self.path.clone();
|
||||
let file_path = file_path.to_string();
|
||||
let temp_path_clone = temp_path.clone();
|
||||
let _lock = self.lock.write().await; // Write lock for file write
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
// Use add_file_to_path which handles streaming internally
|
||||
image
|
||||
.add_file_to_path(
|
||||
&temp_path_clone,
|
||||
&file_path,
|
||||
true, // create_parents
|
||||
true, // overwrite
|
||||
)
|
||||
.map_err(ventoy_to_app_error)?;
|
||||
|
||||
Ok::<(), AppError>(())
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?;
|
||||
|
||||
// Cleanup temp file
|
||||
let _ = tokio::fs::remove_file(&temp_path).await;
|
||||
|
||||
result?;
|
||||
Ok(bytes_written)
|
||||
}
|
||||
|
||||
/// Read a file from the drive (for download)
|
||||
pub async fn read_file(&self, file_path: &str) -> Result<Vec<u8>> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
let path = self.path.clone();
|
||||
let file_path = file_path.to_string();
|
||||
let _lock = self.lock.read().await; // Read lock for file read
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
image
|
||||
.read_file(&file_path)
|
||||
.map_err(ventoy_to_app_error)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Get file information without reading content
|
||||
///
|
||||
/// Returns file size, name, and other metadata.
|
||||
/// Returns None if the file doesn't exist.
|
||||
pub async fn get_file_info(&self, file_path: &str) -> Result<Option<DriveFile>> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
let path = self.path.clone();
|
||||
let file_path_owned = file_path.to_string();
|
||||
let _lock = self.lock.read().await; // Read lock for file info
|
||||
|
||||
let info = tokio::task::spawn_blocking(move || {
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
image
|
||||
.get_file_info(&file_path_owned)
|
||||
.map_err(ventoy_to_app_error)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))??;
|
||||
|
||||
Ok(info.map(|f| DriveFile {
|
||||
name: f.name,
|
||||
path: f.path,
|
||||
size: f.size,
|
||||
is_dir: f.is_directory,
|
||||
modified: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Read a file from the drive as a stream (for large file downloads)
|
||||
///
|
||||
/// Returns an async channel receiver that yields chunks of file data.
|
||||
/// This avoids loading the entire file into memory.
|
||||
pub async fn read_file_stream(
|
||||
&self,
|
||||
file_path: &str,
|
||||
) -> Result<(
|
||||
u64,
|
||||
tokio::sync::mpsc::Receiver<std::result::Result<bytes::Bytes, std::io::Error>>,
|
||||
)> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
// First, get the file size
|
||||
let file_info = self
|
||||
.get_file_info(file_path)
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound(format!("File not found: {}", file_path)))?;
|
||||
|
||||
if file_info.is_dir {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"'{}' is a directory",
|
||||
file_path
|
||||
)));
|
||||
}
|
||||
|
||||
let file_size = file_info.size;
|
||||
let path = self.path.clone();
|
||||
let file_path_owned = file_path.to_string();
|
||||
let lock = self.lock.clone();
|
||||
|
||||
// Create a channel for streaming data
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, std::io::Error>>(8);
|
||||
|
||||
// Spawn blocking task to read and send chunks
|
||||
tokio::task::spawn_blocking(move || {
|
||||
// Hold read lock for the entire read operation
|
||||
let rt = tokio::runtime::Handle::current();
|
||||
let _lock = rt.block_on(lock.read()); // Read lock for streaming
|
||||
|
||||
let image = match VentoyImage::open(&path) {
|
||||
Ok(img) => img,
|
||||
Err(e) => {
|
||||
let _ = rt.block_on(tx.send(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
))));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Create a channel writer that sends chunks
|
||||
let mut chunk_writer = ChannelWriter::new(tx.clone(), rt.clone());
|
||||
|
||||
// Stream the file through the writer
|
||||
if let Err(e) = image.read_file_to_writer(&file_path_owned, &mut chunk_writer) {
|
||||
let _ = rt.block_on(tx.send(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
))));
|
||||
}
|
||||
});
|
||||
|
||||
Ok((file_size, rx))
|
||||
}
|
||||
|
||||
/// Create a directory
|
||||
pub async fn mkdir(&self, dir_path: &str) -> Result<()> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
let path = self.path.clone();
|
||||
let dir_path = dir_path.to_string();
|
||||
let _lock = self.lock.write().await; // Write lock for mkdir
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
image
|
||||
.create_directory(&dir_path, true)
|
||||
.map_err(ventoy_to_app_error)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Delete a file or directory
|
||||
pub async fn delete(&self, path_to_delete: &str) -> Result<()> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
let path = self.path.clone();
|
||||
let path_to_delete = path_to_delete.to_string();
|
||||
let _lock = self.lock.write().await; // Write lock for delete
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
// Use recursive delete to handle directories
|
||||
image
|
||||
.remove_recursive(&path_to_delete)
|
||||
.map_err(ventoy_to_app_error)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert VentoyError to AppError
|
||||
fn ventoy_to_app_error(err: VentoyError) -> AppError {
|
||||
match err {
|
||||
VentoyError::Io(e) => AppError::Io(e),
|
||||
VentoyError::InvalidSize(s) => AppError::BadRequest(format!("Invalid size: {}", s)),
|
||||
VentoyError::SizeParseError(s) => {
|
||||
AppError::BadRequest(format!("Size parse error: {}", s))
|
||||
}
|
||||
VentoyError::FilesystemError(s) => {
|
||||
AppError::Internal(format!("Filesystem error: {}", s))
|
||||
}
|
||||
VentoyError::ImageError(s) => AppError::Internal(format!("Image error: {}", s)),
|
||||
VentoyError::FileNotFound(s) => AppError::NotFound(format!("File not found: {}", s)),
|
||||
VentoyError::ResourceNotFound(s) => {
|
||||
AppError::Internal(format!("Resource not found: {}", s))
|
||||
}
|
||||
VentoyError::PartitionError(s) => {
|
||||
AppError::Internal(format!("Partition error: {}", s))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert VentoyFileInfo to DriveFile
|
||||
fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFile {
|
||||
let full_path = if parent_path.is_empty() || parent_path == "/" {
|
||||
format!("/{}", info.name)
|
||||
} else {
|
||||
format!("{}/{}", parent_path.trim_end_matches('/'), info.name)
|
||||
};
|
||||
|
||||
DriveFile {
|
||||
name: info.name,
|
||||
path: full_path,
|
||||
size: info.size,
|
||||
is_dir: info.is_directory,
|
||||
modified: None, // Ventoy FileInfo doesn't include timestamps
|
||||
}
|
||||
}
|
||||
|
||||
/// A writer that sends chunks to an async channel
|
||||
///
|
||||
/// This bridges the sync Write trait with async channels for streaming.
|
||||
struct ChannelWriter {
|
||||
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
|
||||
rt: tokio::runtime::Handle,
|
||||
buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ChannelWriter {
|
||||
fn new(
|
||||
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
|
||||
rt: tokio::runtime::Handle,
|
||||
) -> Self {
|
||||
Self {
|
||||
tx,
|
||||
rt,
|
||||
buffer: Vec::with_capacity(STREAM_CHUNK_SIZE),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_buffer(&mut self) -> std::io::Result<()> {
|
||||
if self.buffer.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let chunk = bytes::Bytes::copy_from_slice(&self.buffer);
|
||||
self.buffer.clear();
|
||||
|
||||
self.rt
|
||||
.block_on(self.tx.send(Ok(chunk)))
|
||||
.map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel closed"))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Write for ChannelWriter {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
let mut written = 0;
|
||||
|
||||
while written < buf.len() {
|
||||
let space = STREAM_CHUNK_SIZE - self.buffer.len();
|
||||
let to_copy = std::cmp::min(space, buf.len() - written);
|
||||
|
||||
self.buffer.extend_from_slice(&buf[written..written + to_copy]);
|
||||
written += to_copy;
|
||||
|
||||
if self.buffer.len() >= STREAM_CHUNK_SIZE {
|
||||
self.flush_buffer()?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(written)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.flush_buffer()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ChannelWriter {
|
||||
fn drop(&mut self) {
|
||||
// Flush any remaining data when the writer is dropped
|
||||
let _ = self.flush_buffer();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drive_init() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path);
|
||||
|
||||
let info = drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
assert!(info.initialized);
|
||||
assert!(drive.exists());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drive_mkdir() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path);
|
||||
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
drive.mkdir("/isos").await.unwrap();
|
||||
|
||||
let files = drive.list_files("/").await.unwrap();
|
||||
assert_eq!(files.len(), 1);
|
||||
assert!(files[0].is_dir);
|
||||
assert_eq!(files[0].name, "isos");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drive_file_write_and_read() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Write a test file
|
||||
let test_content = b"Hello, Ventoy!";
|
||||
let test_file_path = temp_dir.path().join("test.txt");
|
||||
std::fs::write(&test_file_path, test_content).unwrap();
|
||||
|
||||
// Add file to drive using ventoy-img directly
|
||||
let path = drive.path().clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).unwrap();
|
||||
image.add_file(&test_file_path).unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Read file from drive
|
||||
let read_data = drive.read_file("/test.txt").await.unwrap();
|
||||
assert_eq!(read_data, test_content);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drive_get_file_info() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Create a directory
|
||||
drive.mkdir("/mydir").await.unwrap();
|
||||
|
||||
// Write a test file
|
||||
let test_content = b"Test file content for info check";
|
||||
let test_file_path = temp_dir.path().join("info_test.txt");
|
||||
std::fs::write(&test_file_path, test_content).unwrap();
|
||||
|
||||
// Add file to drive
|
||||
let path = drive.path().clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).unwrap();
|
||||
image.add_file(&test_file_path).unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Test get_file_info for file
|
||||
let file_info = drive.get_file_info("/info_test.txt").await.unwrap();
|
||||
assert!(file_info.is_some());
|
||||
let file_info = file_info.unwrap();
|
||||
assert_eq!(file_info.name, "info_test.txt");
|
||||
assert_eq!(file_info.size, test_content.len() as u64);
|
||||
assert!(!file_info.is_dir);
|
||||
|
||||
// Test get_file_info for directory
|
||||
let dir_info = drive.get_file_info("/mydir").await.unwrap();
|
||||
assert!(dir_info.is_some());
|
||||
let dir_info = dir_info.unwrap();
|
||||
assert_eq!(dir_info.name, "mydir");
|
||||
assert!(dir_info.is_dir);
|
||||
|
||||
// Test get_file_info for non-existent file
|
||||
let not_found = drive.get_file_info("/nonexistent.txt").await.unwrap();
|
||||
assert!(not_found.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drive_stream_read() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Create test data that spans multiple chunks (>64KB)
|
||||
let test_size = 200 * 1024; // 200 KB
|
||||
let test_content: Vec<u8> = (0..test_size).map(|i| (i % 256) as u8).collect();
|
||||
let test_file_path = temp_dir.path().join("large_file.bin");
|
||||
std::fs::write(&test_file_path, &test_content).unwrap();
|
||||
|
||||
// Add file to drive
|
||||
let path = drive.path().clone();
|
||||
let file_path_clone = test_file_path.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).unwrap();
|
||||
image.add_file(&file_path_clone).unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Stream read the file
|
||||
let (file_size, mut rx) = drive.read_file_stream("/large_file.bin").await.unwrap();
|
||||
assert_eq!(file_size, test_size as u64);
|
||||
|
||||
// Collect all chunks
|
||||
let mut received_data = Vec::new();
|
||||
while let Some(chunk_result) = rx.recv().await {
|
||||
let chunk = chunk_result.expect("Chunk should not be an error");
|
||||
received_data.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
// Verify data matches
|
||||
assert_eq!(received_data.len(), test_content.len());
|
||||
assert_eq!(received_data, test_content);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drive_stream_read_small_file() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Create a small test file
|
||||
let test_content = b"Small file for streaming test";
|
||||
let test_file_path = temp_dir.path().join("small.txt");
|
||||
std::fs::write(&test_file_path, test_content).unwrap();
|
||||
|
||||
// Add file to drive
|
||||
let path = drive.path().clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).unwrap();
|
||||
image.add_file(&test_file_path).unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Stream read the file
|
||||
let (file_size, mut rx) = drive.read_file_stream("/small.txt").await.unwrap();
|
||||
assert_eq!(file_size, test_content.len() as u64);
|
||||
|
||||
// Collect all chunks
|
||||
let mut received_data = Vec::new();
|
||||
while let Some(chunk_result) = rx.recv().await {
|
||||
let chunk = chunk_result.expect("Chunk should not be an error");
|
||||
received_data.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
// Verify data matches
|
||||
assert_eq!(received_data.as_slice(), test_content);
|
||||
}
|
||||
}
|
||||
138
src/otg/configfs.rs
Normal file
138
src/otg/configfs.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
//! ConfigFS file operations for USB Gadget
|
||||
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// ConfigFS base path for USB gadgets
|
||||
pub const CONFIGFS_PATH: &str = "/sys/kernel/config/usb_gadget";
|
||||
|
||||
/// Default gadget name
|
||||
pub const DEFAULT_GADGET_NAME: &str = "one-kvm";
|
||||
|
||||
/// USB Vendor ID (Linux Foundation)
|
||||
pub const USB_VENDOR_ID: u16 = 0x1d6b;
|
||||
|
||||
/// USB Product ID (Multifunction Composite Gadget)
|
||||
pub const USB_PRODUCT_ID: u16 = 0x0104;
|
||||
|
||||
/// USB device version
|
||||
pub const USB_BCD_DEVICE: u16 = 0x0100;
|
||||
|
||||
/// USB spec version (USB 2.0)
|
||||
pub const USB_BCD_USB: u16 = 0x0200;
|
||||
|
||||
/// Check if ConfigFS is available
|
||||
pub fn is_configfs_available() -> bool {
|
||||
Path::new(CONFIGFS_PATH).exists()
|
||||
}
|
||||
|
||||
/// Find available UDC (USB Device Controller)
|
||||
pub fn find_udc() -> Option<String> {
|
||||
let udc_path = Path::new("/sys/class/udc");
|
||||
if !udc_path.exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
fs::read_dir(udc_path)
|
||||
.ok()?
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.file_name().to_string_lossy().to_string())
|
||||
.next()
|
||||
}
|
||||
|
||||
/// Write string content to a file
|
||||
///
|
||||
/// For sysfs files, this function appends a newline and flushes
|
||||
/// to ensure the kernel processes the write immediately.
|
||||
///
|
||||
/// IMPORTANT: sysfs attributes require a single atomic write() syscall.
|
||||
/// The kernel processes the value on the first write(), so we must
|
||||
/// build the complete buffer (including newline) before writing.
|
||||
pub fn write_file(path: &Path, content: &str) -> Result<()> {
|
||||
// For sysfs files (especially write-only ones like forced_eject),
|
||||
// we need to use simple O_WRONLY without O_TRUNC
|
||||
// O_TRUNC may fail on special files or require read permission
|
||||
let mut file = OpenOptions::new()
|
||||
.write(true)
|
||||
.open(path)
|
||||
.or_else(|e| {
|
||||
// If open fails, try create (for regular files)
|
||||
if path.exists() {
|
||||
Err(e)
|
||||
} else {
|
||||
File::create(path)
|
||||
}
|
||||
})
|
||||
.map_err(|e| AppError::Internal(format!("Failed to open {}: {}", path.display(), e)))?;
|
||||
|
||||
// Build complete buffer with newline, then write in single syscall.
|
||||
// This is critical for sysfs - multiple write() calls may cause
|
||||
// the kernel to only process partial data or return EINVAL.
|
||||
let data: std::borrow::Cow<[u8]> = if content.ends_with('\n') {
|
||||
content.as_bytes().into()
|
||||
} else {
|
||||
let mut buf = content.as_bytes().to_vec();
|
||||
buf.push(b'\n');
|
||||
buf.into()
|
||||
};
|
||||
|
||||
file.write_all(&data)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
|
||||
|
||||
// Explicitly flush to ensure sysfs processes the write
|
||||
file.flush()
|
||||
.map_err(|e| AppError::Internal(format!("Failed to flush {}: {}", path.display(), e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write binary content to a file
|
||||
pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
|
||||
let mut file = File::create(path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create {}: {}", path.display(), e)))?;
|
||||
|
||||
file.write_all(data)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read string content from a file
|
||||
pub fn read_file(path: &Path) -> Result<String> {
|
||||
fs::read_to_string(path)
|
||||
.map(|s| s.trim().to_string())
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
/// Create directory if not exists
|
||||
pub fn create_dir(path: &Path) -> Result<()> {
|
||||
fs::create_dir_all(path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create directory {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
/// Remove directory
|
||||
pub fn remove_dir(path: &Path) -> Result<()> {
|
||||
if path.exists() {
|
||||
fs::remove_dir(path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to remove directory {}: {}", path.display(), e)))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove file
|
||||
pub fn remove_file(path: &Path) -> Result<()> {
|
||||
if path.exists() {
|
||||
fs::remove_file(path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to remove file {}: {}", path.display(), e)))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create symlink
|
||||
pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> {
|
||||
std::os::unix::fs::symlink(src, dest)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create symlink {} -> {}: {}", dest.display(), src.display(), e)))
|
||||
}
|
||||
91
src/otg/endpoint.rs
Normal file
91
src/otg/endpoint.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
//! USB Endpoint allocation management
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Default maximum endpoints for typical UDC
|
||||
pub const DEFAULT_MAX_ENDPOINTS: u8 = 16;
|
||||
|
||||
/// Endpoint allocator - manages UDC endpoint resources
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EndpointAllocator {
|
||||
max_endpoints: u8,
|
||||
used_endpoints: u8,
|
||||
}
|
||||
|
||||
impl EndpointAllocator {
|
||||
/// Create a new endpoint allocator
|
||||
pub fn new(max_endpoints: u8) -> Self {
|
||||
Self {
|
||||
max_endpoints,
|
||||
used_endpoints: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate endpoints for a function
|
||||
pub fn allocate(&mut self, count: u8) -> Result<()> {
|
||||
if self.used_endpoints + count > self.max_endpoints {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Not enough endpoints: need {}, available {}",
|
||||
count,
|
||||
self.available()
|
||||
)));
|
||||
}
|
||||
self.used_endpoints += count;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Release endpoints
|
||||
pub fn release(&mut self, count: u8) {
|
||||
self.used_endpoints = self.used_endpoints.saturating_sub(count);
|
||||
}
|
||||
|
||||
/// Get available endpoint count
|
||||
pub fn available(&self) -> u8 {
|
||||
self.max_endpoints.saturating_sub(self.used_endpoints)
|
||||
}
|
||||
|
||||
/// Get used endpoint count
|
||||
pub fn used(&self) -> u8 {
|
||||
self.used_endpoints
|
||||
}
|
||||
|
||||
/// Get maximum endpoint count
|
||||
pub fn max(&self) -> u8 {
|
||||
self.max_endpoints
|
||||
}
|
||||
|
||||
/// Check if can allocate
|
||||
pub fn can_allocate(&self, count: u8) -> bool {
|
||||
self.available() >= count
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EndpointAllocator {
|
||||
fn default() -> Self {
|
||||
Self::new(DEFAULT_MAX_ENDPOINTS)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_allocator() {
|
||||
let mut alloc = EndpointAllocator::new(8);
|
||||
assert_eq!(alloc.available(), 8);
|
||||
|
||||
alloc.allocate(2).unwrap();
|
||||
assert_eq!(alloc.available(), 6);
|
||||
assert_eq!(alloc.used(), 2);
|
||||
|
||||
alloc.allocate(4).unwrap();
|
||||
assert_eq!(alloc.available(), 2);
|
||||
|
||||
// Should fail - not enough endpoints
|
||||
assert!(alloc.allocate(3).is_err());
|
||||
|
||||
alloc.release(2);
|
||||
assert_eq!(alloc.available(), 4);
|
||||
}
|
||||
}
|
||||
42
src/otg/function.rs
Normal file
42
src/otg/function.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
//! USB Gadget Function trait definition
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
/// Function metadata
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FunctionMeta {
|
||||
/// Function name (e.g., "hid.usb0")
|
||||
pub name: String,
|
||||
/// Human-readable description
|
||||
pub description: String,
|
||||
/// Number of endpoints used
|
||||
pub endpoints: u8,
|
||||
/// Whether the function is enabled
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// USB Gadget Function trait
|
||||
pub trait GadgetFunction: Send + Sync {
|
||||
/// Get function name (e.g., "hid.usb0", "mass_storage.usb0")
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Get number of endpoints required
|
||||
fn endpoints_required(&self) -> u8;
|
||||
|
||||
/// Get function metadata
|
||||
fn meta(&self) -> FunctionMeta;
|
||||
|
||||
/// Create function directory and configuration in ConfigFS
|
||||
fn create(&self, gadget_path: &Path) -> Result<()>;
|
||||
|
||||
/// Link function to configuration
|
||||
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()>;
|
||||
|
||||
/// Unlink function from configuration
|
||||
fn unlink(&self, config_path: &Path) -> Result<()>;
|
||||
|
||||
/// Cleanup function directory
|
||||
fn cleanup(&self, gadget_path: &Path) -> Result<()>;
|
||||
}
|
||||
226
src/otg/hid.rs
Normal file
226
src/otg/hid.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
//! HID Function implementation for USB Gadget
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::debug;
|
||||
|
||||
use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file};
|
||||
use super::function::{FunctionMeta, GadgetFunction};
|
||||
use super::report_desc::{KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
|
||||
use crate::error::Result;
|
||||
|
||||
/// HID function type
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum HidFunctionType {
|
||||
/// Keyboard with LED feedback support
|
||||
/// Uses 2 endpoints: IN (reports) + OUT (LED status)
|
||||
Keyboard,
|
||||
/// Relative mouse (traditional mouse movement)
|
||||
/// Uses 1 endpoint: IN
|
||||
MouseRelative,
|
||||
/// Absolute mouse (touchscreen-like positioning)
|
||||
/// Uses 1 endpoint: IN
|
||||
MouseAbsolute,
|
||||
}
|
||||
|
||||
impl HidFunctionType {
|
||||
/// Get endpoints required for this function type
|
||||
pub fn endpoints(&self) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 2, // IN + OUT for LED
|
||||
HidFunctionType::MouseRelative => 1,
|
||||
HidFunctionType::MouseAbsolute => 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get HID protocol
|
||||
pub fn protocol(&self) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 1, // Keyboard
|
||||
HidFunctionType::MouseRelative => 2, // Mouse
|
||||
HidFunctionType::MouseAbsolute => 2, // Mouse
|
||||
}
|
||||
}
|
||||
|
||||
/// Get HID subclass
|
||||
pub fn subclass(&self) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 1, // Boot interface
|
||||
HidFunctionType::MouseRelative => 1, // Boot interface
|
||||
HidFunctionType::MouseAbsolute => 0, // No boot interface (absolute not in boot protocol)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get report length in bytes
|
||||
pub fn report_length(&self) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 8,
|
||||
HidFunctionType::MouseRelative => 4,
|
||||
HidFunctionType::MouseAbsolute => 6,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get report descriptor
|
||||
pub fn report_desc(&self) -> &'static [u8] {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => KEYBOARD_WITH_LED,
|
||||
HidFunctionType::MouseRelative => MOUSE_RELATIVE,
|
||||
HidFunctionType::MouseAbsolute => MOUSE_ABSOLUTE,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get description
|
||||
pub fn description(&self) -> &'static str {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => "Keyboard",
|
||||
HidFunctionType::MouseRelative => "Relative Mouse",
|
||||
HidFunctionType::MouseAbsolute => "Absolute Mouse",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HID Function for USB Gadget
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HidFunction {
|
||||
/// Instance number (usb0, usb1, ...)
|
||||
instance: u8,
|
||||
/// Function type
|
||||
func_type: HidFunctionType,
|
||||
/// Cached function name (avoids repeated allocation)
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl HidFunction {
|
||||
/// Create a keyboard function
|
||||
pub fn keyboard(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
func_type: HidFunctionType::Keyboard,
|
||||
name: format!("hid.usb{}", instance),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a relative mouse function
|
||||
pub fn mouse_relative(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
func_type: HidFunctionType::MouseRelative,
|
||||
name: format!("hid.usb{}", instance),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an absolute mouse function
|
||||
pub fn mouse_absolute(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
func_type: HidFunctionType::MouseAbsolute,
|
||||
name: format!("hid.usb{}", instance),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get function path in gadget
|
||||
fn function_path(&self, gadget_path: &Path) -> PathBuf {
|
||||
gadget_path.join("functions").join(self.name())
|
||||
}
|
||||
|
||||
/// Get expected device path (e.g., /dev/hidg0)
|
||||
pub fn device_path(&self) -> PathBuf {
|
||||
PathBuf::from(format!("/dev/hidg{}", self.instance))
|
||||
}
|
||||
}
|
||||
|
||||
impl GadgetFunction for HidFunction {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn endpoints_required(&self) -> u8 {
|
||||
self.func_type.endpoints()
|
||||
}
|
||||
|
||||
fn meta(&self) -> FunctionMeta {
|
||||
FunctionMeta {
|
||||
name: self.name().to_string(),
|
||||
description: self.func_type.description().to_string(),
|
||||
endpoints: self.endpoints_required(),
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn create(&self, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
create_dir(&func_path)?;
|
||||
|
||||
// Set HID parameters
|
||||
write_file(&func_path.join("protocol"), &self.func_type.protocol().to_string())?;
|
||||
write_file(&func_path.join("subclass"), &self.func_type.subclass().to_string())?;
|
||||
write_file(&func_path.join("report_length"), &self.func_type.report_length().to_string())?;
|
||||
|
||||
// For keyboard, enable OUT endpoint for LED feedback
|
||||
// no_out_endpoint: 0 = enable OUT endpoint, 1 = disable
|
||||
if matches!(self.func_type, HidFunctionType::Keyboard) {
|
||||
let no_out_path = func_path.join("no_out_endpoint");
|
||||
if no_out_path.exists() || func_path.exists() {
|
||||
// Try to write, ignore error if file doesn't exist yet
|
||||
let _ = write_file(&no_out_path, "0");
|
||||
}
|
||||
}
|
||||
|
||||
// Write report descriptor
|
||||
write_bytes(&func_path.join("report_desc"), self.func_type.report_desc())?;
|
||||
|
||||
debug!("Created HID function: {} at {}", self.name(), func_path.display());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
let link_path = config_path.join(self.name());
|
||||
|
||||
if !link_path.exists() {
|
||||
create_symlink(&func_path, &link_path)?;
|
||||
debug!("Linked HID function {} to config", self.name());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unlink(&self, config_path: &Path) -> Result<()> {
|
||||
let link_path = config_path.join(self.name());
|
||||
remove_file(&link_path)?;
|
||||
debug!("Unlinked HID function {}", self.name());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
remove_dir(&func_path)?;
|
||||
debug!("Cleaned up HID function {}", self.name());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_hid_function_types() {
|
||||
assert_eq!(HidFunctionType::Keyboard.endpoints(), 2);
|
||||
assert_eq!(HidFunctionType::MouseRelative.endpoints(), 1);
|
||||
assert_eq!(HidFunctionType::MouseAbsolute.endpoints(), 1);
|
||||
|
||||
assert_eq!(HidFunctionType::Keyboard.report_length(), 8);
|
||||
assert_eq!(HidFunctionType::MouseRelative.report_length(), 4);
|
||||
assert_eq!(HidFunctionType::MouseAbsolute.report_length(), 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hid_function_names() {
|
||||
let kb = HidFunction::keyboard(0);
|
||||
assert_eq!(kb.name(), "hid.usb0");
|
||||
assert_eq!(kb.device_path(), PathBuf::from("/dev/hidg0"));
|
||||
|
||||
let mouse = HidFunction::mouse_relative(1);
|
||||
assert_eq!(mouse.name(), "hid.usb1");
|
||||
}
|
||||
}
|
||||
394
src/otg/manager.rs
Normal file
394
src/otg/manager.rs
Normal file
@@ -0,0 +1,394 @@
|
||||
//! OTG Gadget Manager - unified management for USB Gadget functions
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::configfs::{
|
||||
create_dir, find_udc, is_configfs_available, remove_dir, write_file, CONFIGFS_PATH,
|
||||
DEFAULT_GADGET_NAME, USB_BCD_DEVICE, USB_BCD_USB, USB_PRODUCT_ID, USB_VENDOR_ID,
|
||||
};
|
||||
use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS};
|
||||
use super::function::{FunctionMeta, GadgetFunction};
|
||||
use super::hid::HidFunction;
|
||||
use super::msd::MsdFunction;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// OTG Gadget Manager - unified management for HID and MSD
|
||||
pub struct OtgGadgetManager {
|
||||
/// Gadget name
|
||||
gadget_name: String,
|
||||
/// Gadget path in ConfigFS
|
||||
gadget_path: PathBuf,
|
||||
/// Configuration path
|
||||
config_path: PathBuf,
|
||||
/// Endpoint allocator
|
||||
endpoint_allocator: EndpointAllocator,
|
||||
/// HID instance counter
|
||||
hid_instance: u8,
|
||||
/// MSD instance counter
|
||||
msd_instance: u8,
|
||||
/// Registered functions
|
||||
functions: Vec<Box<dyn GadgetFunction>>,
|
||||
/// Function metadata
|
||||
meta: HashMap<String, FunctionMeta>,
|
||||
/// Bound UDC name
|
||||
bound_udc: Option<String>,
|
||||
/// Whether gadget was created by us
|
||||
created_by_us: bool,
|
||||
}
|
||||
|
||||
impl OtgGadgetManager {
|
||||
/// Create a new gadget manager with default settings
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(DEFAULT_GADGET_NAME, DEFAULT_MAX_ENDPOINTS)
|
||||
}
|
||||
|
||||
/// Create a new gadget manager with custom configuration
|
||||
pub fn with_config(gadget_name: &str, max_endpoints: u8) -> Self {
|
||||
let gadget_path = PathBuf::from(CONFIGFS_PATH).join(gadget_name);
|
||||
let config_path = gadget_path.join("configs/c.1");
|
||||
|
||||
Self {
|
||||
gadget_name: gadget_name.to_string(),
|
||||
gadget_path,
|
||||
config_path,
|
||||
endpoint_allocator: EndpointAllocator::new(max_endpoints),
|
||||
hid_instance: 0,
|
||||
msd_instance: 0,
|
||||
// Pre-allocate for typical use: 3 HID (keyboard, rel mouse, abs mouse) + 1 MSD
|
||||
functions: Vec::with_capacity(4),
|
||||
meta: HashMap::with_capacity(4),
|
||||
bound_udc: None,
|
||||
created_by_us: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if ConfigFS is available
|
||||
pub fn is_available() -> bool {
|
||||
is_configfs_available()
|
||||
}
|
||||
|
||||
/// Find available UDC
|
||||
pub fn find_udc() -> Option<String> {
|
||||
find_udc()
|
||||
}
|
||||
|
||||
/// Check if gadget exists
|
||||
pub fn gadget_exists(&self) -> bool {
|
||||
self.gadget_path.exists()
|
||||
}
|
||||
|
||||
/// Check if gadget is bound to UDC
|
||||
pub fn is_bound(&self) -> bool {
|
||||
let udc_file = self.gadget_path.join("UDC");
|
||||
if let Ok(content) = fs::read_to_string(&udc_file) {
|
||||
!content.trim().is_empty()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Add keyboard function
|
||||
/// Returns the expected device path (e.g., /dev/hidg0)
|
||||
pub fn add_keyboard(&mut self) -> Result<PathBuf> {
|
||||
let func = HidFunction::keyboard(self.hid_instance);
|
||||
let device_path = func.device_path();
|
||||
self.add_function(Box::new(func))?;
|
||||
self.hid_instance += 1;
|
||||
Ok(device_path)
|
||||
}
|
||||
|
||||
/// Add relative mouse function
|
||||
pub fn add_mouse_relative(&mut self) -> Result<PathBuf> {
|
||||
let func = HidFunction::mouse_relative(self.hid_instance);
|
||||
let device_path = func.device_path();
|
||||
self.add_function(Box::new(func))?;
|
||||
self.hid_instance += 1;
|
||||
Ok(device_path)
|
||||
}
|
||||
|
||||
/// Add absolute mouse function
|
||||
pub fn add_mouse_absolute(&mut self) -> Result<PathBuf> {
|
||||
let func = HidFunction::mouse_absolute(self.hid_instance);
|
||||
let device_path = func.device_path();
|
||||
self.add_function(Box::new(func))?;
|
||||
self.hid_instance += 1;
|
||||
Ok(device_path)
|
||||
}
|
||||
|
||||
/// Add MSD function (returns MsdFunction handle for LUN configuration)
|
||||
pub fn add_msd(&mut self) -> Result<MsdFunction> {
|
||||
let func = MsdFunction::new(self.msd_instance);
|
||||
let func_clone = func.clone();
|
||||
self.add_function(Box::new(func))?;
|
||||
self.msd_instance += 1;
|
||||
Ok(func_clone)
|
||||
}
|
||||
|
||||
/// Add a generic function
|
||||
fn add_function(&mut self, func: Box<dyn GadgetFunction>) -> Result<()> {
|
||||
let endpoints = func.endpoints_required();
|
||||
|
||||
// Check endpoint availability
|
||||
if !self.endpoint_allocator.can_allocate(endpoints) {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Not enough endpoints for function {}: need {}, available {}",
|
||||
func.name(),
|
||||
endpoints,
|
||||
self.endpoint_allocator.available()
|
||||
)));
|
||||
}
|
||||
|
||||
// Allocate endpoints
|
||||
self.endpoint_allocator.allocate(endpoints)?;
|
||||
|
||||
// Store metadata
|
||||
self.meta.insert(func.name().to_string(), func.meta());
|
||||
|
||||
// Store function
|
||||
self.functions.push(func);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Setup the gadget (create directories and configure)
|
||||
pub fn setup(&mut self) -> Result<()> {
|
||||
info!("Setting up OTG USB Gadget: {}", self.gadget_name);
|
||||
|
||||
// Check ConfigFS availability
|
||||
if !Self::is_available() {
|
||||
return Err(AppError::Internal(
|
||||
"ConfigFS not available. Is it mounted at /sys/kernel/config?".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check if gadget already exists and is bound
|
||||
if self.gadget_exists() {
|
||||
if self.is_bound() {
|
||||
info!("Gadget already exists and is bound, skipping setup");
|
||||
return Ok(());
|
||||
}
|
||||
warn!("Gadget exists but not bound, will reconfigure");
|
||||
self.cleanup()?;
|
||||
}
|
||||
|
||||
// Create gadget directory
|
||||
create_dir(&self.gadget_path)?;
|
||||
self.created_by_us = true;
|
||||
|
||||
// Set device descriptors
|
||||
self.set_device_descriptors()?;
|
||||
|
||||
// Create strings
|
||||
self.create_strings()?;
|
||||
|
||||
// Create configuration
|
||||
self.create_configuration()?;
|
||||
|
||||
// Create and link all functions
|
||||
for func in &self.functions {
|
||||
func.create(&self.gadget_path)?;
|
||||
func.link(&self.config_path, &self.gadget_path)?;
|
||||
}
|
||||
|
||||
info!("OTG USB Gadget setup complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bind gadget to UDC
|
||||
pub fn bind(&mut self) -> Result<()> {
|
||||
let udc = Self::find_udc().ok_or_else(|| {
|
||||
AppError::Internal("No USB Device Controller (UDC) found".to_string())
|
||||
})?;
|
||||
|
||||
info!("Binding gadget to UDC: {}", udc);
|
||||
write_file(&self.gadget_path.join("UDC"), &udc)?;
|
||||
self.bound_udc = Some(udc);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Unbind gadget from UDC
|
||||
pub fn unbind(&mut self) -> Result<()> {
|
||||
if self.is_bound() {
|
||||
write_file(&self.gadget_path.join("UDC"), "")?;
|
||||
self.bound_udc = None;
|
||||
info!("Unbound gadget from UDC");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cleanup all resources
|
||||
pub fn cleanup(&mut self) -> Result<()> {
|
||||
if !self.gadget_exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Cleaning up OTG USB Gadget: {}", self.gadget_name);
|
||||
|
||||
// Unbind from UDC first
|
||||
let _ = self.unbind();
|
||||
|
||||
// Unlink and cleanup functions
|
||||
for func in self.functions.iter().rev() {
|
||||
let _ = func.unlink(&self.config_path);
|
||||
}
|
||||
|
||||
// Remove config strings
|
||||
let config_strings = self.config_path.join("strings/0x409");
|
||||
let _ = remove_dir(&config_strings);
|
||||
let _ = remove_dir(&self.config_path);
|
||||
|
||||
// Cleanup functions
|
||||
for func in self.functions.iter().rev() {
|
||||
let _ = func.cleanup(&self.gadget_path);
|
||||
}
|
||||
|
||||
// Remove gadget strings
|
||||
let gadget_strings = self.gadget_path.join("strings/0x409");
|
||||
let _ = remove_dir(&gadget_strings);
|
||||
|
||||
// Remove gadget directory
|
||||
if let Err(e) = remove_dir(&self.gadget_path) {
|
||||
warn!("Could not remove gadget directory: {}", e);
|
||||
}
|
||||
|
||||
self.created_by_us = false;
|
||||
info!("OTG USB Gadget cleanup complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set USB device descriptors
|
||||
fn set_device_descriptors(&self) -> Result<()> {
|
||||
write_file(&self.gadget_path.join("idVendor"), &format!("0x{:04x}", USB_VENDOR_ID))?;
|
||||
write_file(&self.gadget_path.join("idProduct"), &format!("0x{:04x}", USB_PRODUCT_ID))?;
|
||||
write_file(&self.gadget_path.join("bcdDevice"), &format!("0x{:04x}", USB_BCD_DEVICE))?;
|
||||
write_file(&self.gadget_path.join("bcdUSB"), &format!("0x{:04x}", USB_BCD_USB))?;
|
||||
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; // Composite device
|
||||
write_file(&self.gadget_path.join("bDeviceSubClass"), "0x00")?;
|
||||
write_file(&self.gadget_path.join("bDeviceProtocol"), "0x00")?;
|
||||
debug!("Set device descriptors");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create USB strings
|
||||
fn create_strings(&self) -> Result<()> {
|
||||
let strings_path = self.gadget_path.join("strings/0x409");
|
||||
create_dir(&strings_path)?;
|
||||
|
||||
write_file(&strings_path.join("serialnumber"), "0123456789")?;
|
||||
write_file(&strings_path.join("manufacturer"), "One-KVM")?;
|
||||
write_file(&strings_path.join("product"), "One-KVM HID Device")?;
|
||||
debug!("Created USB strings");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create configuration
|
||||
fn create_configuration(&self) -> Result<()> {
|
||||
create_dir(&self.config_path)?;
|
||||
|
||||
// Create config strings
|
||||
let strings_path = self.config_path.join("strings/0x409");
|
||||
create_dir(&strings_path)?;
|
||||
write_file(&strings_path.join("configuration"), "Config 1: HID + MSD")?;
|
||||
|
||||
// Set max power (500mA)
|
||||
write_file(&self.config_path.join("MaxPower"), "500")?;
|
||||
|
||||
debug!("Created configuration c.1");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get function metadata
|
||||
pub fn get_meta(&self) -> &HashMap<String, FunctionMeta> {
|
||||
&self.meta
|
||||
}
|
||||
|
||||
/// Get endpoint usage info
|
||||
pub fn endpoint_info(&self) -> (u8, u8) {
|
||||
(self.endpoint_allocator.used(), self.endpoint_allocator.max())
|
||||
}
|
||||
|
||||
/// Get gadget path
|
||||
pub fn gadget_path(&self) -> &PathBuf {
|
||||
&self.gadget_path
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OtgGadgetManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OtgGadgetManager {
|
||||
fn drop(&mut self) {
|
||||
if self.created_by_us {
|
||||
if let Err(e) = self.cleanup() {
|
||||
error!("Failed to cleanup OTG gadget on drop: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for HID devices to become available
|
||||
///
|
||||
/// Uses exponential backoff starting from 10ms, capped at 100ms,
|
||||
/// to reduce CPU usage while still providing fast response.
|
||||
pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> bool {
|
||||
let start = std::time::Instant::now();
|
||||
let timeout = std::time::Duration::from_millis(timeout_ms);
|
||||
|
||||
// Exponential backoff: start at 10ms, double each time, cap at 100ms
|
||||
let mut delay_ms = 10u64;
|
||||
const MAX_DELAY_MS: u64 = 100;
|
||||
|
||||
while start.elapsed() < timeout {
|
||||
if device_paths.iter().all(|p| p.exists()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calculate remaining time to avoid overshooting timeout
|
||||
let remaining = timeout.saturating_sub(start.elapsed());
|
||||
let sleep_duration = std::time::Duration::from_millis(delay_ms).min(remaining);
|
||||
|
||||
if sleep_duration.is_zero() {
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(sleep_duration).await;
|
||||
|
||||
// Exponential backoff with cap
|
||||
delay_ms = (delay_ms * 2).min(MAX_DELAY_MS);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_manager_creation() {
|
||||
let manager = OtgGadgetManager::new();
|
||||
assert_eq!(manager.gadget_name, DEFAULT_GADGET_NAME);
|
||||
assert!(!manager.gadget_exists()); // Won't exist in test environment
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_endpoint_tracking() {
|
||||
let mut manager = OtgGadgetManager::with_config("test", 8);
|
||||
|
||||
// Keyboard uses 2 endpoints
|
||||
let _ = manager.add_keyboard();
|
||||
assert_eq!(manager.endpoint_allocator.used(), 2);
|
||||
|
||||
// Mouse uses 1 endpoint each
|
||||
let _ = manager.add_mouse_relative();
|
||||
let _ = manager.add_mouse_absolute();
|
||||
assert_eq!(manager.endpoint_allocator.used(), 4);
|
||||
}
|
||||
}
|
||||
35
src/otg/mod.rs
Normal file
35
src/otg/mod.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
//! OTG USB Gadget unified management module
|
||||
//!
|
||||
//! This module provides unified management for USB Gadget functions:
|
||||
//! - HID (Keyboard, Mouse)
|
||||
//! - MSD (Mass Storage Device)
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! OtgService (high-level coordination)
|
||||
//! └── OtgGadgetManager (gadget lifecycle)
|
||||
//! ├── EndpointAllocator (manages UDC endpoints)
|
||||
//! ├── HidFunction (keyboard, mouse_rel, mouse_abs)
|
||||
//! └── MsdFunction (mass storage)
|
||||
//! ```
|
||||
//!
|
||||
//! The recommended way to use this module is through `OtgService`, which provides
|
||||
//! a high-level interface for enabling/disabling HID and MSD functions independently.
|
||||
//! Both `HidController` and `MsdController` should share the same `OtgService` instance.
|
||||
|
||||
pub mod configfs;
|
||||
pub mod endpoint;
|
||||
pub mod function;
|
||||
pub mod hid;
|
||||
pub mod manager;
|
||||
pub mod msd;
|
||||
pub mod report_desc;
|
||||
pub mod service;
|
||||
|
||||
pub use endpoint::EndpointAllocator;
|
||||
pub use function::{FunctionMeta, GadgetFunction};
|
||||
pub use hid::{HidFunction, HidFunctionType};
|
||||
pub use manager::{wait_for_hid_devices, OtgGadgetManager};
|
||||
pub use msd::{MsdFunction, MsdLunConfig};
|
||||
pub use report_desc::{KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
|
||||
pub use service::{HidDevicePaths, OtgService, OtgServiceState};
|
||||
411
src/otg/msd.rs
Normal file
411
src/otg/msd.rs
Normal file
@@ -0,0 +1,411 @@
|
||||
//! MSD (Mass Storage Device) Function implementation for USB Gadget
|
||||
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_file};
|
||||
use super::function::{FunctionMeta, GadgetFunction};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// MSD LUN configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MsdLunConfig {
|
||||
/// File/image path to expose
|
||||
pub file: PathBuf,
|
||||
/// Mount as CD-ROM
|
||||
pub cdrom: bool,
|
||||
/// Read-only mode
|
||||
pub ro: bool,
|
||||
/// Removable media
|
||||
pub removable: bool,
|
||||
/// Disable Force Unit Access
|
||||
pub nofua: bool,
|
||||
}
|
||||
|
||||
impl Default for MsdLunConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
file: PathBuf::new(),
|
||||
cdrom: false,
|
||||
ro: false,
|
||||
removable: true,
|
||||
nofua: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MsdLunConfig {
|
||||
/// Create CD-ROM configuration
|
||||
pub fn cdrom(file: PathBuf) -> Self {
|
||||
Self {
|
||||
file,
|
||||
cdrom: true,
|
||||
ro: true,
|
||||
removable: true,
|
||||
nofua: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create disk configuration
|
||||
pub fn disk(file: PathBuf, read_only: bool) -> Self {
|
||||
Self {
|
||||
file,
|
||||
cdrom: false,
|
||||
ro: read_only,
|
||||
removable: true,
|
||||
nofua: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD Function for USB Gadget
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MsdFunction {
|
||||
/// Instance number (usb0, usb1, ...)
|
||||
instance: u8,
|
||||
/// Cached function name (avoids repeated allocation)
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl MsdFunction {
|
||||
/// Create a new MSD function
|
||||
pub fn new(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
name: format!("mass_storage.usb{}", instance),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get function path in gadget
|
||||
fn function_path(&self, gadget_path: &Path) -> PathBuf {
|
||||
gadget_path.join("functions").join(self.name())
|
||||
}
|
||||
|
||||
/// Get LUN path
|
||||
fn lun_path(&self, gadget_path: &Path, lun: u8) -> PathBuf {
|
||||
self.function_path(gadget_path).join(format!("lun.{}", lun))
|
||||
}
|
||||
|
||||
/// Configure a LUN with specified settings (async version)
|
||||
///
|
||||
/// This is the preferred method for async contexts. It runs the blocking
|
||||
/// file I/O and USB timing operations in a separate thread pool.
|
||||
pub async fn configure_lun_async(
|
||||
&self,
|
||||
gadget_path: &Path,
|
||||
lun: u8,
|
||||
config: &MsdLunConfig,
|
||||
) -> Result<()> {
|
||||
let gadget_path = gadget_path.to_path_buf();
|
||||
let config = config.clone();
|
||||
let this = self.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || this.configure_lun(&gadget_path, lun, &config))
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Configure a LUN with specified settings
|
||||
/// Note: This should be called after the gadget is set up
|
||||
///
|
||||
/// This implementation is based on PiKVM's MSD drive configuration.
|
||||
/// Key improvements:
|
||||
/// - Uses forced_eject when available (safer than clearing file directly)
|
||||
/// - Reduced sleep times to minimize HID interference
|
||||
/// - Better retry logic for EBUSY errors
|
||||
///
|
||||
/// **Note**: This is a blocking function. In async contexts, prefer
|
||||
/// `configure_lun_async` to avoid blocking the runtime.
|
||||
pub fn configure_lun(&self, gadget_path: &Path, lun: u8, config: &MsdLunConfig) -> Result<()> {
|
||||
let lun_path = self.lun_path(gadget_path, lun);
|
||||
|
||||
if !lun_path.exists() {
|
||||
create_dir(&lun_path)?;
|
||||
}
|
||||
|
||||
// Batch read all current values to minimize syscalls
|
||||
let read_attr = |attr: &str| -> String {
|
||||
fs::read_to_string(lun_path.join(attr))
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_string()
|
||||
};
|
||||
|
||||
let current_cdrom = read_attr("cdrom");
|
||||
let current_ro = read_attr("ro");
|
||||
let current_removable = read_attr("removable");
|
||||
let current_nofua = read_attr("nofua");
|
||||
|
||||
// Prepare new values
|
||||
let new_cdrom = if config.cdrom { "1" } else { "0" };
|
||||
let new_ro = if config.ro { "1" } else { "0" };
|
||||
let new_removable = if config.removable { "1" } else { "0" };
|
||||
let new_nofua = if config.nofua { "1" } else { "0" };
|
||||
|
||||
// Disconnect current file first using forced_eject if available (PiKVM approach)
|
||||
let forced_eject_path = lun_path.join("forced_eject");
|
||||
if forced_eject_path.exists() {
|
||||
// forced_eject is safer - it forcibly detaches regardless of host state
|
||||
debug!("Using forced_eject to clear LUN {}", lun);
|
||||
let _ = write_file(&forced_eject_path, "1");
|
||||
} else {
|
||||
// Fallback to clearing file directly
|
||||
let _ = write_file(&lun_path.join("file"), "");
|
||||
}
|
||||
|
||||
// Brief yield to allow USB stack to process the disconnect
|
||||
// Reduced from 200ms to 50ms - let USB protocol handle timing
|
||||
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||
|
||||
// Write only changed attributes
|
||||
let cdrom_changed = current_cdrom != new_cdrom;
|
||||
if cdrom_changed {
|
||||
debug!("Updating LUN {} cdrom: {} -> {}", lun, current_cdrom, new_cdrom);
|
||||
write_file(&lun_path.join("cdrom"), new_cdrom)?;
|
||||
}
|
||||
if current_ro != new_ro {
|
||||
debug!("Updating LUN {} ro: {} -> {}", lun, current_ro, new_ro);
|
||||
write_file(&lun_path.join("ro"), new_ro)?;
|
||||
}
|
||||
if current_removable != new_removable {
|
||||
debug!("Updating LUN {} removable: {} -> {}", lun, current_removable, new_removable);
|
||||
write_file(&lun_path.join("removable"), new_removable)?;
|
||||
}
|
||||
if current_nofua != new_nofua {
|
||||
debug!("Updating LUN {} nofua: {} -> {}", lun, current_nofua, new_nofua);
|
||||
write_file(&lun_path.join("nofua"), new_nofua)?;
|
||||
}
|
||||
|
||||
// If cdrom mode changed, brief yield for USB host
|
||||
if cdrom_changed {
|
||||
debug!("CDROM mode changed, brief yield for USB host");
|
||||
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||
}
|
||||
|
||||
// Set file path (this triggers the actual mount) - with retry on EBUSY
|
||||
if config.file.exists() {
|
||||
let file_path = config.file.to_string_lossy();
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 0..5 {
|
||||
match write_file(&lun_path.join("file"), file_path.as_ref()) {
|
||||
Ok(_) => {
|
||||
info!(
|
||||
"LUN {} configured with file: {} (cdrom={}, ro={})",
|
||||
lun,
|
||||
config.file.display(),
|
||||
config.cdrom,
|
||||
config.ro
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
// Check if it's EBUSY (error code 16)
|
||||
let is_busy = e.to_string().contains("Device or resource busy")
|
||||
|| e.to_string().contains("os error 16");
|
||||
|
||||
if is_busy && attempt < 4 {
|
||||
warn!(
|
||||
"LUN {} file write busy, retrying (attempt {}/5)",
|
||||
lun,
|
||||
attempt + 1
|
||||
);
|
||||
// Exponential backoff: 50, 100, 200, 400ms
|
||||
std::thread::sleep(std::time::Duration::from_millis(50 << attempt));
|
||||
last_error = Some(e);
|
||||
continue;
|
||||
}
|
||||
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, all retries failed
|
||||
if let Some(e) = last_error {
|
||||
return Err(e);
|
||||
}
|
||||
} else if !config.file.as_os_str().is_empty() {
|
||||
warn!("LUN {} file does not exist: {}", lun, config.file.display());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect LUN (async version)
|
||||
///
|
||||
/// Preferred for async contexts.
|
||||
pub async fn disconnect_lun_async(&self, gadget_path: &Path, lun: u8) -> Result<()> {
|
||||
let gadget_path = gadget_path.to_path_buf();
|
||||
let this = self.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || this.disconnect_lun(&gadget_path, lun))
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Disconnect LUN (clear file)
|
||||
///
|
||||
/// This method uses forced_eject when available, which is safer than
|
||||
/// directly clearing the file path. Based on PiKVM's implementation.
|
||||
/// See: https://docs.kernel.org/usb/mass-storage.html
|
||||
pub fn disconnect_lun(&self, gadget_path: &Path, lun: u8) -> Result<()> {
|
||||
let lun_path = self.lun_path(gadget_path, lun);
|
||||
|
||||
if lun_path.exists() {
|
||||
// Prefer forced_eject if available (PiKVM approach)
|
||||
// forced_eject forcibly detaches the backing file regardless of host state
|
||||
let forced_eject_path = lun_path.join("forced_eject");
|
||||
if forced_eject_path.exists() {
|
||||
debug!("Using forced_eject to disconnect LUN {} at {:?}", lun, forced_eject_path);
|
||||
match write_file(&forced_eject_path, "1") {
|
||||
Ok(_) => debug!("forced_eject write succeeded"),
|
||||
Err(e) => {
|
||||
warn!("forced_eject write failed: {}, falling back to clearing file", e);
|
||||
write_file(&lun_path.join("file"), "")?;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to clearing file directly
|
||||
write_file(&lun_path.join("file"), "")?;
|
||||
}
|
||||
info!("LUN {} disconnected", lun);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current LUN file path
|
||||
pub fn get_lun_file(&self, gadget_path: &Path, lun: u8) -> Option<PathBuf> {
|
||||
let lun_path = self.lun_path(gadget_path, lun);
|
||||
let file_path = lun_path.join("file");
|
||||
|
||||
if let Ok(content) = fs::read_to_string(&file_path) {
|
||||
let content = content.trim();
|
||||
if !content.is_empty() {
|
||||
return Some(PathBuf::from(content));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if LUN is connected
|
||||
pub fn is_lun_connected(&self, gadget_path: &Path, lun: u8) -> bool {
|
||||
self.get_lun_file(gadget_path, lun).is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl GadgetFunction for MsdFunction {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn endpoints_required(&self) -> u8 {
|
||||
2 // IN + OUT for bulk transfers
|
||||
}
|
||||
|
||||
fn meta(&self) -> FunctionMeta {
|
||||
FunctionMeta {
|
||||
name: self.name().to_string(),
|
||||
description: if self.instance == 0 {
|
||||
"Mass Storage Drive".to_string()
|
||||
} else {
|
||||
format!("Extra Drive #{}", self.instance)
|
||||
},
|
||||
endpoints: self.endpoints_required(),
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn create(&self, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
create_dir(&func_path)?;
|
||||
|
||||
// Set stall to 0 (workaround for some hosts)
|
||||
let stall_path = func_path.join("stall");
|
||||
if stall_path.exists() {
|
||||
let _ = write_file(&stall_path, "0");
|
||||
}
|
||||
|
||||
// LUN 0 is created automatically, but ensure it exists
|
||||
let lun0_path = func_path.join("lun.0");
|
||||
if !lun0_path.exists() {
|
||||
create_dir(&lun0_path)?;
|
||||
}
|
||||
|
||||
// Set default LUN 0 parameters
|
||||
let _ = write_file(&lun0_path.join("cdrom"), "0");
|
||||
let _ = write_file(&lun0_path.join("ro"), "0");
|
||||
let _ = write_file(&lun0_path.join("removable"), "1");
|
||||
let _ = write_file(&lun0_path.join("nofua"), "1");
|
||||
|
||||
debug!("Created MSD function: {}", self.name());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
let link_path = config_path.join(self.name());
|
||||
|
||||
if !link_path.exists() {
|
||||
create_symlink(&func_path, &link_path)?;
|
||||
debug!("Linked MSD function {} to config", self.name());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unlink(&self, config_path: &Path) -> Result<()> {
|
||||
let link_path = config_path.join(self.name());
|
||||
remove_file(&link_path)?;
|
||||
debug!("Unlinked MSD function {}", self.name());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
|
||||
// Disconnect all LUNs first
|
||||
for lun in 0..8 {
|
||||
let _ = self.disconnect_lun(gadget_path, lun);
|
||||
}
|
||||
|
||||
// Remove function directory
|
||||
if let Err(e) = remove_dir(&func_path) {
|
||||
warn!("Could not remove MSD function directory: {}", e);
|
||||
}
|
||||
|
||||
debug!("Cleaned up MSD function {}", self.name());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_lun_config_cdrom() {
|
||||
let config = MsdLunConfig::cdrom(PathBuf::from("/tmp/test.iso"));
|
||||
assert!(config.cdrom);
|
||||
assert!(config.ro);
|
||||
assert!(config.removable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lun_config_disk() {
|
||||
let config = MsdLunConfig::disk(PathBuf::from("/tmp/test.img"), false);
|
||||
assert!(!config.cdrom);
|
||||
assert!(!config.ro);
|
||||
assert!(config.removable);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_msd_function_name() {
|
||||
let msd = MsdFunction::new(0);
|
||||
assert_eq!(msd.name(), "mass_storage.usb0");
|
||||
assert_eq!(msd.endpoints_required(), 2);
|
||||
}
|
||||
}
|
||||
160
src/otg/report_desc.rs
Normal file
160
src/otg/report_desc.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
//! HID Report Descriptors
|
||||
|
||||
/// Keyboard HID Report Descriptor with LED output support
|
||||
/// Report format (8 bytes input):
|
||||
/// [0] Modifier keys (8 bits)
|
||||
/// [1] Reserved
|
||||
/// [2-7] Key codes (6 keys)
|
||||
/// LED output (1 byte):
|
||||
/// Bit 0: Num Lock
|
||||
/// Bit 1: Caps Lock
|
||||
/// Bit 2: Scroll Lock
|
||||
/// Bit 3: Compose
|
||||
/// Bit 4: Kana
|
||||
pub const KEYBOARD_WITH_LED: &[u8] = &[
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x06, // Usage (Keyboard)
|
||||
0xA1, 0x01, // Collection (Application)
|
||||
// Modifier keys input (8 bits)
|
||||
0x05, 0x07, // Usage Page (Key Codes)
|
||||
0x19, 0xE0, // Usage Minimum (224) - Left Control
|
||||
0x29, 0xE7, // Usage Maximum (231) - Right GUI
|
||||
0x15, 0x00, // Logical Minimum (0)
|
||||
0x25, 0x01, // Logical Maximum (1)
|
||||
0x75, 0x01, // Report Size (1)
|
||||
0x95, 0x08, // Report Count (8)
|
||||
0x81, 0x02, // Input (Data, Variable, Absolute) - Modifier byte
|
||||
// Reserved byte
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x75, 0x08, // Report Size (8)
|
||||
0x81, 0x01, // Input (Constant) - Reserved byte
|
||||
// LED output (5 bits)
|
||||
0x95, 0x05, // Report Count (5)
|
||||
0x75, 0x01, // Report Size (1)
|
||||
0x05, 0x08, // Usage Page (LEDs)
|
||||
0x19, 0x01, // Usage Minimum (1) - Num Lock
|
||||
0x29, 0x05, // Usage Maximum (5) - Kana
|
||||
0x91, 0x02, // Output (Data, Variable, Absolute) - LED bits
|
||||
// LED padding (3 bits)
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x75, 0x03, // Report Size (3)
|
||||
0x91, 0x01, // Output (Constant) - Padding
|
||||
// Key array (6 bytes)
|
||||
0x95, 0x06, // Report Count (6)
|
||||
0x75, 0x08, // Report Size (8)
|
||||
0x15, 0x00, // Logical Minimum (0)
|
||||
0x26, 0xFF, 0x00, // Logical Maximum (255)
|
||||
0x05, 0x07, // Usage Page (Key Codes)
|
||||
0x19, 0x00, // Usage Minimum (0)
|
||||
0x2A, 0xFF, 0x00, // Usage Maximum (255)
|
||||
0x81, 0x00, // Input (Data, Array) - Key array (6 keys)
|
||||
0xC0, // End Collection
|
||||
];
|
||||
|
||||
/// Relative Mouse HID Report Descriptor (4 bytes report)
|
||||
/// Report format:
|
||||
/// [0] Buttons (5 bits) + padding (3 bits)
|
||||
/// [1] X movement (signed 8-bit)
|
||||
/// [2] Y movement (signed 8-bit)
|
||||
/// [3] Wheel (signed 8-bit)
|
||||
pub const MOUSE_RELATIVE: &[u8] = &[
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x02, // Usage (Mouse)
|
||||
0xA1, 0x01, // Collection (Application)
|
||||
0x09, 0x01, // Usage (Pointer)
|
||||
0xA1, 0x00, // Collection (Physical)
|
||||
// Buttons (5 bits)
|
||||
0x05, 0x09, // Usage Page (Button)
|
||||
0x19, 0x01, // Usage Minimum (1)
|
||||
0x29, 0x05, // Usage Maximum (5) - 5 buttons
|
||||
0x15, 0x00, // Logical Minimum (0)
|
||||
0x25, 0x01, // Logical Maximum (1)
|
||||
0x95, 0x05, // Report Count (5)
|
||||
0x75, 0x01, // Report Size (1)
|
||||
0x81, 0x02, // Input (Data, Variable, Absolute) - Button bits
|
||||
// Padding (3 bits)
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x75, 0x03, // Report Size (3)
|
||||
0x81, 0x01, // Input (Constant) - Padding
|
||||
// X, Y movement
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x30, // Usage (X)
|
||||
0x09, 0x31, // Usage (Y)
|
||||
0x15, 0x81, // Logical Minimum (-127)
|
||||
0x25, 0x7F, // Logical Maximum (127)
|
||||
0x75, 0x08, // Report Size (8)
|
||||
0x95, 0x02, // Report Count (2)
|
||||
0x81, 0x06, // Input (Data, Variable, Relative) - X, Y
|
||||
// Wheel
|
||||
0x09, 0x38, // Usage (Wheel)
|
||||
0x15, 0x81, // Logical Minimum (-127)
|
||||
0x25, 0x7F, // Logical Maximum (127)
|
||||
0x75, 0x08, // Report Size (8)
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x81, 0x06, // Input (Data, Variable, Relative) - Wheel
|
||||
0xC0, // End Collection
|
||||
0xC0, // End Collection
|
||||
];
|
||||
|
||||
/// Absolute Mouse HID Report Descriptor (6 bytes report)
|
||||
/// Report format:
|
||||
/// [0] Buttons (5 bits) + padding (3 bits)
|
||||
/// [1-2] X position (16-bit, 0-32767)
|
||||
/// [3-4] Y position (16-bit, 0-32767)
|
||||
/// [5] Wheel (signed 8-bit)
|
||||
pub const MOUSE_ABSOLUTE: &[u8] = &[
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x02, // Usage (Mouse)
|
||||
0xA1, 0x01, // Collection (Application)
|
||||
0x09, 0x01, // Usage (Pointer)
|
||||
0xA1, 0x00, // Collection (Physical)
|
||||
// Buttons (5 bits)
|
||||
0x05, 0x09, // Usage Page (Button)
|
||||
0x19, 0x01, // Usage Minimum (1)
|
||||
0x29, 0x05, // Usage Maximum (5) - 5 buttons
|
||||
0x15, 0x00, // Logical Minimum (0)
|
||||
0x25, 0x01, // Logical Maximum (1)
|
||||
0x95, 0x05, // Report Count (5)
|
||||
0x75, 0x01, // Report Size (1)
|
||||
0x81, 0x02, // Input (Data, Variable, Absolute) - Button bits
|
||||
// Padding (3 bits)
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x75, 0x03, // Report Size (3)
|
||||
0x81, 0x01, // Input (Constant) - Padding
|
||||
// X position (16-bit absolute)
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x30, // Usage (X)
|
||||
0x16, 0x00, 0x00, // Logical Minimum (0)
|
||||
0x26, 0xFF, 0x7F, // Logical Maximum (32767)
|
||||
0x75, 0x10, // Report Size (16)
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x81, 0x02, // Input (Data, Variable, Absolute) - X
|
||||
// Y position (16-bit absolute)
|
||||
0x09, 0x31, // Usage (Y)
|
||||
0x16, 0x00, 0x00, // Logical Minimum (0)
|
||||
0x26, 0xFF, 0x7F, // Logical Maximum (32767)
|
||||
0x75, 0x10, // Report Size (16)
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x81, 0x02, // Input (Data, Variable, Absolute) - Y
|
||||
// Wheel
|
||||
0x09, 0x38, // Usage (Wheel)
|
||||
0x15, 0x81, // Logical Minimum (-127)
|
||||
0x25, 0x7F, // Logical Maximum (127)
|
||||
0x75, 0x08, // Report Size (8)
|
||||
0x95, 0x01, // Report Count (1)
|
||||
0x81, 0x06, // Input (Data, Variable, Relative) - Wheel
|
||||
0xC0, // End Collection
|
||||
0xC0, // End Collection
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_report_descriptor_sizes() {
|
||||
assert!(!KEYBOARD_WITH_LED.is_empty());
|
||||
assert!(!MOUSE_RELATIVE.is_empty());
|
||||
assert!(!MOUSE_ABSOLUTE.is_empty());
|
||||
}
|
||||
}
|
||||
503
src/otg/service.rs
Normal file
503
src/otg/service.rs
Normal file
@@ -0,0 +1,503 @@
|
||||
//! OTG Service - unified gadget lifecycle management
|
||||
//!
|
||||
//! This module provides centralized management for USB OTG gadget functions.
|
||||
//! It solves the ownership problem where both HID and MSD need access to the
|
||||
//! same USB gadget but should be independently configurable.
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! ┌─────────────────────────┐
|
||||
//! │ OtgService │
|
||||
//! │ ┌───────────────────┐ │
|
||||
//! │ │ OtgGadgetManager │ │
|
||||
//! │ └───────────────────┘ │
|
||||
//! │ ↓ ↓ │
|
||||
//! │ ┌─────┐ ┌─────┐ │
|
||||
//! │ │ HID │ │ MSD │ │
|
||||
//! │ └─────┘ └─────┘ │
|
||||
//! └─────────────────────────┘
|
||||
//! ↑ ↑
|
||||
//! HidController MsdController
|
||||
//! ```
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::{AtomicU8, Ordering};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::manager::{wait_for_hid_devices, OtgGadgetManager};
|
||||
use super::msd::MsdFunction;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Bitflags for requested functions (lock-free)
|
||||
const FLAG_HID: u8 = 0b01;
|
||||
const FLAG_MSD: u8 = 0b10;
|
||||
|
||||
/// HID device paths
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HidDevicePaths {
|
||||
pub keyboard: PathBuf,
|
||||
pub mouse_relative: PathBuf,
|
||||
pub mouse_absolute: PathBuf,
|
||||
}
|
||||
|
||||
impl Default for HidDevicePaths {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
keyboard: PathBuf::from("/dev/hidg0"),
|
||||
mouse_relative: PathBuf::from("/dev/hidg1"),
|
||||
mouse_absolute: PathBuf::from("/dev/hidg2"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// OTG Service state
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct OtgServiceState {
|
||||
/// Whether the gadget is created and bound
|
||||
pub gadget_active: bool,
|
||||
/// Whether HID functions are enabled
|
||||
pub hid_enabled: bool,
|
||||
/// Whether MSD function is enabled
|
||||
pub msd_enabled: bool,
|
||||
/// HID device paths (set after gadget setup)
|
||||
pub hid_paths: Option<HidDevicePaths>,
|
||||
/// Error message if setup failed
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// OTG Service - unified gadget lifecycle management
|
||||
///
|
||||
/// This service owns the OtgGadgetManager and provides a high-level interface
|
||||
/// for enabling/disabling HID and MSD functions. It ensures proper coordination
|
||||
/// between the two subsystems and handles gadget lifecycle management.
|
||||
pub struct OtgService {
|
||||
/// The underlying gadget manager
|
||||
manager: Mutex<Option<OtgGadgetManager>>,
|
||||
/// Current state
|
||||
state: RwLock<OtgServiceState>,
|
||||
/// MSD function handle (for runtime LUN configuration)
|
||||
msd_function: RwLock<Option<MsdFunction>>,
|
||||
/// Requested functions flags (atomic, lock-free read/write)
|
||||
requested_flags: AtomicU8,
|
||||
}
|
||||
|
||||
impl OtgService {
|
||||
/// Create a new OTG service
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
manager: Mutex::new(None),
|
||||
state: RwLock::new(OtgServiceState::default()),
|
||||
msd_function: RwLock::new(None),
|
||||
requested_flags: AtomicU8::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if HID is requested (lock-free)
|
||||
#[inline]
|
||||
fn is_hid_requested(&self) -> bool {
|
||||
self.requested_flags.load(Ordering::Acquire) & FLAG_HID != 0
|
||||
}
|
||||
|
||||
/// Check if MSD is requested (lock-free)
|
||||
#[inline]
|
||||
fn is_msd_requested(&self) -> bool {
|
||||
self.requested_flags.load(Ordering::Acquire) & FLAG_MSD != 0
|
||||
}
|
||||
|
||||
/// Set HID requested flag (lock-free)
|
||||
#[inline]
|
||||
fn set_hid_requested(&self, requested: bool) {
|
||||
if requested {
|
||||
self.requested_flags.fetch_or(FLAG_HID, Ordering::Release);
|
||||
} else {
|
||||
self.requested_flags.fetch_and(!FLAG_HID, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
/// Set MSD requested flag (lock-free)
|
||||
#[inline]
|
||||
fn set_msd_requested(&self, requested: bool) {
|
||||
if requested {
|
||||
self.requested_flags.fetch_or(FLAG_MSD, Ordering::Release);
|
||||
} else {
|
||||
self.requested_flags.fetch_and(!FLAG_MSD, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if OTG is available on this system
|
||||
pub fn is_available() -> bool {
|
||||
OtgGadgetManager::is_available() && OtgGadgetManager::find_udc().is_some()
|
||||
}
|
||||
|
||||
/// Get current service state
|
||||
pub async fn state(&self) -> OtgServiceState {
|
||||
self.state.read().await.clone()
|
||||
}
|
||||
|
||||
/// Check if gadget is active
|
||||
pub async fn is_gadget_active(&self) -> bool {
|
||||
self.state.read().await.gadget_active
|
||||
}
|
||||
|
||||
/// Check if HID is enabled
|
||||
pub async fn is_hid_enabled(&self) -> bool {
|
||||
self.state.read().await.hid_enabled
|
||||
}
|
||||
|
||||
/// Check if MSD is enabled
|
||||
pub async fn is_msd_enabled(&self) -> bool {
|
||||
self.state.read().await.msd_enabled
|
||||
}
|
||||
|
||||
/// Get gadget path (for MSD LUN configuration)
|
||||
pub async fn gadget_path(&self) -> Option<PathBuf> {
|
||||
let manager = self.manager.lock().await;
|
||||
manager.as_ref().map(|m| m.gadget_path().clone())
|
||||
}
|
||||
|
||||
/// Get HID device paths
|
||||
pub async fn hid_device_paths(&self) -> Option<HidDevicePaths> {
|
||||
self.state.read().await.hid_paths.clone()
|
||||
}
|
||||
|
||||
/// Get MSD function handle (for LUN configuration)
|
||||
pub async fn msd_function(&self) -> Option<MsdFunction> {
|
||||
self.msd_function.read().await.clone()
|
||||
}
|
||||
|
||||
/// Enable HID functions
|
||||
///
|
||||
/// This will create the gadget if not already created, add HID functions,
|
||||
/// and bind the gadget to UDC.
|
||||
pub async fn enable_hid(&self) -> Result<HidDevicePaths> {
|
||||
info!("Enabling HID functions via OtgService");
|
||||
|
||||
// Mark HID as requested (lock-free)
|
||||
self.set_hid_requested(true);
|
||||
|
||||
// Check if already enabled
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if state.hid_enabled {
|
||||
if let Some(ref paths) = state.hid_paths {
|
||||
info!("HID already enabled, returning existing paths");
|
||||
return Ok(paths.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate gadget with both HID and MSD if needed
|
||||
self.recreate_gadget().await?;
|
||||
|
||||
// Get HID paths from state
|
||||
let state = self.state.read().await;
|
||||
state
|
||||
.hid_paths
|
||||
.clone()
|
||||
.ok_or_else(|| AppError::Internal("HID paths not set after gadget setup".to_string()))
|
||||
}
|
||||
|
||||
/// Disable HID functions
|
||||
///
|
||||
/// This will unbind the gadget, remove HID functions, and optionally
|
||||
/// recreate the gadget with only MSD if MSD is still enabled.
|
||||
pub async fn disable_hid(&self) -> Result<()> {
|
||||
info!("Disabling HID functions via OtgService");
|
||||
|
||||
// Mark HID as not requested (lock-free)
|
||||
self.set_hid_requested(false);
|
||||
|
||||
// Check if HID is enabled
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if !state.hid_enabled {
|
||||
info!("HID already disabled");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate gadget without HID (or destroy if MSD also disabled)
|
||||
self.recreate_gadget().await
|
||||
}
|
||||
|
||||
/// Enable MSD function
|
||||
///
|
||||
/// This will create the gadget if not already created, add MSD function,
|
||||
/// and bind the gadget to UDC.
|
||||
pub async fn enable_msd(&self) -> Result<MsdFunction> {
|
||||
info!("Enabling MSD function via OtgService");
|
||||
|
||||
// Mark MSD as requested (lock-free)
|
||||
self.set_msd_requested(true);
|
||||
|
||||
// Check if already enabled
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if state.msd_enabled {
|
||||
let msd = self.msd_function.read().await;
|
||||
if let Some(ref func) = *msd {
|
||||
info!("MSD already enabled, returning existing function");
|
||||
return Ok(func.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate gadget with both HID and MSD if needed
|
||||
self.recreate_gadget().await?;
|
||||
|
||||
// Get MSD function
|
||||
let msd = self.msd_function.read().await;
|
||||
msd.clone()
|
||||
.ok_or_else(|| AppError::Internal("MSD function not set after gadget setup".to_string()))
|
||||
}
|
||||
|
||||
/// Disable MSD function
|
||||
///
|
||||
/// This will unbind the gadget, remove MSD function, and optionally
|
||||
/// recreate the gadget with only HID if HID is still enabled.
|
||||
pub async fn disable_msd(&self) -> Result<()> {
|
||||
info!("Disabling MSD function via OtgService");
|
||||
|
||||
// Mark MSD as not requested (lock-free)
|
||||
self.set_msd_requested(false);
|
||||
|
||||
// Check if MSD is enabled
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if !state.msd_enabled {
|
||||
info!("MSD already disabled");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Recreate gadget without MSD (or destroy if HID also disabled)
|
||||
self.recreate_gadget().await
|
||||
}
|
||||
|
||||
/// Recreate the gadget with currently requested functions
|
||||
///
|
||||
/// This is called whenever the set of enabled functions changes.
|
||||
/// It will:
|
||||
/// 1. Check if recreation is needed (function set changed)
|
||||
/// 2. If needed: cleanup existing gadget
|
||||
/// 3. Create new gadget with requested functions
|
||||
/// 4. Setup and bind
|
||||
async fn recreate_gadget(&self) -> Result<()> {
|
||||
// Read requested flags atomically (lock-free)
|
||||
let hid_requested = self.is_hid_requested();
|
||||
let msd_requested = self.is_msd_requested();
|
||||
|
||||
info!(
|
||||
"Recreating gadget with: HID={}, MSD={}",
|
||||
hid_requested, msd_requested
|
||||
);
|
||||
|
||||
// Check if gadget already matches requested state
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if state.gadget_active
|
||||
&& state.hid_enabled == hid_requested
|
||||
&& state.msd_enabled == msd_requested
|
||||
{
|
||||
info!("Gadget already has requested functions, skipping recreate");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup existing gadget
|
||||
{
|
||||
let mut manager = self.manager.lock().await;
|
||||
if let Some(mut m) = manager.take() {
|
||||
info!("Cleaning up existing gadget before recreate");
|
||||
if let Err(e) = m.cleanup() {
|
||||
warn!("Error cleaning up existing gadget: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear MSD function
|
||||
*self.msd_function.write().await = None;
|
||||
|
||||
// Update state to inactive
|
||||
{
|
||||
let mut state = self.state.write().await;
|
||||
state.gadget_active = false;
|
||||
state.hid_enabled = false;
|
||||
state.msd_enabled = false;
|
||||
state.hid_paths = None;
|
||||
state.error = None;
|
||||
}
|
||||
|
||||
// If nothing requested, we're done
|
||||
if !hid_requested && !msd_requested {
|
||||
info!("No functions requested, gadget destroyed");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Check if OTG is available
|
||||
if !Self::is_available() {
|
||||
let error = "OTG not available: ConfigFS not mounted or no UDC found".to_string();
|
||||
let mut state = self.state.write().await;
|
||||
state.error = Some(error.clone());
|
||||
return Err(AppError::Internal(error));
|
||||
}
|
||||
|
||||
// Create new gadget manager
|
||||
let mut manager = OtgGadgetManager::new();
|
||||
let mut hid_paths = None;
|
||||
|
||||
// Add HID functions if requested
|
||||
if hid_requested {
|
||||
match (
|
||||
manager.add_keyboard(),
|
||||
manager.add_mouse_relative(),
|
||||
manager.add_mouse_absolute(),
|
||||
) {
|
||||
(Ok(kb), Ok(rel), Ok(abs)) => {
|
||||
hid_paths = Some(HidDevicePaths {
|
||||
keyboard: kb,
|
||||
mouse_relative: rel,
|
||||
mouse_absolute: abs,
|
||||
});
|
||||
debug!("HID functions added to gadget");
|
||||
}
|
||||
(Err(e), _, _) | (_, Err(e), _) | (_, _, Err(e)) => {
|
||||
let error = format!("Failed to add HID functions: {}", e);
|
||||
let mut state = self.state.write().await;
|
||||
state.error = Some(error.clone());
|
||||
return Err(AppError::Internal(error));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add MSD function if requested
|
||||
let msd_func = if msd_requested {
|
||||
match manager.add_msd() {
|
||||
Ok(func) => {
|
||||
debug!("MSD function added to gadget");
|
||||
Some(func)
|
||||
}
|
||||
Err(e) => {
|
||||
let error = format!("Failed to add MSD function: {}", e);
|
||||
let mut state = self.state.write().await;
|
||||
state.error = Some(error.clone());
|
||||
return Err(AppError::Internal(error));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Setup gadget
|
||||
if let Err(e) = manager.setup() {
|
||||
let error = format!("Failed to setup gadget: {}", e);
|
||||
let mut state = self.state.write().await;
|
||||
state.error = Some(error.clone());
|
||||
return Err(AppError::Internal(error));
|
||||
}
|
||||
|
||||
// Bind to UDC
|
||||
if let Err(e) = manager.bind() {
|
||||
let error = format!("Failed to bind gadget to UDC: {}", e);
|
||||
let mut state = self.state.write().await;
|
||||
state.error = Some(error.clone());
|
||||
// Cleanup on failure
|
||||
let _ = manager.cleanup();
|
||||
return Err(AppError::Internal(error));
|
||||
}
|
||||
|
||||
// Wait for HID devices to appear
|
||||
if let Some(ref paths) = hid_paths {
|
||||
let device_paths = vec![
|
||||
paths.keyboard.clone(),
|
||||
paths.mouse_relative.clone(),
|
||||
paths.mouse_absolute.clone(),
|
||||
];
|
||||
if !wait_for_hid_devices(&device_paths, 2000).await {
|
||||
warn!("HID devices did not appear after gadget setup");
|
||||
}
|
||||
}
|
||||
|
||||
// Store manager and update state
|
||||
{
|
||||
*self.manager.lock().await = Some(manager);
|
||||
}
|
||||
|
||||
{
|
||||
*self.msd_function.write().await = msd_func;
|
||||
}
|
||||
|
||||
{
|
||||
let mut state = self.state.write().await;
|
||||
state.gadget_active = true;
|
||||
state.hid_enabled = hid_requested;
|
||||
state.msd_enabled = msd_requested;
|
||||
state.hid_paths = hid_paths;
|
||||
state.error = None;
|
||||
}
|
||||
|
||||
info!("Gadget created successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the OTG service and cleanup all resources
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down OTG service");
|
||||
|
||||
// Mark nothing as requested (lock-free)
|
||||
self.requested_flags.store(0, Ordering::Release);
|
||||
|
||||
// Cleanup gadget
|
||||
let mut manager = self.manager.lock().await;
|
||||
if let Some(mut m) = manager.take() {
|
||||
if let Err(e) = m.cleanup() {
|
||||
warn!("Error cleaning up gadget during shutdown: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear state
|
||||
*self.msd_function.write().await = None;
|
||||
{
|
||||
let mut state = self.state.write().await;
|
||||
*state = OtgServiceState::default();
|
||||
}
|
||||
|
||||
info!("OTG service shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OtgService {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OtgService {
|
||||
fn drop(&mut self) {
|
||||
// Gadget cleanup is handled by OtgGadgetManager's Drop
|
||||
debug!("OtgService dropping");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_service_creation() {
|
||||
let service = OtgService::new();
|
||||
// Just test that creation doesn't panic
|
||||
assert!(!OtgService::is_available() || true); // Depends on environment
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initial_state() {
|
||||
let service = OtgService::new();
|
||||
let state = service.state().await;
|
||||
assert!(!state.gadget_active);
|
||||
assert!(!state.hid_enabled);
|
||||
assert!(!state.msd_enabled);
|
||||
}
|
||||
}
|
||||
227
src/state.rs
Normal file
227
src/state.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
|
||||
use crate::atx::AtxController;
|
||||
use crate::audio::AudioController;
|
||||
use crate::auth::{SessionStore, UserStore};
|
||||
use crate::config::ConfigStore;
|
||||
use crate::events::{AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, MsdDeviceInfo, SystemEvent, VideoDeviceInfo};
|
||||
use crate::extensions::ExtensionManager;
|
||||
use crate::hid::HidController;
|
||||
use crate::msd::MsdController;
|
||||
use crate::otg::OtgService;
|
||||
use crate::video::VideoStreamManager;
|
||||
|
||||
/// Application-wide state shared across handlers
|
||||
///
|
||||
/// # Video Streaming
|
||||
///
|
||||
/// All video operations should go through `stream_manager`:
|
||||
/// - `stream_manager.webrtc_streamer()` - WebRTC streaming (H264, extensible to VP8/VP9/H265)
|
||||
/// - `stream_manager.mjpeg_handler()` - MJPEG stream handler
|
||||
/// - `stream_manager.streamer()` - Low-level video capture
|
||||
/// - `stream_manager.start()` / `stream_manager.stop()` - Stream control
|
||||
/// - `stream_manager.stats()` - Stream statistics
|
||||
/// - `stream_manager.list_devices()` - List video devices
|
||||
pub struct AppState {
|
||||
/// Configuration store
|
||||
pub config: ConfigStore,
|
||||
/// Session store
|
||||
pub sessions: SessionStore,
|
||||
/// User store
|
||||
pub users: UserStore,
|
||||
/// OTG Service - centralized USB gadget lifecycle management
|
||||
/// This is the single owner of OtgGadgetManager, coordinating HID and MSD functions
|
||||
pub otg_service: Arc<OtgService>,
|
||||
/// Video stream manager (unified MJPEG/WebRTC management)
|
||||
/// This is the single entry point for all video operations.
|
||||
pub stream_manager: Arc<VideoStreamManager>,
|
||||
/// HID controller
|
||||
pub hid: Arc<HidController>,
|
||||
/// MSD controller (optional, may not be initialized)
|
||||
pub msd: Arc<RwLock<Option<MsdController>>>,
|
||||
/// ATX controller (optional, may not be initialized)
|
||||
pub atx: Arc<RwLock<Option<AtxController>>>,
|
||||
/// Audio controller
|
||||
pub audio: Arc<AudioController>,
|
||||
/// Extension manager (ttyd, gostc, easytier)
|
||||
pub extensions: Arc<ExtensionManager>,
|
||||
/// Event bus for real-time notifications
|
||||
pub events: Arc<EventBus>,
|
||||
/// Shutdown signal sender
|
||||
pub shutdown_tx: broadcast::Sender<()>,
|
||||
/// Data directory path
|
||||
data_dir: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
/// Create new application state
|
||||
pub fn new(
|
||||
config: ConfigStore,
|
||||
sessions: SessionStore,
|
||||
users: UserStore,
|
||||
otg_service: Arc<OtgService>,
|
||||
stream_manager: Arc<VideoStreamManager>,
|
||||
hid: Arc<HidController>,
|
||||
msd: Option<MsdController>,
|
||||
atx: Option<AtxController>,
|
||||
audio: Arc<AudioController>,
|
||||
extensions: Arc<ExtensionManager>,
|
||||
events: Arc<EventBus>,
|
||||
shutdown_tx: broadcast::Sender<()>,
|
||||
data_dir: std::path::PathBuf,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
config,
|
||||
sessions,
|
||||
users,
|
||||
otg_service,
|
||||
stream_manager,
|
||||
hid,
|
||||
msd: Arc::new(RwLock::new(msd)),
|
||||
atx: Arc::new(RwLock::new(atx)),
|
||||
audio,
|
||||
extensions,
|
||||
events,
|
||||
shutdown_tx,
|
||||
data_dir,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get data directory path
|
||||
pub fn data_dir(&self) -> &std::path::PathBuf {
|
||||
&self.data_dir
|
||||
}
|
||||
|
||||
/// Subscribe to shutdown signal
|
||||
pub fn shutdown_signal(&self) -> broadcast::Receiver<()> {
|
||||
self.shutdown_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get complete device information for WebSocket clients
|
||||
///
|
||||
/// This method collects the current state of all devices (video, HID, MSD, ATX, Audio)
|
||||
/// and returns a DeviceInfo event that can be sent to clients.
|
||||
/// Uses tokio::join! to collect all device info in parallel for better performance.
|
||||
pub async fn get_device_info(&self) -> SystemEvent {
|
||||
// Collect all device info in parallel
|
||||
let (video, hid, msd, atx, audio) = tokio::join!(
|
||||
self.collect_video_info(),
|
||||
self.collect_hid_info(),
|
||||
self.collect_msd_info(),
|
||||
self.collect_atx_info(),
|
||||
self.collect_audio_info(),
|
||||
);
|
||||
|
||||
SystemEvent::DeviceInfo {
|
||||
video,
|
||||
hid,
|
||||
msd,
|
||||
atx,
|
||||
audio,
|
||||
}
|
||||
}
|
||||
|
||||
/// Publish DeviceInfo event to all connected WebSocket clients
|
||||
pub async fn publish_device_info(&self) {
|
||||
let device_info = self.get_device_info().await;
|
||||
self.events.publish(device_info);
|
||||
}
|
||||
|
||||
/// Collect video device information
|
||||
async fn collect_video_info(&self) -> VideoDeviceInfo {
|
||||
// Use stream_manager to get video info (includes stream_mode)
|
||||
self.stream_manager.get_video_info().await
|
||||
}
|
||||
|
||||
/// Collect HID device information
|
||||
async fn collect_hid_info(&self) -> HidDeviceInfo {
|
||||
let info = self.hid.info().await;
|
||||
let backend_type = self.hid.backend_type().await;
|
||||
|
||||
match info {
|
||||
Some(hid_info) => HidDeviceInfo {
|
||||
available: true,
|
||||
backend: hid_info.name.to_string(),
|
||||
initialized: hid_info.initialized,
|
||||
supports_absolute_mouse: hid_info.supports_absolute_mouse,
|
||||
device: match backend_type {
|
||||
crate::hid::HidBackendType::Ch9329 { ref port, .. } => Some(port.clone()),
|
||||
_ => None,
|
||||
},
|
||||
error: None,
|
||||
},
|
||||
None => HidDeviceInfo {
|
||||
available: false,
|
||||
backend: backend_type.name_str().to_string(),
|
||||
initialized: false,
|
||||
supports_absolute_mouse: false,
|
||||
device: match backend_type {
|
||||
crate::hid::HidBackendType::Ch9329 { ref port, .. } => Some(port.clone()),
|
||||
_ => None,
|
||||
},
|
||||
error: Some("HID backend not available".to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect MSD device information (optional)
|
||||
async fn collect_msd_info(&self) -> Option<MsdDeviceInfo> {
|
||||
let msd_guard = self.msd.read().await;
|
||||
let msd = msd_guard.as_ref()?;
|
||||
|
||||
let state = msd.state().await;
|
||||
Some(MsdDeviceInfo {
|
||||
available: state.available,
|
||||
mode: match state.mode {
|
||||
crate::msd::MsdMode::None => "none",
|
||||
crate::msd::MsdMode::Image => "image",
|
||||
crate::msd::MsdMode::Drive => "drive",
|
||||
}
|
||||
.to_string(),
|
||||
connected: state.connected,
|
||||
image_id: state.current_image.map(|img| img.id),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Collect ATX device information (optional)
|
||||
async fn collect_atx_info(&self) -> Option<AtxDeviceInfo> {
|
||||
// Predefined backend strings to avoid repeated allocations
|
||||
const BACKEND_POWER_ONLY: &str = "power: configured, reset: none";
|
||||
const BACKEND_RESET_ONLY: &str = "power: none, reset: configured";
|
||||
const BACKEND_BOTH: &str = "power: configured, reset: configured";
|
||||
const BACKEND_NONE: &str = "none";
|
||||
|
||||
let atx_guard = self.atx.read().await;
|
||||
let atx = atx_guard.as_ref()?;
|
||||
|
||||
let state = atx.state().await;
|
||||
Some(AtxDeviceInfo {
|
||||
available: state.available,
|
||||
backend: match (state.power_configured, state.reset_configured) {
|
||||
(true, true) => BACKEND_BOTH,
|
||||
(true, false) => BACKEND_POWER_ONLY,
|
||||
(false, true) => BACKEND_RESET_ONLY,
|
||||
(false, false) => BACKEND_NONE,
|
||||
}
|
||||
.to_string(),
|
||||
initialized: state.power_configured || state.reset_configured,
|
||||
power_on: state.power_status == crate::atx::PowerStatus::On,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Collect Audio device information (optional)
|
||||
async fn collect_audio_info(&self) -> Option<AudioDeviceInfo> {
|
||||
let status = self.audio.status().await;
|
||||
|
||||
Some(AudioDeviceInfo {
|
||||
available: status.enabled,
|
||||
streaming: status.streaming,
|
||||
device: status.device,
|
||||
quality: status.quality.to_string(),
|
||||
error: status.error,
|
||||
})
|
||||
}
|
||||
}
|
||||
564
src/stream/mjpeg.rs
Normal file
564
src/stream/mjpeg.rs
Normal file
@@ -0,0 +1,564 @@
|
||||
//! MJPEG stream handler
|
||||
//!
|
||||
//! Manages video frame distribution and per-client statistics.
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use parking_lot::Mutex as ParkingMutex;
|
||||
use parking_lot::RwLock as ParkingRwLock;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::video::encoder::JpegEncoder;
|
||||
use crate::video::encoder::traits::{Encoder, EncoderConfig};
|
||||
use crate::video::format::PixelFormat;
|
||||
use crate::video::VideoFrame;
|
||||
|
||||
/// Client ID type (UUID string)
|
||||
pub type ClientId = String;
|
||||
|
||||
/// Per-client session information
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientSession {
|
||||
/// Unique client ID
|
||||
pub id: ClientId,
|
||||
/// Connection timestamp
|
||||
pub connected_at: Instant,
|
||||
/// Last activity timestamp (frame sent)
|
||||
pub last_activity: Instant,
|
||||
/// Frames sent to this client
|
||||
pub frames_sent: u64,
|
||||
/// FPS calculator (1-second rolling window)
|
||||
pub fps_calculator: FpsCalculator,
|
||||
}
|
||||
|
||||
impl ClientSession {
|
||||
/// Create a new client session
|
||||
pub fn new(id: ClientId) -> Self {
|
||||
let now = Instant::now();
|
||||
Self {
|
||||
id,
|
||||
connected_at: now,
|
||||
last_activity: now,
|
||||
frames_sent: 0,
|
||||
fps_calculator: FpsCalculator::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get connection duration
|
||||
pub fn connected_duration(&self) -> Duration {
|
||||
self.last_activity.duration_since(self.connected_at)
|
||||
}
|
||||
|
||||
/// Get idle duration
|
||||
pub fn idle_duration(&self) -> Duration {
|
||||
Instant::now().duration_since(self.last_activity)
|
||||
}
|
||||
}
|
||||
|
||||
/// Rolling window FPS calculator
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FpsCalculator {
|
||||
/// Frame timestamps in last window
|
||||
frame_times: VecDeque<Instant>,
|
||||
/// Window duration (default 1 second)
|
||||
window: Duration,
|
||||
/// Cached count of frames in current window (optimization to avoid O(n) filtering)
|
||||
count_in_window: usize,
|
||||
}
|
||||
|
||||
impl FpsCalculator {
|
||||
/// Create a new FPS calculator with 1-second window
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
frame_times: VecDeque::with_capacity(120), // Max 120fps tracking
|
||||
window: Duration::from_secs(1),
|
||||
count_in_window: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a frame sent
|
||||
pub fn record_frame(&mut self) {
|
||||
let now = Instant::now();
|
||||
self.frame_times.push_back(now);
|
||||
|
||||
// Remove frames outside window and maintain count
|
||||
let cutoff = now - self.window;
|
||||
while let Some(&oldest) = self.frame_times.front() {
|
||||
if oldest < cutoff {
|
||||
self.frame_times.pop_front();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Update cached count
|
||||
self.count_in_window = self.frame_times.len();
|
||||
}
|
||||
|
||||
/// Calculate current FPS (frames in last 1 second window)
|
||||
pub fn current_fps(&self) -> u32 {
|
||||
// Return cached count instead of filtering entire deque (O(1) instead of O(n))
|
||||
self.count_in_window as u32
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FpsCalculator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Auto-pause configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoPauseConfig {
|
||||
/// Enable auto-pause when no clients
|
||||
pub enabled: bool,
|
||||
/// Delay before pausing (default 10s)
|
||||
pub shutdown_delay_secs: u64,
|
||||
/// Client timeout for cleanup (default 30s)
|
||||
pub client_timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for AutoPauseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
shutdown_delay_secs: 10,
|
||||
client_timeout_secs: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MJPEG stream handler
|
||||
/// Manages video frame distribution to HTTP clients
|
||||
pub struct MjpegStreamHandler {
|
||||
/// Current frame (latest) - using ArcSwap for lock-free reads
|
||||
current_frame: ArcSwap<Option<VideoFrame>>,
|
||||
/// Frame update notification
|
||||
frame_notify: broadcast::Sender<()>,
|
||||
/// Whether stream is online
|
||||
online: AtomicBool,
|
||||
/// Frame sequence counter
|
||||
sequence: AtomicU64,
|
||||
/// Per-client sessions (ClientId -> ClientSession)
|
||||
/// Use parking_lot::RwLock for better performance
|
||||
clients: ParkingRwLock<HashMap<ClientId, ClientSession>>,
|
||||
/// Auto-pause configuration
|
||||
auto_pause_config: ParkingRwLock<AutoPauseConfig>,
|
||||
/// Last frame timestamp
|
||||
last_frame_ts: ParkingRwLock<Option<Instant>>,
|
||||
/// Dropped same frames count
|
||||
dropped_same_frames: AtomicU64,
|
||||
/// Maximum consecutive same frames to drop (0 = disabled)
|
||||
max_drop_same_frames: AtomicU64,
|
||||
/// JPEG encoder for non-JPEG input formats
|
||||
jpeg_encoder: ParkingMutex<Option<JpegEncoder>>,
|
||||
}
|
||||
|
||||
impl MjpegStreamHandler {
|
||||
/// Create a new MJPEG stream handler
|
||||
pub fn new() -> Self {
|
||||
Self::with_drop_limit(100) // Default: drop up to 100 same frames
|
||||
}
|
||||
|
||||
/// Create handler with custom drop limit
|
||||
pub fn with_drop_limit(max_drop: u64) -> Self {
|
||||
let (frame_notify, _) = broadcast::channel(4); // Reduced from 16 for lower latency
|
||||
Self {
|
||||
current_frame: ArcSwap::from_pointee(None),
|
||||
frame_notify,
|
||||
online: AtomicBool::new(false),
|
||||
sequence: AtomicU64::new(0),
|
||||
clients: ParkingRwLock::new(HashMap::new()),
|
||||
jpeg_encoder: ParkingMutex::new(None),
|
||||
auto_pause_config: ParkingRwLock::new(AutoPauseConfig::default()),
|
||||
last_frame_ts: ParkingRwLock::new(None),
|
||||
dropped_same_frames: AtomicU64::new(0),
|
||||
max_drop_same_frames: AtomicU64::new(max_drop),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update current frame
|
||||
pub fn update_frame(&self, frame: VideoFrame) {
|
||||
// If frame is not JPEG, encode it
|
||||
let frame = if !frame.format.is_compressed() {
|
||||
match self.encode_to_jpeg(&frame) {
|
||||
Ok(jpeg_frame) => jpeg_frame,
|
||||
Err(e) => {
|
||||
warn!("Failed to encode frame to JPEG: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
frame
|
||||
};
|
||||
|
||||
// Frame deduplication (ustreamer-style)
|
||||
// Check if this frame is identical to the previous one
|
||||
let max_drop = self.max_drop_same_frames.load(Ordering::Relaxed);
|
||||
if max_drop > 0 && frame.online {
|
||||
let current = self.current_frame.load();
|
||||
if let Some(ref prev_frame) = **current {
|
||||
let dropped_count = self.dropped_same_frames.load(Ordering::Relaxed);
|
||||
|
||||
// Check if we should drop this frame
|
||||
if dropped_count < max_drop && frames_are_identical(prev_frame, &frame) {
|
||||
// Check last frame timestamp to ensure minimum 1fps
|
||||
let last_ts = *self.last_frame_ts.read();
|
||||
let should_force_send = if let Some(ts) = last_ts {
|
||||
ts.elapsed() >= Duration::from_secs(1)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if !should_force_send {
|
||||
// Drop this duplicate frame
|
||||
self.dropped_same_frames.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
// If more than 1 second since last frame, force send even if identical
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Frame is different or limit reached or forced by 1fps guarantee, update
|
||||
self.dropped_same_frames.store(0, Ordering::Relaxed);
|
||||
|
||||
self.sequence.fetch_add(1, Ordering::Relaxed);
|
||||
self.online.store(true, Ordering::SeqCst);
|
||||
*self.last_frame_ts.write() = Some(Instant::now());
|
||||
self.current_frame.store(Arc::new(Some(frame)));
|
||||
|
||||
// Notify waiting clients
|
||||
let _ = self.frame_notify.send(());
|
||||
}
|
||||
|
||||
/// Encode a non-JPEG frame to JPEG
|
||||
fn encode_to_jpeg(&self, frame: &VideoFrame) -> Result<VideoFrame, String> {
|
||||
let resolution = frame.resolution;
|
||||
let sequence = self.sequence.load(Ordering::Relaxed);
|
||||
|
||||
// Get or create encoder
|
||||
let mut encoder_guard = self.jpeg_encoder.lock();
|
||||
let encoder = encoder_guard.get_or_insert_with(|| {
|
||||
let config = EncoderConfig::jpeg(resolution, 85);
|
||||
match JpegEncoder::new(config) {
|
||||
Ok(enc) => {
|
||||
debug!("Created JPEG encoder for MJPEG stream: {}x{}", resolution.width, resolution.height);
|
||||
enc
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create JPEG encoder: {}, using default", e);
|
||||
// Try with default config
|
||||
JpegEncoder::new(EncoderConfig::jpeg(resolution, 85))
|
||||
.expect("Failed to create default JPEG encoder")
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Check if resolution changed
|
||||
if encoder.config().resolution != resolution {
|
||||
debug!("Resolution changed, recreating JPEG encoder: {}x{}", resolution.width, resolution.height);
|
||||
let config = EncoderConfig::jpeg(resolution, 85);
|
||||
*encoder = JpegEncoder::new(config).map_err(|e| format!("Failed to create encoder: {}", e))?;
|
||||
}
|
||||
|
||||
// Encode based on input format
|
||||
let encoded = match frame.format {
|
||||
PixelFormat::Yuyv => {
|
||||
encoder.encode_yuyv(frame.data(), sequence)
|
||||
.map_err(|e| format!("YUYV encode failed: {}", e))?
|
||||
}
|
||||
PixelFormat::Nv12 => {
|
||||
encoder.encode_nv12(frame.data(), sequence)
|
||||
.map_err(|e| format!("NV12 encode failed: {}", e))?
|
||||
}
|
||||
PixelFormat::Rgb24 => {
|
||||
encoder.encode_rgb(frame.data(), sequence)
|
||||
.map_err(|e| format!("RGB encode failed: {}", e))?
|
||||
}
|
||||
PixelFormat::Bgr24 => {
|
||||
encoder.encode_bgr(frame.data(), sequence)
|
||||
.map_err(|e| format!("BGR encode failed: {}", e))?
|
||||
}
|
||||
_ => {
|
||||
return Err(format!("Unsupported format for JPEG encoding: {}", frame.format));
|
||||
}
|
||||
};
|
||||
|
||||
// Create new VideoFrame with JPEG data
|
||||
Ok(VideoFrame::from_vec(
|
||||
encoded.data.to_vec(),
|
||||
resolution,
|
||||
PixelFormat::Mjpeg,
|
||||
0, // stride not relevant for JPEG
|
||||
sequence,
|
||||
))
|
||||
}
|
||||
|
||||
/// Set stream offline
|
||||
pub fn set_offline(&self) {
|
||||
self.online.store(false, Ordering::SeqCst);
|
||||
let _ = self.frame_notify.send(());
|
||||
}
|
||||
|
||||
/// Set stream online (called when streaming starts)
|
||||
pub fn set_online(&self) {
|
||||
self.online.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if stream is online
|
||||
pub fn is_online(&self) -> bool {
|
||||
self.online.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Get current client count
|
||||
pub fn client_count(&self) -> u64 {
|
||||
self.clients.read().len() as u64
|
||||
}
|
||||
|
||||
/// Register a new client
|
||||
pub fn register_client(&self, client_id: ClientId) {
|
||||
let session = ClientSession::new(client_id.clone());
|
||||
self.clients.write().insert(client_id.clone(), session);
|
||||
info!("Client {} connected (total: {})", client_id, self.client_count());
|
||||
}
|
||||
|
||||
/// Unregister a client
|
||||
pub fn unregister_client(&self, client_id: &str) {
|
||||
if let Some(session) = self.clients.write().remove(client_id) {
|
||||
let duration = session.connected_duration();
|
||||
let duration_secs = duration.as_secs_f32();
|
||||
let avg_fps = if duration_secs > 0.1 {
|
||||
session.frames_sent as f32 / duration_secs
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
info!(
|
||||
"Client {} disconnected after {:.1}s ({} frames, {:.1} avg FPS)",
|
||||
client_id, duration_secs, session.frames_sent, avg_fps
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Record frame sent to a specific client
|
||||
pub fn record_frame_sent(&self, client_id: &str) {
|
||||
if let Some(session) = self.clients.write().get_mut(client_id) {
|
||||
session.last_activity = Instant::now();
|
||||
session.frames_sent += 1;
|
||||
session.fps_calculator.record_frame();
|
||||
}
|
||||
}
|
||||
|
||||
/// Get per-client statistics
|
||||
pub fn get_clients_stat(&self) -> HashMap<String, crate::events::types::ClientStats> {
|
||||
self.clients
|
||||
.read()
|
||||
.iter()
|
||||
.map(|(id, session)| {
|
||||
(
|
||||
id.clone(),
|
||||
crate::events::types::ClientStats {
|
||||
id: id.clone(),
|
||||
fps: session.fps_calculator.current_fps(),
|
||||
connected_secs: session.connected_duration().as_secs(),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get auto-pause configuration
|
||||
pub fn auto_pause_config(&self) -> AutoPauseConfig {
|
||||
self.auto_pause_config.read().clone()
|
||||
}
|
||||
|
||||
/// Update auto-pause configuration
|
||||
pub fn set_auto_pause_config(&self, config: AutoPauseConfig) {
|
||||
let config_clone = config.clone();
|
||||
*self.auto_pause_config.write() = config;
|
||||
info!(
|
||||
"Auto-pause config updated: enabled={}, delay={}s, timeout={}s",
|
||||
config_clone.enabled, config_clone.shutdown_delay_secs, config_clone.client_timeout_secs
|
||||
);
|
||||
}
|
||||
|
||||
/// Get current frame (if any)
|
||||
pub fn current_frame(&self) -> Option<VideoFrame> {
|
||||
(**self.current_frame.load()).clone()
|
||||
}
|
||||
|
||||
/// Subscribe to frame updates
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<()> {
|
||||
self.frame_notify.subscribe()
|
||||
}
|
||||
|
||||
/// Disconnect all clients (used during config changes)
|
||||
/// This clears the client list and sets the stream offline,
|
||||
/// which will cause all active MJPEG streams to terminate.
|
||||
pub fn disconnect_all_clients(&self) {
|
||||
let count = {
|
||||
let mut clients = self.clients.write();
|
||||
let count = clients.len();
|
||||
clients.clear();
|
||||
count
|
||||
};
|
||||
if count > 0 {
|
||||
info!("Disconnected all {} MJPEG clients for config change", count);
|
||||
}
|
||||
// Set offline to signal all streaming tasks to stop
|
||||
self.set_offline();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MjpegStreamHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII guard for client lifecycle management
|
||||
/// Ensures cleanup even on panic or abrupt disconnection
|
||||
pub struct ClientGuard {
|
||||
client_id: ClientId,
|
||||
handler: Arc<MjpegStreamHandler>,
|
||||
}
|
||||
|
||||
impl ClientGuard {
|
||||
/// Create a new client guard
|
||||
pub fn new(client_id: ClientId, handler: Arc<MjpegStreamHandler>) -> Self {
|
||||
handler.register_client(client_id.clone());
|
||||
Self {
|
||||
client_id,
|
||||
handler,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get client ID
|
||||
pub fn id(&self) -> &ClientId {
|
||||
&self.client_id
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ClientGuard {
|
||||
fn drop(&mut self) {
|
||||
self.handler.unregister_client(&self.client_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl MjpegStreamHandler {
|
||||
/// Start stale client cleanup task
|
||||
/// Should be called once when handler is created
|
||||
pub fn start_cleanup_task(self: Arc<Self>) {
|
||||
let handler = self.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(5));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let timeout_secs = handler.auto_pause_config().client_timeout_secs;
|
||||
let timeout = Duration::from_secs(timeout_secs);
|
||||
let now = Instant::now();
|
||||
let mut stale = Vec::new();
|
||||
|
||||
// Find stale clients
|
||||
{
|
||||
let clients = handler.clients.read();
|
||||
for (id, session) in clients.iter() {
|
||||
if now.duration_since(session.last_activity) > timeout {
|
||||
stale.push(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove stale clients
|
||||
if !stale.is_empty() {
|
||||
let mut clients = handler.clients.write();
|
||||
for id in stale {
|
||||
if let Some(session) = clients.remove(&id) {
|
||||
warn!(
|
||||
"Removed stale client {} (inactive for {:.1}s)",
|
||||
id,
|
||||
now.duration_since(session.last_activity).as_secs_f32()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare two frames for equality (hash-based, ustreamer-style)
|
||||
/// Returns true if frames are identical in geometry and content
|
||||
fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
|
||||
// Quick checks first (geometry)
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if a.resolution.width != b.resolution.width || a.resolution.height != b.resolution.height {
|
||||
return false;
|
||||
}
|
||||
|
||||
if a.format != b.format {
|
||||
return false;
|
||||
}
|
||||
|
||||
if a.stride != b.stride {
|
||||
return false;
|
||||
}
|
||||
|
||||
if a.online != b.online {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compare hashes instead of full binary data
|
||||
// Hash is computed once and cached in OnceLock for efficiency
|
||||
// This is much faster than binary comparison for large frames (1080p MJPEG)
|
||||
a.get_hash() == b.get_hash()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use bytes::Bytes;
|
||||
use crate::video::{format::Resolution, PixelFormat};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_handler() {
|
||||
let handler = MjpegStreamHandler::new();
|
||||
assert!(!handler.is_online());
|
||||
assert_eq!(handler.client_count(), 0);
|
||||
|
||||
// Create a frame
|
||||
let _frame = VideoFrame::new(
|
||||
Bytes::from(vec![0xFF, 0xD8, 0x00, 0x00, 0xFF, 0xD9]),
|
||||
Resolution::VGA,
|
||||
PixelFormat::Mjpeg,
|
||||
0,
|
||||
1,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fps_calculator() {
|
||||
let mut calc = FpsCalculator::new();
|
||||
|
||||
// Initially empty
|
||||
assert_eq!(calc.current_fps(), 0);
|
||||
|
||||
// Record some frames
|
||||
calc.record_frame();
|
||||
calc.record_frame();
|
||||
calc.record_frame();
|
||||
|
||||
// Should have 3 frames in window
|
||||
assert!(calc.frame_times.len() == 3);
|
||||
}
|
||||
}
|
||||
487
src/stream/mjpeg_streamer.rs
Normal file
487
src/stream/mjpeg_streamer.rs
Normal file
@@ -0,0 +1,487 @@
|
||||
//! MJPEG Streamer - High-level MJPEG/HTTP streaming manager
|
||||
//!
|
||||
//! This module provides a unified interface for MJPEG streaming mode,
|
||||
//! integrating video capture, MJPEG distribution, and WebSocket HID.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! MjpegStreamer
|
||||
//! |
|
||||
//! +-- VideoCapturer (V4L2 video capture)
|
||||
//! +-- MjpegStreamHandler (HTTP multipart video)
|
||||
//! +-- WsHidHandler (WebSocket HID)
|
||||
//! ```
|
||||
//!
|
||||
//! Note: Audio WebSocket is handled separately by audio_ws.rs (/api/ws/audio)
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::info;
|
||||
|
||||
use crate::audio::AudioController;
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::events::{EventBus, SystemEvent};
|
||||
use crate::hid::HidController;
|
||||
use crate::video::capture::{CaptureConfig, VideoCapturer};
|
||||
use crate::video::device::{enumerate_devices, find_best_device, VideoDeviceInfo};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
use crate::video::frame::VideoFrame;
|
||||
|
||||
use super::mjpeg::MjpegStreamHandler;
|
||||
use super::ws_hid::WsHidHandler;
|
||||
|
||||
/// MJPEG streamer configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MjpegStreamerConfig {
|
||||
/// Device path (None = auto-detect)
|
||||
pub device_path: Option<PathBuf>,
|
||||
/// Desired resolution
|
||||
pub resolution: Resolution,
|
||||
/// Desired format
|
||||
pub format: PixelFormat,
|
||||
/// Desired FPS
|
||||
pub fps: u32,
|
||||
/// JPEG quality (1-100)
|
||||
pub jpeg_quality: u8,
|
||||
}
|
||||
|
||||
impl Default for MjpegStreamerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_path: None,
|
||||
resolution: Resolution::HD1080,
|
||||
format: PixelFormat::Mjpeg,
|
||||
fps: 30,
|
||||
jpeg_quality: 80,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MJPEG streamer state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum MjpegStreamerState {
|
||||
/// Not initialized
|
||||
Uninitialized,
|
||||
/// Ready but not streaming
|
||||
Ready,
|
||||
/// Actively streaming
|
||||
Streaming,
|
||||
/// No video signal
|
||||
NoSignal,
|
||||
/// Error occurred
|
||||
Error,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MjpegStreamerState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MjpegStreamerState::Uninitialized => write!(f, "uninitialized"),
|
||||
MjpegStreamerState::Ready => write!(f, "ready"),
|
||||
MjpegStreamerState::Streaming => write!(f, "streaming"),
|
||||
MjpegStreamerState::NoSignal => write!(f, "no_signal"),
|
||||
MjpegStreamerState::Error => write!(f, "error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MJPEG streamer statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MjpegStreamerStats {
|
||||
/// Current state
|
||||
pub state: String,
|
||||
/// Current device path
|
||||
pub device: Option<String>,
|
||||
/// Video resolution
|
||||
pub resolution: Option<(u32, u32)>,
|
||||
/// Video format
|
||||
pub format: Option<String>,
|
||||
/// Current FPS
|
||||
pub fps: u32,
|
||||
/// MJPEG client count
|
||||
pub mjpeg_clients: u64,
|
||||
/// WebSocket HID client count
|
||||
pub ws_hid_clients: usize,
|
||||
/// Total frames captured
|
||||
pub frames_captured: u64,
|
||||
}
|
||||
|
||||
/// MJPEG Streamer
|
||||
///
|
||||
/// High-level manager for MJPEG/HTTP streaming mode.
|
||||
/// Integrates video capture, MJPEG distribution, and WebSocket HID.
|
||||
pub struct MjpegStreamer {
|
||||
// === Video ===
|
||||
config: RwLock<MjpegStreamerConfig>,
|
||||
capturer: RwLock<Option<Arc<VideoCapturer>>>,
|
||||
mjpeg_handler: Arc<MjpegStreamHandler>,
|
||||
current_device: RwLock<Option<VideoDeviceInfo>>,
|
||||
state: RwLock<MjpegStreamerState>,
|
||||
|
||||
// === Audio (controller reference only, WS handled by audio_ws.rs) ===
|
||||
audio_controller: RwLock<Option<Arc<AudioController>>>,
|
||||
audio_enabled: AtomicBool,
|
||||
|
||||
// === HID ===
|
||||
ws_hid_handler: Arc<WsHidHandler>,
|
||||
hid_controller: RwLock<Option<Arc<HidController>>>,
|
||||
|
||||
// === Control ===
|
||||
start_lock: tokio::sync::Mutex<()>,
|
||||
events: RwLock<Option<Arc<EventBus>>>,
|
||||
config_changing: AtomicBool,
|
||||
}
|
||||
|
||||
impl MjpegStreamer {
|
||||
/// Create a new MJPEG streamer
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
config: RwLock::new(MjpegStreamerConfig::default()),
|
||||
capturer: RwLock::new(None),
|
||||
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
|
||||
current_device: RwLock::new(None),
|
||||
state: RwLock::new(MjpegStreamerState::Uninitialized),
|
||||
audio_controller: RwLock::new(None),
|
||||
audio_enabled: AtomicBool::new(false),
|
||||
ws_hid_handler: WsHidHandler::new(),
|
||||
hid_controller: RwLock::new(None),
|
||||
start_lock: tokio::sync::Mutex::new(()),
|
||||
events: RwLock::new(None),
|
||||
config_changing: AtomicBool::new(false),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with specific config
|
||||
pub fn with_config(config: MjpegStreamerConfig) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
config: RwLock::new(config),
|
||||
capturer: RwLock::new(None),
|
||||
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
|
||||
current_device: RwLock::new(None),
|
||||
state: RwLock::new(MjpegStreamerState::Uninitialized),
|
||||
audio_controller: RwLock::new(None),
|
||||
audio_enabled: AtomicBool::new(false),
|
||||
ws_hid_handler: WsHidHandler::new(),
|
||||
hid_controller: RwLock::new(None),
|
||||
start_lock: tokio::sync::Mutex::new(()),
|
||||
events: RwLock::new(None),
|
||||
config_changing: AtomicBool::new(false),
|
||||
})
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Configuration and Setup
|
||||
// ========================================================================
|
||||
|
||||
/// Set event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Set audio controller (for reference, WebSocket handled by audio_ws.rs)
|
||||
pub async fn set_audio_controller(&self, audio: Arc<AudioController>) {
|
||||
*self.audio_controller.write().await = Some(audio);
|
||||
info!("MjpegStreamer: Audio controller set");
|
||||
}
|
||||
|
||||
/// Set HID controller
|
||||
pub async fn set_hid_controller(&self, hid: Arc<HidController>) {
|
||||
*self.hid_controller.write().await = Some(hid.clone());
|
||||
self.ws_hid_handler.set_hid_controller(hid);
|
||||
info!("MjpegStreamer: HID controller set");
|
||||
}
|
||||
|
||||
/// Enable or disable audio
|
||||
pub fn set_audio_enabled(&self, enabled: bool) {
|
||||
self.audio_enabled.store(enabled, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if audio is enabled
|
||||
pub fn is_audio_enabled(&self) -> bool {
|
||||
self.audio_enabled.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// State and Status
|
||||
// ========================================================================
|
||||
|
||||
/// Get current state
|
||||
pub async fn state(&self) -> MjpegStreamerState {
|
||||
*self.state.read().await
|
||||
}
|
||||
|
||||
/// Check if config is currently being changed
|
||||
pub fn is_config_changing(&self) -> bool {
|
||||
self.config_changing.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Get current device info
|
||||
pub async fn current_device(&self) -> Option<VideoDeviceInfo> {
|
||||
self.current_device.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get statistics
|
||||
pub async fn stats(&self) -> MjpegStreamerStats {
|
||||
let state = *self.state.read().await;
|
||||
let device = self.current_device.read().await;
|
||||
let config = self.config.read().await;
|
||||
|
||||
let (resolution, format, frames_captured) = if let Some(ref cap) = *self.capturer.read().await {
|
||||
let stats = cap.stats().await;
|
||||
(
|
||||
Some((config.resolution.width, config.resolution.height)),
|
||||
Some(config.format.to_string()),
|
||||
stats.frames_captured,
|
||||
)
|
||||
} else {
|
||||
(None, None, 0)
|
||||
};
|
||||
|
||||
MjpegStreamerStats {
|
||||
state: state.to_string(),
|
||||
device: device.as_ref().map(|d| d.path.display().to_string()),
|
||||
resolution,
|
||||
format,
|
||||
fps: config.fps,
|
||||
mjpeg_clients: self.mjpeg_handler.client_count(),
|
||||
ws_hid_clients: self.ws_hid_handler.client_count(),
|
||||
frames_captured,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Handler Access
|
||||
// ========================================================================
|
||||
|
||||
/// Get MJPEG handler for HTTP streaming
|
||||
pub fn mjpeg_handler(&self) -> Arc<MjpegStreamHandler> {
|
||||
self.mjpeg_handler.clone()
|
||||
}
|
||||
|
||||
/// Get WebSocket HID handler
|
||||
pub fn ws_hid_handler(&self) -> Arc<WsHidHandler> {
|
||||
self.ws_hid_handler.clone()
|
||||
}
|
||||
|
||||
/// Get frame sender for WebRTC integration
|
||||
pub async fn frame_sender(&self) -> Option<broadcast::Sender<VideoFrame>> {
|
||||
if let Some(ref cap) = *self.capturer.read().await {
|
||||
Some(cap.frame_sender())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Initialization
|
||||
// ========================================================================
|
||||
|
||||
/// Initialize with auto-detected device
|
||||
pub async fn init_auto(self: &Arc<Self>) -> Result<()> {
|
||||
let best = find_best_device()?;
|
||||
self.init_with_device(best).await
|
||||
}
|
||||
|
||||
/// Initialize with specific device
|
||||
pub async fn init_with_device(self: &Arc<Self>, device: VideoDeviceInfo) -> Result<()> {
|
||||
info!("MjpegStreamer: Initializing with device: {}", device.path.display());
|
||||
|
||||
let config = self.config.read().await.clone();
|
||||
|
||||
// Create capture config
|
||||
let capture_config = CaptureConfig {
|
||||
device_path: device.path.clone(),
|
||||
resolution: config.resolution,
|
||||
format: config.format,
|
||||
fps: config.fps,
|
||||
buffer_count: 4,
|
||||
timeout: std::time::Duration::from_secs(5),
|
||||
jpeg_quality: config.jpeg_quality,
|
||||
};
|
||||
|
||||
// Create capturer
|
||||
let capturer = Arc::new(VideoCapturer::new(capture_config));
|
||||
|
||||
// Store device and capturer
|
||||
*self.current_device.write().await = Some(device);
|
||||
*self.capturer.write().await = Some(capturer);
|
||||
*self.state.write().await = MjpegStreamerState::Ready;
|
||||
|
||||
self.publish_state_change().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Streaming Control
|
||||
// ========================================================================
|
||||
|
||||
/// Start streaming
|
||||
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
||||
let _lock = self.start_lock.lock().await;
|
||||
|
||||
if self.config_changing.load(Ordering::SeqCst) {
|
||||
return Err(AppError::VideoError("Config change in progress".to_string()));
|
||||
}
|
||||
|
||||
let state = *self.state.read().await;
|
||||
if state == MjpegStreamerState::Streaming {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Get capturer
|
||||
let capturer = self.capturer.read().await.clone();
|
||||
let capturer = capturer.ok_or_else(|| AppError::VideoError("Not initialized".to_string()))?;
|
||||
|
||||
// Start capture
|
||||
capturer.start().await?;
|
||||
|
||||
// Start frame forwarding task
|
||||
let handler = self.mjpeg_handler.clone();
|
||||
let mut frame_rx = capturer.frame_sender().subscribe();
|
||||
tokio::spawn(async move {
|
||||
while let Ok(frame) = frame_rx.recv().await {
|
||||
handler.update_frame(frame);
|
||||
}
|
||||
});
|
||||
|
||||
// Note: Audio WebSocket is handled separately by audio_ws.rs (/api/ws/audio)
|
||||
|
||||
*self.state.write().await = MjpegStreamerState::Streaming;
|
||||
self.mjpeg_handler.set_online();
|
||||
|
||||
self.publish_state_change().await;
|
||||
info!("MjpegStreamer: Streaming started");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop streaming
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
let state = *self.state.read().await;
|
||||
if state != MjpegStreamerState::Streaming {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Stop capturer
|
||||
if let Some(ref cap) = *self.capturer.read().await {
|
||||
let _ = cap.stop().await;
|
||||
}
|
||||
|
||||
// Set offline
|
||||
self.mjpeg_handler.set_offline();
|
||||
*self.state.write().await = MjpegStreamerState::Ready;
|
||||
|
||||
self.publish_state_change().await;
|
||||
info!("MjpegStreamer: Streaming stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if streaming
|
||||
pub async fn is_streaming(&self) -> bool {
|
||||
*self.state.read().await == MjpegStreamerState::Streaming
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Configuration Updates
|
||||
// ========================================================================
|
||||
|
||||
/// Apply video configuration
|
||||
///
|
||||
/// This stops the current stream, reconfigures the capturer, and restarts.
|
||||
pub async fn apply_config(self: &Arc<Self>, config: MjpegStreamerConfig) -> Result<()> {
|
||||
info!("MjpegStreamer: Applying config: {:?}", config);
|
||||
|
||||
self.config_changing.store(true, Ordering::SeqCst);
|
||||
|
||||
// Stop current stream
|
||||
self.stop().await?;
|
||||
|
||||
// Disconnect all MJPEG clients
|
||||
self.mjpeg_handler.disconnect_all_clients();
|
||||
|
||||
// Release capturer
|
||||
*self.capturer.write().await = None;
|
||||
|
||||
// Update config
|
||||
*self.config.write().await = config.clone();
|
||||
|
||||
// Re-initialize if device path is set
|
||||
if let Some(ref path) = config.device_path {
|
||||
let devices = enumerate_devices()?;
|
||||
let device = devices
|
||||
.into_iter()
|
||||
.find(|d| d.path == *path)
|
||||
.ok_or_else(|| AppError::VideoError(format!("Device not found: {}", path.display())))?;
|
||||
|
||||
self.init_with_device(device).await?;
|
||||
}
|
||||
|
||||
self.config_changing.store(false, Ordering::SeqCst);
|
||||
self.publish_state_change().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Internal
|
||||
// ========================================================================
|
||||
|
||||
/// Publish state change event
|
||||
async fn publish_state_change(&self) {
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
let state = *self.state.read().await;
|
||||
let device = self.current_device.read().await;
|
||||
|
||||
events.publish(SystemEvent::StreamStateChanged {
|
||||
state: state.to_string(),
|
||||
device: device.as_ref().map(|d| d.path.display().to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MjpegStreamer {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
config: RwLock::new(MjpegStreamerConfig::default()),
|
||||
capturer: RwLock::new(None),
|
||||
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
|
||||
current_device: RwLock::new(None),
|
||||
state: RwLock::new(MjpegStreamerState::Uninitialized),
|
||||
audio_controller: RwLock::new(None),
|
||||
audio_enabled: AtomicBool::new(false),
|
||||
ws_hid_handler: WsHidHandler::new(),
|
||||
hid_controller: RwLock::new(None),
|
||||
start_lock: tokio::sync::Mutex::new(()),
|
||||
events: RwLock::new(None),
|
||||
config_changing: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mjpeg_streamer_creation() {
|
||||
let streamer = MjpegStreamer::new();
|
||||
assert!(!streamer.is_config_changing());
|
||||
assert!(!streamer.is_audio_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mjpeg_streamer_config_default() {
|
||||
let config = MjpegStreamerConfig::default();
|
||||
assert_eq!(config.resolution, Resolution::HD1080);
|
||||
assert_eq!(config.format, PixelFormat::Mjpeg);
|
||||
assert_eq!(config.fps, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mjpeg_streamer_state_display() {
|
||||
assert_eq!(MjpegStreamerState::Streaming.to_string(), "streaming");
|
||||
assert_eq!(MjpegStreamerState::Ready.to_string(), "ready");
|
||||
}
|
||||
}
|
||||
17
src/stream/mod.rs
Normal file
17
src/stream/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
//! Video streaming module
|
||||
//!
|
||||
//! Provides MJPEG streaming and WebSocket handlers for MJPEG mode.
|
||||
//!
|
||||
//! # Components
|
||||
//!
|
||||
//! - `MjpegStreamer` - High-level MJPEG streaming manager
|
||||
//! - `MjpegStreamHandler` - HTTP multipart MJPEG video streaming
|
||||
//! - `WsHidHandler` - WebSocket HID input handler
|
||||
|
||||
pub mod mjpeg;
|
||||
pub mod mjpeg_streamer;
|
||||
pub mod ws_hid;
|
||||
|
||||
pub use mjpeg::{ClientGuard, MjpegStreamHandler};
|
||||
pub use mjpeg_streamer::{MjpegStreamer, MjpegStreamerConfig, MjpegStreamerState, MjpegStreamerStats};
|
||||
pub use ws_hid::WsHidHandler;
|
||||
280
src/stream/ws_hid.rs
Normal file
280
src/stream/ws_hid.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
//! WebSocket HID Handler for MJPEG mode
|
||||
//!
|
||||
//! This module provides a standalone WebSocket HID handler that can be used
|
||||
//! independently of the application state. It manages multiple WebSocket
|
||||
//! connections and forwards HID events to the HID controller.
|
||||
//!
|
||||
//! # Protocol
|
||||
//!
|
||||
//! Only binary protocol is supported for optimal performance.
|
||||
//! See `crate::hid::datachannel` for message format details.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! WsHidHandler
|
||||
//! |
|
||||
//! +-- clients: HashMap<ClientId, WsHidClient>
|
||||
//! +-- hid_controller: Arc<HidController>
|
||||
//! |
|
||||
//! +-- add_client() -> spawns client handler task
|
||||
//! +-- remove_client()
|
||||
//! ```
|
||||
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use parking_lot::RwLock;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::hid::datachannel::{parse_hid_message, HidChannelEvent};
|
||||
use crate::hid::HidController;
|
||||
|
||||
/// Client ID type
|
||||
pub type ClientId = String;
|
||||
|
||||
/// WebSocket HID client information
|
||||
#[derive(Debug)]
|
||||
pub struct WsHidClient {
|
||||
/// Client ID
|
||||
pub id: ClientId,
|
||||
/// Connection timestamp
|
||||
pub connected_at: Instant,
|
||||
/// Events processed
|
||||
pub events_processed: AtomicU64,
|
||||
/// Shutdown signal sender
|
||||
shutdown_tx: mpsc::Sender<()>,
|
||||
}
|
||||
|
||||
impl WsHidClient {
|
||||
/// Get events processed count
|
||||
pub fn events_count(&self) -> u64 {
|
||||
self.events_processed.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get connection duration in seconds
|
||||
pub fn connected_secs(&self) -> u64 {
|
||||
self.connected_at.elapsed().as_secs()
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket HID Handler
|
||||
///
|
||||
/// Manages WebSocket connections for HID input in MJPEG mode.
|
||||
/// Only binary protocol is supported for optimal performance.
|
||||
pub struct WsHidHandler {
|
||||
/// HID controller reference
|
||||
hid_controller: RwLock<Option<Arc<HidController>>>,
|
||||
/// Active clients
|
||||
clients: RwLock<HashMap<ClientId, Arc<WsHidClient>>>,
|
||||
/// Running state
|
||||
running: AtomicBool,
|
||||
/// Total events processed
|
||||
total_events: AtomicU64,
|
||||
}
|
||||
|
||||
impl WsHidHandler {
|
||||
/// Create a new WebSocket HID handler
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
hid_controller: RwLock::new(None),
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
running: AtomicBool::new(true),
|
||||
total_events: AtomicU64::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Set HID controller
|
||||
pub fn set_hid_controller(&self, hid: Arc<HidController>) {
|
||||
*self.hid_controller.write() = Some(hid);
|
||||
info!("WsHidHandler: HID controller set");
|
||||
}
|
||||
|
||||
/// Get HID controller
|
||||
pub fn hid_controller(&self) -> Option<Arc<HidController>> {
|
||||
self.hid_controller.read().clone()
|
||||
}
|
||||
|
||||
/// Check if HID controller is available
|
||||
pub fn is_hid_available(&self) -> bool {
|
||||
self.hid_controller.read().is_some()
|
||||
}
|
||||
|
||||
/// Get client count
|
||||
pub fn client_count(&self) -> usize {
|
||||
self.clients.read().len()
|
||||
}
|
||||
|
||||
/// Check if running
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.running.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Stop the handler
|
||||
pub fn stop(&self) {
|
||||
self.running.store(false, Ordering::SeqCst);
|
||||
// Signal all clients to disconnect
|
||||
let clients = self.clients.read();
|
||||
for client in clients.values() {
|
||||
let _ = client.shutdown_tx.try_send(());
|
||||
}
|
||||
}
|
||||
|
||||
/// Get total events processed
|
||||
pub fn total_events(&self) -> u64 {
|
||||
self.total_events.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Add a new WebSocket client
|
||||
///
|
||||
/// This spawns a background task to handle the WebSocket connection.
|
||||
pub async fn add_client(self: &Arc<Self>, client_id: ClientId, socket: WebSocket) {
|
||||
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
|
||||
|
||||
let client = Arc::new(WsHidClient {
|
||||
id: client_id.clone(),
|
||||
connected_at: Instant::now(),
|
||||
events_processed: AtomicU64::new(0),
|
||||
shutdown_tx,
|
||||
});
|
||||
|
||||
self.clients.write().insert(client_id.clone(), client.clone());
|
||||
info!(
|
||||
"WsHidHandler: Client {} connected (total: {})",
|
||||
client_id,
|
||||
self.client_count()
|
||||
);
|
||||
|
||||
// Spawn handler task
|
||||
let handler = self.clone();
|
||||
tokio::spawn(async move {
|
||||
handler
|
||||
.handle_client(client_id.clone(), socket, client, shutdown_rx)
|
||||
.await;
|
||||
handler.remove_client(&client_id);
|
||||
});
|
||||
}
|
||||
|
||||
/// Remove a client
|
||||
pub fn remove_client(&self, client_id: &str) {
|
||||
if let Some(client) = self.clients.write().remove(client_id) {
|
||||
info!(
|
||||
"WsHidHandler: Client {} disconnected after {}s ({} events)",
|
||||
client_id,
|
||||
client.connected_secs(),
|
||||
client.events_count()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a WebSocket client connection
|
||||
async fn handle_client(
|
||||
&self,
|
||||
client_id: ClientId,
|
||||
socket: WebSocket,
|
||||
client: Arc<WsHidClient>,
|
||||
mut shutdown_rx: mpsc::Receiver<()>,
|
||||
) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Send initial status as binary: 0x00 = ok, 0x01 = error
|
||||
let status_byte = if self.is_hid_available() { 0x00u8 } else { 0x01u8 };
|
||||
let _ = sender.send(Message::Binary(vec![status_byte])).await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
_ = shutdown_rx.recv() => {
|
||||
debug!("WsHidHandler: Client {} received shutdown signal", client_id);
|
||||
break;
|
||||
}
|
||||
|
||||
msg = receiver.next() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
if let Err(e) = self.handle_binary_message(&data, &client).await {
|
||||
warn!("WsHidHandler: Failed to handle binary message: {}", e);
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = sender.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
debug!("WsHidHandler: Client {} closed connection", client_id);
|
||||
break;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("WsHidHandler: WebSocket error for client {}: {}", client_id, e);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
debug!("WsHidHandler: Client {} stream ended", client_id);
|
||||
break;
|
||||
}
|
||||
// Ignore text messages - binary protocol only
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
warn!("WsHidHandler: Ignoring text message from client {} (binary protocol only)", client_id);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle binary HID message
|
||||
async fn handle_binary_message(&self, data: &[u8], client: &WsHidClient) -> Result<(), String> {
|
||||
let hid = self
|
||||
.hid_controller
|
||||
.read()
|
||||
.clone()
|
||||
.ok_or("HID controller not available")?;
|
||||
|
||||
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
|
||||
|
||||
match event {
|
||||
HidChannelEvent::Keyboard(kb_event) => {
|
||||
hid.send_keyboard(kb_event)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
}
|
||||
HidChannelEvent::Mouse(ms_event) => {
|
||||
hid.send_mouse(ms_event).await.map_err(|e| e.to_string())?;
|
||||
}
|
||||
}
|
||||
|
||||
client.events_processed.fetch_add(1, Ordering::Relaxed);
|
||||
self.total_events.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WsHidHandler {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
hid_controller: RwLock::new(None),
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
running: AtomicBool::new(true),
|
||||
total_events: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ws_hid_handler_creation() {
|
||||
let handler = WsHidHandler::new();
|
||||
assert!(handler.is_running());
|
||||
assert_eq!(handler.client_count(), 0);
|
||||
assert!(!handler.is_hid_available());
|
||||
}
|
||||
}
|
||||
7
src/utils/mod.rs
Normal file
7
src/utils/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
//! Utility modules for One-KVM
|
||||
//!
|
||||
//! This module contains common utilities used across the codebase.
|
||||
|
||||
pub mod throttle;
|
||||
|
||||
pub use throttle::LogThrottler;
|
||||
247
src/utils/throttle.rs
Normal file
247
src/utils/throttle.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
//! Log throttling utility
|
||||
//!
|
||||
//! Provides a mechanism to limit how often the same log message is recorded,
|
||||
//! preventing log flooding when errors occur repeatedly.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Log throttler that limits how often the same message is logged
|
||||
///
|
||||
/// This is useful for preventing log flooding when errors occur repeatedly,
|
||||
/// such as when a device is disconnected and reconnection attempts fail.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use one_kvm::utils::LogThrottler;
|
||||
///
|
||||
/// let throttler = LogThrottler::new(Duration::from_secs(5));
|
||||
///
|
||||
/// // First call returns true
|
||||
/// assert!(throttler.should_log("device_error"));
|
||||
///
|
||||
/// // Subsequent calls within 5 seconds return false
|
||||
/// assert!(!throttler.should_log("device_error"));
|
||||
/// ```
|
||||
pub struct LogThrottler {
|
||||
/// Map of message key to last log time
|
||||
last_logged: RwLock<HashMap<String, Instant>>,
|
||||
/// Throttle interval
|
||||
interval: Duration,
|
||||
}
|
||||
|
||||
impl LogThrottler {
|
||||
/// Create a new log throttler with the specified interval
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `interval` - The minimum time between log messages for the same key
|
||||
pub fn new(interval: Duration) -> Self {
|
||||
Self {
|
||||
last_logged: RwLock::new(HashMap::new()),
|
||||
interval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new log throttler with interval specified in seconds
|
||||
pub fn with_secs(secs: u64) -> Self {
|
||||
Self::new(Duration::from_secs(secs))
|
||||
}
|
||||
|
||||
/// Check if a message should be logged (not throttled)
|
||||
///
|
||||
/// Returns `true` if the message should be logged, `false` if it should be throttled.
|
||||
/// If `true` is returned, the internal timestamp is updated.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - A unique identifier for the message type
|
||||
pub fn should_log(&self, key: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
// First check with read lock (fast path)
|
||||
{
|
||||
let map = self.last_logged.read().unwrap();
|
||||
if let Some(last) = map.get(key) {
|
||||
if now.duration_since(*last) < self.interval {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update with write lock
|
||||
let mut map = self.last_logged.write().unwrap();
|
||||
// Double-check after acquiring write lock
|
||||
if let Some(last) = map.get(key) {
|
||||
if now.duration_since(*last) < self.interval {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
map.insert(key.to_string(), now);
|
||||
true
|
||||
}
|
||||
|
||||
/// Clear throttle state for a specific key
|
||||
///
|
||||
/// This should be called when an error condition recovers,
|
||||
/// so the next error will be logged immediately.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - The key to clear
|
||||
pub fn clear(&self, key: &str) {
|
||||
self.last_logged.write().unwrap().remove(key);
|
||||
}
|
||||
|
||||
/// Clear all throttle state
|
||||
pub fn clear_all(&self) {
|
||||
self.last_logged.write().unwrap().clear();
|
||||
}
|
||||
|
||||
/// Get the number of tracked keys
|
||||
pub fn len(&self) -> usize {
|
||||
self.last_logged.read().unwrap().len()
|
||||
}
|
||||
|
||||
/// Check if the throttler is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.last_logged.read().unwrap().is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LogThrottler {
|
||||
/// Create a default log throttler with 5 second interval
|
||||
fn default() -> Self {
|
||||
Self::with_secs(5)
|
||||
}
|
||||
}
|
||||
|
||||
/// Macro for throttled warning logging
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use one_kvm::utils::LogThrottler;
|
||||
/// use one_kvm::warn_throttled;
|
||||
///
|
||||
/// let throttler = LogThrottler::default();
|
||||
/// warn_throttled!(throttler, "my_error", "Error occurred: {}", "details");
|
||||
/// ```
|
||||
#[macro_export]
|
||||
macro_rules! warn_throttled {
|
||||
($throttler:expr, $key:expr, $($arg:tt)*) => {
|
||||
if $throttler.should_log($key) {
|
||||
tracing::warn!($($arg)*);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro for throttled error logging
|
||||
#[macro_export]
|
||||
macro_rules! error_throttled {
|
||||
($throttler:expr, $key:expr, $($arg:tt)*) => {
|
||||
if $throttler.should_log($key) {
|
||||
tracing::error!($($arg)*);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro for throttled info logging
|
||||
#[macro_export]
|
||||
macro_rules! info_throttled {
|
||||
($throttler:expr, $key:expr, $($arg:tt)*) => {
|
||||
if $throttler.should_log($key) {
|
||||
tracing::info!($($arg)*);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn test_should_log_first_call() {
|
||||
let throttler = LogThrottler::with_secs(1);
|
||||
assert!(throttler.should_log("test_key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_throttling() {
|
||||
let throttler = LogThrottler::new(Duration::from_millis(100));
|
||||
|
||||
// First call should succeed
|
||||
assert!(throttler.should_log("test_key"));
|
||||
|
||||
// Immediate second call should be throttled
|
||||
assert!(!throttler.should_log("test_key"));
|
||||
|
||||
// Wait for throttle to expire
|
||||
thread::sleep(Duration::from_millis(150));
|
||||
|
||||
// Should succeed again
|
||||
assert!(throttler.should_log("test_key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_keys() {
|
||||
let throttler = LogThrottler::with_secs(10);
|
||||
|
||||
// Different keys should be independent
|
||||
assert!(throttler.should_log("key1"));
|
||||
assert!(throttler.should_log("key2"));
|
||||
assert!(!throttler.should_log("key1"));
|
||||
assert!(!throttler.should_log("key2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear() {
|
||||
let throttler = LogThrottler::with_secs(10);
|
||||
|
||||
assert!(throttler.should_log("test_key"));
|
||||
assert!(!throttler.should_log("test_key"));
|
||||
|
||||
// Clear the key
|
||||
throttler.clear("test_key");
|
||||
|
||||
// Should be able to log again
|
||||
assert!(throttler.should_log("test_key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear_all() {
|
||||
let throttler = LogThrottler::with_secs(10);
|
||||
|
||||
assert!(throttler.should_log("key1"));
|
||||
assert!(throttler.should_log("key2"));
|
||||
|
||||
throttler.clear_all();
|
||||
|
||||
assert!(throttler.should_log("key1"));
|
||||
assert!(throttler.should_log("key2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default() {
|
||||
let throttler = LogThrottler::default();
|
||||
assert!(throttler.should_log("test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_len_and_is_empty() {
|
||||
let throttler = LogThrottler::with_secs(10);
|
||||
|
||||
assert!(throttler.is_empty());
|
||||
assert_eq!(throttler.len(), 0);
|
||||
|
||||
throttler.should_log("key1");
|
||||
assert!(!throttler.is_empty());
|
||||
assert_eq!(throttler.len(), 1);
|
||||
|
||||
throttler.should_log("key2");
|
||||
assert_eq!(throttler.len(), 2);
|
||||
}
|
||||
}
|
||||
693
src/video/capture.rs
Normal file
693
src/video/capture.rs
Normal file
@@ -0,0 +1,693 @@
|
||||
//! V4L2 video capture implementation
|
||||
//!
|
||||
//! Provides async video capture using memory-mapped buffers.
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{broadcast, watch, Mutex};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use v4l::buffer::Type as BufferType;
|
||||
use v4l::io::traits::CaptureStream;
|
||||
use v4l::prelude::*;
|
||||
use v4l::video::capture::Parameters;
|
||||
use v4l::video::Capture;
|
||||
use v4l::Format;
|
||||
|
||||
use super::format::{PixelFormat, Resolution};
|
||||
use super::frame::VideoFrame;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Default number of capture buffers (reduced from 4 to 2 for lower latency)
|
||||
const DEFAULT_BUFFER_COUNT: u32 = 2;
|
||||
/// Default capture timeout in seconds
|
||||
const DEFAULT_TIMEOUT: u64 = 2;
|
||||
/// Minimum valid frame size (bytes)
|
||||
const MIN_FRAME_SIZE: usize = 128;
|
||||
|
||||
/// Video capturer configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CaptureConfig {
|
||||
/// Device path
|
||||
pub device_path: PathBuf,
|
||||
/// Desired resolution
|
||||
pub resolution: Resolution,
|
||||
/// Desired pixel format
|
||||
pub format: PixelFormat,
|
||||
/// Desired frame rate (0 = max available)
|
||||
pub fps: u32,
|
||||
/// Number of capture buffers
|
||||
pub buffer_count: u32,
|
||||
/// Capture timeout
|
||||
pub timeout: Duration,
|
||||
/// JPEG quality (1-100, for MJPEG sources with hardware quality control)
|
||||
pub jpeg_quality: u8,
|
||||
}
|
||||
|
||||
impl Default for CaptureConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_path: PathBuf::from("/dev/video0"),
|
||||
resolution: Resolution::HD1080,
|
||||
format: PixelFormat::Mjpeg,
|
||||
fps: 30,
|
||||
buffer_count: DEFAULT_BUFFER_COUNT,
|
||||
timeout: Duration::from_secs(DEFAULT_TIMEOUT),
|
||||
jpeg_quality: 80,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CaptureConfig {
|
||||
/// Create config for a specific device
|
||||
pub fn for_device(path: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
device_path: path.as_ref().to_path_buf(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set resolution
|
||||
pub fn with_resolution(mut self, width: u32, height: u32) -> Self {
|
||||
self.resolution = Resolution::new(width, height);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set format
|
||||
pub fn with_format(mut self, format: PixelFormat) -> Self {
|
||||
self.format = format;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set frame rate
|
||||
pub fn with_fps(mut self, fps: u32) -> Self {
|
||||
self.fps = fps;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Capture statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct CaptureStats {
|
||||
/// Total frames captured
|
||||
pub frames_captured: u64,
|
||||
/// Frames dropped (invalid/too small)
|
||||
pub frames_dropped: u64,
|
||||
/// Current FPS (calculated)
|
||||
pub current_fps: f32,
|
||||
/// Average frame size in bytes
|
||||
pub avg_frame_size: usize,
|
||||
/// Capture errors
|
||||
pub errors: u64,
|
||||
/// Last frame timestamp
|
||||
pub last_frame_ts: Option<Instant>,
|
||||
/// Whether signal is present
|
||||
pub signal_present: bool,
|
||||
}
|
||||
|
||||
/// Video capturer state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CaptureState {
|
||||
/// Not started
|
||||
Stopped,
|
||||
/// Starting (initializing device)
|
||||
Starting,
|
||||
/// Running and capturing
|
||||
Running,
|
||||
/// No signal from source
|
||||
NoSignal,
|
||||
/// Error occurred
|
||||
Error,
|
||||
/// Device was lost (disconnected)
|
||||
DeviceLost,
|
||||
}
|
||||
|
||||
/// Async video capturer
|
||||
pub struct VideoCapturer {
|
||||
config: CaptureConfig,
|
||||
state: Arc<watch::Sender<CaptureState>>,
|
||||
state_rx: watch::Receiver<CaptureState>,
|
||||
stats: Arc<Mutex<CaptureStats>>,
|
||||
frame_tx: broadcast::Sender<VideoFrame>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
sequence: Arc<AtomicU64>,
|
||||
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
/// Last error that occurred (device path, reason)
|
||||
last_error: Arc<parking_lot::RwLock<Option<(String, String)>>>,
|
||||
}
|
||||
|
||||
impl VideoCapturer {
|
||||
/// Create a new video capturer
|
||||
pub fn new(config: CaptureConfig) -> Self {
|
||||
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
|
||||
let (frame_tx, _) = broadcast::channel(16); // Buffer up to 16 frames
|
||||
|
||||
Self {
|
||||
config,
|
||||
state: Arc::new(state_tx),
|
||||
state_rx,
|
||||
stats: Arc::new(Mutex::new(CaptureStats::default())),
|
||||
frame_tx,
|
||||
stop_flag: Arc::new(AtomicBool::new(false)),
|
||||
sequence: Arc::new(AtomicU64::new(0)),
|
||||
capture_handle: Mutex::new(None),
|
||||
last_error: Arc::new(parking_lot::RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current capture state
|
||||
pub fn state(&self) -> CaptureState {
|
||||
*self.state_rx.borrow()
|
||||
}
|
||||
|
||||
/// Subscribe to state changes
|
||||
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
|
||||
self.state_rx.clone()
|
||||
}
|
||||
|
||||
/// Get last error (device path, reason)
|
||||
pub fn last_error(&self) -> Option<(String, String)> {
|
||||
self.last_error.read().clone()
|
||||
}
|
||||
|
||||
/// Clear last error
|
||||
pub fn clear_error(&self) {
|
||||
*self.last_error.write() = None;
|
||||
}
|
||||
|
||||
/// Subscribe to frames
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<VideoFrame> {
|
||||
self.frame_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get frame sender (for sharing with other components like WebRTC)
|
||||
pub fn frame_sender(&self) -> broadcast::Sender<VideoFrame> {
|
||||
self.frame_tx.clone()
|
||||
}
|
||||
|
||||
/// Get capture statistics
|
||||
pub async fn stats(&self) -> CaptureStats {
|
||||
self.stats.lock().await.clone()
|
||||
}
|
||||
|
||||
/// Get config
|
||||
pub fn config(&self) -> &CaptureConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Start capturing in background
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
let current_state = self.state();
|
||||
// Already running or starting - nothing to do
|
||||
if current_state == CaptureState::Running || current_state == CaptureState::Starting {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!(
|
||||
"Starting capture on {:?} at {}x{} {}",
|
||||
self.config.device_path,
|
||||
self.config.resolution.width,
|
||||
self.config.resolution.height,
|
||||
self.config.format
|
||||
);
|
||||
|
||||
// Set Starting state immediately to prevent concurrent start attempts
|
||||
let _ = self.state.send(CaptureState::Starting);
|
||||
|
||||
// Clear any previous error
|
||||
*self.last_error.write() = None;
|
||||
|
||||
self.stop_flag.store(false, Ordering::SeqCst);
|
||||
|
||||
let config = self.config.clone();
|
||||
let state = self.state.clone();
|
||||
let stats = self.stats.clone();
|
||||
let frame_tx = self.frame_tx.clone();
|
||||
let stop_flag = self.stop_flag.clone();
|
||||
let sequence = self.sequence.clone();
|
||||
let last_error = self.last_error.clone();
|
||||
|
||||
let handle = tokio::task::spawn_blocking(move || {
|
||||
capture_loop(config, state, stats, frame_tx, stop_flag, sequence, last_error);
|
||||
});
|
||||
|
||||
*self.capture_handle.lock().await = Some(handle);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop capturing
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
info!("Stopping capture");
|
||||
self.stop_flag.store(true, Ordering::SeqCst);
|
||||
|
||||
if let Some(handle) = self.capture_handle.lock().await.take() {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
||||
let _ = self.state.send(CaptureState::Stopped);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if capturing
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.state() == CaptureState::Running
|
||||
}
|
||||
|
||||
/// Get the latest frame (if any receivers would get it)
|
||||
pub fn latest_frame(&self) -> Option<VideoFrame> {
|
||||
// This is a bit tricky with broadcast - we'd need to track internally
|
||||
// For now, callers should use subscribe()
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Main capture loop (runs in blocking thread)
|
||||
fn capture_loop(
|
||||
config: CaptureConfig,
|
||||
state: Arc<watch::Sender<CaptureState>>,
|
||||
stats: Arc<Mutex<CaptureStats>>,
|
||||
frame_tx: broadcast::Sender<VideoFrame>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
sequence: Arc<AtomicU64>,
|
||||
error_holder: Arc<parking_lot::RwLock<Option<(String, String)>>>,
|
||||
) {
|
||||
let result = run_capture(
|
||||
&config,
|
||||
&state,
|
||||
&stats,
|
||||
&frame_tx,
|
||||
&stop_flag,
|
||||
&sequence,
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
let _ = state.send(CaptureState::Stopped);
|
||||
}
|
||||
Err(AppError::VideoDeviceLost { device, reason }) => {
|
||||
error!("Video device lost: {} - {}", device, reason);
|
||||
// Store the error for recovery handling
|
||||
*error_holder.write() = Some((device, reason));
|
||||
let _ = state.send(CaptureState::DeviceLost);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Capture error: {}", e);
|
||||
let _ = state.send(CaptureState::Error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run_capture(
|
||||
config: &CaptureConfig,
|
||||
state: &watch::Sender<CaptureState>,
|
||||
stats: &Arc<Mutex<CaptureStats>>,
|
||||
frame_tx: &broadcast::Sender<VideoFrame>,
|
||||
stop_flag: &AtomicBool,
|
||||
sequence: &AtomicU64,
|
||||
) -> Result<()> {
|
||||
// Retry logic for device busy errors
|
||||
const MAX_RETRIES: u32 = 5;
|
||||
const RETRY_DELAY_MS: u64 = 200;
|
||||
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 0..MAX_RETRIES {
|
||||
if stop_flag.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Open device
|
||||
let device = match Device::with_path(&config.device_path) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("busy") || err_str.contains("resource") {
|
||||
warn!(
|
||||
"Device busy on attempt {}/{}, retrying in {}ms...",
|
||||
attempt + 1,
|
||||
MAX_RETRIES,
|
||||
RETRY_DELAY_MS
|
||||
);
|
||||
std::thread::sleep(Duration::from_millis(RETRY_DELAY_MS));
|
||||
last_error = Some(AppError::VideoError(format!(
|
||||
"Failed to open device {:?}: {}",
|
||||
config.device_path, e
|
||||
)));
|
||||
continue;
|
||||
}
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Failed to open device {:?}: {}",
|
||||
config.device_path, e
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Set format
|
||||
let format = Format::new(
|
||||
config.resolution.width,
|
||||
config.resolution.height,
|
||||
config.format.to_fourcc(),
|
||||
);
|
||||
|
||||
let actual_format = match device.set_format(&format) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("busy") || err_str.contains("resource") {
|
||||
warn!(
|
||||
"Device busy on set_format attempt {}/{}, retrying in {}ms...",
|
||||
attempt + 1,
|
||||
MAX_RETRIES,
|
||||
RETRY_DELAY_MS
|
||||
);
|
||||
std::thread::sleep(Duration::from_millis(RETRY_DELAY_MS));
|
||||
last_error = Some(AppError::VideoError(format!("Failed to set format: {}", e)));
|
||||
continue;
|
||||
}
|
||||
return Err(AppError::VideoError(format!("Failed to set format: {}", e)));
|
||||
}
|
||||
};
|
||||
|
||||
// Device opened and format set successfully - proceed with capture
|
||||
return run_capture_inner(
|
||||
config,
|
||||
state,
|
||||
stats,
|
||||
frame_tx,
|
||||
stop_flag,
|
||||
sequence,
|
||||
device,
|
||||
actual_format,
|
||||
);
|
||||
}
|
||||
|
||||
// All retries exhausted
|
||||
Err(last_error.unwrap_or_else(|| {
|
||||
AppError::VideoError("Failed to open device after all retries".to_string())
|
||||
}))
|
||||
}
|
||||
|
||||
/// Inner capture function after device is successfully opened
|
||||
fn run_capture_inner(
|
||||
config: &CaptureConfig,
|
||||
state: &watch::Sender<CaptureState>,
|
||||
stats: &Arc<Mutex<CaptureStats>>,
|
||||
frame_tx: &broadcast::Sender<VideoFrame>,
|
||||
stop_flag: &AtomicBool,
|
||||
sequence: &AtomicU64,
|
||||
device: Device,
|
||||
actual_format: Format,
|
||||
) -> Result<()> {
|
||||
info!(
|
||||
"Capture format: {}x{} {:?} stride={}",
|
||||
actual_format.width, actual_format.height, actual_format.fourcc, actual_format.stride
|
||||
);
|
||||
|
||||
let resolution = Resolution::new(actual_format.width, actual_format.height);
|
||||
let pixel_format = PixelFormat::from_fourcc(actual_format.fourcc).unwrap_or(config.format);
|
||||
|
||||
// Try to set hardware FPS (V4L2 VIDIOC_S_PARM)
|
||||
if config.fps > 0 {
|
||||
match device.set_params(&Parameters::with_fps(config.fps)) {
|
||||
Ok(actual_params) => {
|
||||
// Extract actual FPS from returned interval (numerator/denominator)
|
||||
let actual_hw_fps = if actual_params.interval.numerator > 0 {
|
||||
actual_params.interval.denominator / actual_params.interval.numerator
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if actual_hw_fps == config.fps {
|
||||
info!("Hardware FPS set successfully: {} fps", actual_hw_fps);
|
||||
} else if actual_hw_fps > 0 {
|
||||
info!(
|
||||
"Hardware FPS coerced: requested {} fps, got {} fps",
|
||||
config.fps, actual_hw_fps
|
||||
);
|
||||
} else {
|
||||
warn!("Hardware FPS setting returned invalid interval");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to set hardware FPS: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create stream with mmap buffers
|
||||
let mut stream =
|
||||
MmapStream::with_buffers(&device, BufferType::VideoCapture, config.buffer_count)
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to create stream: {}", e)))?;
|
||||
|
||||
let _ = state.send(CaptureState::Running);
|
||||
info!("Capture started");
|
||||
|
||||
// FPS calculation variables
|
||||
let mut fps_frame_count = 0u64;
|
||||
let mut fps_window_start = Instant::now();
|
||||
let fps_window_duration = Duration::from_secs(1);
|
||||
|
||||
// Main capture loop
|
||||
while !stop_flag.load(Ordering::Relaxed) {
|
||||
// Try to capture a frame
|
||||
let (buf, meta) = match stream.next() {
|
||||
Ok(frame_data) => frame_data,
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::TimedOut {
|
||||
warn!("Capture timeout - no signal?");
|
||||
let _ = state.send(CaptureState::NoSignal);
|
||||
|
||||
// Update stats
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.signal_present = false;
|
||||
}
|
||||
|
||||
// Wait a bit before retrying
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for device loss errors
|
||||
let is_device_lost = match e.raw_os_error() {
|
||||
Some(6) => true, // ENXIO - No such device or address
|
||||
Some(19) => true, // ENODEV - No such device
|
||||
Some(5) => true, // EIO - I/O error (device removed)
|
||||
Some(32) => true, // EPIPE - Broken pipe
|
||||
Some(108) => true, // ESHUTDOWN - Transport endpoint shutdown
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if is_device_lost {
|
||||
let device_path = config.device_path.display().to_string();
|
||||
error!("Video device lost: {} - {}", device_path, e);
|
||||
return Err(AppError::VideoDeviceLost {
|
||||
device: device_path,
|
||||
reason: e.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
error!("Capture error: {}", e);
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.errors += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Use actual bytes used, not buffer size
|
||||
let frame_size = meta.bytesused as usize;
|
||||
|
||||
// Validate frame
|
||||
if frame_size < MIN_FRAME_SIZE {
|
||||
debug!("Dropping small frame: {} bytes (bytesused={})", frame_size, meta.bytesused);
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.frames_dropped += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// For JPEG formats, validate header
|
||||
if pixel_format.is_compressed() && !is_valid_jpeg(&buf[..frame_size]) {
|
||||
debug!("Dropping invalid JPEG frame (size={})", frame_size);
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.frames_dropped += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create frame with actual data size
|
||||
let seq = sequence.fetch_add(1, Ordering::Relaxed);
|
||||
let frame = VideoFrame::new(
|
||||
Bytes::copy_from_slice(&buf[..frame_size]),
|
||||
resolution,
|
||||
pixel_format,
|
||||
actual_format.stride,
|
||||
seq,
|
||||
);
|
||||
|
||||
// Update state if was no signal
|
||||
if *state.borrow() == CaptureState::NoSignal {
|
||||
let _ = state.send(CaptureState::Running);
|
||||
}
|
||||
|
||||
// Send frame to subscribers
|
||||
let receiver_count = frame_tx.receiver_count();
|
||||
if receiver_count > 0 {
|
||||
if let Err(e) = frame_tx.send(frame) {
|
||||
debug!("No active receivers for frame: {}", e);
|
||||
}
|
||||
} else if seq % 60 == 0 {
|
||||
// Log every 60 frames (about 1 second at 60fps) when no receivers
|
||||
debug!("No receivers for video frames (receiver_count=0)");
|
||||
}
|
||||
|
||||
// Update stats
|
||||
if let Ok(mut s) = stats.try_lock() {
|
||||
s.frames_captured += 1;
|
||||
s.signal_present = true;
|
||||
s.last_frame_ts = Some(Instant::now());
|
||||
|
||||
// Update FPS calculation
|
||||
fps_frame_count += 1;
|
||||
let elapsed = fps_window_start.elapsed();
|
||||
|
||||
if elapsed >= fps_window_duration {
|
||||
// Calculate FPS from the completed window
|
||||
s.current_fps = (fps_frame_count as f32 / elapsed.as_secs_f32()).max(0.0);
|
||||
// Reset for next window
|
||||
fps_frame_count = 0;
|
||||
fps_window_start = Instant::now();
|
||||
} else if elapsed.as_millis() > 100 && fps_frame_count > 0 {
|
||||
// Provide partial estimate if we have at least 100ms of data
|
||||
s.current_fps = (fps_frame_count as f32 / elapsed.as_secs_f32()).max(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Capture stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate JPEG frame data
|
||||
fn is_valid_jpeg(data: &[u8]) -> bool {
|
||||
if data.len() < 125 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check start marker (0xFFD8)
|
||||
let start_marker = ((data[0] as u16) << 8) | data[1] as u16;
|
||||
if start_marker != 0xFFD8 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check end marker
|
||||
let end = data.len();
|
||||
let end_marker = ((data[end - 2] as u16) << 8) | data[end - 1] as u16;
|
||||
|
||||
// Valid end markers: 0xFFD9, 0xD900, 0x0000 (padded)
|
||||
matches!(end_marker, 0xFFD9 | 0xD900 | 0x0000)
|
||||
}
|
||||
|
||||
/// Frame grabber for one-shot capture
|
||||
pub struct FrameGrabber {
|
||||
device_path: PathBuf,
|
||||
}
|
||||
|
||||
impl FrameGrabber {
|
||||
/// Create a new frame grabber
|
||||
pub fn new(device_path: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
device_path: device_path.as_ref().to_path_buf(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Capture a single frame
|
||||
pub async fn grab(
|
||||
&self,
|
||||
resolution: Resolution,
|
||||
format: PixelFormat,
|
||||
) -> Result<VideoFrame> {
|
||||
let device_path = self.device_path.clone();
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
grab_single_frame(&device_path, resolution, format)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AppError::VideoError(format!("Grab task failed: {}", e)))?
|
||||
}
|
||||
}
|
||||
|
||||
fn grab_single_frame(
|
||||
device_path: &Path,
|
||||
resolution: Resolution,
|
||||
format: PixelFormat,
|
||||
) -> Result<VideoFrame> {
|
||||
let device = Device::with_path(device_path).map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to open device: {}", e))
|
||||
})?;
|
||||
|
||||
let fmt = Format::new(resolution.width, resolution.height, format.to_fourcc());
|
||||
let actual = device.set_format(&fmt).map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to set format: {}", e))
|
||||
})?;
|
||||
|
||||
let mut stream = MmapStream::with_buffers(&device, BufferType::VideoCapture, 2)
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to create stream: {}", e)))?;
|
||||
|
||||
// Try to get a valid frame (skip first few which might be bad)
|
||||
for attempt in 0..5 {
|
||||
match stream.next() {
|
||||
Ok((buf, _meta)) => {
|
||||
if buf.len() >= MIN_FRAME_SIZE {
|
||||
let actual_format =
|
||||
PixelFormat::from_fourcc(actual.fourcc).unwrap_or(format);
|
||||
|
||||
return Ok(VideoFrame::new(
|
||||
Bytes::copy_from_slice(buf),
|
||||
Resolution::new(actual.width, actual.height),
|
||||
actual_format,
|
||||
actual.stride,
|
||||
0,
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt == 4 {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Failed to grab frame: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(AppError::VideoError("Failed to capture valid frame".to_string()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_valid_jpeg() {
|
||||
// Valid JPEG header and footer
|
||||
let mut data = vec![0xFF, 0xD8]; // SOI
|
||||
data.extend(vec![0u8; 200]); // Content
|
||||
data.extend([0xFF, 0xD9]); // EOI
|
||||
|
||||
assert!(is_valid_jpeg(&data));
|
||||
|
||||
// Invalid - too small
|
||||
assert!(!is_valid_jpeg(&[0xFF, 0xD8, 0xFF, 0xD9]));
|
||||
|
||||
// Invalid - wrong header
|
||||
let mut bad = vec![0x00, 0x00];
|
||||
bad.extend(vec![0u8; 200]);
|
||||
assert!(!is_valid_jpeg(&bad));
|
||||
}
|
||||
}
|
||||
640
src/video/convert.rs
Normal file
640
src/video/convert.rs
Normal file
@@ -0,0 +1,640 @@
|
||||
//! Pixel format conversion utilities
|
||||
//!
|
||||
//! This module provides SIMD-accelerated color space conversion using libyuv.
|
||||
//! Primary use case: YUYV (from V4L2 capture) → YUV420P/NV12 (for H264 encoding)
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
|
||||
/// YUV420P buffer with separate Y, U, V planes
|
||||
pub struct Yuv420pBuffer {
|
||||
/// Raw buffer containing all planes
|
||||
data: Vec<u8>,
|
||||
/// Width of the frame
|
||||
width: u32,
|
||||
/// Height of the frame
|
||||
height: u32,
|
||||
/// Y plane offset (always 0)
|
||||
y_offset: usize,
|
||||
/// U plane offset
|
||||
u_offset: usize,
|
||||
/// V plane offset
|
||||
v_offset: usize,
|
||||
}
|
||||
|
||||
impl Yuv420pBuffer {
|
||||
/// Create a new YUV420P buffer for the given resolution
|
||||
pub fn new(resolution: Resolution) -> Self {
|
||||
let width = resolution.width;
|
||||
let height = resolution.height;
|
||||
|
||||
// YUV420P: Y = width*height, U = width*height/4, V = width*height/4
|
||||
let y_size = (width * height) as usize;
|
||||
let uv_size = y_size / 4;
|
||||
let total_size = y_size + uv_size * 2;
|
||||
|
||||
Self {
|
||||
data: vec![0u8; total_size],
|
||||
width,
|
||||
height,
|
||||
y_offset: 0,
|
||||
u_offset: y_size,
|
||||
v_offset: y_size + uv_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the raw buffer as bytes
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Get the raw buffer as mutable bytes
|
||||
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
/// Get Y plane
|
||||
pub fn y_plane(&self) -> &[u8] {
|
||||
&self.data[self.y_offset..self.u_offset]
|
||||
}
|
||||
|
||||
/// Get Y plane mutable
|
||||
pub fn y_plane_mut(&mut self) -> &mut [u8] {
|
||||
let u_offset = self.u_offset;
|
||||
&mut self.data[self.y_offset..u_offset]
|
||||
}
|
||||
|
||||
/// Get U plane
|
||||
pub fn u_plane(&self) -> &[u8] {
|
||||
&self.data[self.u_offset..self.v_offset]
|
||||
}
|
||||
|
||||
/// Get U plane mutable
|
||||
pub fn u_plane_mut(&mut self) -> &mut [u8] {
|
||||
let v_offset = self.v_offset;
|
||||
let u_offset = self.u_offset;
|
||||
&mut self.data[u_offset..v_offset]
|
||||
}
|
||||
|
||||
/// Get V plane
|
||||
pub fn v_plane(&self) -> &[u8] {
|
||||
&self.data[self.v_offset..]
|
||||
}
|
||||
|
||||
/// Get V plane mutable
|
||||
pub fn v_plane_mut(&mut self) -> &mut [u8] {
|
||||
let v_offset = self.v_offset;
|
||||
&mut self.data[v_offset..]
|
||||
}
|
||||
|
||||
/// Get buffer length
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Check if buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
|
||||
/// Get resolution
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
Resolution::new(self.width, self.height)
|
||||
}
|
||||
}
|
||||
|
||||
/// NV12 buffer with Y plane and interleaved UV plane
|
||||
pub struct Nv12Buffer {
|
||||
/// Raw buffer containing Y plane followed by interleaved UV plane
|
||||
data: Vec<u8>,
|
||||
/// Width of the frame
|
||||
width: u32,
|
||||
/// Height of the frame
|
||||
height: u32,
|
||||
}
|
||||
|
||||
impl Nv12Buffer {
|
||||
/// Create a new NV12 buffer for the given resolution
|
||||
pub fn new(resolution: Resolution) -> Self {
|
||||
let width = resolution.width;
|
||||
let height = resolution.height;
|
||||
// NV12: Y = width*height, UV = width*height/2 (interleaved)
|
||||
let y_size = (width * height) as usize;
|
||||
let uv_size = y_size / 2;
|
||||
let total_size = y_size + uv_size;
|
||||
|
||||
Self {
|
||||
data: vec![0u8; total_size],
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the raw buffer as bytes
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Get the raw buffer as mutable bytes
|
||||
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
/// Get Y plane
|
||||
pub fn y_plane(&self) -> &[u8] {
|
||||
let y_size = (self.width * self.height) as usize;
|
||||
&self.data[..y_size]
|
||||
}
|
||||
|
||||
/// Get Y plane mutable
|
||||
pub fn y_plane_mut(&mut self) -> &mut [u8] {
|
||||
let y_size = (self.width * self.height) as usize;
|
||||
&mut self.data[..y_size]
|
||||
}
|
||||
|
||||
/// Get UV plane (interleaved)
|
||||
pub fn uv_plane(&self) -> &[u8] {
|
||||
let y_size = (self.width * self.height) as usize;
|
||||
&self.data[y_size..]
|
||||
}
|
||||
|
||||
/// Get UV plane mutable
|
||||
pub fn uv_plane_mut(&mut self) -> &mut [u8] {
|
||||
let y_size = (self.width * self.height) as usize;
|
||||
&mut self.data[y_size..]
|
||||
}
|
||||
|
||||
/// Get buffer length
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Check if buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
|
||||
/// Get resolution
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
Resolution::new(self.width, self.height)
|
||||
}
|
||||
}
|
||||
|
||||
/// Pixel format converter using libyuv (SIMD accelerated)
|
||||
pub struct PixelConverter {
|
||||
/// Source format
|
||||
src_format: PixelFormat,
|
||||
/// Destination format
|
||||
dst_format: PixelFormat,
|
||||
/// Frame resolution
|
||||
resolution: Resolution,
|
||||
/// Output buffer (reused across conversions)
|
||||
output_buffer: Yuv420pBuffer,
|
||||
}
|
||||
|
||||
impl PixelConverter {
|
||||
/// Create a new converter for YUYV → YUV420P
|
||||
pub fn yuyv_to_yuv420p(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Yuyv,
|
||||
dst_format: PixelFormat::Yuv420,
|
||||
resolution,
|
||||
output_buffer: Yuv420pBuffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for UYVY → YUV420P
|
||||
pub fn uyvy_to_yuv420p(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Uyvy,
|
||||
dst_format: PixelFormat::Yuv420,
|
||||
resolution,
|
||||
output_buffer: Yuv420pBuffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for YVYU → YUV420P
|
||||
pub fn yvyu_to_yuv420p(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Yvyu,
|
||||
dst_format: PixelFormat::Yuv420,
|
||||
resolution,
|
||||
output_buffer: Yuv420pBuffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for NV12 → YUV420P
|
||||
pub fn nv12_to_yuv420p(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Nv12,
|
||||
dst_format: PixelFormat::Yuv420,
|
||||
resolution,
|
||||
output_buffer: Yuv420pBuffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for YVU420 → YUV420P (swap U and V planes)
|
||||
pub fn yvu420_to_yuv420p(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Yvu420,
|
||||
dst_format: PixelFormat::Yuv420,
|
||||
resolution,
|
||||
output_buffer: Yuv420pBuffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for RGB24 → YUV420P
|
||||
pub fn rgb24_to_yuv420p(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Rgb24,
|
||||
dst_format: PixelFormat::Yuv420,
|
||||
resolution,
|
||||
output_buffer: Yuv420pBuffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for BGR24 → YUV420P
|
||||
pub fn bgr24_to_yuv420p(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Bgr24,
|
||||
dst_format: PixelFormat::Yuv420,
|
||||
resolution,
|
||||
output_buffer: Yuv420pBuffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a frame and return reference to the output buffer
|
||||
pub fn convert(&mut self, input: &[u8]) -> Result<&[u8]> {
|
||||
let width = self.resolution.width as i32;
|
||||
let height = self.resolution.height as i32;
|
||||
let expected_size = self.output_buffer.len();
|
||||
|
||||
match (self.src_format, self.dst_format) {
|
||||
(PixelFormat::Yuyv, PixelFormat::Yuv420) => {
|
||||
libyuv::yuy2_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
|
||||
}
|
||||
(PixelFormat::Uyvy, PixelFormat::Yuv420) => {
|
||||
libyuv::uyvy_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
|
||||
}
|
||||
(PixelFormat::Nv12, PixelFormat::Yuv420) => {
|
||||
libyuv::nv12_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
|
||||
}
|
||||
(PixelFormat::Rgb24, PixelFormat::Yuv420) => {
|
||||
libyuv::rgb24_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
|
||||
}
|
||||
(PixelFormat::Bgr24, PixelFormat::Yuv420) => {
|
||||
libyuv::bgr24_to_i420(input, self.output_buffer.as_bytes_mut(), width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv conversion failed: {}", e)))?;
|
||||
}
|
||||
(PixelFormat::Yvyu, PixelFormat::Yuv420) => {
|
||||
// YVYU is not directly supported by libyuv, use software conversion
|
||||
self.convert_yvyu_to_yuv420p_sw(input)?;
|
||||
}
|
||||
(PixelFormat::Yvu420, PixelFormat::Yuv420) => {
|
||||
// YVU420 just swaps U and V planes
|
||||
self.convert_yvu420_to_yuv420p_sw(input)?;
|
||||
}
|
||||
(PixelFormat::Yuv420, PixelFormat::Yuv420) => {
|
||||
// No conversion needed, just copy
|
||||
if input.len() < expected_size {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Input buffer too small: {} < {}",
|
||||
input.len(),
|
||||
expected_size
|
||||
)));
|
||||
}
|
||||
self.output_buffer.as_bytes_mut().copy_from_slice(&input[..expected_size]);
|
||||
}
|
||||
_ => {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Unsupported conversion: {} → {}",
|
||||
self.src_format, self.dst_format
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(self.output_buffer.as_bytes())
|
||||
}
|
||||
|
||||
/// Get output buffer length
|
||||
pub fn output_len(&self) -> usize {
|
||||
self.output_buffer.len()
|
||||
}
|
||||
|
||||
/// Get resolution
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
self.resolution
|
||||
}
|
||||
|
||||
/// Software conversion for YVYU (not supported by libyuv)
|
||||
fn convert_yvyu_to_yuv420p_sw(&mut self, yvyu: &[u8]) -> Result<()> {
|
||||
let width = self.resolution.width as usize;
|
||||
let height = self.resolution.height as usize;
|
||||
let y_size = width * height;
|
||||
let uv_size = y_size / 4;
|
||||
let half_width = width / 2;
|
||||
|
||||
let data = self.output_buffer.as_bytes_mut();
|
||||
let (y_plane, uv_planes) = data.split_at_mut(y_size);
|
||||
let (u_plane, v_plane) = uv_planes.split_at_mut(uv_size);
|
||||
|
||||
for row in (0..height).step_by(2) {
|
||||
let yvyu_row0_offset = row * width * 2;
|
||||
let yvyu_row1_offset = (row + 1) * width * 2;
|
||||
let y_row0_offset = row * width;
|
||||
let y_row1_offset = (row + 1) * width;
|
||||
let uv_row_offset = (row / 2) * half_width;
|
||||
|
||||
for col in (0..width).step_by(2) {
|
||||
let yvyu_offset0 = yvyu_row0_offset + col * 2;
|
||||
let yvyu_offset1 = yvyu_row1_offset + col * 2;
|
||||
|
||||
// YVYU: Y0, V0, Y1, U0
|
||||
let y0_0 = yvyu[yvyu_offset0];
|
||||
let v0 = yvyu[yvyu_offset0 + 1];
|
||||
let y0_1 = yvyu[yvyu_offset0 + 2];
|
||||
let u0 = yvyu[yvyu_offset0 + 3];
|
||||
|
||||
let y1_0 = yvyu[yvyu_offset1];
|
||||
let v1 = yvyu[yvyu_offset1 + 1];
|
||||
let y1_1 = yvyu[yvyu_offset1 + 2];
|
||||
let u1 = yvyu[yvyu_offset1 + 3];
|
||||
|
||||
y_plane[y_row0_offset + col] = y0_0;
|
||||
y_plane[y_row0_offset + col + 1] = y0_1;
|
||||
y_plane[y_row1_offset + col] = y1_0;
|
||||
y_plane[y_row1_offset + col + 1] = y1_1;
|
||||
|
||||
let uv_idx = uv_row_offset + col / 2;
|
||||
u_plane[uv_idx] = ((u0 as u16 + u1 as u16) / 2) as u8;
|
||||
v_plane[uv_idx] = ((v0 as u16 + v1 as u16) / 2) as u8;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Software conversion for YVU420 (just swap U and V)
|
||||
fn convert_yvu420_to_yuv420p_sw(&mut self, yvu420: &[u8]) -> Result<()> {
|
||||
let width = self.resolution.width as usize;
|
||||
let height = self.resolution.height as usize;
|
||||
let y_size = width * height;
|
||||
let uv_size = y_size / 4;
|
||||
|
||||
let data = self.output_buffer.as_bytes_mut();
|
||||
let (y_plane, uv_planes) = data.split_at_mut(y_size);
|
||||
let (u_plane, v_plane) = uv_planes.split_at_mut(uv_size);
|
||||
|
||||
// Copy Y plane directly
|
||||
y_plane.copy_from_slice(&yvu420[..y_size]);
|
||||
|
||||
// In YVU420, V comes before U
|
||||
let v_src = &yvu420[y_size..y_size + uv_size];
|
||||
let u_src = &yvu420[y_size + uv_size..];
|
||||
|
||||
// Swap U and V
|
||||
u_plane.copy_from_slice(u_src);
|
||||
v_plane.copy_from_slice(v_src);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate YUV420P buffer size for a given resolution
|
||||
pub fn yuv420p_buffer_size(resolution: Resolution) -> usize {
|
||||
let pixels = (resolution.width * resolution.height) as usize;
|
||||
pixels + pixels / 2
|
||||
}
|
||||
|
||||
/// Calculate YUYV buffer size for a given resolution
|
||||
pub fn yuyv_buffer_size(resolution: Resolution) -> usize {
|
||||
(resolution.width * resolution.height * 2) as usize
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MJPEG Decoder - Decodes JPEG to YUV420P using libyuv
|
||||
// ============================================================================
|
||||
|
||||
/// MJPEG/JPEG decoder that outputs YUV420P using libyuv
|
||||
pub struct MjpegDecoder {
|
||||
/// Resolution hint (can be updated from decoded frame)
|
||||
resolution: Resolution,
|
||||
/// YUV420P output buffer
|
||||
yuv_buffer: Yuv420pBuffer,
|
||||
}
|
||||
|
||||
impl MjpegDecoder {
|
||||
/// Create a new MJPEG decoder with expected resolution
|
||||
pub fn new(resolution: Resolution) -> Result<Self> {
|
||||
Ok(Self {
|
||||
resolution,
|
||||
yuv_buffer: Yuv420pBuffer::new(resolution),
|
||||
})
|
||||
}
|
||||
|
||||
/// Decode MJPEG/JPEG data to YUV420P using libyuv
|
||||
pub fn decode(&mut self, jpeg_data: &[u8]) -> Result<&[u8]> {
|
||||
// Get MJPEG dimensions
|
||||
let (width, height) = libyuv::mjpeg_size(jpeg_data)
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to get MJPEG size: {}", e)))?;
|
||||
|
||||
// Check if resolution changed
|
||||
if width != self.resolution.width as i32 || height != self.resolution.height as i32 {
|
||||
tracing::debug!(
|
||||
"MJPEG resolution changed: {}x{} -> {}x{}",
|
||||
self.resolution.width,
|
||||
self.resolution.height,
|
||||
width,
|
||||
height
|
||||
);
|
||||
self.resolution = Resolution::new(width as u32, height as u32);
|
||||
self.yuv_buffer = Yuv420pBuffer::new(self.resolution);
|
||||
}
|
||||
|
||||
// Decode MJPEG directly to I420 using libyuv
|
||||
libyuv::mjpeg_to_i420(jpeg_data, self.yuv_buffer.as_bytes_mut(), width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("MJPEG decode failed: {}", e)))?;
|
||||
|
||||
Ok(self.yuv_buffer.as_bytes())
|
||||
}
|
||||
|
||||
/// Get current resolution
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
self.resolution
|
||||
}
|
||||
|
||||
/// Get YUV420P buffer size
|
||||
pub fn yuv_buffer_size(&self) -> usize {
|
||||
self.yuv_buffer.len()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// NV12 Converter for VAAPI encoder (using libyuv)
|
||||
// ============================================================================
|
||||
|
||||
/// Pixel format converter that outputs NV12 (for VAAPI encoders)
|
||||
pub struct Nv12Converter {
|
||||
/// Source format
|
||||
src_format: PixelFormat,
|
||||
/// Frame resolution
|
||||
resolution: Resolution,
|
||||
/// Output buffer (reused across conversions)
|
||||
output_buffer: Nv12Buffer,
|
||||
}
|
||||
|
||||
impl Nv12Converter {
|
||||
/// Create a new converter for BGR24 → NV12
|
||||
pub fn bgr24_to_nv12(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Bgr24,
|
||||
resolution,
|
||||
output_buffer: Nv12Buffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for RGB24 → NV12
|
||||
pub fn rgb24_to_nv12(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Rgb24,
|
||||
resolution,
|
||||
output_buffer: Nv12Buffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new converter for YUYV → NV12
|
||||
pub fn yuyv_to_nv12(resolution: Resolution) -> Self {
|
||||
Self {
|
||||
src_format: PixelFormat::Yuyv,
|
||||
resolution,
|
||||
output_buffer: Nv12Buffer::new(resolution),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a frame and return reference to the output buffer
|
||||
pub fn convert(&mut self, input: &[u8]) -> Result<&[u8]> {
|
||||
let width = self.resolution.width as i32;
|
||||
let height = self.resolution.height as i32;
|
||||
let dst = self.output_buffer.as_bytes_mut();
|
||||
|
||||
let result = match self.src_format {
|
||||
PixelFormat::Bgr24 => libyuv::bgr24_to_nv12(input, dst, width, height),
|
||||
PixelFormat::Rgb24 => libyuv::rgb24_to_nv12(input, dst, width, height),
|
||||
PixelFormat::Yuyv => libyuv::yuy2_to_nv12(input, dst, width, height),
|
||||
_ => {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Unsupported conversion to NV12: {}",
|
||||
self.src_format
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
result.map_err(|e| AppError::VideoError(format!("libyuv NV12 conversion failed: {}", e)))?;
|
||||
Ok(self.output_buffer.as_bytes())
|
||||
}
|
||||
|
||||
/// Get output buffer length
|
||||
pub fn output_len(&self) -> usize {
|
||||
self.output_buffer.len()
|
||||
}
|
||||
|
||||
/// Get resolution
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
self.resolution
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Standalone conversion functions (using libyuv)
|
||||
// ============================================================================
|
||||
|
||||
/// Convert BGR24 to NV12 using libyuv
|
||||
pub fn bgr_to_nv12(bgr: &[u8], nv12: &mut [u8], width: usize, height: usize) {
|
||||
if let Err(e) = libyuv::bgr24_to_nv12(bgr, nv12, width as i32, height as i32) {
|
||||
tracing::error!("libyuv BGR24→NV12 conversion failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert RGB24 to NV12 using libyuv
|
||||
pub fn rgb_to_nv12(rgb: &[u8], nv12: &mut [u8], width: usize, height: usize) {
|
||||
if let Err(e) = libyuv::rgb24_to_nv12(rgb, nv12, width as i32, height as i32) {
|
||||
tracing::error!("libyuv RGB24→NV12 conversion failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert YUYV to NV12 using libyuv
|
||||
pub fn yuyv_to_nv12(yuyv: &[u8], nv12: &mut [u8], width: usize, height: usize) {
|
||||
if let Err(e) = libyuv::yuy2_to_nv12(yuyv, nv12, width as i32, height as i32) {
|
||||
tracing::error!("libyuv YUYV→NV12 conversion failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Extended PixelConverter for MJPEG support
|
||||
// ============================================================================
|
||||
|
||||
/// MJPEG to YUV420P converter (wraps MjpegDecoder)
|
||||
pub struct MjpegToYuv420Converter {
|
||||
decoder: MjpegDecoder,
|
||||
}
|
||||
|
||||
impl MjpegToYuv420Converter {
|
||||
/// Create a new MJPEG to YUV420P converter
|
||||
pub fn new(resolution: Resolution) -> Result<Self> {
|
||||
Ok(Self {
|
||||
decoder: MjpegDecoder::new(resolution)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert MJPEG data to YUV420P
|
||||
pub fn convert(&mut self, mjpeg_data: &[u8]) -> Result<&[u8]> {
|
||||
self.decoder.decode(mjpeg_data)
|
||||
}
|
||||
|
||||
/// Get current resolution
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
self.decoder.resolution()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_yuv420p_buffer_creation() {
|
||||
let buffer = Yuv420pBuffer::new(Resolution::HD720);
|
||||
assert_eq!(buffer.len(), 1280 * 720 * 3 / 2);
|
||||
assert_eq!(buffer.y_plane().len(), 1280 * 720);
|
||||
assert_eq!(buffer.u_plane().len(), 1280 * 720 / 4);
|
||||
assert_eq!(buffer.v_plane().len(), 1280 * 720 / 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nv12_buffer_creation() {
|
||||
let buffer = Nv12Buffer::new(Resolution::HD720);
|
||||
assert_eq!(buffer.len(), 1280 * 720 * 3 / 2);
|
||||
assert_eq!(buffer.y_plane().len(), 1280 * 720);
|
||||
assert_eq!(buffer.uv_plane().len(), 1280 * 720 / 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yuyv_to_yuv420p_conversion() {
|
||||
let resolution = Resolution::new(4, 4);
|
||||
let mut converter = PixelConverter::yuyv_to_yuv420p(resolution);
|
||||
|
||||
// Create YUYV data (4x4 = 32 bytes)
|
||||
let yuyv = vec![
|
||||
16, 128, 17, 129, 18, 130, 19, 131,
|
||||
20, 132, 21, 133, 22, 134, 23, 135,
|
||||
24, 136, 25, 137, 26, 138, 27, 139,
|
||||
28, 140, 29, 141, 30, 142, 31, 143,
|
||||
];
|
||||
|
||||
let result = converter.convert(&yuyv).unwrap();
|
||||
assert_eq!(result.len(), 24); // 4*4 + 2*2 + 2*2 = 24 bytes
|
||||
}
|
||||
}
|
||||
645
src/video/decoder/mjpeg.rs
Normal file
645
src/video/decoder/mjpeg.rs
Normal file
@@ -0,0 +1,645 @@
|
||||
//! MJPEG VAAPI hardware decoder
|
||||
//!
|
||||
//! Uses hwcodec's FFmpeg VAAPI backend to decode MJPEG to NV12.
|
||||
//! This provides hardware-accelerated JPEG decoding with direct NV12 output,
|
||||
//! which is the optimal format for VAAPI H264 encoding.
|
||||
|
||||
use std::sync::Once;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use hwcodec::ffmpeg::AVHWDeviceType;
|
||||
use hwcodec::ffmpeg::AVPixelFormat;
|
||||
use hwcodec::ffmpeg_ram::decode::{DecodeContext, DecodeFrame, Decoder};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::format::Resolution;
|
||||
|
||||
// libyuv for SIMD-accelerated YUV conversion
|
||||
|
||||
static INIT_LOGGING: Once = Once::new();
|
||||
|
||||
/// Initialize hwcodec logging (only once)
|
||||
fn init_hwcodec_logging() {
|
||||
INIT_LOGGING.call_once(|| {
|
||||
debug!("hwcodec MJPEG decoder logging initialized");
|
||||
});
|
||||
}
|
||||
|
||||
/// MJPEG VAAPI decoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MjpegVaapiDecoderConfig {
|
||||
/// Expected resolution (can be updated from decoded frame)
|
||||
pub resolution: Resolution,
|
||||
/// Use hardware acceleration (VAAPI)
|
||||
pub use_hwaccel: bool,
|
||||
}
|
||||
|
||||
impl Default for MjpegVaapiDecoderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
resolution: Resolution::HD1080,
|
||||
use_hwaccel: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decoded frame data in NV12 format
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecodedNv12Frame {
|
||||
/// Y plane data
|
||||
pub y_plane: Vec<u8>,
|
||||
/// UV interleaved plane data
|
||||
pub uv_plane: Vec<u8>,
|
||||
/// Y plane linesize (stride)
|
||||
pub y_linesize: i32,
|
||||
/// UV plane linesize (stride)
|
||||
pub uv_linesize: i32,
|
||||
/// Frame width
|
||||
pub width: i32,
|
||||
/// Frame height
|
||||
pub height: i32,
|
||||
}
|
||||
|
||||
/// Decoded frame data in YUV420P (I420) format
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecodedYuv420pFrame {
|
||||
/// Y plane data
|
||||
pub y_plane: Vec<u8>,
|
||||
/// U plane data
|
||||
pub u_plane: Vec<u8>,
|
||||
/// V plane data
|
||||
pub v_plane: Vec<u8>,
|
||||
/// Y plane linesize (stride)
|
||||
pub y_linesize: i32,
|
||||
/// U plane linesize (stride)
|
||||
pub u_linesize: i32,
|
||||
/// V plane linesize (stride)
|
||||
pub v_linesize: i32,
|
||||
/// Frame width
|
||||
pub width: i32,
|
||||
/// Frame height
|
||||
pub height: i32,
|
||||
}
|
||||
|
||||
impl DecodedYuv420pFrame {
|
||||
/// Get packed YUV420P data (Y plane followed by U and V planes, with stride removed)
|
||||
pub fn to_packed_yuv420p(&self) -> Vec<u8> {
|
||||
let width = self.width as usize;
|
||||
let height = self.height as usize;
|
||||
let y_size = width * height;
|
||||
let uv_size = width * height / 4;
|
||||
|
||||
let mut packed = Vec::with_capacity(y_size + uv_size * 2);
|
||||
|
||||
// Copy Y plane, removing stride padding if any
|
||||
if self.y_linesize as usize == width {
|
||||
packed.extend_from_slice(&self.y_plane[..y_size]);
|
||||
} else {
|
||||
for row in 0..height {
|
||||
let src_offset = row * self.y_linesize as usize;
|
||||
packed.extend_from_slice(&self.y_plane[src_offset..src_offset + width]);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy U plane
|
||||
let uv_width = width / 2;
|
||||
let uv_height = height / 2;
|
||||
if self.u_linesize as usize == uv_width {
|
||||
packed.extend_from_slice(&self.u_plane[..uv_size]);
|
||||
} else {
|
||||
for row in 0..uv_height {
|
||||
let src_offset = row * self.u_linesize as usize;
|
||||
packed.extend_from_slice(&self.u_plane[src_offset..src_offset + uv_width]);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy V plane
|
||||
if self.v_linesize as usize == uv_width {
|
||||
packed.extend_from_slice(&self.v_plane[..uv_size]);
|
||||
} else {
|
||||
for row in 0..uv_height {
|
||||
let src_offset = row * self.v_linesize as usize;
|
||||
packed.extend_from_slice(&self.v_plane[src_offset..src_offset + uv_width]);
|
||||
}
|
||||
}
|
||||
|
||||
packed
|
||||
}
|
||||
|
||||
/// Copy packed YUV420P data to external buffer (zero allocation)
|
||||
/// Returns the number of bytes written, or None if buffer too small
|
||||
pub fn copy_to_packed_yuv420p(&self, dst: &mut [u8]) -> Option<usize> {
|
||||
let width = self.width as usize;
|
||||
let height = self.height as usize;
|
||||
let y_size = width * height;
|
||||
let uv_size = width * height / 4;
|
||||
let total_size = y_size + uv_size * 2;
|
||||
|
||||
if dst.len() < total_size {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Copy Y plane
|
||||
if self.y_linesize as usize == width {
|
||||
dst[..y_size].copy_from_slice(&self.y_plane[..y_size]);
|
||||
} else {
|
||||
for row in 0..height {
|
||||
let src_offset = row * self.y_linesize as usize;
|
||||
let dst_offset = row * width;
|
||||
dst[dst_offset..dst_offset + width]
|
||||
.copy_from_slice(&self.y_plane[src_offset..src_offset + width]);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy U plane
|
||||
let uv_width = width / 2;
|
||||
let uv_height = height / 2;
|
||||
if self.u_linesize as usize == uv_width {
|
||||
dst[y_size..y_size + uv_size].copy_from_slice(&self.u_plane[..uv_size]);
|
||||
} else {
|
||||
for row in 0..uv_height {
|
||||
let src_offset = row * self.u_linesize as usize;
|
||||
let dst_offset = y_size + row * uv_width;
|
||||
dst[dst_offset..dst_offset + uv_width]
|
||||
.copy_from_slice(&self.u_plane[src_offset..src_offset + uv_width]);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy V plane
|
||||
let v_offset = y_size + uv_size;
|
||||
if self.v_linesize as usize == uv_width {
|
||||
dst[v_offset..v_offset + uv_size].copy_from_slice(&self.v_plane[..uv_size]);
|
||||
} else {
|
||||
for row in 0..uv_height {
|
||||
let src_offset = row * self.v_linesize as usize;
|
||||
let dst_offset = v_offset + row * uv_width;
|
||||
dst[dst_offset..dst_offset + uv_width]
|
||||
.copy_from_slice(&self.v_plane[src_offset..src_offset + uv_width]);
|
||||
}
|
||||
}
|
||||
|
||||
Some(total_size)
|
||||
}
|
||||
}
|
||||
|
||||
impl DecodedNv12Frame {
|
||||
/// Get packed NV12 data (Y plane followed by UV plane, with stride removed)
|
||||
pub fn to_packed_nv12(&self) -> Vec<u8> {
|
||||
let width = self.width as usize;
|
||||
let height = self.height as usize;
|
||||
let y_size = width * height;
|
||||
let uv_size = width * height / 2;
|
||||
|
||||
let mut packed = Vec::with_capacity(y_size + uv_size);
|
||||
|
||||
// Copy Y plane, removing stride padding if any
|
||||
if self.y_linesize as usize == width {
|
||||
// No padding, direct copy
|
||||
packed.extend_from_slice(&self.y_plane[..y_size]);
|
||||
} else {
|
||||
// Has padding, copy row by row
|
||||
for row in 0..height {
|
||||
let src_offset = row * self.y_linesize as usize;
|
||||
packed.extend_from_slice(&self.y_plane[src_offset..src_offset + width]);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy UV plane, removing stride padding if any
|
||||
let uv_height = height / 2;
|
||||
if self.uv_linesize as usize == width {
|
||||
// No padding, direct copy
|
||||
packed.extend_from_slice(&self.uv_plane[..uv_size]);
|
||||
} else {
|
||||
// Has padding, copy row by row
|
||||
for row in 0..uv_height {
|
||||
let src_offset = row * self.uv_linesize as usize;
|
||||
packed.extend_from_slice(&self.uv_plane[src_offset..src_offset + width]);
|
||||
}
|
||||
}
|
||||
|
||||
packed
|
||||
}
|
||||
|
||||
/// Copy packed NV12 data to external buffer (zero allocation)
|
||||
/// Returns the number of bytes written, or None if buffer too small
|
||||
pub fn copy_to_packed_nv12(&self, dst: &mut [u8]) -> Option<usize> {
|
||||
let width = self.width as usize;
|
||||
let height = self.height as usize;
|
||||
let y_size = width * height;
|
||||
let uv_size = width * height / 2;
|
||||
let total_size = y_size + uv_size;
|
||||
|
||||
if dst.len() < total_size {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Copy Y plane, removing stride padding if any
|
||||
if self.y_linesize as usize == width {
|
||||
// No padding, direct copy
|
||||
dst[..y_size].copy_from_slice(&self.y_plane[..y_size]);
|
||||
} else {
|
||||
// Has padding, copy row by row
|
||||
for row in 0..height {
|
||||
let src_offset = row * self.y_linesize as usize;
|
||||
let dst_offset = row * width;
|
||||
dst[dst_offset..dst_offset + width]
|
||||
.copy_from_slice(&self.y_plane[src_offset..src_offset + width]);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy UV plane, removing stride padding if any
|
||||
let uv_height = height / 2;
|
||||
if self.uv_linesize as usize == width {
|
||||
// No padding, direct copy
|
||||
dst[y_size..total_size].copy_from_slice(&self.uv_plane[..uv_size]);
|
||||
} else {
|
||||
// Has padding, copy row by row
|
||||
for row in 0..uv_height {
|
||||
let src_offset = row * self.uv_linesize as usize;
|
||||
let dst_offset = y_size + row * width;
|
||||
dst[dst_offset..dst_offset + width]
|
||||
.copy_from_slice(&self.uv_plane[src_offset..src_offset + width]);
|
||||
}
|
||||
}
|
||||
|
||||
Some(total_size)
|
||||
}
|
||||
}
|
||||
|
||||
/// MJPEG VAAPI hardware decoder
|
||||
///
|
||||
/// Decodes MJPEG frames to NV12 format using VAAPI hardware acceleration.
|
||||
/// This is optimal for feeding into VAAPI H264 encoder.
|
||||
pub struct MjpegVaapiDecoder {
|
||||
/// hwcodec decoder instance
|
||||
decoder: Decoder,
|
||||
/// Configuration
|
||||
config: MjpegVaapiDecoderConfig,
|
||||
/// Frame counter
|
||||
frame_count: u64,
|
||||
/// Whether hardware acceleration is active
|
||||
hwaccel_active: bool,
|
||||
}
|
||||
|
||||
impl MjpegVaapiDecoder {
|
||||
/// Create a new MJPEG decoder
|
||||
/// Note: VAAPI does not support MJPEG decoding on most hardware,
|
||||
/// so we use software decoding and convert to NV12 for VAAPI encoding.
|
||||
pub fn new(config: MjpegVaapiDecoderConfig) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
// VAAPI doesn't support MJPEG decoding, always use software decoder
|
||||
// The output will be converted to NV12 for VAAPI H264 encoding
|
||||
let device_type = AVHWDeviceType::AV_HWDEVICE_TYPE_NONE;
|
||||
|
||||
info!(
|
||||
"Creating MJPEG decoder with software decoding (VAAPI doesn't support MJPEG decode)"
|
||||
);
|
||||
|
||||
let ctx = DecodeContext {
|
||||
name: "mjpeg".to_string(),
|
||||
device_type,
|
||||
thread_count: 4, // Use multiple threads for software decoding
|
||||
};
|
||||
|
||||
let decoder = Decoder::new(ctx).map_err(|_| {
|
||||
AppError::VideoError("Failed to create MJPEG software decoder".to_string())
|
||||
})?;
|
||||
|
||||
// hwaccel is not actually active for MJPEG decoding
|
||||
let hwaccel_active = false;
|
||||
|
||||
info!(
|
||||
"MJPEG decoder created successfully (software decode, will convert to NV12)"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
decoder,
|
||||
config,
|
||||
frame_count: 0,
|
||||
hwaccel_active,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with default config (VAAPI enabled)
|
||||
pub fn with_vaapi(resolution: Resolution) -> Result<Self> {
|
||||
Self::new(MjpegVaapiDecoderConfig {
|
||||
resolution,
|
||||
use_hwaccel: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with software decoding (fallback)
|
||||
pub fn with_software(resolution: Resolution) -> Result<Self> {
|
||||
Self::new(MjpegVaapiDecoderConfig {
|
||||
resolution,
|
||||
use_hwaccel: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if hardware acceleration is active
|
||||
pub fn is_hwaccel_active(&self) -> bool {
|
||||
self.hwaccel_active
|
||||
}
|
||||
|
||||
/// Decode MJPEG frame to NV12
|
||||
///
|
||||
/// Returns the decoded frame in NV12 format, or an error if decoding fails.
|
||||
pub fn decode(&mut self, jpeg_data: &[u8]) -> Result<DecodedNv12Frame> {
|
||||
if jpeg_data.len() < 2 {
|
||||
return Err(AppError::VideoError("JPEG data too small".to_string()));
|
||||
}
|
||||
|
||||
// Verify JPEG signature (FFD8)
|
||||
if jpeg_data[0] != 0xFF || jpeg_data[1] != 0xD8 {
|
||||
return Err(AppError::VideoError("Invalid JPEG signature".to_string()));
|
||||
}
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
let frames = self.decoder.decode(jpeg_data).map_err(|e| {
|
||||
AppError::VideoError(format!("MJPEG decode failed: error code {}", e))
|
||||
})?;
|
||||
|
||||
if frames.is_empty() {
|
||||
return Err(AppError::VideoError("Decoder returned no frames".to_string()));
|
||||
}
|
||||
|
||||
let frame = &frames[0];
|
||||
|
||||
// Handle different output formats
|
||||
// VAAPI MJPEG decoder may output NV12, YUV420P, or YUVJ420P (JPEG full-range)
|
||||
if frame.pixfmt == AVPixelFormat::AV_PIX_FMT_NV12
|
||||
|| frame.pixfmt == AVPixelFormat::AV_PIX_FMT_NV21
|
||||
{
|
||||
// NV12/NV21 format: Y plane + UV interleaved plane
|
||||
if frame.data.len() < 2 {
|
||||
return Err(AppError::VideoError("Invalid NV12 frame data".to_string()));
|
||||
}
|
||||
|
||||
return Ok(DecodedNv12Frame {
|
||||
y_plane: frame.data[0].clone(),
|
||||
uv_plane: frame.data[1].clone(),
|
||||
y_linesize: frame.linesize[0],
|
||||
uv_linesize: frame.linesize[1],
|
||||
width: frame.width,
|
||||
height: frame.height,
|
||||
});
|
||||
}
|
||||
|
||||
// YUV420P or YUVJ420P (JPEG full-range) - need to convert to NV12
|
||||
if frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUV420P
|
||||
|| frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUVJ420P
|
||||
{
|
||||
return Self::convert_yuv420p_to_nv12_static(frame);
|
||||
}
|
||||
|
||||
// YUV422P or YUVJ422P (JPEG full-range 4:2:2) - need to convert to NV12
|
||||
if frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUV422P
|
||||
|| frame.pixfmt == AVPixelFormat::AV_PIX_FMT_YUVJ422P
|
||||
{
|
||||
return Self::convert_yuv422p_to_nv12_static(frame);
|
||||
}
|
||||
|
||||
Err(AppError::VideoError(format!(
|
||||
"Unexpected output format: {:?} (expected NV12, YUV420P, YUV422P, or YUVJ variants)",
|
||||
frame.pixfmt
|
||||
)))
|
||||
}
|
||||
|
||||
/// Convert YUV420P frame to NV12 format using libyuv (SIMD accelerated)
|
||||
fn convert_yuv420p_to_nv12_static(frame: &DecodeFrame) -> Result<DecodedNv12Frame> {
|
||||
if frame.data.len() < 3 {
|
||||
return Err(AppError::VideoError("Invalid YUV420P frame data".to_string()));
|
||||
}
|
||||
|
||||
let width = frame.width as i32;
|
||||
let height = frame.height as i32;
|
||||
let y_linesize = frame.linesize[0];
|
||||
let u_linesize = frame.linesize[1];
|
||||
let v_linesize = frame.linesize[2];
|
||||
|
||||
// Allocate packed NV12 output buffer
|
||||
let nv12_size = (width * height * 3 / 2) as usize;
|
||||
let mut nv12_data = vec![0u8; nv12_size];
|
||||
|
||||
// Use libyuv for SIMD-accelerated I420 → NV12 conversion
|
||||
libyuv::i420_to_nv12_planar(
|
||||
&frame.data[0], y_linesize,
|
||||
&frame.data[1], u_linesize,
|
||||
&frame.data[2], v_linesize,
|
||||
&mut nv12_data,
|
||||
width, height,
|
||||
).map_err(|e| AppError::VideoError(format!("libyuv I420→NV12 failed: {}", e)))?;
|
||||
|
||||
// Split into Y and UV planes for DecodedNv12Frame
|
||||
let y_size = (width * height) as usize;
|
||||
let y_plane = nv12_data[..y_size].to_vec();
|
||||
let uv_plane = nv12_data[y_size..].to_vec();
|
||||
|
||||
Ok(DecodedNv12Frame {
|
||||
y_plane,
|
||||
uv_plane,
|
||||
y_linesize: width, // Output is packed, no padding
|
||||
uv_linesize: width,
|
||||
width: frame.width,
|
||||
height: frame.height,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert YUV422P frame to NV12 format using libyuv (SIMD accelerated)
|
||||
/// Pipeline: I422 (YUV422P) → I420 → NV12
|
||||
fn convert_yuv422p_to_nv12_static(frame: &DecodeFrame) -> Result<DecodedNv12Frame> {
|
||||
if frame.data.len() < 3 {
|
||||
return Err(AppError::VideoError("Invalid YUV422P frame data".to_string()));
|
||||
}
|
||||
|
||||
let width = frame.width as i32;
|
||||
let height = frame.height as i32;
|
||||
let y_linesize = frame.linesize[0];
|
||||
let u_linesize = frame.linesize[1];
|
||||
let v_linesize = frame.linesize[2];
|
||||
|
||||
// Step 1: I422 → I420 (vertical chroma downsampling via SIMD)
|
||||
let i420_size = (width * height * 3 / 2) as usize;
|
||||
let mut i420_data = vec![0u8; i420_size];
|
||||
|
||||
libyuv::i422_to_i420_planar(
|
||||
&frame.data[0], y_linesize,
|
||||
&frame.data[1], u_linesize,
|
||||
&frame.data[2], v_linesize,
|
||||
&mut i420_data,
|
||||
width, height,
|
||||
).map_err(|e| AppError::VideoError(format!("libyuv I422→I420 failed: {}", e)))?;
|
||||
|
||||
// Step 2: I420 → NV12 (UV interleaving via SIMD)
|
||||
let nv12_size = (width * height * 3 / 2) as usize;
|
||||
let mut nv12_data = vec![0u8; nv12_size];
|
||||
|
||||
libyuv::i420_to_nv12(&i420_data, &mut nv12_data, width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv I420→NV12 failed: {}", e)))?;
|
||||
|
||||
// Split into Y and UV planes for DecodedNv12Frame
|
||||
let y_size = (width * height) as usize;
|
||||
let y_plane = nv12_data[..y_size].to_vec();
|
||||
let uv_plane = nv12_data[y_size..].to_vec();
|
||||
|
||||
Ok(DecodedNv12Frame {
|
||||
y_plane,
|
||||
uv_plane,
|
||||
y_linesize: width,
|
||||
uv_linesize: width,
|
||||
width: frame.width,
|
||||
height: frame.height,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get frame count
|
||||
pub fn frame_count(&self) -> u64 {
|
||||
self.frame_count
|
||||
}
|
||||
|
||||
/// Get current resolution from config
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
self.config.resolution
|
||||
}
|
||||
}
|
||||
|
||||
/// Libyuv-based MJPEG decoder for direct YUV420P output
|
||||
///
|
||||
/// This decoder is optimized for software encoders (libvpx, libx265) that need YUV420P input.
|
||||
/// It uses libyuv's MJPGToI420 to decode directly to I420/YUV420P format.
|
||||
pub struct MjpegTurboDecoder {
|
||||
/// Frame counter
|
||||
frame_count: u64,
|
||||
}
|
||||
|
||||
impl MjpegTurboDecoder {
|
||||
/// Create a new libyuv-based MJPEG decoder
|
||||
pub fn new(resolution: Resolution) -> Result<Self> {
|
||||
info!(
|
||||
"Created libyuv MJPEG decoder for {}x{} (direct YUV420P output)",
|
||||
resolution.width, resolution.height
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
frame_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Decode MJPEG frame directly to YUV420P using libyuv
|
||||
///
|
||||
/// This is the optimal path for software encoders that need YUV420P input.
|
||||
/// libyuv handles all JPEG subsampling formats internally.
|
||||
pub fn decode_to_yuv420p(&mut self, jpeg_data: &[u8]) -> Result<DecodedYuv420pFrame> {
|
||||
if jpeg_data.len() < 2 || jpeg_data[0] != 0xFF || jpeg_data[1] != 0xD8 {
|
||||
return Err(AppError::VideoError("Invalid JPEG data".to_string()));
|
||||
}
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
// Get JPEG dimensions
|
||||
let (width, height) = libyuv::mjpeg_size(jpeg_data)
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to read MJPEG size: {}", e)))?;
|
||||
|
||||
let y_size = (width * height) as usize;
|
||||
let uv_size = y_size / 4;
|
||||
let yuv420_size = y_size + uv_size * 2;
|
||||
|
||||
let mut yuv_data = vec![0u8; yuv420_size];
|
||||
|
||||
libyuv::mjpeg_to_i420(jpeg_data, &mut yuv_data, width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv MJPEG→I420 failed: {}", e)))?;
|
||||
|
||||
Ok(DecodedYuv420pFrame {
|
||||
y_plane: yuv_data[..y_size].to_vec(),
|
||||
u_plane: yuv_data[y_size..y_size + uv_size].to_vec(),
|
||||
v_plane: yuv_data[y_size + uv_size..].to_vec(),
|
||||
y_linesize: width,
|
||||
u_linesize: width / 2,
|
||||
v_linesize: width / 2,
|
||||
width,
|
||||
height,
|
||||
})
|
||||
}
|
||||
|
||||
/// Decode directly to packed YUV420P buffer using libyuv
|
||||
///
|
||||
/// This uses libyuv's MJPGToI420 which handles all JPEG subsampling formats
|
||||
/// and converts to I420 directly.
|
||||
pub fn decode_to_yuv420p_buffer(&mut self, jpeg_data: &[u8], dst: &mut [u8]) -> Result<usize> {
|
||||
if jpeg_data.len() < 2 || jpeg_data[0] != 0xFF || jpeg_data[1] != 0xD8 {
|
||||
return Err(AppError::VideoError("Invalid JPEG data".to_string()));
|
||||
}
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
// Get JPEG dimensions from libyuv
|
||||
let (width, height) = libyuv::mjpeg_size(jpeg_data)
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to read MJPEG size: {}", e)))?;
|
||||
|
||||
let yuv420_size = (width * height * 3 / 2) as usize;
|
||||
|
||||
if dst.len() < yuv420_size {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Buffer too small: {} < {}", dst.len(), yuv420_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Decode MJPEG directly to I420 using libyuv
|
||||
// libyuv handles all JPEG subsampling formats (4:2:0, 4:2:2, 4:4:4) internally
|
||||
libyuv::mjpeg_to_i420(jpeg_data, &mut dst[..yuv420_size], width, height)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv MJPEG→I420 failed: {}", e)))?;
|
||||
|
||||
Ok(yuv420_size)
|
||||
}
|
||||
|
||||
/// Get frame count
|
||||
pub fn frame_count(&self) -> u64 {
|
||||
self.frame_count
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if MJPEG VAAPI decoder is available
|
||||
pub fn is_mjpeg_vaapi_available() -> bool {
|
||||
let ctx = DecodeContext {
|
||||
name: "mjpeg".to_string(),
|
||||
device_type: AVHWDeviceType::AV_HWDEVICE_TYPE_VAAPI,
|
||||
thread_count: 1,
|
||||
};
|
||||
|
||||
match Decoder::new(ctx) {
|
||||
Ok(_) => {
|
||||
info!("MJPEG VAAPI decoder is available");
|
||||
true
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("MJPEG VAAPI decoder is not available");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mjpeg_vaapi_availability() {
|
||||
let available = is_mjpeg_vaapi_available();
|
||||
println!("MJPEG VAAPI available: {}", available);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decoder_creation() {
|
||||
let config = MjpegVaapiDecoderConfig::default();
|
||||
match MjpegVaapiDecoder::new(config) {
|
||||
Ok(decoder) => {
|
||||
println!("Decoder created, hwaccel: {}", decoder.is_hwaccel_active());
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Failed to create decoder: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
11
src/video/decoder/mod.rs
Normal file
11
src/video/decoder/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! Video decoder implementations
|
||||
//!
|
||||
//! This module provides video decoding capabilities including:
|
||||
//! - MJPEG VAAPI hardware decoding (outputs NV12)
|
||||
//! - MJPEG turbojpeg decoding (outputs YUV420P directly)
|
||||
|
||||
pub mod mjpeg;
|
||||
|
||||
pub use mjpeg::{
|
||||
DecodedYuv420pFrame, MjpegTurboDecoder, MjpegVaapiDecoder, MjpegVaapiDecoderConfig,
|
||||
};
|
||||
459
src/video/device.rs
Normal file
459
src/video/device.rs
Normal file
@@ -0,0 +1,459 @@
|
||||
//! V4L2 device enumeration and capability query
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::{debug, info, warn};
|
||||
use v4l::capability::Flags;
|
||||
use v4l::prelude::*;
|
||||
use v4l::video::Capture;
|
||||
use v4l::Format;
|
||||
use v4l::FourCC;
|
||||
|
||||
use super::format::{PixelFormat, Resolution};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Information about a video device
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VideoDeviceInfo {
|
||||
/// Device path (e.g., /dev/video0)
|
||||
pub path: PathBuf,
|
||||
/// Device name from driver
|
||||
pub name: String,
|
||||
/// Driver name
|
||||
pub driver: String,
|
||||
/// Bus info
|
||||
pub bus_info: String,
|
||||
/// Card name
|
||||
pub card: String,
|
||||
/// Supported pixel formats
|
||||
pub formats: Vec<FormatInfo>,
|
||||
/// Device capabilities
|
||||
pub capabilities: DeviceCapabilities,
|
||||
/// Whether this is likely an HDMI capture card
|
||||
pub is_capture_card: bool,
|
||||
/// Priority score for device selection (higher is better)
|
||||
pub priority: u32,
|
||||
}
|
||||
|
||||
/// Information about a supported format
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FormatInfo {
|
||||
/// Pixel format
|
||||
pub format: PixelFormat,
|
||||
/// Supported resolutions
|
||||
pub resolutions: Vec<ResolutionInfo>,
|
||||
/// Description from driver
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// Information about a supported resolution and frame rates
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResolutionInfo {
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub fps: Vec<u32>,
|
||||
}
|
||||
|
||||
impl ResolutionInfo {
|
||||
pub fn new(width: u32, height: u32, fps: Vec<u32>) -> Self {
|
||||
Self { width, height, fps }
|
||||
}
|
||||
|
||||
pub fn resolution(&self) -> Resolution {
|
||||
Resolution::new(self.width, self.height)
|
||||
}
|
||||
}
|
||||
|
||||
/// Device capabilities
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct DeviceCapabilities {
|
||||
pub video_capture: bool,
|
||||
pub video_capture_mplane: bool,
|
||||
pub video_output: bool,
|
||||
pub streaming: bool,
|
||||
pub read_write: bool,
|
||||
}
|
||||
|
||||
/// Wrapper around a V4L2 video device
|
||||
pub struct VideoDevice {
|
||||
pub path: PathBuf,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl VideoDevice {
|
||||
/// Open a video device by path
|
||||
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref().to_path_buf();
|
||||
debug!("Opening video device: {:?}", path);
|
||||
|
||||
let device = Device::with_path(&path).map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to open device {:?}: {}", path, e))
|
||||
})?;
|
||||
|
||||
Ok(Self { path, device })
|
||||
}
|
||||
|
||||
/// Get device capabilities
|
||||
pub fn capabilities(&self) -> Result<DeviceCapabilities> {
|
||||
let caps = self.device.query_caps().map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to query capabilities: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(DeviceCapabilities {
|
||||
video_capture: caps.capabilities.contains(Flags::VIDEO_CAPTURE),
|
||||
video_capture_mplane: caps.capabilities.contains(Flags::VIDEO_CAPTURE_MPLANE),
|
||||
video_output: caps.capabilities.contains(Flags::VIDEO_OUTPUT),
|
||||
streaming: caps.capabilities.contains(Flags::STREAMING),
|
||||
read_write: caps.capabilities.contains(Flags::READ_WRITE),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get detailed device information
|
||||
pub fn info(&self) -> Result<VideoDeviceInfo> {
|
||||
let caps = self.device.query_caps().map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to query capabilities: {}", e))
|
||||
})?;
|
||||
|
||||
let capabilities = DeviceCapabilities {
|
||||
video_capture: caps.capabilities.contains(Flags::VIDEO_CAPTURE),
|
||||
video_capture_mplane: caps.capabilities.contains(Flags::VIDEO_CAPTURE_MPLANE),
|
||||
video_output: caps.capabilities.contains(Flags::VIDEO_OUTPUT),
|
||||
streaming: caps.capabilities.contains(Flags::STREAMING),
|
||||
read_write: caps.capabilities.contains(Flags::READ_WRITE),
|
||||
};
|
||||
|
||||
let formats = self.enumerate_formats()?;
|
||||
|
||||
// Determine if this is likely an HDMI capture card
|
||||
let is_capture_card = Self::detect_capture_card(&caps.card, &caps.driver, &formats);
|
||||
|
||||
// Calculate priority score
|
||||
let priority = Self::calculate_priority(&caps.card, &caps.driver, &formats, is_capture_card);
|
||||
|
||||
Ok(VideoDeviceInfo {
|
||||
path: self.path.clone(),
|
||||
name: caps.card.clone(),
|
||||
driver: caps.driver.clone(),
|
||||
bus_info: caps.bus.clone(),
|
||||
card: caps.card,
|
||||
formats,
|
||||
capabilities,
|
||||
is_capture_card,
|
||||
priority,
|
||||
})
|
||||
}
|
||||
|
||||
/// Enumerate supported formats
|
||||
pub fn enumerate_formats(&self) -> Result<Vec<FormatInfo>> {
|
||||
let mut formats = Vec::new();
|
||||
|
||||
// Get supported formats
|
||||
let format_descs = self.device.enum_formats().map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to enumerate formats: {}", e))
|
||||
})?;
|
||||
|
||||
for desc in format_descs {
|
||||
// Try to convert FourCC to our PixelFormat
|
||||
if let Some(format) = PixelFormat::from_fourcc(desc.fourcc) {
|
||||
let resolutions = self.enumerate_resolutions(desc.fourcc)?;
|
||||
|
||||
formats.push(FormatInfo {
|
||||
format,
|
||||
resolutions,
|
||||
description: desc.description.clone(),
|
||||
});
|
||||
} else {
|
||||
debug!(
|
||||
"Skipping unsupported format: {:?} ({})",
|
||||
desc.fourcc, desc.description
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by format priority (MJPEG first)
|
||||
formats.sort_by(|a, b| b.format.priority().cmp(&a.format.priority()));
|
||||
|
||||
Ok(formats)
|
||||
}
|
||||
|
||||
/// Enumerate resolutions for a specific format
|
||||
fn enumerate_resolutions(&self, fourcc: FourCC) -> Result<Vec<ResolutionInfo>> {
|
||||
let mut resolutions = Vec::new();
|
||||
|
||||
// Try to enumerate frame sizes
|
||||
match self.device.enum_framesizes(fourcc) {
|
||||
Ok(sizes) => {
|
||||
for size in sizes {
|
||||
match size.size {
|
||||
v4l::framesize::FrameSizeEnum::Discrete(d) => {
|
||||
let fps = self.enumerate_fps(fourcc, d.width, d.height).unwrap_or_default();
|
||||
resolutions.push(ResolutionInfo::new(d.width, d.height, fps));
|
||||
}
|
||||
v4l::framesize::FrameSizeEnum::Stepwise(s) => {
|
||||
// For stepwise, add some common resolutions
|
||||
for res in [
|
||||
Resolution::VGA,
|
||||
Resolution::HD720,
|
||||
Resolution::HD1080,
|
||||
Resolution::UHD4K,
|
||||
] {
|
||||
if res.width >= s.min_width
|
||||
&& res.width <= s.max_width
|
||||
&& res.height >= s.min_height
|
||||
&& res.height <= s.max_height
|
||||
{
|
||||
let fps = self.enumerate_fps(fourcc, res.width, res.height).unwrap_or_default();
|
||||
resolutions.push(ResolutionInfo::new(res.width, res.height, fps));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to enumerate frame sizes for {:?}: {}", fourcc, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by resolution (largest first)
|
||||
resolutions.sort_by(|a, b| (b.width * b.height).cmp(&(a.width * a.height)));
|
||||
resolutions.dedup_by(|a, b| a.width == b.width && a.height == b.height);
|
||||
|
||||
Ok(resolutions)
|
||||
}
|
||||
|
||||
/// Enumerate FPS for a specific resolution
|
||||
fn enumerate_fps(&self, fourcc: FourCC, width: u32, height: u32) -> Result<Vec<u32>> {
|
||||
let mut fps_list = Vec::new();
|
||||
|
||||
match self.device.enum_frameintervals(fourcc, width, height) {
|
||||
Ok(intervals) => {
|
||||
for interval in intervals {
|
||||
match interval.interval {
|
||||
v4l::frameinterval::FrameIntervalEnum::Discrete(fraction) => {
|
||||
if fraction.numerator > 0 {
|
||||
let fps = fraction.denominator / fraction.numerator;
|
||||
fps_list.push(fps);
|
||||
}
|
||||
}
|
||||
v4l::frameinterval::FrameIntervalEnum::Stepwise(step) => {
|
||||
// Just pick max/min/step
|
||||
if step.max.numerator > 0 {
|
||||
let min_fps = step.max.denominator / step.max.numerator;
|
||||
let max_fps = step.min.denominator / step.min.numerator;
|
||||
fps_list.push(min_fps);
|
||||
if max_fps != min_fps {
|
||||
fps_list.push(max_fps);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// If enumeration fails, assume 30fps
|
||||
fps_list.push(30);
|
||||
}
|
||||
}
|
||||
|
||||
fps_list.sort_by(|a, b| b.cmp(a));
|
||||
fps_list.dedup();
|
||||
Ok(fps_list)
|
||||
}
|
||||
|
||||
/// Get current format
|
||||
pub fn get_format(&self) -> Result<Format> {
|
||||
self.device.format().map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to get format: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
/// Set capture format
|
||||
pub fn set_format(&self, width: u32, height: u32, format: PixelFormat) -> Result<Format> {
|
||||
let fmt = Format::new(width, height, format.to_fourcc());
|
||||
|
||||
// Request the format
|
||||
let actual = self.device.set_format(&fmt).map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to set format: {}", e))
|
||||
})?;
|
||||
|
||||
if actual.width != width || actual.height != height {
|
||||
warn!(
|
||||
"Requested {}x{}, got {}x{}",
|
||||
width, height, actual.width, actual.height
|
||||
);
|
||||
}
|
||||
|
||||
Ok(actual)
|
||||
}
|
||||
|
||||
/// Detect if device is likely an HDMI capture card
|
||||
fn detect_capture_card(card: &str, driver: &str, formats: &[FormatInfo]) -> bool {
|
||||
let card_lower = card.to_lowercase();
|
||||
let driver_lower = driver.to_lowercase();
|
||||
|
||||
// Known capture card patterns
|
||||
let capture_patterns = [
|
||||
"hdmi",
|
||||
"capture",
|
||||
"grabber",
|
||||
"usb3",
|
||||
"ms2109",
|
||||
"ms2130",
|
||||
"macrosilicon",
|
||||
"tc358743",
|
||||
"uvc",
|
||||
];
|
||||
|
||||
// Check card/driver names
|
||||
for pattern in capture_patterns {
|
||||
if card_lower.contains(pattern) || driver_lower.contains(pattern) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Capture cards usually support MJPEG and high resolutions
|
||||
let has_mjpeg = formats.iter().any(|f| f.format == PixelFormat::Mjpeg);
|
||||
let has_1080p = formats.iter().any(|f| {
|
||||
f.resolutions
|
||||
.iter()
|
||||
.any(|r| r.width >= 1920 && r.height >= 1080)
|
||||
});
|
||||
|
||||
has_mjpeg && has_1080p
|
||||
}
|
||||
|
||||
/// Calculate device priority for selection
|
||||
fn calculate_priority(
|
||||
_card: &str,
|
||||
driver: &str,
|
||||
formats: &[FormatInfo],
|
||||
is_capture_card: bool,
|
||||
) -> u32 {
|
||||
let mut priority = 0u32;
|
||||
|
||||
// Capture cards get highest priority
|
||||
if is_capture_card {
|
||||
priority += 1000;
|
||||
}
|
||||
|
||||
// MJPEG support is valuable
|
||||
if formats.iter().any(|f| f.format == PixelFormat::Mjpeg) {
|
||||
priority += 100;
|
||||
}
|
||||
|
||||
// High resolution support
|
||||
let max_resolution = formats
|
||||
.iter()
|
||||
.flat_map(|f| &f.resolutions)
|
||||
.map(|r| r.width * r.height)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
priority += (max_resolution / 100000) as u32;
|
||||
|
||||
// Known good drivers get bonus
|
||||
let good_drivers = ["uvcvideo", "tc358743"];
|
||||
if good_drivers.iter().any(|d| driver.contains(d)) {
|
||||
priority += 50;
|
||||
}
|
||||
|
||||
priority
|
||||
}
|
||||
|
||||
/// Get the inner device reference (for advanced usage)
|
||||
pub fn inner(&self) -> &Device {
|
||||
&self.device
|
||||
}
|
||||
}
|
||||
|
||||
/// Enumerate all video capture devices
|
||||
pub fn enumerate_devices() -> Result<Vec<VideoDeviceInfo>> {
|
||||
info!("Enumerating video devices...");
|
||||
|
||||
let mut devices = Vec::new();
|
||||
|
||||
// Scan /dev/video* devices
|
||||
for entry in std::fs::read_dir("/dev").map_err(|e| {
|
||||
AppError::VideoError(format!("Failed to read /dev: {}", e))
|
||||
})? {
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let path = entry.path();
|
||||
let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
|
||||
|
||||
if !name.starts_with("video") {
|
||||
continue;
|
||||
}
|
||||
|
||||
debug!("Found video device: {:?}", path);
|
||||
|
||||
// Try to open and query the device
|
||||
match VideoDevice::open(&path) {
|
||||
Ok(device) => {
|
||||
match device.info() {
|
||||
Ok(info) => {
|
||||
// Only include devices with video capture capability
|
||||
if info.capabilities.video_capture || info.capabilities.video_capture_mplane
|
||||
{
|
||||
info!(
|
||||
"Found capture device: {} ({}) - {} formats",
|
||||
info.name,
|
||||
info.driver,
|
||||
info.formats.len()
|
||||
);
|
||||
devices.push(info);
|
||||
} else {
|
||||
debug!("Skipping non-capture device: {:?}", path);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get info for {:?}: {}", path, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to open {:?}: {}", path, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by priority (highest first)
|
||||
devices.sort_by(|a, b| b.priority.cmp(&a.priority));
|
||||
|
||||
info!("Found {} video capture devices", devices.len());
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
/// Find the best video device for KVM use
|
||||
pub fn find_best_device() -> Result<VideoDeviceInfo> {
|
||||
let devices = enumerate_devices()?;
|
||||
|
||||
devices.into_iter().next().ok_or_else(|| {
|
||||
AppError::VideoError("No video capture devices found".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pixel_format_conversion() {
|
||||
let format = PixelFormat::Mjpeg;
|
||||
let fourcc = format.to_fourcc();
|
||||
let back = PixelFormat::from_fourcc(fourcc);
|
||||
assert_eq!(back, Some(format));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolution() {
|
||||
let res = Resolution::HD1080;
|
||||
assert_eq!(res.width, 1920);
|
||||
assert_eq!(res.height, 1080);
|
||||
assert!(res.is_valid());
|
||||
}
|
||||
}
|
||||
370
src/video/encoder/codec.rs
Normal file
370
src/video/encoder/codec.rs
Normal file
@@ -0,0 +1,370 @@
|
||||
//! WebRTC Video Codec abstraction layer
|
||||
//!
|
||||
//! This module provides a unified interface for video codecs used in WebRTC streaming.
|
||||
//! It supports multiple codec types (H264, VP8, VP9, H265) with a common API.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! VideoCodec (trait)
|
||||
//! |
|
||||
//! +-- H264Codec (current implementation)
|
||||
//! +-- VP8Codec (reserved)
|
||||
//! +-- VP9Codec (reserved)
|
||||
//! +-- H265Codec (reserved)
|
||||
//! ```
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::video::format::Resolution;
|
||||
|
||||
/// Supported video codec types for WebRTC
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum VideoCodecType {
|
||||
/// H.264/AVC - widely supported, good compression
|
||||
H264,
|
||||
/// VP8 - royalty-free, good browser support
|
||||
VP8,
|
||||
/// VP9 - better compression than VP8
|
||||
VP9,
|
||||
/// H.265/HEVC - best compression, limited browser support
|
||||
H265,
|
||||
}
|
||||
|
||||
impl VideoCodecType {
|
||||
/// Get the codec name for SDP
|
||||
pub fn sdp_name(&self) -> &'static str {
|
||||
match self {
|
||||
VideoCodecType::H264 => "H264",
|
||||
VideoCodecType::VP8 => "VP8",
|
||||
VideoCodecType::VP9 => "VP9",
|
||||
VideoCodecType::H265 => "H265",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default RTP payload type
|
||||
pub fn default_payload_type(&self) -> u8 {
|
||||
match self {
|
||||
VideoCodecType::H264 => 96,
|
||||
VideoCodecType::VP8 => 97,
|
||||
VideoCodecType::VP9 => 98,
|
||||
VideoCodecType::H265 => 99,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the RTP clock rate (always 90000 for video)
|
||||
pub fn clock_rate(&self) -> u32 {
|
||||
90000
|
||||
}
|
||||
|
||||
/// Get the MIME type
|
||||
pub fn mime_type(&self) -> &'static str {
|
||||
match self {
|
||||
VideoCodecType::H264 => "video/H264",
|
||||
VideoCodecType::VP8 => "video/VP8",
|
||||
VideoCodecType::VP9 => "video/VP9",
|
||||
VideoCodecType::H265 => "video/H265",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VideoCodecType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.sdp_name())
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoded video frame for WebRTC transmission
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodecFrame {
|
||||
/// Encoded data (Annex B format for H264/H265, raw for VP8/VP9)
|
||||
pub data: Bytes,
|
||||
/// Presentation timestamp in milliseconds
|
||||
pub pts_ms: i64,
|
||||
/// Whether this is a keyframe (IDR for H264, key frame for VP8/VP9)
|
||||
pub is_keyframe: bool,
|
||||
/// Codec type
|
||||
pub codec: VideoCodecType,
|
||||
/// Frame sequence number
|
||||
pub sequence: u64,
|
||||
/// Frame duration
|
||||
pub duration: Duration,
|
||||
}
|
||||
|
||||
impl CodecFrame {
|
||||
/// Create a new H264 frame
|
||||
pub fn h264(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
|
||||
Self {
|
||||
data,
|
||||
pts_ms,
|
||||
is_keyframe,
|
||||
codec: VideoCodecType::H264,
|
||||
sequence,
|
||||
duration: Duration::from_millis(1000 / fps as u64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new VP8 frame
|
||||
pub fn vp8(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
|
||||
Self {
|
||||
data,
|
||||
pts_ms,
|
||||
is_keyframe,
|
||||
codec: VideoCodecType::VP8,
|
||||
sequence,
|
||||
duration: Duration::from_millis(1000 / fps as u64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new VP9 frame
|
||||
pub fn vp9(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
|
||||
Self {
|
||||
data,
|
||||
pts_ms,
|
||||
is_keyframe,
|
||||
codec: VideoCodecType::VP9,
|
||||
sequence,
|
||||
duration: Duration::from_millis(1000 / fps as u64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new H265 frame
|
||||
pub fn h265(data: Bytes, pts_ms: i64, is_keyframe: bool, sequence: u64, fps: u32) -> Self {
|
||||
Self {
|
||||
data,
|
||||
pts_ms,
|
||||
is_keyframe,
|
||||
codec: VideoCodecType::H265,
|
||||
sequence,
|
||||
duration: Duration::from_millis(1000 / fps as u64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get frame size in bytes
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Check if frame is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Video codec configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VideoCodecConfig {
|
||||
/// Codec type
|
||||
pub codec: VideoCodecType,
|
||||
/// Target resolution
|
||||
pub resolution: Resolution,
|
||||
/// Target bitrate in kbps
|
||||
pub bitrate_kbps: u32,
|
||||
/// Target FPS
|
||||
pub fps: u32,
|
||||
/// GOP size (keyframe interval in frames)
|
||||
pub gop_size: u32,
|
||||
/// Profile (codec-specific)
|
||||
pub profile: Option<String>,
|
||||
/// Level (codec-specific)
|
||||
pub level: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for VideoCodecConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
codec: VideoCodecType::H264,
|
||||
resolution: Resolution::HD720,
|
||||
bitrate_kbps: 8000,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
profile: None,
|
||||
level: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VideoCodecConfig {
|
||||
/// Create H264 config with common settings
|
||||
pub fn h264(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
|
||||
Self {
|
||||
codec: VideoCodecType::H264,
|
||||
resolution,
|
||||
bitrate_kbps,
|
||||
fps,
|
||||
gop_size: fps, // 1 second GOP
|
||||
profile: Some("baseline".to_string()),
|
||||
level: Some("3.1".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create VP8 config
|
||||
pub fn vp8(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
|
||||
Self {
|
||||
codec: VideoCodecType::VP8,
|
||||
resolution,
|
||||
bitrate_kbps,
|
||||
fps,
|
||||
gop_size: fps,
|
||||
profile: None,
|
||||
level: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create VP9 config
|
||||
pub fn vp9(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
|
||||
Self {
|
||||
codec: VideoCodecType::VP9,
|
||||
resolution,
|
||||
bitrate_kbps,
|
||||
fps,
|
||||
gop_size: fps,
|
||||
profile: None,
|
||||
level: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create H265 config
|
||||
pub fn h265(resolution: Resolution, bitrate_kbps: u32, fps: u32) -> Self {
|
||||
Self {
|
||||
codec: VideoCodecType::H265,
|
||||
resolution,
|
||||
bitrate_kbps,
|
||||
fps,
|
||||
gop_size: fps,
|
||||
profile: Some("main".to_string()),
|
||||
level: Some("4.0".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// WebRTC video codec trait
|
||||
///
|
||||
/// This trait defines the interface for video codecs used in WebRTC streaming.
|
||||
/// Implementations should handle format conversion internally if needed.
|
||||
pub trait VideoCodec: Send {
|
||||
/// Get codec type
|
||||
fn codec_type(&self) -> VideoCodecType;
|
||||
|
||||
/// Get codec name for display
|
||||
fn codec_name(&self) -> &'static str;
|
||||
|
||||
/// Get RTP payload type
|
||||
fn payload_type(&self) -> u8 {
|
||||
self.codec_type().default_payload_type()
|
||||
}
|
||||
|
||||
/// Get SDP fmtp parameters (codec-specific)
|
||||
///
|
||||
/// For H264: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f"
|
||||
/// For VP8/VP9: None or empty
|
||||
fn sdp_fmtp(&self) -> Option<String>;
|
||||
|
||||
/// Encode a raw frame (NV12 format expected)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `frame` - Raw frame data in NV12 format
|
||||
/// * `pts_ms` - Presentation timestamp in milliseconds
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(Some(frame))` - Encoded frame
|
||||
/// * `Ok(None)` - Encoder is buffering (no output yet)
|
||||
/// * `Err(e)` - Encoding error
|
||||
fn encode(&mut self, frame: &[u8], pts_ms: i64) -> Result<Option<CodecFrame>>;
|
||||
|
||||
/// Set target bitrate dynamically
|
||||
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()>;
|
||||
|
||||
/// Request a keyframe on next encode
|
||||
fn request_keyframe(&mut self);
|
||||
|
||||
/// Get current configuration
|
||||
fn config(&self) -> &VideoCodecConfig;
|
||||
|
||||
/// Flush any pending frames
|
||||
fn flush(&mut self) -> Result<Vec<CodecFrame>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
/// Reset encoder state
|
||||
fn reset(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Video codec factory trait
|
||||
///
|
||||
/// Used to create codec instances and query available codecs.
|
||||
pub trait VideoCodecFactory: Send + Sync {
|
||||
/// Create a codec with the given configuration
|
||||
fn create(&self, config: VideoCodecConfig) -> Result<Box<dyn VideoCodec>>;
|
||||
|
||||
/// Get supported codec types
|
||||
fn supported_codecs(&self) -> Vec<VideoCodecType>;
|
||||
|
||||
/// Check if a specific codec is available
|
||||
fn is_codec_available(&self, codec: VideoCodecType) -> bool {
|
||||
self.supported_codecs().contains(&codec)
|
||||
}
|
||||
|
||||
/// Get the best available codec (based on priority)
|
||||
fn best_codec(&self) -> Option<VideoCodecType> {
|
||||
// Priority: H264 > VP8 > VP9 > H265
|
||||
let supported = self.supported_codecs();
|
||||
if supported.contains(&VideoCodecType::H264) {
|
||||
Some(VideoCodecType::H264)
|
||||
} else if supported.contains(&VideoCodecType::VP8) {
|
||||
Some(VideoCodecType::VP8)
|
||||
} else if supported.contains(&VideoCodecType::VP9) {
|
||||
Some(VideoCodecType::VP9)
|
||||
} else if supported.contains(&VideoCodecType::H265) {
|
||||
Some(VideoCodecType::H265)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_codec_type_properties() {
|
||||
assert_eq!(VideoCodecType::H264.sdp_name(), "H264");
|
||||
assert_eq!(VideoCodecType::H264.default_payload_type(), 96);
|
||||
assert_eq!(VideoCodecType::H264.clock_rate(), 90000);
|
||||
assert_eq!(VideoCodecType::H264.mime_type(), "video/H264");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codec_frame_creation() {
|
||||
let data = Bytes::from(vec![0x00, 0x00, 0x00, 0x01, 0x65]);
|
||||
let frame = CodecFrame::h264(data.clone(), 1000, true, 1, 30);
|
||||
|
||||
assert_eq!(frame.codec, VideoCodecType::H264);
|
||||
assert!(frame.is_keyframe);
|
||||
assert_eq!(frame.pts_ms, 1000);
|
||||
assert_eq!(frame.sequence, 1);
|
||||
assert_eq!(frame.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codec_config_default() {
|
||||
let config = VideoCodecConfig::default();
|
||||
assert_eq!(config.codec, VideoCodecType::H264);
|
||||
assert_eq!(config.bitrate_kbps, 2000);
|
||||
assert_eq!(config.fps, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codec_config_h264() {
|
||||
let config = VideoCodecConfig::h264(Resolution::HD1080, 4000, 60);
|
||||
assert_eq!(config.codec, VideoCodecType::H264);
|
||||
assert_eq!(config.bitrate_kbps, 4000);
|
||||
assert_eq!(config.fps, 60);
|
||||
assert_eq!(config.gop_size, 60);
|
||||
}
|
||||
}
|
||||
532
src/video/encoder/h264.rs
Normal file
532
src/video/encoder/h264.rs
Normal file
@@ -0,0 +1,532 @@
|
||||
//! H.264 encoder using hwcodec (rustdesk's FFmpeg wrapper)
|
||||
//!
|
||||
//! Supports multiple encoder backends via FFmpeg:
|
||||
//! - VAAPI (Intel/AMD/NVIDIA on Linux)
|
||||
//! - NVENC (NVIDIA)
|
||||
//! - AMF (AMD)
|
||||
//! - Software (libx264)
|
||||
//!
|
||||
//! The encoder is selected automatically based on availability.
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::sync::Once;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use hwcodec::common::{Quality, RateControl};
|
||||
use hwcodec::ffmpeg::AVPixelFormat;
|
||||
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
|
||||
use hwcodec::ffmpeg_ram::CodecInfo;
|
||||
|
||||
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
|
||||
static INIT_LOGGING: Once = Once::new();
|
||||
|
||||
/// Initialize hwcodec logging (only once)
|
||||
fn init_hwcodec_logging() {
|
||||
INIT_LOGGING.call_once(|| {
|
||||
// hwcodec uses the `log` crate, which will work with our tracing subscriber
|
||||
debug!("hwcodec logging initialized");
|
||||
});
|
||||
}
|
||||
|
||||
/// H.264 encoder type (detected from hwcodec)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum H264EncoderType {
|
||||
/// NVIDIA NVENC
|
||||
Nvenc,
|
||||
/// Intel Quick Sync (QSV)
|
||||
Qsv,
|
||||
/// AMD AMF
|
||||
Amf,
|
||||
/// VAAPI (Linux generic)
|
||||
Vaapi,
|
||||
/// RKMPP (Rockchip) - requires hwcodec extension
|
||||
Rkmpp,
|
||||
/// V4L2 M2M (ARM generic) - requires hwcodec extension
|
||||
V4l2M2m,
|
||||
/// Software encoding (libx264/openh264)
|
||||
Software,
|
||||
/// No encoder available
|
||||
None,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for H264EncoderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
H264EncoderType::Nvenc => write!(f, "NVENC"),
|
||||
H264EncoderType::Qsv => write!(f, "QSV"),
|
||||
H264EncoderType::Amf => write!(f, "AMF"),
|
||||
H264EncoderType::Vaapi => write!(f, "VAAPI"),
|
||||
H264EncoderType::Rkmpp => write!(f, "RKMPP"),
|
||||
H264EncoderType::V4l2M2m => write!(f, "V4L2 M2M"),
|
||||
H264EncoderType::Software => write!(f, "Software"),
|
||||
H264EncoderType::None => write!(f, "None"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for H264EncoderType {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
/// Map codec name to encoder type
|
||||
fn codec_name_to_type(name: &str) -> H264EncoderType {
|
||||
if name.contains("nvenc") {
|
||||
H264EncoderType::Nvenc
|
||||
} else if name.contains("qsv") {
|
||||
H264EncoderType::Qsv
|
||||
} else if name.contains("amf") {
|
||||
H264EncoderType::Amf
|
||||
} else if name.contains("vaapi") {
|
||||
H264EncoderType::Vaapi
|
||||
} else if name.contains("rkmpp") {
|
||||
H264EncoderType::Rkmpp
|
||||
} else if name.contains("v4l2m2m") {
|
||||
H264EncoderType::V4l2M2m
|
||||
} else {
|
||||
H264EncoderType::Software
|
||||
}
|
||||
}
|
||||
|
||||
/// Input pixel format for H264 encoder
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum H264InputFormat {
|
||||
/// YUV420P (I420) - planar Y, U, V
|
||||
Yuv420p,
|
||||
/// NV12 - Y plane + interleaved UV plane (optimal for VAAPI)
|
||||
Nv12,
|
||||
}
|
||||
|
||||
impl Default for H264InputFormat {
|
||||
fn default() -> Self {
|
||||
Self::Nv12 // Default to NV12 for VAAPI compatibility
|
||||
}
|
||||
}
|
||||
|
||||
/// H.264 encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct H264Config {
|
||||
/// Base encoder config
|
||||
pub base: EncoderConfig,
|
||||
/// Target bitrate in kbps
|
||||
pub bitrate_kbps: u32,
|
||||
/// GOP size (keyframe interval)
|
||||
pub gop_size: u32,
|
||||
/// Frame rate
|
||||
pub fps: u32,
|
||||
/// Input pixel format
|
||||
pub input_format: H264InputFormat,
|
||||
}
|
||||
|
||||
impl Default for H264Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base: EncoderConfig::default(),
|
||||
bitrate_kbps: 8000,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: H264InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl H264Config {
|
||||
/// Create config for low latency streaming with NV12 input (optimal for VAAPI)
|
||||
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
base: EncoderConfig::h264(resolution, bitrate_kbps),
|
||||
bitrate_kbps,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: H264InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for low latency streaming with YUV420P input
|
||||
pub fn low_latency_yuv420p(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
base: EncoderConfig::h264(resolution, bitrate_kbps),
|
||||
bitrate_kbps,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: H264InputFormat::Yuv420p,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for quality streaming
|
||||
pub fn quality(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
base: EncoderConfig::h264(resolution, bitrate_kbps),
|
||||
bitrate_kbps,
|
||||
gop_size: 60,
|
||||
fps: 30,
|
||||
input_format: H264InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set input format
|
||||
pub fn with_input_format(mut self, format: H264InputFormat) -> Self {
|
||||
self.input_format = format;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Get available H264 encoders from hwcodec
|
||||
pub fn get_available_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: String::new(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt: AVPixelFormat::AV_PIX_FMT_YUV420P,
|
||||
align: 1,
|
||||
fps: 30,
|
||||
gop: 30,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (ultrafast)
|
||||
kbs: 2000,
|
||||
q: 23,
|
||||
thread_count: 4,
|
||||
};
|
||||
|
||||
HwEncoder::available_encoders(ctx, None)
|
||||
}
|
||||
|
||||
/// Detect best available H.264 encoder
|
||||
pub fn detect_best_encoder(width: u32, height: u32) -> (H264EncoderType, Option<String>) {
|
||||
let encoders = get_available_encoders(width, height);
|
||||
|
||||
if encoders.is_empty() {
|
||||
warn!("No H.264 encoders available from hwcodec");
|
||||
return (H264EncoderType::None, None);
|
||||
}
|
||||
|
||||
// Find H264 encoder (not H265)
|
||||
for codec in &encoders {
|
||||
if codec.format == hwcodec::common::DataFormat::H264 {
|
||||
let encoder_type = codec_name_to_type(&codec.name);
|
||||
info!("Best H.264 encoder: {} ({})", codec.name, encoder_type);
|
||||
return (encoder_type, Some(codec.name.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
(H264EncoderType::None, None)
|
||||
}
|
||||
|
||||
/// Encoded frame from hwcodec (cloned for ownership)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HwEncodeFrame {
|
||||
pub data: Vec<u8>,
|
||||
pub pts: i64,
|
||||
pub key: i32,
|
||||
}
|
||||
|
||||
/// H.264 encoder using hwcodec
|
||||
pub struct H264Encoder {
|
||||
/// hwcodec encoder instance
|
||||
inner: HwEncoder,
|
||||
/// Encoder configuration
|
||||
config: H264Config,
|
||||
/// Detected encoder type
|
||||
encoder_type: H264EncoderType,
|
||||
/// Codec name
|
||||
codec_name: String,
|
||||
/// Frame counter
|
||||
frame_count: u64,
|
||||
/// YUV420P buffer for input (reserved for future use)
|
||||
#[allow(dead_code)]
|
||||
yuv_buffer: Vec<u8>,
|
||||
/// Required YUV buffer length from hwcodec
|
||||
yuv_length: i32,
|
||||
}
|
||||
|
||||
impl H264Encoder {
|
||||
/// Create a new H.264 encoder with automatic codec detection
|
||||
pub fn new(config: H264Config) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
// Detect best encoder
|
||||
let (_encoder_type, codec_name) = detect_best_encoder(width, height);
|
||||
|
||||
let codec_name = codec_name.ok_or_else(|| {
|
||||
AppError::VideoError("No H.264 encoder available".to_string())
|
||||
})?;
|
||||
|
||||
Self::with_codec(config, &codec_name)
|
||||
}
|
||||
|
||||
/// Create encoder with specific codec name
|
||||
pub fn with_codec(config: H264Config, codec_name: &str) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
// Select pixel format based on config
|
||||
let pixfmt = match config.input_format {
|
||||
H264InputFormat::Nv12 => AVPixelFormat::AV_PIX_FMT_NV12,
|
||||
H264InputFormat::Yuv420p => AVPixelFormat::AV_PIX_FMT_YUV420P,
|
||||
};
|
||||
|
||||
info!(
|
||||
"Creating H.264 encoder: {} at {}x{} @ {} kbps (input: {:?})",
|
||||
codec_name, width, height, config.bitrate_kbps, config.input_format
|
||||
);
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: codec_name.to_string(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt,
|
||||
align: 1,
|
||||
fps: config.fps as i32,
|
||||
gop: config.gop_size as i32,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Low, // Use low quality preset for fastest encoding (lowest latency)
|
||||
kbs: config.bitrate_kbps as i32,
|
||||
q: 23,
|
||||
thread_count: 4, // Use 4 threads for better performance
|
||||
};
|
||||
|
||||
let inner = HwEncoder::new(ctx).map_err(|_| {
|
||||
AppError::VideoError(format!("Failed to create encoder: {}", codec_name))
|
||||
})?;
|
||||
|
||||
let yuv_length = inner.length;
|
||||
let encoder_type = codec_name_to_type(codec_name);
|
||||
|
||||
info!(
|
||||
"H.264 encoder created: {} (type: {}, buffer_length: {}, input_format: {:?})",
|
||||
codec_name, encoder_type, yuv_length, config.input_format
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
config,
|
||||
encoder_type,
|
||||
codec_name: codec_name.to_string(),
|
||||
frame_count: 0,
|
||||
yuv_buffer: vec![0u8; yuv_length as usize],
|
||||
yuv_length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with auto-detected encoder
|
||||
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
|
||||
let config = H264Config::low_latency(resolution, bitrate_kbps);
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
/// Get encoder type
|
||||
pub fn encoder_type(&self) -> &H264EncoderType {
|
||||
&self.encoder_type
|
||||
}
|
||||
|
||||
/// Get codec name
|
||||
pub fn codec_name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
/// Update bitrate dynamically
|
||||
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
|
||||
AppError::VideoError("Failed to set bitrate".to_string())
|
||||
})?;
|
||||
self.config.bitrate_kbps = bitrate_kbps;
|
||||
debug!("Bitrate updated to {} kbps", bitrate_kbps);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Request next frame to be a keyframe (IDR)
|
||||
pub fn request_keyframe(&mut self) {
|
||||
self.inner.request_keyframe();
|
||||
debug!("H264 keyframe requested");
|
||||
}
|
||||
|
||||
/// Encode raw frame data (YUV420P or NV12 depending on config)
|
||||
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
if data.len() < self.yuv_length as usize {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Frame data too small: {} < {}",
|
||||
data.len(),
|
||||
self.yuv_length
|
||||
)));
|
||||
}
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
match self.inner.encode(data, pts_ms) {
|
||||
Ok(frames) => {
|
||||
// Copy frame data to owned HwEncodeFrame
|
||||
let owned_frames: Vec<HwEncodeFrame> = frames
|
||||
.iter()
|
||||
.map(|f| HwEncodeFrame {
|
||||
data: f.data.clone(),
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect();
|
||||
Ok(owned_frames)
|
||||
}
|
||||
Err(e) => {
|
||||
// For the first ~30 frames, x264 may fail due to initialization
|
||||
// Log as warning instead of error to avoid alarming users
|
||||
if self.frame_count <= 30 {
|
||||
warn!(
|
||||
"Encode failed during initialization (frame {}): {} - this is normal for x264",
|
||||
self.frame_count, e
|
||||
);
|
||||
} else {
|
||||
error!("Encode failed: {}", e);
|
||||
}
|
||||
Err(AppError::VideoError(format!("Encode failed: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode YUV420P data (legacy method, use encode_raw for new code)
|
||||
pub fn encode_yuv420p(&mut self, yuv_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
self.encode_raw(yuv_data, pts_ms)
|
||||
}
|
||||
|
||||
/// Encode NV12 data
|
||||
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
self.encode_raw(nv12_data, pts_ms)
|
||||
}
|
||||
|
||||
/// Get input format
|
||||
pub fn input_format(&self) -> H264InputFormat {
|
||||
self.config.input_format
|
||||
}
|
||||
|
||||
/// Get YUV buffer info (linesize, offset, length)
|
||||
pub fn yuv_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
|
||||
(
|
||||
self.inner.linesize.clone(),
|
||||
self.inner.offset.clone(),
|
||||
self.inner.length,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: H264Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
|
||||
// that are not Send by default. However, we ensure that H264Encoder is only used from
|
||||
// a single task/thread at a time (encoding is sequential), so this is safe.
|
||||
// The raw pointers are internal FFmpeg context that doesn't escape the encoder.
|
||||
unsafe impl Send for H264Encoder {}
|
||||
|
||||
impl Encoder for H264Encoder {
|
||||
fn name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
fn output_format(&self) -> EncodedFormat {
|
||||
EncodedFormat::H264
|
||||
}
|
||||
|
||||
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
// Assume input is YUV420P
|
||||
let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
|
||||
|
||||
let frames = self.encode_yuv420p(data, pts_ms)?;
|
||||
|
||||
if frames.is_empty() {
|
||||
// Encoder needs more frames (shouldn't happen with our config)
|
||||
warn!("Encoder returned no frames");
|
||||
return Err(AppError::VideoError("Encoder returned no frames".to_string()));
|
||||
}
|
||||
|
||||
// Take the first frame
|
||||
let frame = &frames[0];
|
||||
let key_frame = frame.key == 1;
|
||||
|
||||
Ok(EncodedFrame::h264(
|
||||
Bytes::from(frame.data.clone()),
|
||||
self.config.base.resolution,
|
||||
key_frame,
|
||||
sequence,
|
||||
frame.pts as u64,
|
||||
frame.pts as u64,
|
||||
))
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
|
||||
// hwcodec doesn't have explicit flush, return empty
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<()> {
|
||||
self.frame_count = 0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn config(&self) -> &EncoderConfig {
|
||||
&self.config.base
|
||||
}
|
||||
|
||||
fn supports_format(&self, format: PixelFormat) -> bool {
|
||||
// Check if the format matches our configured input format
|
||||
match self.config.input_format {
|
||||
H264InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
|
||||
H264InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoder statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EncoderStats {
|
||||
/// Total frames encoded
|
||||
pub frames_encoded: u64,
|
||||
/// Total bytes output
|
||||
pub bytes_output: u64,
|
||||
/// Current encoding FPS
|
||||
pub fps: f32,
|
||||
/// Average encoding time per frame (ms)
|
||||
pub avg_encode_time_ms: f32,
|
||||
/// Keyframes encoded
|
||||
pub keyframes: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_detect_encoder() {
|
||||
let (encoder_type, codec_name) = detect_best_encoder(1280, 720);
|
||||
println!("Detected encoder: {:?} ({:?})", encoder_type, codec_name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_available_encoders() {
|
||||
let encoders = get_available_encoders(1280, 720);
|
||||
println!("Available encoders:");
|
||||
for enc in &encoders {
|
||||
println!(" - {} ({:?})", enc.name, enc.format);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_encoder() {
|
||||
let config = H264Config::low_latency(Resolution::HD720, 2000);
|
||||
match H264Encoder::new(config) {
|
||||
Ok(encoder) => {
|
||||
println!("Created encoder: {} ({})", encoder.codec_name(), encoder.encoder_type());
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Failed to create encoder: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
577
src/video/encoder/h265.rs
Normal file
577
src/video/encoder/h265.rs
Normal file
@@ -0,0 +1,577 @@
|
||||
//! H.265/HEVC encoder using hwcodec (FFmpeg wrapper)
|
||||
//!
|
||||
//! Supports both hardware and software encoding:
|
||||
//! - Hardware: VAAPI, NVENC, QSV, AMF, RKMPP, V4L2 M2M
|
||||
//! - Software: libx265 (CPU-based, high CPU usage)
|
||||
//!
|
||||
//! Hardware encoding is preferred when available for better performance.
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::sync::Once;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use hwcodec::common::{DataFormat, Quality, RateControl};
|
||||
use hwcodec::ffmpeg::AVPixelFormat;
|
||||
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
|
||||
use hwcodec::ffmpeg_ram::CodecInfo;
|
||||
|
||||
use super::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
|
||||
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
|
||||
static INIT_LOGGING: Once = Once::new();
|
||||
|
||||
/// Initialize hwcodec logging (only once)
|
||||
fn init_hwcodec_logging() {
|
||||
INIT_LOGGING.call_once(|| {
|
||||
debug!("hwcodec logging initialized for H265");
|
||||
});
|
||||
}
|
||||
|
||||
/// H.265 encoder type (detected from hwcodec)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum H265EncoderType {
|
||||
/// NVIDIA NVENC
|
||||
Nvenc,
|
||||
/// Intel Quick Sync (QSV)
|
||||
Qsv,
|
||||
/// AMD AMF
|
||||
Amf,
|
||||
/// VAAPI (Linux generic)
|
||||
Vaapi,
|
||||
/// RKMPP (Rockchip)
|
||||
Rkmpp,
|
||||
/// V4L2 M2M (ARM generic)
|
||||
V4l2M2m,
|
||||
/// Software encoder (libx265)
|
||||
Software,
|
||||
/// No encoder available
|
||||
None,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for H265EncoderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
H265EncoderType::Nvenc => write!(f, "NVENC"),
|
||||
H265EncoderType::Qsv => write!(f, "QSV"),
|
||||
H265EncoderType::Amf => write!(f, "AMF"),
|
||||
H265EncoderType::Vaapi => write!(f, "VAAPI"),
|
||||
H265EncoderType::Rkmpp => write!(f, "RKMPP"),
|
||||
H265EncoderType::V4l2M2m => write!(f, "V4L2 M2M"),
|
||||
H265EncoderType::Software => write!(f, "Software"),
|
||||
H265EncoderType::None => write!(f, "None"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for H265EncoderType {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncoderBackend> for H265EncoderType {
|
||||
fn from(backend: EncoderBackend) -> Self {
|
||||
match backend {
|
||||
EncoderBackend::Nvenc => H265EncoderType::Nvenc,
|
||||
EncoderBackend::Qsv => H265EncoderType::Qsv,
|
||||
EncoderBackend::Amf => H265EncoderType::Amf,
|
||||
EncoderBackend::Vaapi => H265EncoderType::Vaapi,
|
||||
EncoderBackend::Rkmpp => H265EncoderType::Rkmpp,
|
||||
EncoderBackend::V4l2m2m => H265EncoderType::V4l2M2m,
|
||||
EncoderBackend::Software => H265EncoderType::Software,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Input pixel format for H265 encoder
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum H265InputFormat {
|
||||
/// YUV420P (I420) - planar Y, U, V
|
||||
Yuv420p,
|
||||
/// NV12 - Y plane + interleaved UV plane (optimal for hardware encoders)
|
||||
Nv12,
|
||||
}
|
||||
|
||||
impl Default for H265InputFormat {
|
||||
fn default() -> Self {
|
||||
Self::Nv12 // Default to NV12 for hardware encoder compatibility
|
||||
}
|
||||
}
|
||||
|
||||
/// H.265 encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct H265Config {
|
||||
/// Base encoder config
|
||||
pub base: EncoderConfig,
|
||||
/// Target bitrate in kbps
|
||||
pub bitrate_kbps: u32,
|
||||
/// GOP size (keyframe interval)
|
||||
pub gop_size: u32,
|
||||
/// Frame rate
|
||||
pub fps: u32,
|
||||
/// Input pixel format
|
||||
pub input_format: H265InputFormat,
|
||||
}
|
||||
|
||||
impl Default for H265Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base: EncoderConfig::default(),
|
||||
bitrate_kbps: 8000,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: H265InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl H265Config {
|
||||
/// Create config for low latency streaming with NV12 input
|
||||
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
base: EncoderConfig {
|
||||
resolution,
|
||||
input_format: PixelFormat::Nv12,
|
||||
quality: bitrate_kbps,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
},
|
||||
bitrate_kbps,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: H265InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for quality streaming
|
||||
pub fn quality(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
base: EncoderConfig {
|
||||
resolution,
|
||||
input_format: PixelFormat::Nv12,
|
||||
quality: bitrate_kbps,
|
||||
fps: 30,
|
||||
gop_size: 60,
|
||||
},
|
||||
bitrate_kbps,
|
||||
gop_size: 60,
|
||||
fps: 30,
|
||||
input_format: H265InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set input format
|
||||
pub fn with_input_format(mut self, format: H265InputFormat) -> Self {
|
||||
self.input_format = format;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Get available H265 hardware encoders from hwcodec
|
||||
pub fn get_available_h265_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: String::new(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
|
||||
align: 1,
|
||||
fps: 30,
|
||||
gop: 30,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Default,
|
||||
kbs: 2000,
|
||||
q: 23,
|
||||
thread_count: 1,
|
||||
};
|
||||
|
||||
let all_encoders = HwEncoder::available_encoders(ctx, None);
|
||||
|
||||
// Include both hardware and software H265 encoders
|
||||
all_encoders
|
||||
.into_iter()
|
||||
.filter(|e| e.format == DataFormat::H265)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Detect best available H.265 encoder (hardware preferred, software fallback)
|
||||
pub fn detect_best_h265_encoder(width: u32, height: u32) -> (H265EncoderType, Option<String>) {
|
||||
let encoders = get_available_h265_encoders(width, height);
|
||||
|
||||
if encoders.is_empty() {
|
||||
warn!("No H.265 encoders available");
|
||||
return (H265EncoderType::None, None);
|
||||
}
|
||||
|
||||
// Prefer hardware encoders over software (libx265)
|
||||
// Hardware priority: NVENC > QSV > AMF > VAAPI > RKMPP > V4L2 M2M > Software
|
||||
let codec = encoders
|
||||
.iter()
|
||||
.find(|e| !e.name.contains("libx265"))
|
||||
.or_else(|| encoders.first())
|
||||
.unwrap();
|
||||
|
||||
let encoder_type = if codec.name.contains("nvenc") {
|
||||
H265EncoderType::Nvenc
|
||||
} else if codec.name.contains("qsv") {
|
||||
H265EncoderType::Qsv
|
||||
} else if codec.name.contains("amf") {
|
||||
H265EncoderType::Amf
|
||||
} else if codec.name.contains("vaapi") {
|
||||
H265EncoderType::Vaapi
|
||||
} else if codec.name.contains("rkmpp") {
|
||||
H265EncoderType::Rkmpp
|
||||
} else if codec.name.contains("v4l2m2m") {
|
||||
H265EncoderType::V4l2M2m
|
||||
} else if codec.name.contains("libx265") {
|
||||
H265EncoderType::Software
|
||||
} else {
|
||||
H265EncoderType::Software // Default to software for unknown
|
||||
};
|
||||
|
||||
info!(
|
||||
"Selected H.265 encoder: {} ({})",
|
||||
codec.name, encoder_type
|
||||
);
|
||||
(encoder_type, Some(codec.name.clone()))
|
||||
}
|
||||
|
||||
/// Check if H265 hardware encoding is available
|
||||
pub fn is_h265_available() -> bool {
|
||||
let registry = EncoderRegistry::global();
|
||||
registry.is_format_available(VideoEncoderType::H265, true)
|
||||
}
|
||||
|
||||
/// Encoded frame from hwcodec (cloned for ownership)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HwEncodeFrame {
|
||||
pub data: Vec<u8>,
|
||||
pub pts: i64,
|
||||
pub key: i32,
|
||||
}
|
||||
|
||||
/// H.265 encoder using hwcodec (hardware only)
|
||||
pub struct H265Encoder {
|
||||
/// hwcodec encoder instance
|
||||
inner: HwEncoder,
|
||||
/// Encoder configuration
|
||||
config: H265Config,
|
||||
/// Detected encoder type
|
||||
encoder_type: H265EncoderType,
|
||||
/// Codec name
|
||||
codec_name: String,
|
||||
/// Frame counter
|
||||
frame_count: u64,
|
||||
/// Required buffer length from hwcodec
|
||||
buffer_length: i32,
|
||||
}
|
||||
|
||||
impl H265Encoder {
|
||||
/// Create a new H.265 encoder with automatic hardware codec detection
|
||||
///
|
||||
/// Returns an error if no hardware encoder is available.
|
||||
pub fn new(config: H265Config) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
// Detect best hardware encoder
|
||||
let (encoder_type, codec_name) = detect_best_h265_encoder(width, height);
|
||||
|
||||
if encoder_type == H265EncoderType::None {
|
||||
return Err(AppError::VideoError(
|
||||
"No H.265 encoder available. Please ensure FFmpeg is built with libx265 support.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let codec_name = codec_name.unwrap();
|
||||
Self::with_codec(config, &codec_name)
|
||||
}
|
||||
|
||||
/// Create encoder with specific codec name
|
||||
pub fn with_codec(config: H265Config, codec_name: &str) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
// Determine if this is a software encoder
|
||||
let is_software = codec_name.contains("libx265");
|
||||
|
||||
// Warn about software encoder performance
|
||||
if is_software {
|
||||
warn!(
|
||||
"Using software H.265 encoder (libx265) - high CPU usage expected. \
|
||||
Hardware encoder is recommended for better performance."
|
||||
);
|
||||
}
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
// Software encoders (libx265) require YUV420P, hardware encoders use NV12
|
||||
let (pixfmt, actual_input_format) = if is_software {
|
||||
(AVPixelFormat::AV_PIX_FMT_YUV420P, H265InputFormat::Yuv420p)
|
||||
} else {
|
||||
match config.input_format {
|
||||
H265InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, H265InputFormat::Nv12),
|
||||
H265InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, H265InputFormat::Yuv420p),
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"Creating H.265 encoder: {} at {}x{} @ {} kbps (input: {:?})",
|
||||
codec_name, width, height, config.bitrate_kbps, actual_input_format
|
||||
);
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: codec_name.to_string(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt,
|
||||
align: 1,
|
||||
fps: config.fps as i32,
|
||||
gop: config.gop_size as i32,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Default,
|
||||
kbs: config.bitrate_kbps as i32,
|
||||
q: 23,
|
||||
thread_count: 1,
|
||||
};
|
||||
|
||||
let inner = HwEncoder::new(ctx).map_err(|_| {
|
||||
AppError::VideoError(format!("Failed to create H.265 encoder: {}", codec_name))
|
||||
})?;
|
||||
|
||||
let buffer_length = inner.length;
|
||||
let backend = EncoderBackend::from_codec_name(codec_name);
|
||||
let encoder_type = H265EncoderType::from(backend);
|
||||
|
||||
// Update config to reflect actual input format used
|
||||
let mut config = config;
|
||||
config.input_format = actual_input_format;
|
||||
|
||||
info!(
|
||||
"H.265 encoder created: {} (type: {}, buffer_length: {})",
|
||||
codec_name, encoder_type, buffer_length
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
config,
|
||||
encoder_type,
|
||||
codec_name: codec_name.to_string(),
|
||||
frame_count: 0,
|
||||
buffer_length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with auto-detected encoder
|
||||
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
|
||||
let config = H265Config::low_latency(resolution, bitrate_kbps);
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
/// Get encoder type
|
||||
pub fn encoder_type(&self) -> &H265EncoderType {
|
||||
&self.encoder_type
|
||||
}
|
||||
|
||||
/// Get codec name
|
||||
pub fn codec_name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
/// Update bitrate dynamically
|
||||
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
|
||||
AppError::VideoError("Failed to set H.265 bitrate".to_string())
|
||||
})?;
|
||||
self.config.bitrate_kbps = bitrate_kbps;
|
||||
debug!("H.265 bitrate updated to {} kbps", bitrate_kbps);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Request next frame to be a keyframe (IDR)
|
||||
pub fn request_keyframe(&mut self) {
|
||||
self.inner.request_keyframe();
|
||||
debug!("H265 keyframe requested");
|
||||
}
|
||||
|
||||
/// Encode raw frame data (NV12 or YUV420P depending on config)
|
||||
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
if data.len() < self.buffer_length as usize {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Frame data too small: {} < {}",
|
||||
data.len(),
|
||||
self.buffer_length
|
||||
)));
|
||||
}
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
// Debug log every 30 frames (1 second at 30fps)
|
||||
if self.frame_count % 30 == 1 {
|
||||
debug!(
|
||||
"[H265] Encoding frame #{}: input_size={}, pts_ms={}, codec={}",
|
||||
self.frame_count,
|
||||
data.len(),
|
||||
pts_ms,
|
||||
self.codec_name
|
||||
);
|
||||
}
|
||||
|
||||
match self.inner.encode(data, pts_ms) {
|
||||
Ok(frames) => {
|
||||
let owned_frames: Vec<HwEncodeFrame> = frames
|
||||
.iter()
|
||||
.map(|f| HwEncodeFrame {
|
||||
data: f.data.clone(),
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Log encoded output
|
||||
if !owned_frames.is_empty() {
|
||||
let total_size: usize = owned_frames.iter().map(|f| f.data.len()).sum();
|
||||
let keyframe = owned_frames.iter().any(|f| f.key == 1);
|
||||
|
||||
if keyframe || self.frame_count % 30 == 1 {
|
||||
debug!(
|
||||
"[H265] Encoded frame #{}: output_size={}, keyframe={}, frame_count={}",
|
||||
self.frame_count, total_size, keyframe, owned_frames.len()
|
||||
);
|
||||
|
||||
// Log first few bytes of keyframe for debugging
|
||||
if keyframe && !owned_frames[0].data.is_empty() {
|
||||
let preview_len = owned_frames[0].data.len().min(32);
|
||||
debug!(
|
||||
"[H265] Keyframe data preview: {:02x?}",
|
||||
&owned_frames[0].data[..preview_len]
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("[H265] Encoder returned empty frame list for frame #{}", self.frame_count);
|
||||
}
|
||||
|
||||
Ok(owned_frames)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("[H265] Encode failed at frame #{}: {}", self.frame_count, e);
|
||||
Err(AppError::VideoError(format!("H.265 encode failed: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode NV12 data
|
||||
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
self.encode_raw(nv12_data, pts_ms)
|
||||
}
|
||||
|
||||
/// Get input format
|
||||
pub fn input_format(&self) -> H265InputFormat {
|
||||
self.config.input_format
|
||||
}
|
||||
|
||||
/// Get buffer info (linesize, offset, length)
|
||||
pub fn buffer_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
|
||||
(
|
||||
self.inner.linesize.clone(),
|
||||
self.inner.offset.clone(),
|
||||
self.inner.length,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: H265Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
|
||||
// that are not Send by default. However, we ensure that H265Encoder is only used from
|
||||
// a single task/thread at a time (encoding is sequential), so this is safe.
|
||||
unsafe impl Send for H265Encoder {}
|
||||
|
||||
impl Encoder for H265Encoder {
|
||||
fn name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
fn output_format(&self) -> EncodedFormat {
|
||||
EncodedFormat::H265
|
||||
}
|
||||
|
||||
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
|
||||
|
||||
let frames = self.encode_raw(data, pts_ms)?;
|
||||
|
||||
if frames.is_empty() {
|
||||
warn!("H.265 encoder returned no frames");
|
||||
return Err(AppError::VideoError(
|
||||
"H.265 encoder returned no frames".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let frame = &frames[0];
|
||||
let key_frame = frame.key == 1;
|
||||
|
||||
Ok(EncodedFrame {
|
||||
data: Bytes::from(frame.data.clone()),
|
||||
format: EncodedFormat::H265,
|
||||
resolution: self.config.base.resolution,
|
||||
key_frame,
|
||||
sequence,
|
||||
timestamp: std::time::Instant::now(),
|
||||
pts: frame.pts as u64,
|
||||
dts: frame.pts as u64,
|
||||
})
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<()> {
|
||||
self.frame_count = 0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn config(&self) -> &EncoderConfig {
|
||||
&self.config.base
|
||||
}
|
||||
|
||||
fn supports_format(&self, format: PixelFormat) -> bool {
|
||||
match self.config.input_format {
|
||||
H265InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
|
||||
H265InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_detect_h265_encoder() {
|
||||
let (encoder_type, codec_name) = detect_best_h265_encoder(1280, 720);
|
||||
println!("Detected H.265 encoder: {:?} ({:?})", encoder_type, codec_name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_available_h265_encoders() {
|
||||
let encoders = get_available_h265_encoders(1280, 720);
|
||||
println!("Available H.265 hardware encoders:");
|
||||
for enc in &encoders {
|
||||
println!(" - {} ({:?})", enc.name, enc.format);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_h265_availability() {
|
||||
let available = is_h265_available();
|
||||
println!("H.265 hardware encoding available: {}", available);
|
||||
}
|
||||
}
|
||||
226
src/video/encoder/jpeg.rs
Normal file
226
src/video/encoder/jpeg.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
//! JPEG encoder implementation
|
||||
//!
|
||||
//! Provides JPEG encoding for raw video frames (YUYV, NV12, RGB, BGR)
|
||||
//! Uses libyuv for SIMD-accelerated color space conversion to I420,
|
||||
//! then turbojpeg for direct YUV encoding (skips internal color conversion).
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use super::traits::{EncodedFormat, EncodedFrame, EncoderConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
|
||||
/// JPEG encoder using libyuv + turbojpeg
|
||||
///
|
||||
/// Encoding pipeline (all SIMD accelerated):
|
||||
/// ```text
|
||||
/// YUYV/NV12/BGR24/RGB24 ──libyuv──> I420 ──turbojpeg──> JPEG
|
||||
/// ```
|
||||
///
|
||||
/// Note: This encoder is NOT thread-safe due to turbojpeg limitations.
|
||||
/// Use it from a single thread or wrap in a Mutex.
|
||||
pub struct JpegEncoder {
|
||||
config: EncoderConfig,
|
||||
compressor: turbojpeg::Compressor,
|
||||
/// I420 buffer for YUV encoding (Y + U + V planes)
|
||||
i420_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl JpegEncoder {
|
||||
/// Create a new JPEG encoder
|
||||
pub fn new(config: EncoderConfig) -> Result<Self> {
|
||||
let resolution = config.resolution;
|
||||
let width = resolution.width as usize;
|
||||
let height = resolution.height as usize;
|
||||
// I420: Y = width*height, U = width*height/4, V = width*height/4
|
||||
let i420_size = width * height * 3 / 2;
|
||||
|
||||
let mut compressor = turbojpeg::Compressor::new()
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to create turbojpeg compressor: {}", e)))?;
|
||||
|
||||
compressor.set_quality(config.quality.min(100) as i32)
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to set JPEG quality: {}", e)))?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
compressor,
|
||||
i420_buffer: vec![0u8; i420_size],
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with specific quality
|
||||
pub fn with_quality(resolution: Resolution, quality: u32) -> Result<Self> {
|
||||
let config = EncoderConfig::jpeg(resolution, quality);
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
/// Set JPEG quality (1-100)
|
||||
pub fn set_quality(&mut self, quality: u32) -> Result<()> {
|
||||
self.compressor.set_quality(quality.min(100) as i32)
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to set JPEG quality: {}", e)))?;
|
||||
self.config.quality = quality;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encode I420 buffer to JPEG using turbojpeg's YUV encoder
|
||||
#[inline]
|
||||
fn encode_i420_to_jpeg(&mut self, sequence: u64) -> Result<EncodedFrame> {
|
||||
let width = self.config.resolution.width as usize;
|
||||
let height = self.config.resolution.height as usize;
|
||||
|
||||
// Create YuvImage for turbojpeg (I420 = YUV420 = Sub2x2)
|
||||
let yuv_image = turbojpeg::YuvImage {
|
||||
pixels: self.i420_buffer.as_slice(),
|
||||
width,
|
||||
height,
|
||||
align: 1, // No padding between rows
|
||||
subsamp: turbojpeg::Subsamp::Sub2x2, // YUV 4:2:0
|
||||
};
|
||||
|
||||
// Compress YUV directly to JPEG (skips color space conversion!)
|
||||
let jpeg_data = self.compressor.compress_yuv_to_vec(yuv_image)
|
||||
.map_err(|e| AppError::VideoError(format!("JPEG compression failed: {}", e)))?;
|
||||
|
||||
Ok(EncodedFrame::jpeg(
|
||||
Bytes::from(jpeg_data),
|
||||
self.config.resolution,
|
||||
sequence,
|
||||
))
|
||||
}
|
||||
|
||||
/// Encode YUYV (YUV422) frame to JPEG
|
||||
pub fn encode_yuyv(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
let width = self.config.resolution.width as usize;
|
||||
let height = self.config.resolution.height as usize;
|
||||
let expected_size = width * height * 2;
|
||||
|
||||
if data.len() < expected_size {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"YUYV data too small: {} < {}",
|
||||
data.len(),
|
||||
expected_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Convert YUYV to I420 using libyuv (SIMD accelerated)
|
||||
libyuv::yuy2_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv YUYV→I420 failed: {}", e)))?;
|
||||
|
||||
self.encode_i420_to_jpeg(sequence)
|
||||
}
|
||||
|
||||
/// Encode NV12 frame to JPEG
|
||||
pub fn encode_nv12(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
let width = self.config.resolution.width as usize;
|
||||
let height = self.config.resolution.height as usize;
|
||||
let expected_size = width * height * 3 / 2;
|
||||
|
||||
if data.len() < expected_size {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"NV12 data too small: {} < {}",
|
||||
data.len(),
|
||||
expected_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Convert NV12 to I420 using libyuv (SIMD accelerated)
|
||||
libyuv::nv12_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv NV12→I420 failed: {}", e)))?;
|
||||
|
||||
self.encode_i420_to_jpeg(sequence)
|
||||
}
|
||||
|
||||
/// Encode RGB24 frame to JPEG
|
||||
pub fn encode_rgb(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
let width = self.config.resolution.width as usize;
|
||||
let height = self.config.resolution.height as usize;
|
||||
let expected_size = width * height * 3;
|
||||
|
||||
if data.len() < expected_size {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"RGB data too small: {} < {}",
|
||||
data.len(),
|
||||
expected_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Convert RGB24 to I420 using libyuv (SIMD accelerated)
|
||||
libyuv::rgb24_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv RGB24→I420 failed: {}", e)))?;
|
||||
|
||||
self.encode_i420_to_jpeg(sequence)
|
||||
}
|
||||
|
||||
/// Encode BGR24 frame to JPEG
|
||||
pub fn encode_bgr(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
let width = self.config.resolution.width as usize;
|
||||
let height = self.config.resolution.height as usize;
|
||||
let expected_size = width * height * 3;
|
||||
|
||||
if data.len() < expected_size {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"BGR data too small: {} < {}",
|
||||
data.len(),
|
||||
expected_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Convert BGR24 to I420 using libyuv (SIMD accelerated)
|
||||
// Note: libyuv's RAWToI420 is BGR24 → I420
|
||||
libyuv::bgr24_to_i420(data, &mut self.i420_buffer, width as i32, height as i32)
|
||||
.map_err(|e| AppError::VideoError(format!("libyuv BGR24→I420 failed: {}", e)))?;
|
||||
|
||||
self.encode_i420_to_jpeg(sequence)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::video::encoder::traits::Encoder for JpegEncoder {
|
||||
fn name(&self) -> &str {
|
||||
"JPEG (libyuv+turbojpeg)"
|
||||
}
|
||||
|
||||
fn output_format(&self) -> EncodedFormat {
|
||||
EncodedFormat::Jpeg
|
||||
}
|
||||
|
||||
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
match self.config.input_format {
|
||||
PixelFormat::Yuyv | PixelFormat::Yvyu => self.encode_yuyv(data, sequence),
|
||||
PixelFormat::Nv12 => self.encode_nv12(data, sequence),
|
||||
PixelFormat::Rgb24 => self.encode_rgb(data, sequence),
|
||||
PixelFormat::Bgr24 => self.encode_bgr(data, sequence),
|
||||
_ => Err(AppError::VideoError(format!(
|
||||
"Unsupported input format for JPEG: {}",
|
||||
self.config.input_format
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn config(&self) -> &EncoderConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn supports_format(&self, format: PixelFormat) -> bool {
|
||||
matches!(
|
||||
format,
|
||||
PixelFormat::Yuyv
|
||||
| PixelFormat::Yvyu
|
||||
| PixelFormat::Nv12
|
||||
| PixelFormat::Rgb24
|
||||
| PixelFormat::Bgr24
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_i420_buffer_size() {
|
||||
// 1920x1080 I420 = 1920*1080 + 960*540 + 960*540 = 3110400 bytes
|
||||
let config = EncoderConfig::jpeg(Resolution::HD1080, 80);
|
||||
let encoder = JpegEncoder::new(config).unwrap();
|
||||
assert_eq!(encoder.i420_buffer.len(), 1920 * 1080 * 3 / 2);
|
||||
}
|
||||
}
|
||||
43
src/video/encoder/mod.rs
Normal file
43
src/video/encoder/mod.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
//! Video encoder implementations
|
||||
//!
|
||||
//! This module provides video encoding capabilities including:
|
||||
//! - JPEG encoding for raw frames (YUYV, NV12, etc.)
|
||||
//! - H264 encoding (hardware + software)
|
||||
//! - H265 encoding (hardware only)
|
||||
//! - VP8 encoding (hardware only - VAAPI)
|
||||
//! - VP9 encoding (hardware only - VAAPI)
|
||||
//! - WebRTC video codec abstraction
|
||||
//! - Encoder registry for automatic detection
|
||||
|
||||
pub mod codec;
|
||||
pub mod h264;
|
||||
pub mod h265;
|
||||
pub mod jpeg;
|
||||
pub mod registry;
|
||||
pub mod traits;
|
||||
pub mod vp8;
|
||||
pub mod vp9;
|
||||
|
||||
// Core traits and types
|
||||
pub use traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig, EncoderFactory};
|
||||
|
||||
// WebRTC codec abstraction
|
||||
pub use codec::{CodecFrame, VideoCodec, VideoCodecConfig, VideoCodecFactory, VideoCodecType};
|
||||
|
||||
// Encoder registry
|
||||
pub use registry::{AvailableEncoder, EncoderBackend, EncoderRegistry, VideoEncoderType};
|
||||
|
||||
// H264 encoder
|
||||
pub use h264::{H264Config, H264Encoder, H264EncoderType, H264InputFormat};
|
||||
|
||||
// H265 encoder (hardware only)
|
||||
pub use h265::{H265Config, H265Encoder, H265EncoderType, H265InputFormat};
|
||||
|
||||
// VP8 encoder (hardware only)
|
||||
pub use vp8::{VP8Config, VP8Encoder, VP8EncoderType, VP8InputFormat};
|
||||
|
||||
// VP9 encoder (hardware only)
|
||||
pub use vp9::{VP9Config, VP9Encoder, VP9EncoderType, VP9InputFormat};
|
||||
|
||||
// JPEG encoder
|
||||
pub use jpeg::JpegEncoder;
|
||||
531
src/video/encoder/registry.rs
Normal file
531
src/video/encoder/registry.rs
Normal file
@@ -0,0 +1,531 @@
|
||||
//! Encoder registry - Detection and management of available video encoders
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - Automatic detection of available hardware/software encoders
|
||||
//! - Encoder selection based on format and priority
|
||||
//! - Global registry for encoder availability queries
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::OnceLock;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use hwcodec::common::{DataFormat, Quality, RateControl};
|
||||
use hwcodec::ffmpeg::AVPixelFormat;
|
||||
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
|
||||
use hwcodec::ffmpeg_ram::CodecInfo;
|
||||
|
||||
/// Video encoder format type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum VideoEncoderType {
|
||||
/// H.264/AVC
|
||||
H264,
|
||||
/// H.265/HEVC
|
||||
H265,
|
||||
/// VP8
|
||||
VP8,
|
||||
/// VP9
|
||||
VP9,
|
||||
}
|
||||
|
||||
impl VideoEncoderType {
|
||||
/// Convert to hwcodec DataFormat
|
||||
pub fn to_data_format(&self) -> DataFormat {
|
||||
match self {
|
||||
VideoEncoderType::H264 => DataFormat::H264,
|
||||
VideoEncoderType::H265 => DataFormat::H265,
|
||||
VideoEncoderType::VP8 => DataFormat::VP8,
|
||||
VideoEncoderType::VP9 => DataFormat::VP9,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from hwcodec DataFormat
|
||||
pub fn from_data_format(format: DataFormat) -> Option<Self> {
|
||||
match format {
|
||||
DataFormat::H264 => Some(VideoEncoderType::H264),
|
||||
DataFormat::H265 => Some(VideoEncoderType::H265),
|
||||
DataFormat::VP8 => Some(VideoEncoderType::VP8),
|
||||
DataFormat::VP9 => Some(VideoEncoderType::VP9),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get codec name prefix for FFmpeg
|
||||
pub fn codec_prefix(&self) -> &'static str {
|
||||
match self {
|
||||
VideoEncoderType::H264 => "h264",
|
||||
VideoEncoderType::H265 => "hevc",
|
||||
VideoEncoderType::VP8 => "vp8",
|
||||
VideoEncoderType::VP9 => "vp9",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get display name
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
VideoEncoderType::H264 => "H.264",
|
||||
VideoEncoderType::H265 => "H.265/HEVC",
|
||||
VideoEncoderType::VP8 => "VP8",
|
||||
VideoEncoderType::VP9 => "VP9",
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this format requires hardware-only encoding
|
||||
/// H264 supports software fallback, others require hardware
|
||||
pub fn hardware_only(&self) -> bool {
|
||||
match self {
|
||||
VideoEncoderType::H264 => false,
|
||||
VideoEncoderType::H265 => true,
|
||||
VideoEncoderType::VP8 => true,
|
||||
VideoEncoderType::VP9 => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VideoEncoderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.display_name())
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoder backend type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum EncoderBackend {
|
||||
/// Intel/AMD/NVIDIA VAAPI (Linux)
|
||||
Vaapi,
|
||||
/// NVIDIA NVENC
|
||||
Nvenc,
|
||||
/// Intel Quick Sync Video
|
||||
Qsv,
|
||||
/// AMD AMF
|
||||
Amf,
|
||||
/// Rockchip MPP
|
||||
Rkmpp,
|
||||
/// V4L2 Memory-to-Memory (ARM)
|
||||
V4l2m2m,
|
||||
/// Software encoding (libx264, libx265, libvpx)
|
||||
Software,
|
||||
}
|
||||
|
||||
impl EncoderBackend {
|
||||
/// Detect backend from codec name
|
||||
pub fn from_codec_name(name: &str) -> Self {
|
||||
if name.contains("vaapi") {
|
||||
EncoderBackend::Vaapi
|
||||
} else if name.contains("nvenc") {
|
||||
EncoderBackend::Nvenc
|
||||
} else if name.contains("qsv") {
|
||||
EncoderBackend::Qsv
|
||||
} else if name.contains("amf") {
|
||||
EncoderBackend::Amf
|
||||
} else if name.contains("rkmpp") {
|
||||
EncoderBackend::Rkmpp
|
||||
} else if name.contains("v4l2m2m") {
|
||||
EncoderBackend::V4l2m2m
|
||||
} else {
|
||||
EncoderBackend::Software
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is a hardware backend
|
||||
pub fn is_hardware(&self) -> bool {
|
||||
!matches!(self, EncoderBackend::Software)
|
||||
}
|
||||
|
||||
/// Get display name
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
EncoderBackend::Vaapi => "VAAPI",
|
||||
EncoderBackend::Nvenc => "NVENC",
|
||||
EncoderBackend::Qsv => "QSV",
|
||||
EncoderBackend::Amf => "AMF",
|
||||
EncoderBackend::Rkmpp => "RKMPP",
|
||||
EncoderBackend::V4l2m2m => "V4L2 M2M",
|
||||
EncoderBackend::Software => "Software",
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from string (case-insensitive)
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"vaapi" => Some(EncoderBackend::Vaapi),
|
||||
"nvenc" => Some(EncoderBackend::Nvenc),
|
||||
"qsv" => Some(EncoderBackend::Qsv),
|
||||
"amf" => Some(EncoderBackend::Amf),
|
||||
"rkmpp" => Some(EncoderBackend::Rkmpp),
|
||||
"v4l2m2m" | "v4l2" => Some(EncoderBackend::V4l2m2m),
|
||||
"software" | "cpu" => Some(EncoderBackend::Software),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EncoderBackend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.display_name())
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about an available encoder
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AvailableEncoder {
|
||||
/// Encoder format type
|
||||
pub format: VideoEncoderType,
|
||||
/// FFmpeg codec name (e.g., "h264_vaapi", "hevc_nvenc")
|
||||
pub codec_name: String,
|
||||
/// Backend type
|
||||
pub backend: EncoderBackend,
|
||||
/// Priority (lower is better)
|
||||
pub priority: i32,
|
||||
/// Whether this is a hardware encoder
|
||||
pub is_hardware: bool,
|
||||
}
|
||||
|
||||
impl AvailableEncoder {
|
||||
/// Create from hwcodec CodecInfo
|
||||
pub fn from_codec_info(info: &CodecInfo) -> Option<Self> {
|
||||
let format = VideoEncoderType::from_data_format(info.format)?;
|
||||
let backend = EncoderBackend::from_codec_name(&info.name);
|
||||
let is_hardware = backend.is_hardware();
|
||||
|
||||
Some(Self {
|
||||
format,
|
||||
codec_name: info.name.clone(),
|
||||
backend,
|
||||
priority: info.priority,
|
||||
is_hardware,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Global encoder registry
|
||||
///
|
||||
/// Detects and caches available encoders at startup.
|
||||
/// Use `EncoderRegistry::global()` to access the singleton instance.
|
||||
pub struct EncoderRegistry {
|
||||
/// Available encoders grouped by format
|
||||
encoders: HashMap<VideoEncoderType, Vec<AvailableEncoder>>,
|
||||
/// Detection resolution (used for testing)
|
||||
detection_resolution: (u32, u32),
|
||||
}
|
||||
|
||||
impl EncoderRegistry {
|
||||
/// Get the global registry instance
|
||||
///
|
||||
/// The registry is initialized lazily on first access with 1920x1080 detection.
|
||||
pub fn global() -> &'static Self {
|
||||
static INSTANCE: OnceLock<EncoderRegistry> = OnceLock::new();
|
||||
INSTANCE.get_or_init(|| {
|
||||
let mut registry = EncoderRegistry::new();
|
||||
registry.detect_encoders(1920, 1080);
|
||||
registry
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new empty registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
encoders: HashMap::new(),
|
||||
detection_resolution: (0, 0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect all available encoders
|
||||
///
|
||||
/// This queries hwcodec/FFmpeg for available encoders and populates the registry.
|
||||
pub fn detect_encoders(&mut self, width: u32, height: u32) {
|
||||
info!("Detecting available video encoders at {}x{}", width, height);
|
||||
|
||||
self.encoders.clear();
|
||||
self.detection_resolution = (width, height);
|
||||
|
||||
// Create test context for encoder detection
|
||||
let ctx = EncodeContext {
|
||||
name: String::new(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
|
||||
align: 1,
|
||||
fps: 30,
|
||||
gop: 30,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Default,
|
||||
kbs: 2000,
|
||||
q: 23,
|
||||
thread_count: 1,
|
||||
};
|
||||
|
||||
// Get all available encoders from hwcodec
|
||||
let all_encoders = HwEncoder::available_encoders(ctx, None);
|
||||
|
||||
info!("Found {} encoders from hwcodec", all_encoders.len());
|
||||
|
||||
for codec_info in &all_encoders {
|
||||
if let Some(encoder) = AvailableEncoder::from_codec_info(codec_info) {
|
||||
debug!(
|
||||
"Detected encoder: {} ({}) - {} priority={}",
|
||||
encoder.codec_name,
|
||||
encoder.format,
|
||||
encoder.backend,
|
||||
encoder.priority
|
||||
);
|
||||
|
||||
self.encoders
|
||||
.entry(encoder.format)
|
||||
.or_default()
|
||||
.push(encoder);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort encoders by priority (lower is better)
|
||||
for encoders in self.encoders.values_mut() {
|
||||
encoders.sort_by_key(|e| e.priority);
|
||||
}
|
||||
|
||||
// Register software encoders as fallback
|
||||
info!("Registering software encoders...");
|
||||
let software_encoders = [
|
||||
(VideoEncoderType::H264, "libx264", 100),
|
||||
(VideoEncoderType::H265, "libx265", 100),
|
||||
(VideoEncoderType::VP8, "libvpx", 100),
|
||||
(VideoEncoderType::VP9, "libvpx-vp9", 100),
|
||||
];
|
||||
|
||||
for (format, codec_name, priority) in software_encoders {
|
||||
self.encoders
|
||||
.entry(format)
|
||||
.or_default()
|
||||
.push(AvailableEncoder {
|
||||
format,
|
||||
codec_name: codec_name.to_string(),
|
||||
backend: EncoderBackend::Software,
|
||||
priority,
|
||||
is_hardware: false,
|
||||
});
|
||||
|
||||
debug!(
|
||||
"Registered software encoder: {} for {} (priority: {})",
|
||||
codec_name, format, priority
|
||||
);
|
||||
}
|
||||
|
||||
// Log summary
|
||||
for (format, encoders) in &self.encoders {
|
||||
let hw_count = encoders.iter().filter(|e| e.is_hardware).count();
|
||||
let sw_count = encoders.len() - hw_count;
|
||||
info!(
|
||||
"{}: {} encoders ({} hardware, {} software)",
|
||||
format,
|
||||
encoders.len(),
|
||||
hw_count,
|
||||
sw_count
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the best encoder for a format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `format` - The video format to encode
|
||||
/// * `hardware_only` - If true, only return hardware encoders
|
||||
///
|
||||
/// # Returns
|
||||
/// The best available encoder, or None if no suitable encoder is found
|
||||
pub fn best_encoder(
|
||||
&self,
|
||||
format: VideoEncoderType,
|
||||
hardware_only: bool,
|
||||
) -> Option<&AvailableEncoder> {
|
||||
self.encoders.get(&format)?.iter().find(|e| {
|
||||
if hardware_only {
|
||||
e.is_hardware
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Get all encoders for a format
|
||||
pub fn encoders_for_format(&self, format: VideoEncoderType) -> &[AvailableEncoder] {
|
||||
self.encoders
|
||||
.get(&format)
|
||||
.map(|v| v.as_slice())
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
/// Get all available formats
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `hardware_only` - If true, only return formats with hardware encoders
|
||||
pub fn available_formats(&self, hardware_only: bool) -> Vec<VideoEncoderType> {
|
||||
self.encoders
|
||||
.iter()
|
||||
.filter(|(_, encoders)| {
|
||||
if hardware_only {
|
||||
encoders.iter().any(|e| e.is_hardware)
|
||||
} else {
|
||||
!encoders.is_empty()
|
||||
}
|
||||
})
|
||||
.map(|(format, _)| *format)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if a format is available
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `format` - The video format to check
|
||||
/// * `hardware_only` - If true, only check for hardware encoders
|
||||
pub fn is_format_available(&self, format: VideoEncoderType, hardware_only: bool) -> bool {
|
||||
self.best_encoder(format, hardware_only).is_some()
|
||||
}
|
||||
|
||||
/// Get available formats for user selection
|
||||
///
|
||||
/// Returns formats that are actually usable based on their requirements:
|
||||
/// - H264: Available if any encoder exists (hardware or software)
|
||||
/// - H265/VP8/VP9: Available only if hardware encoder exists
|
||||
pub fn selectable_formats(&self) -> Vec<VideoEncoderType> {
|
||||
let mut formats = Vec::new();
|
||||
|
||||
// H264 - supports software fallback
|
||||
if self.is_format_available(VideoEncoderType::H264, false) {
|
||||
formats.push(VideoEncoderType::H264);
|
||||
}
|
||||
|
||||
// H265/VP8/VP9 - hardware only
|
||||
for format in [
|
||||
VideoEncoderType::H265,
|
||||
VideoEncoderType::VP8,
|
||||
VideoEncoderType::VP9,
|
||||
] {
|
||||
if self.is_format_available(format, true) {
|
||||
formats.push(format);
|
||||
}
|
||||
}
|
||||
|
||||
formats
|
||||
}
|
||||
|
||||
/// Get detection resolution
|
||||
pub fn detection_resolution(&self) -> (u32, u32) {
|
||||
self.detection_resolution
|
||||
}
|
||||
|
||||
/// Get all available backend types
|
||||
pub fn available_backends(&self) -> Vec<EncoderBackend> {
|
||||
use std::collections::HashSet;
|
||||
|
||||
let mut backends = HashSet::new();
|
||||
for encoders in self.encoders.values() {
|
||||
for encoder in encoders {
|
||||
backends.insert(encoder.backend);
|
||||
}
|
||||
}
|
||||
|
||||
let mut result: Vec<_> = backends.into_iter().collect();
|
||||
// Sort: hardware backends first, software last
|
||||
result.sort_by_key(|b| if b.is_hardware() { 0 } else { 1 });
|
||||
result
|
||||
}
|
||||
|
||||
/// Get formats supported by a specific backend
|
||||
pub fn formats_for_backend(&self, backend: EncoderBackend) -> Vec<VideoEncoderType> {
|
||||
let mut formats = Vec::new();
|
||||
for (format, encoders) in &self.encoders {
|
||||
if encoders.iter().any(|e| e.backend == backend) {
|
||||
formats.push(*format);
|
||||
}
|
||||
}
|
||||
formats
|
||||
}
|
||||
|
||||
/// Get encoder for a format with specific backend
|
||||
pub fn encoder_with_backend(
|
||||
&self,
|
||||
format: VideoEncoderType,
|
||||
backend: EncoderBackend,
|
||||
) -> Option<&AvailableEncoder> {
|
||||
self.encoders
|
||||
.get(&format)?
|
||||
.iter()
|
||||
.find(|e| e.backend == backend)
|
||||
}
|
||||
|
||||
/// Get encoders grouped by backend for a format
|
||||
pub fn encoders_by_backend(
|
||||
&self,
|
||||
format: VideoEncoderType,
|
||||
) -> HashMap<EncoderBackend, Vec<&AvailableEncoder>> {
|
||||
let mut grouped = HashMap::new();
|
||||
if let Some(encoders) = self.encoders.get(&format) {
|
||||
for encoder in encoders {
|
||||
grouped
|
||||
.entry(encoder.backend)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(encoder);
|
||||
}
|
||||
}
|
||||
grouped
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EncoderRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_video_encoder_type_display() {
|
||||
assert_eq!(VideoEncoderType::H264.display_name(), "H.264");
|
||||
assert_eq!(VideoEncoderType::H265.display_name(), "H.265/HEVC");
|
||||
assert_eq!(VideoEncoderType::VP8.display_name(), "VP8");
|
||||
assert_eq!(VideoEncoderType::VP9.display_name(), "VP9");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encoder_backend_detection() {
|
||||
assert_eq!(
|
||||
EncoderBackend::from_codec_name("h264_vaapi"),
|
||||
EncoderBackend::Vaapi
|
||||
);
|
||||
assert_eq!(
|
||||
EncoderBackend::from_codec_name("hevc_nvenc"),
|
||||
EncoderBackend::Nvenc
|
||||
);
|
||||
assert_eq!(
|
||||
EncoderBackend::from_codec_name("h264_qsv"),
|
||||
EncoderBackend::Qsv
|
||||
);
|
||||
assert_eq!(
|
||||
EncoderBackend::from_codec_name("libx264"),
|
||||
EncoderBackend::Software
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hardware_only_requirement() {
|
||||
assert!(!VideoEncoderType::H264.hardware_only());
|
||||
assert!(VideoEncoderType::H265.hardware_only());
|
||||
assert!(VideoEncoderType::VP8.hardware_only());
|
||||
assert!(VideoEncoderType::VP9.hardware_only());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_registry_detection() {
|
||||
let mut registry = EncoderRegistry::new();
|
||||
registry.detect_encoders(1280, 720);
|
||||
|
||||
// Should have detected at least H264 (software fallback available)
|
||||
println!("Available formats: {:?}", registry.available_formats(false));
|
||||
println!(
|
||||
"Selectable formats: {:?}",
|
||||
registry.selectable_formats()
|
||||
);
|
||||
}
|
||||
}
|
||||
188
src/video/encoder/traits.rs
Normal file
188
src/video/encoder/traits.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
//! Encoder traits and common types
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
use crate::error::Result;
|
||||
|
||||
/// Encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EncoderConfig {
|
||||
/// Target resolution
|
||||
pub resolution: Resolution,
|
||||
/// Input pixel format
|
||||
pub input_format: PixelFormat,
|
||||
/// Output quality (1-100 for JPEG, bitrate kbps for H264)
|
||||
pub quality: u32,
|
||||
/// Target frame rate
|
||||
pub fps: u32,
|
||||
/// Keyframe interval (for H264)
|
||||
pub gop_size: u32,
|
||||
}
|
||||
|
||||
impl Default for EncoderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
resolution: Resolution::HD1080,
|
||||
input_format: PixelFormat::Yuyv,
|
||||
quality: 80,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderConfig {
|
||||
pub fn jpeg(resolution: Resolution, quality: u32) -> Self {
|
||||
Self {
|
||||
resolution,
|
||||
input_format: PixelFormat::Yuyv,
|
||||
quality,
|
||||
fps: 30,
|
||||
gop_size: 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn h264(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
resolution,
|
||||
input_format: PixelFormat::Yuyv,
|
||||
quality: bitrate_kbps,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoded frame output
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EncodedFrame {
|
||||
/// Encoded data
|
||||
pub data: Bytes,
|
||||
/// Output format (JPEG, H264, etc.)
|
||||
pub format: EncodedFormat,
|
||||
/// Resolution
|
||||
pub resolution: Resolution,
|
||||
/// Whether this is a key frame
|
||||
pub key_frame: bool,
|
||||
/// Frame sequence number
|
||||
pub sequence: u64,
|
||||
/// Encoding timestamp
|
||||
pub timestamp: Instant,
|
||||
/// Presentation timestamp (for video sync)
|
||||
pub pts: u64,
|
||||
/// Decode timestamp (for B-frames)
|
||||
pub dts: u64,
|
||||
}
|
||||
|
||||
impl EncodedFrame {
|
||||
pub fn jpeg(data: Bytes, resolution: Resolution, sequence: u64) -> Self {
|
||||
Self {
|
||||
data,
|
||||
format: EncodedFormat::Jpeg,
|
||||
resolution,
|
||||
key_frame: true,
|
||||
sequence,
|
||||
timestamp: Instant::now(),
|
||||
pts: sequence,
|
||||
dts: sequence,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn h264(
|
||||
data: Bytes,
|
||||
resolution: Resolution,
|
||||
key_frame: bool,
|
||||
sequence: u64,
|
||||
pts: u64,
|
||||
dts: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
data,
|
||||
format: EncodedFormat::H264,
|
||||
resolution,
|
||||
key_frame,
|
||||
sequence,
|
||||
timestamp: Instant::now(),
|
||||
pts,
|
||||
dts,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Encoded output format
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum EncodedFormat {
|
||||
Jpeg,
|
||||
H264,
|
||||
H265,
|
||||
Vp8,
|
||||
Vp9,
|
||||
Av1,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EncodedFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
EncodedFormat::Jpeg => write!(f, "JPEG"),
|
||||
EncodedFormat::H264 => write!(f, "H.264"),
|
||||
EncodedFormat::H265 => write!(f, "H.265"),
|
||||
EncodedFormat::Vp8 => write!(f, "VP8"),
|
||||
EncodedFormat::Vp9 => write!(f, "VP9"),
|
||||
EncodedFormat::Av1 => write!(f, "AV1"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic encoder trait
|
||||
/// Note: Not Sync because some encoders (like turbojpeg) are not thread-safe
|
||||
pub trait Encoder: Send {
|
||||
/// Get encoder name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Get output format
|
||||
fn output_format(&self) -> EncodedFormat;
|
||||
|
||||
/// Encode a raw frame
|
||||
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame>;
|
||||
|
||||
/// Flush any pending frames
|
||||
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
/// Reset encoder state
|
||||
fn reset(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current configuration
|
||||
fn config(&self) -> &EncoderConfig;
|
||||
|
||||
/// Check if encoder supports the given input format
|
||||
fn supports_format(&self, format: PixelFormat) -> bool;
|
||||
}
|
||||
|
||||
/// Encoder factory for creating encoders
|
||||
pub trait EncoderFactory: Send + Sync {
|
||||
/// Create an encoder with the given configuration
|
||||
fn create(&self, config: EncoderConfig) -> Result<Box<dyn Encoder>>;
|
||||
|
||||
/// Get encoder type name
|
||||
fn encoder_type(&self) -> &str;
|
||||
|
||||
/// Check if this encoder is available on the system
|
||||
fn is_available(&self) -> bool;
|
||||
|
||||
/// Get encoder priority (higher = preferred)
|
||||
fn priority(&self) -> u32;
|
||||
}
|
||||
488
src/video/encoder/vp8.rs
Normal file
488
src/video/encoder/vp8.rs
Normal file
@@ -0,0 +1,488 @@
|
||||
//! VP8 encoder using hwcodec (FFmpeg wrapper)
|
||||
//!
|
||||
//! Supports both hardware and software encoding:
|
||||
//! - Hardware: VAAPI (Intel on Linux)
|
||||
//! - Software: libvpx (CPU-based, high CPU usage)
|
||||
//!
|
||||
//! Hardware encoding is preferred when available for better performance.
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::sync::Once;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use hwcodec::common::{DataFormat, Quality, RateControl};
|
||||
use hwcodec::ffmpeg::AVPixelFormat;
|
||||
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
|
||||
use hwcodec::ffmpeg_ram::CodecInfo;
|
||||
|
||||
use super::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
|
||||
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
|
||||
static INIT_LOGGING: Once = Once::new();
|
||||
|
||||
/// Initialize hwcodec logging (only once)
|
||||
fn init_hwcodec_logging() {
|
||||
INIT_LOGGING.call_once(|| {
|
||||
debug!("hwcodec logging initialized for VP8");
|
||||
});
|
||||
}
|
||||
|
||||
/// VP8 encoder type (detected from hwcodec)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum VP8EncoderType {
|
||||
/// VAAPI (Intel on Linux)
|
||||
Vaapi,
|
||||
/// Software encoder (libvpx)
|
||||
Software,
|
||||
/// No encoder available
|
||||
None,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VP8EncoderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
VP8EncoderType::Vaapi => write!(f, "VAAPI"),
|
||||
VP8EncoderType::Software => write!(f, "Software"),
|
||||
VP8EncoderType::None => write!(f, "None"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VP8EncoderType {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncoderBackend> for VP8EncoderType {
|
||||
fn from(backend: EncoderBackend) -> Self {
|
||||
match backend {
|
||||
EncoderBackend::Vaapi => VP8EncoderType::Vaapi,
|
||||
EncoderBackend::Software => VP8EncoderType::Software,
|
||||
_ => VP8EncoderType::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Input pixel format for VP8 encoder
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VP8InputFormat {
|
||||
/// YUV420P (I420) - planar Y, U, V
|
||||
Yuv420p,
|
||||
/// NV12 - Y plane + interleaved UV plane
|
||||
Nv12,
|
||||
}
|
||||
|
||||
impl Default for VP8InputFormat {
|
||||
fn default() -> Self {
|
||||
Self::Nv12 // Default to NV12 for VAAPI compatibility
|
||||
}
|
||||
}
|
||||
|
||||
/// VP8 encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VP8Config {
|
||||
/// Base encoder config
|
||||
pub base: EncoderConfig,
|
||||
/// Target bitrate in kbps
|
||||
pub bitrate_kbps: u32,
|
||||
/// GOP size (keyframe interval)
|
||||
pub gop_size: u32,
|
||||
/// Frame rate
|
||||
pub fps: u32,
|
||||
/// Input pixel format
|
||||
pub input_format: VP8InputFormat,
|
||||
}
|
||||
|
||||
impl Default for VP8Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base: EncoderConfig::default(),
|
||||
bitrate_kbps: 8000,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: VP8InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VP8Config {
|
||||
/// Create config for low latency streaming with NV12 input
|
||||
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
base: EncoderConfig {
|
||||
resolution,
|
||||
input_format: PixelFormat::Nv12,
|
||||
quality: bitrate_kbps,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
},
|
||||
bitrate_kbps,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: VP8InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set input format
|
||||
pub fn with_input_format(mut self, format: VP8InputFormat) -> Self {
|
||||
self.input_format = format;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Get available VP8 hardware encoders from hwcodec
|
||||
pub fn get_available_vp8_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: String::new(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
|
||||
align: 1,
|
||||
fps: 30,
|
||||
gop: 30,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Default,
|
||||
kbs: 2000,
|
||||
q: 23,
|
||||
thread_count: 1,
|
||||
};
|
||||
|
||||
let all_encoders = HwEncoder::available_encoders(ctx, None);
|
||||
|
||||
// Include both hardware and software VP8 encoders
|
||||
all_encoders
|
||||
.into_iter()
|
||||
.filter(|e| e.format == DataFormat::VP8)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Detect best available VP8 encoder (hardware preferred, software fallback)
|
||||
pub fn detect_best_vp8_encoder(width: u32, height: u32) -> (VP8EncoderType, Option<String>) {
|
||||
let encoders = get_available_vp8_encoders(width, height);
|
||||
|
||||
if encoders.is_empty() {
|
||||
warn!("No VP8 encoders available");
|
||||
return (VP8EncoderType::None, None);
|
||||
}
|
||||
|
||||
// Prefer hardware encoders (VAAPI) over software (libvpx)
|
||||
let codec = encoders
|
||||
.iter()
|
||||
.find(|e| e.name.contains("vaapi"))
|
||||
.or_else(|| encoders.first())
|
||||
.unwrap();
|
||||
|
||||
let encoder_type = if codec.name.contains("vaapi") {
|
||||
VP8EncoderType::Vaapi
|
||||
} else if codec.name.contains("libvpx") {
|
||||
VP8EncoderType::Software
|
||||
} else {
|
||||
VP8EncoderType::Software // Default to software for unknown
|
||||
};
|
||||
|
||||
info!(
|
||||
"Selected VP8 encoder: {} ({})",
|
||||
codec.name, encoder_type
|
||||
);
|
||||
(encoder_type, Some(codec.name.clone()))
|
||||
}
|
||||
|
||||
/// Check if VP8 hardware encoding is available
|
||||
pub fn is_vp8_available() -> bool {
|
||||
let registry = EncoderRegistry::global();
|
||||
registry.is_format_available(VideoEncoderType::VP8, true)
|
||||
}
|
||||
|
||||
/// Encoded frame from hwcodec (cloned for ownership)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HwEncodeFrame {
|
||||
pub data: Vec<u8>,
|
||||
pub pts: i64,
|
||||
pub key: i32,
|
||||
}
|
||||
|
||||
/// VP8 encoder using hwcodec (hardware only - VAAPI)
|
||||
pub struct VP8Encoder {
|
||||
/// hwcodec encoder instance
|
||||
inner: HwEncoder,
|
||||
/// Encoder configuration
|
||||
config: VP8Config,
|
||||
/// Detected encoder type
|
||||
encoder_type: VP8EncoderType,
|
||||
/// Codec name
|
||||
codec_name: String,
|
||||
/// Frame counter
|
||||
frame_count: u64,
|
||||
/// Required buffer length from hwcodec
|
||||
buffer_length: i32,
|
||||
}
|
||||
|
||||
impl VP8Encoder {
|
||||
/// Create a new VP8 encoder with automatic hardware codec detection
|
||||
///
|
||||
/// Returns an error if no hardware encoder is available.
|
||||
/// VP8 hardware encoding requires Intel VAAPI support.
|
||||
pub fn new(config: VP8Config) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
let (encoder_type, codec_name) = detect_best_vp8_encoder(width, height);
|
||||
|
||||
if encoder_type == VP8EncoderType::None {
|
||||
return Err(AppError::VideoError(
|
||||
"No VP8 encoder available. Please ensure FFmpeg is built with libvpx support.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let codec_name = codec_name.unwrap();
|
||||
Self::with_codec(config, &codec_name)
|
||||
}
|
||||
|
||||
/// Create encoder with specific codec name
|
||||
pub fn with_codec(config: VP8Config, codec_name: &str) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
// Determine if this is a software encoder
|
||||
let is_software = codec_name.contains("libvpx");
|
||||
|
||||
// Warn about software encoder performance
|
||||
if is_software {
|
||||
warn!(
|
||||
"Using software VP8 encoder (libvpx) - high CPU usage expected. \
|
||||
Hardware encoder is recommended for better performance."
|
||||
);
|
||||
}
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
// Software encoders (libvpx) require YUV420P, hardware (VAAPI) uses NV12
|
||||
let (pixfmt, actual_input_format) = if is_software {
|
||||
(AVPixelFormat::AV_PIX_FMT_YUV420P, VP8InputFormat::Yuv420p)
|
||||
} else {
|
||||
match config.input_format {
|
||||
VP8InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, VP8InputFormat::Nv12),
|
||||
VP8InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, VP8InputFormat::Yuv420p),
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"Creating VP8 encoder: {} at {}x{} @ {} kbps (input: {:?})",
|
||||
codec_name, width, height, config.bitrate_kbps, actual_input_format
|
||||
);
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: codec_name.to_string(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt,
|
||||
align: 1,
|
||||
fps: config.fps as i32,
|
||||
gop: config.gop_size as i32,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Default,
|
||||
kbs: config.bitrate_kbps as i32,
|
||||
q: 23,
|
||||
thread_count: 1,
|
||||
};
|
||||
|
||||
let inner = HwEncoder::new(ctx).map_err(|_| {
|
||||
AppError::VideoError(format!("Failed to create VP8 encoder: {}", codec_name))
|
||||
})?;
|
||||
|
||||
let buffer_length = inner.length;
|
||||
let backend = EncoderBackend::from_codec_name(codec_name);
|
||||
let encoder_type = VP8EncoderType::from(backend);
|
||||
|
||||
// Update config to reflect actual input format used
|
||||
let mut config = config;
|
||||
config.input_format = actual_input_format;
|
||||
|
||||
info!(
|
||||
"VP8 encoder created: {} (type: {}, buffer_length: {})",
|
||||
codec_name, encoder_type, buffer_length
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
config,
|
||||
encoder_type,
|
||||
codec_name: codec_name.to_string(),
|
||||
frame_count: 0,
|
||||
buffer_length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with auto-detected encoder
|
||||
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
|
||||
let config = VP8Config::low_latency(resolution, bitrate_kbps);
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
/// Get encoder type
|
||||
pub fn encoder_type(&self) -> &VP8EncoderType {
|
||||
&self.encoder_type
|
||||
}
|
||||
|
||||
/// Get codec name
|
||||
pub fn codec_name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
/// Update bitrate dynamically
|
||||
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
|
||||
AppError::VideoError("Failed to set VP8 bitrate".to_string())
|
||||
})?;
|
||||
self.config.bitrate_kbps = bitrate_kbps;
|
||||
debug!("VP8 bitrate updated to {} kbps", bitrate_kbps);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encode raw frame data
|
||||
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
if data.len() < self.buffer_length as usize {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Frame data too small: {} < {}",
|
||||
data.len(),
|
||||
self.buffer_length
|
||||
)));
|
||||
}
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
match self.inner.encode(data, pts_ms) {
|
||||
Ok(frames) => {
|
||||
let owned_frames: Vec<HwEncodeFrame> = frames
|
||||
.iter()
|
||||
.map(|f| HwEncodeFrame {
|
||||
data: f.data.clone(),
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect();
|
||||
Ok(owned_frames)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("VP8 encode failed: {}", e);
|
||||
Err(AppError::VideoError(format!("VP8 encode failed: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode NV12 data
|
||||
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
self.encode_raw(nv12_data, pts_ms)
|
||||
}
|
||||
|
||||
/// Get input format
|
||||
pub fn input_format(&self) -> VP8InputFormat {
|
||||
self.config.input_format
|
||||
}
|
||||
|
||||
/// Get buffer info
|
||||
pub fn buffer_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
|
||||
(
|
||||
self.inner.linesize.clone(),
|
||||
self.inner.offset.clone(),
|
||||
self.inner.length,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: VP8Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
|
||||
// that are not Send by default. However, we ensure that VP8Encoder is only used from
|
||||
// a single task/thread at a time (encoding is sequential), so this is safe.
|
||||
unsafe impl Send for VP8Encoder {}
|
||||
|
||||
impl Encoder for VP8Encoder {
|
||||
fn name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
fn output_format(&self) -> EncodedFormat {
|
||||
EncodedFormat::Vp8
|
||||
}
|
||||
|
||||
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
|
||||
|
||||
let frames = self.encode_raw(data, pts_ms)?;
|
||||
|
||||
if frames.is_empty() {
|
||||
warn!("VP8 encoder returned no frames");
|
||||
return Err(AppError::VideoError(
|
||||
"VP8 encoder returned no frames".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let frame = &frames[0];
|
||||
let key_frame = frame.key == 1;
|
||||
|
||||
Ok(EncodedFrame {
|
||||
data: Bytes::from(frame.data.clone()),
|
||||
format: EncodedFormat::Vp8,
|
||||
resolution: self.config.base.resolution,
|
||||
key_frame,
|
||||
sequence,
|
||||
timestamp: std::time::Instant::now(),
|
||||
pts: frame.pts as u64,
|
||||
dts: frame.pts as u64,
|
||||
})
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<()> {
|
||||
self.frame_count = 0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn config(&self) -> &EncoderConfig {
|
||||
&self.config.base
|
||||
}
|
||||
|
||||
fn supports_format(&self, format: PixelFormat) -> bool {
|
||||
match self.config.input_format {
|
||||
VP8InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
|
||||
VP8InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_detect_vp8_encoder() {
|
||||
let (encoder_type, codec_name) = detect_best_vp8_encoder(1280, 720);
|
||||
println!("Detected VP8 encoder: {:?} ({:?})", encoder_type, codec_name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_available_vp8_encoders() {
|
||||
let encoders = get_available_vp8_encoders(1280, 720);
|
||||
println!("Available VP8 hardware encoders:");
|
||||
for enc in &encoders {
|
||||
println!(" - {} ({:?})", enc.name, enc.format);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vp8_availability() {
|
||||
let available = is_vp8_available();
|
||||
println!("VP8 hardware encoding available: {}", available);
|
||||
}
|
||||
}
|
||||
488
src/video/encoder/vp9.rs
Normal file
488
src/video/encoder/vp9.rs
Normal file
@@ -0,0 +1,488 @@
|
||||
//! VP9 encoder using hwcodec (FFmpeg wrapper)
|
||||
//!
|
||||
//! Supports both hardware and software encoding:
|
||||
//! - Hardware: VAAPI (Intel on Linux)
|
||||
//! - Software: libvpx-vp9 (CPU-based, high CPU usage)
|
||||
//!
|
||||
//! Hardware encoding is preferred when available for better performance.
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::sync::Once;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use hwcodec::common::{DataFormat, Quality, RateControl};
|
||||
use hwcodec::ffmpeg::AVPixelFormat;
|
||||
use hwcodec::ffmpeg_ram::encode::{EncodeContext, Encoder as HwEncoder};
|
||||
use hwcodec::ffmpeg_ram::CodecInfo;
|
||||
|
||||
use super::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
|
||||
use super::traits::{EncodedFormat, EncodedFrame, Encoder, EncoderConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
|
||||
static INIT_LOGGING: Once = Once::new();
|
||||
|
||||
/// Initialize hwcodec logging (only once)
|
||||
fn init_hwcodec_logging() {
|
||||
INIT_LOGGING.call_once(|| {
|
||||
debug!("hwcodec logging initialized for VP9");
|
||||
});
|
||||
}
|
||||
|
||||
/// VP9 encoder type (detected from hwcodec)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum VP9EncoderType {
|
||||
/// VAAPI (Intel on Linux)
|
||||
Vaapi,
|
||||
/// Software encoder (libvpx-vp9)
|
||||
Software,
|
||||
/// No encoder available
|
||||
None,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VP9EncoderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
VP9EncoderType::Vaapi => write!(f, "VAAPI"),
|
||||
VP9EncoderType::Software => write!(f, "Software"),
|
||||
VP9EncoderType::None => write!(f, "None"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VP9EncoderType {
|
||||
fn default() -> Self {
|
||||
Self::None
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EncoderBackend> for VP9EncoderType {
|
||||
fn from(backend: EncoderBackend) -> Self {
|
||||
match backend {
|
||||
EncoderBackend::Vaapi => VP9EncoderType::Vaapi,
|
||||
EncoderBackend::Software => VP9EncoderType::Software,
|
||||
_ => VP9EncoderType::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Input pixel format for VP9 encoder
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VP9InputFormat {
|
||||
/// YUV420P (I420) - planar Y, U, V
|
||||
Yuv420p,
|
||||
/// NV12 - Y plane + interleaved UV plane
|
||||
Nv12,
|
||||
}
|
||||
|
||||
impl Default for VP9InputFormat {
|
||||
fn default() -> Self {
|
||||
Self::Nv12 // Default to NV12 for VAAPI compatibility
|
||||
}
|
||||
}
|
||||
|
||||
/// VP9 encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VP9Config {
|
||||
/// Base encoder config
|
||||
pub base: EncoderConfig,
|
||||
/// Target bitrate in kbps
|
||||
pub bitrate_kbps: u32,
|
||||
/// GOP size (keyframe interval)
|
||||
pub gop_size: u32,
|
||||
/// Frame rate
|
||||
pub fps: u32,
|
||||
/// Input pixel format
|
||||
pub input_format: VP9InputFormat,
|
||||
}
|
||||
|
||||
impl Default for VP9Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base: EncoderConfig::default(),
|
||||
bitrate_kbps: 8000,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: VP9InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VP9Config {
|
||||
/// Create config for low latency streaming with NV12 input
|
||||
pub fn low_latency(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
base: EncoderConfig {
|
||||
resolution,
|
||||
input_format: PixelFormat::Nv12,
|
||||
quality: bitrate_kbps,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
},
|
||||
bitrate_kbps,
|
||||
gop_size: 30,
|
||||
fps: 30,
|
||||
input_format: VP9InputFormat::Nv12,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set input format
|
||||
pub fn with_input_format(mut self, format: VP9InputFormat) -> Self {
|
||||
self.input_format = format;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Get available VP9 hardware encoders from hwcodec
|
||||
pub fn get_available_vp9_encoders(width: u32, height: u32) -> Vec<CodecInfo> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: String::new(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt: AVPixelFormat::AV_PIX_FMT_NV12,
|
||||
align: 1,
|
||||
fps: 30,
|
||||
gop: 30,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Default,
|
||||
kbs: 2000,
|
||||
q: 23,
|
||||
thread_count: 1,
|
||||
};
|
||||
|
||||
let all_encoders = HwEncoder::available_encoders(ctx, None);
|
||||
|
||||
// Include both hardware and software VP9 encoders
|
||||
all_encoders
|
||||
.into_iter()
|
||||
.filter(|e| e.format == DataFormat::VP9)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Detect best available VP9 encoder (hardware preferred, software fallback)
|
||||
pub fn detect_best_vp9_encoder(width: u32, height: u32) -> (VP9EncoderType, Option<String>) {
|
||||
let encoders = get_available_vp9_encoders(width, height);
|
||||
|
||||
if encoders.is_empty() {
|
||||
warn!("No VP9 encoders available");
|
||||
return (VP9EncoderType::None, None);
|
||||
}
|
||||
|
||||
// Prefer hardware encoders (VAAPI) over software (libvpx-vp9)
|
||||
let codec = encoders
|
||||
.iter()
|
||||
.find(|e| e.name.contains("vaapi"))
|
||||
.or_else(|| encoders.first())
|
||||
.unwrap();
|
||||
|
||||
let encoder_type = if codec.name.contains("vaapi") {
|
||||
VP9EncoderType::Vaapi
|
||||
} else if codec.name.contains("libvpx") {
|
||||
VP9EncoderType::Software
|
||||
} else {
|
||||
VP9EncoderType::Software // Default to software for unknown
|
||||
};
|
||||
|
||||
info!(
|
||||
"Selected VP9 encoder: {} ({})",
|
||||
codec.name, encoder_type
|
||||
);
|
||||
(encoder_type, Some(codec.name.clone()))
|
||||
}
|
||||
|
||||
/// Check if VP9 hardware encoding is available
|
||||
pub fn is_vp9_available() -> bool {
|
||||
let registry = EncoderRegistry::global();
|
||||
registry.is_format_available(VideoEncoderType::VP9, true)
|
||||
}
|
||||
|
||||
/// Encoded frame from hwcodec (cloned for ownership)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HwEncodeFrame {
|
||||
pub data: Vec<u8>,
|
||||
pub pts: i64,
|
||||
pub key: i32,
|
||||
}
|
||||
|
||||
/// VP9 encoder using hwcodec (hardware only - VAAPI)
|
||||
pub struct VP9Encoder {
|
||||
/// hwcodec encoder instance
|
||||
inner: HwEncoder,
|
||||
/// Encoder configuration
|
||||
config: VP9Config,
|
||||
/// Detected encoder type
|
||||
encoder_type: VP9EncoderType,
|
||||
/// Codec name
|
||||
codec_name: String,
|
||||
/// Frame counter
|
||||
frame_count: u64,
|
||||
/// Required buffer length from hwcodec
|
||||
buffer_length: i32,
|
||||
}
|
||||
|
||||
impl VP9Encoder {
|
||||
/// Create a new VP9 encoder with automatic hardware codec detection
|
||||
///
|
||||
/// Returns an error if no hardware encoder is available.
|
||||
/// VP9 hardware encoding requires Intel VAAPI support.
|
||||
pub fn new(config: VP9Config) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
let (encoder_type, codec_name) = detect_best_vp9_encoder(width, height);
|
||||
|
||||
if encoder_type == VP9EncoderType::None {
|
||||
return Err(AppError::VideoError(
|
||||
"No VP9 encoder available. Please ensure FFmpeg is built with libvpx support.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let codec_name = codec_name.unwrap();
|
||||
Self::with_codec(config, &codec_name)
|
||||
}
|
||||
|
||||
/// Create encoder with specific codec name
|
||||
pub fn with_codec(config: VP9Config, codec_name: &str) -> Result<Self> {
|
||||
init_hwcodec_logging();
|
||||
|
||||
// Determine if this is a software encoder
|
||||
let is_software = codec_name.contains("libvpx");
|
||||
|
||||
// Warn about software encoder performance
|
||||
if is_software {
|
||||
warn!(
|
||||
"Using software VP9 encoder (libvpx-vp9) - high CPU usage expected. \
|
||||
Hardware encoder is recommended for better performance."
|
||||
);
|
||||
}
|
||||
|
||||
let width = config.base.resolution.width;
|
||||
let height = config.base.resolution.height;
|
||||
|
||||
// Software encoders (libvpx-vp9) require YUV420P, hardware (VAAPI) uses NV12
|
||||
let (pixfmt, actual_input_format) = if is_software {
|
||||
(AVPixelFormat::AV_PIX_FMT_YUV420P, VP9InputFormat::Yuv420p)
|
||||
} else {
|
||||
match config.input_format {
|
||||
VP9InputFormat::Nv12 => (AVPixelFormat::AV_PIX_FMT_NV12, VP9InputFormat::Nv12),
|
||||
VP9InputFormat::Yuv420p => (AVPixelFormat::AV_PIX_FMT_YUV420P, VP9InputFormat::Yuv420p),
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"Creating VP9 encoder: {} at {}x{} @ {} kbps (input: {:?})",
|
||||
codec_name, width, height, config.bitrate_kbps, actual_input_format
|
||||
);
|
||||
|
||||
let ctx = EncodeContext {
|
||||
name: codec_name.to_string(),
|
||||
mc_name: None,
|
||||
width: width as i32,
|
||||
height: height as i32,
|
||||
pixfmt,
|
||||
align: 1,
|
||||
fps: config.fps as i32,
|
||||
gop: config.gop_size as i32,
|
||||
rc: RateControl::RC_CBR,
|
||||
quality: Quality::Quality_Default,
|
||||
kbs: config.bitrate_kbps as i32,
|
||||
q: 31,
|
||||
thread_count: 4, // VP9 benefits from multi-threading
|
||||
};
|
||||
|
||||
let inner = HwEncoder::new(ctx).map_err(|_| {
|
||||
AppError::VideoError(format!("Failed to create VP9 encoder: {}", codec_name))
|
||||
})?;
|
||||
|
||||
let buffer_length = inner.length;
|
||||
let backend = EncoderBackend::from_codec_name(codec_name);
|
||||
let encoder_type = VP9EncoderType::from(backend);
|
||||
|
||||
// Update config to reflect actual input format used
|
||||
let mut config = config;
|
||||
config.input_format = actual_input_format;
|
||||
|
||||
info!(
|
||||
"VP9 encoder created: {} (type: {}, buffer_length: {})",
|
||||
codec_name, encoder_type, buffer_length
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
config,
|
||||
encoder_type,
|
||||
codec_name: codec_name.to_string(),
|
||||
frame_count: 0,
|
||||
buffer_length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with auto-detected encoder
|
||||
pub fn auto(resolution: Resolution, bitrate_kbps: u32) -> Result<Self> {
|
||||
let config = VP9Config::low_latency(resolution, bitrate_kbps);
|
||||
Self::new(config)
|
||||
}
|
||||
|
||||
/// Get encoder type
|
||||
pub fn encoder_type(&self) -> &VP9EncoderType {
|
||||
&self.encoder_type
|
||||
}
|
||||
|
||||
/// Get codec name
|
||||
pub fn codec_name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
/// Update bitrate dynamically
|
||||
pub fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.inner.set_bitrate(bitrate_kbps as i32).map_err(|_| {
|
||||
AppError::VideoError("Failed to set VP9 bitrate".to_string())
|
||||
})?;
|
||||
self.config.bitrate_kbps = bitrate_kbps;
|
||||
debug!("VP9 bitrate updated to {} kbps", bitrate_kbps);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encode raw frame data
|
||||
pub fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
if data.len() < self.buffer_length as usize {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Frame data too small: {} < {}",
|
||||
data.len(),
|
||||
self.buffer_length
|
||||
)));
|
||||
}
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
match self.inner.encode(data, pts_ms) {
|
||||
Ok(frames) => {
|
||||
let owned_frames: Vec<HwEncodeFrame> = frames
|
||||
.iter()
|
||||
.map(|f| HwEncodeFrame {
|
||||
data: f.data.clone(),
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect();
|
||||
Ok(owned_frames)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("VP9 encode failed: {}", e);
|
||||
Err(AppError::VideoError(format!("VP9 encode failed: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode NV12 data
|
||||
pub fn encode_nv12(&mut self, nv12_data: &[u8], pts_ms: i64) -> Result<Vec<HwEncodeFrame>> {
|
||||
self.encode_raw(nv12_data, pts_ms)
|
||||
}
|
||||
|
||||
/// Get input format
|
||||
pub fn input_format(&self) -> VP9InputFormat {
|
||||
self.config.input_format
|
||||
}
|
||||
|
||||
/// Get buffer info
|
||||
pub fn buffer_info(&self) -> (Vec<i32>, Vec<i32>, i32) {
|
||||
(
|
||||
self.inner.linesize.clone(),
|
||||
self.inner.offset.clone(),
|
||||
self.inner.length,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: VP9Encoder contains hwcodec::ffmpeg_ram::encode::Encoder which has raw pointers
|
||||
// that are not Send by default. However, we ensure that VP9Encoder is only used from
|
||||
// a single task/thread at a time (encoding is sequential), so this is safe.
|
||||
unsafe impl Send for VP9Encoder {}
|
||||
|
||||
impl Encoder for VP9Encoder {
|
||||
fn name(&self) -> &str {
|
||||
&self.codec_name
|
||||
}
|
||||
|
||||
fn output_format(&self) -> EncodedFormat {
|
||||
EncodedFormat::Vp9
|
||||
}
|
||||
|
||||
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
|
||||
let pts_ms = (sequence * 1000 / self.config.fps as u64) as i64;
|
||||
|
||||
let frames = self.encode_raw(data, pts_ms)?;
|
||||
|
||||
if frames.is_empty() {
|
||||
warn!("VP9 encoder returned no frames");
|
||||
return Err(AppError::VideoError(
|
||||
"VP9 encoder returned no frames".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let frame = &frames[0];
|
||||
let key_frame = frame.key == 1;
|
||||
|
||||
Ok(EncodedFrame {
|
||||
data: Bytes::from(frame.data.clone()),
|
||||
format: EncodedFormat::Vp9,
|
||||
resolution: self.config.base.resolution,
|
||||
key_frame,
|
||||
sequence,
|
||||
timestamp: std::time::Instant::now(),
|
||||
pts: frame.pts as u64,
|
||||
dts: frame.pts as u64,
|
||||
})
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<Vec<EncodedFrame>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<()> {
|
||||
self.frame_count = 0;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn config(&self) -> &EncoderConfig {
|
||||
&self.config.base
|
||||
}
|
||||
|
||||
fn supports_format(&self, format: PixelFormat) -> bool {
|
||||
match self.config.input_format {
|
||||
VP9InputFormat::Nv12 => matches!(format, PixelFormat::Nv12),
|
||||
VP9InputFormat::Yuv420p => matches!(format, PixelFormat::Yuv420),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_detect_vp9_encoder() {
|
||||
let (encoder_type, codec_name) = detect_best_vp9_encoder(1280, 720);
|
||||
println!("Detected VP9 encoder: {:?} ({:?})", encoder_type, codec_name);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_available_vp9_encoders() {
|
||||
let encoders = get_available_vp9_encoders(1280, 720);
|
||||
println!("Available VP9 hardware encoders:");
|
||||
for enc in &encoders {
|
||||
println!(" - {} ({:?})", enc.name, enc.format);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vp9_availability() {
|
||||
let available = is_vp9_available();
|
||||
println!("VP9 hardware encoding available: {}", available);
|
||||
}
|
||||
}
|
||||
259
src/video/format.rs
Normal file
259
src/video/format.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
//! Pixel format definitions and conversions
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use v4l::format::fourcc;
|
||||
|
||||
/// Supported pixel formats
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum PixelFormat {
|
||||
/// MJPEG compressed format (preferred for capture cards)
|
||||
Mjpeg,
|
||||
/// JPEG compressed format
|
||||
Jpeg,
|
||||
/// YUYV 4:2:2 packed format
|
||||
Yuyv,
|
||||
/// YVYU 4:2:2 packed format
|
||||
Yvyu,
|
||||
/// UYVY 4:2:2 packed format
|
||||
Uyvy,
|
||||
/// NV12 semi-planar format (Y plane + interleaved UV)
|
||||
Nv12,
|
||||
/// NV16 semi-planar format
|
||||
Nv16,
|
||||
/// NV24 semi-planar format
|
||||
Nv24,
|
||||
/// YUV420 planar format
|
||||
Yuv420,
|
||||
/// YVU420 planar format
|
||||
Yvu420,
|
||||
/// RGB565 format
|
||||
Rgb565,
|
||||
/// RGB24 format (3 bytes per pixel)
|
||||
Rgb24,
|
||||
/// BGR24 format (3 bytes per pixel)
|
||||
Bgr24,
|
||||
/// Grayscale format
|
||||
Grey,
|
||||
}
|
||||
|
||||
impl PixelFormat {
|
||||
/// Convert to V4L2 FourCC
|
||||
pub fn to_fourcc(&self) -> fourcc::FourCC {
|
||||
match self {
|
||||
PixelFormat::Mjpeg => fourcc::FourCC::new(b"MJPG"),
|
||||
PixelFormat::Jpeg => fourcc::FourCC::new(b"JPEG"),
|
||||
PixelFormat::Yuyv => fourcc::FourCC::new(b"YUYV"),
|
||||
PixelFormat::Yvyu => fourcc::FourCC::new(b"YVYU"),
|
||||
PixelFormat::Uyvy => fourcc::FourCC::new(b"UYVY"),
|
||||
PixelFormat::Nv12 => fourcc::FourCC::new(b"NV12"),
|
||||
PixelFormat::Nv16 => fourcc::FourCC::new(b"NV16"),
|
||||
PixelFormat::Nv24 => fourcc::FourCC::new(b"NV24"),
|
||||
PixelFormat::Yuv420 => fourcc::FourCC::new(b"YU12"),
|
||||
PixelFormat::Yvu420 => fourcc::FourCC::new(b"YV12"),
|
||||
PixelFormat::Rgb565 => fourcc::FourCC::new(b"RGBP"),
|
||||
PixelFormat::Rgb24 => fourcc::FourCC::new(b"RGB3"),
|
||||
PixelFormat::Bgr24 => fourcc::FourCC::new(b"BGR3"),
|
||||
PixelFormat::Grey => fourcc::FourCC::new(b"GREY"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to convert from V4L2 FourCC
|
||||
pub fn from_fourcc(fourcc: fourcc::FourCC) -> Option<Self> {
|
||||
let repr = fourcc.repr;
|
||||
match &repr {
|
||||
b"MJPG" => Some(PixelFormat::Mjpeg),
|
||||
b"JPEG" => Some(PixelFormat::Jpeg),
|
||||
b"YUYV" => Some(PixelFormat::Yuyv),
|
||||
b"YVYU" => Some(PixelFormat::Yvyu),
|
||||
b"UYVY" => Some(PixelFormat::Uyvy),
|
||||
b"NV12" => Some(PixelFormat::Nv12),
|
||||
b"NV16" => Some(PixelFormat::Nv16),
|
||||
b"NV24" => Some(PixelFormat::Nv24),
|
||||
b"YU12" | b"I420" => Some(PixelFormat::Yuv420),
|
||||
b"YV12" => Some(PixelFormat::Yvu420),
|
||||
b"RGBP" => Some(PixelFormat::Rgb565),
|
||||
b"RGB3" => Some(PixelFormat::Rgb24),
|
||||
b"BGR3" => Some(PixelFormat::Bgr24),
|
||||
b"GREY" | b"Y800" => Some(PixelFormat::Grey),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if format is compressed (JPEG/MJPEG)
|
||||
pub fn is_compressed(&self) -> bool {
|
||||
matches!(self, PixelFormat::Mjpeg | PixelFormat::Jpeg)
|
||||
}
|
||||
|
||||
/// Get bytes per pixel for uncompressed formats
|
||||
/// Returns None for compressed formats
|
||||
pub fn bytes_per_pixel(&self) -> Option<usize> {
|
||||
match self {
|
||||
PixelFormat::Mjpeg | PixelFormat::Jpeg => None,
|
||||
PixelFormat::Yuyv | PixelFormat::Yvyu | PixelFormat::Uyvy => Some(2),
|
||||
PixelFormat::Nv12 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => None, // Variable
|
||||
PixelFormat::Nv16 => None,
|
||||
PixelFormat::Nv24 => None,
|
||||
PixelFormat::Rgb565 => Some(2),
|
||||
PixelFormat::Rgb24 | PixelFormat::Bgr24 => Some(3),
|
||||
PixelFormat::Grey => Some(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate expected frame size for a given resolution
|
||||
/// Returns None for compressed formats (variable size)
|
||||
pub fn frame_size(&self, resolution: Resolution) -> Option<usize> {
|
||||
let pixels = (resolution.width * resolution.height) as usize;
|
||||
match self {
|
||||
PixelFormat::Mjpeg | PixelFormat::Jpeg => None,
|
||||
PixelFormat::Yuyv | PixelFormat::Yvyu | PixelFormat::Uyvy => Some(pixels * 2),
|
||||
PixelFormat::Nv12 | PixelFormat::Yuv420 | PixelFormat::Yvu420 => Some(pixels * 3 / 2),
|
||||
PixelFormat::Nv16 => Some(pixels * 2),
|
||||
PixelFormat::Nv24 => Some(pixels * 3),
|
||||
PixelFormat::Rgb565 => Some(pixels * 2),
|
||||
PixelFormat::Rgb24 | PixelFormat::Bgr24 => Some(pixels * 3),
|
||||
PixelFormat::Grey => Some(pixels),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get priority for format selection (higher is better)
|
||||
/// MJPEG is preferred for HDMI capture cards
|
||||
pub fn priority(&self) -> u8 {
|
||||
match self {
|
||||
PixelFormat::Mjpeg => 100,
|
||||
PixelFormat::Jpeg => 99,
|
||||
PixelFormat::Yuyv => 80,
|
||||
PixelFormat::Nv12 => 75,
|
||||
PixelFormat::Yuv420 => 70,
|
||||
PixelFormat::Uyvy => 65,
|
||||
PixelFormat::Yvyu => 64,
|
||||
PixelFormat::Yvu420 => 63,
|
||||
PixelFormat::Nv16 => 60,
|
||||
PixelFormat::Nv24 => 55,
|
||||
PixelFormat::Rgb24 => 50,
|
||||
PixelFormat::Bgr24 => 49,
|
||||
PixelFormat::Rgb565 => 40,
|
||||
PixelFormat::Grey => 10,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all supported formats
|
||||
pub fn all() -> &'static [PixelFormat] {
|
||||
&[
|
||||
PixelFormat::Mjpeg,
|
||||
PixelFormat::Jpeg,
|
||||
PixelFormat::Yuyv,
|
||||
PixelFormat::Yvyu,
|
||||
PixelFormat::Uyvy,
|
||||
PixelFormat::Nv12,
|
||||
PixelFormat::Nv16,
|
||||
PixelFormat::Nv24,
|
||||
PixelFormat::Yuv420,
|
||||
PixelFormat::Yvu420,
|
||||
PixelFormat::Rgb565,
|
||||
PixelFormat::Rgb24,
|
||||
PixelFormat::Bgr24,
|
||||
PixelFormat::Grey,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PixelFormat {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let name = match self {
|
||||
PixelFormat::Mjpeg => "MJPEG",
|
||||
PixelFormat::Jpeg => "JPEG",
|
||||
PixelFormat::Yuyv => "YUYV",
|
||||
PixelFormat::Yvyu => "YVYU",
|
||||
PixelFormat::Uyvy => "UYVY",
|
||||
PixelFormat::Nv12 => "NV12",
|
||||
PixelFormat::Nv16 => "NV16",
|
||||
PixelFormat::Nv24 => "NV24",
|
||||
PixelFormat::Yuv420 => "YUV420",
|
||||
PixelFormat::Yvu420 => "YVU420",
|
||||
PixelFormat::Rgb565 => "RGB565",
|
||||
PixelFormat::Rgb24 => "RGB24",
|
||||
PixelFormat::Bgr24 => "BGR24",
|
||||
PixelFormat::Grey => "GREY",
|
||||
};
|
||||
write!(f, "{}", name)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for PixelFormat {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_uppercase().as_str() {
|
||||
"MJPEG" | "MJPG" => Ok(PixelFormat::Mjpeg),
|
||||
"JPEG" => Ok(PixelFormat::Jpeg),
|
||||
"YUYV" => Ok(PixelFormat::Yuyv),
|
||||
"YVYU" => Ok(PixelFormat::Yvyu),
|
||||
"UYVY" => Ok(PixelFormat::Uyvy),
|
||||
"NV12" => Ok(PixelFormat::Nv12),
|
||||
"NV16" => Ok(PixelFormat::Nv16),
|
||||
"NV24" => Ok(PixelFormat::Nv24),
|
||||
"YUV420" | "I420" => Ok(PixelFormat::Yuv420),
|
||||
"YVU420" | "YV12" => Ok(PixelFormat::Yvu420),
|
||||
"RGB565" => Ok(PixelFormat::Rgb565),
|
||||
"RGB24" => Ok(PixelFormat::Rgb24),
|
||||
"BGR24" => Ok(PixelFormat::Bgr24),
|
||||
"GREY" | "GRAY" => Ok(PixelFormat::Grey),
|
||||
_ => Err(format!("Unknown pixel format: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolution (width x height)
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct Resolution {
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
impl Resolution {
|
||||
pub fn new(width: u32, height: u32) -> Self {
|
||||
Self { width, height }
|
||||
}
|
||||
|
||||
/// Check if resolution is valid
|
||||
pub fn is_valid(&self) -> bool {
|
||||
self.width >= 160 && self.width <= 15360 && self.height >= 120 && self.height <= 8640
|
||||
}
|
||||
|
||||
/// Get total pixels
|
||||
pub fn pixels(&self) -> u64 {
|
||||
self.width as u64 * self.height as u64
|
||||
}
|
||||
|
||||
/// Common resolutions
|
||||
pub const VGA: Resolution = Resolution {
|
||||
width: 640,
|
||||
height: 480,
|
||||
};
|
||||
pub const HD720: Resolution = Resolution {
|
||||
width: 1280,
|
||||
height: 720,
|
||||
};
|
||||
pub const HD1080: Resolution = Resolution {
|
||||
width: 1920,
|
||||
height: 1080,
|
||||
};
|
||||
pub const UHD4K: Resolution = Resolution {
|
||||
width: 3840,
|
||||
height: 2160,
|
||||
};
|
||||
}
|
||||
|
||||
impl fmt::Display for Resolution {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}x{}", self.width, self.height)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(u32, u32)> for Resolution {
|
||||
fn from((width, height): (u32, u32)) -> Self {
|
||||
Self { width, height }
|
||||
}
|
||||
}
|
||||
239
src/video/frame.rs
Normal file
239
src/video/frame.rs
Normal file
@@ -0,0 +1,239 @@
|
||||
//! Video frame data structures
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Instant;
|
||||
|
||||
use super::format::{PixelFormat, Resolution};
|
||||
|
||||
/// A video frame with metadata
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VideoFrame {
|
||||
/// Raw frame data
|
||||
data: Arc<Bytes>,
|
||||
/// Cached xxHash64 of frame data (lazy computed for deduplication)
|
||||
hash: Arc<OnceLock<u64>>,
|
||||
/// Frame resolution
|
||||
pub resolution: Resolution,
|
||||
/// Pixel format
|
||||
pub format: PixelFormat,
|
||||
/// Stride (bytes per line)
|
||||
pub stride: u32,
|
||||
/// Whether this is a key frame (for compressed formats)
|
||||
pub key_frame: bool,
|
||||
/// Frame sequence number
|
||||
pub sequence: u64,
|
||||
/// Timestamp when frame was captured
|
||||
pub capture_ts: Instant,
|
||||
/// Whether capture is online (signal present)
|
||||
pub online: bool,
|
||||
}
|
||||
|
||||
impl VideoFrame {
|
||||
/// Create a new video frame
|
||||
pub fn new(
|
||||
data: Bytes,
|
||||
resolution: Resolution,
|
||||
format: PixelFormat,
|
||||
stride: u32,
|
||||
sequence: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
data: Arc::new(data),
|
||||
hash: Arc::new(OnceLock::new()),
|
||||
resolution,
|
||||
format,
|
||||
stride,
|
||||
key_frame: true,
|
||||
sequence,
|
||||
capture_ts: Instant::now(),
|
||||
online: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a frame from a Vec<u8>
|
||||
pub fn from_vec(
|
||||
data: Vec<u8>,
|
||||
resolution: Resolution,
|
||||
format: PixelFormat,
|
||||
stride: u32,
|
||||
sequence: u64,
|
||||
) -> Self {
|
||||
Self::new(Bytes::from(data), resolution, format, stride, sequence)
|
||||
}
|
||||
|
||||
/// Get frame data as bytes slice
|
||||
pub fn data(&self) -> &[u8] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Get frame data as Bytes (cheap clone)
|
||||
pub fn data_bytes(&self) -> Bytes {
|
||||
(*self.data).clone()
|
||||
}
|
||||
|
||||
/// Get data length
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Check if frame is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
|
||||
/// Get width
|
||||
pub fn width(&self) -> u32 {
|
||||
self.resolution.width
|
||||
}
|
||||
|
||||
/// Get height
|
||||
pub fn height(&self) -> u32 {
|
||||
self.resolution.height
|
||||
}
|
||||
|
||||
/// Get age of this frame (time since capture)
|
||||
pub fn age(&self) -> std::time::Duration {
|
||||
self.capture_ts.elapsed()
|
||||
}
|
||||
|
||||
/// Check if this frame is still fresh (within threshold)
|
||||
pub fn is_fresh(&self, max_age_ms: u64) -> bool {
|
||||
self.age().as_millis() < max_age_ms as u128
|
||||
}
|
||||
|
||||
/// Get hash of frame data (computed once, cached)
|
||||
/// Used for fast frame deduplication comparison
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
*self.hash.get_or_init(|| {
|
||||
xxhash_rust::xxh64::xxh64(self.data.as_ref(), 0)
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if format is JPEG/MJPEG
|
||||
pub fn is_jpeg(&self) -> bool {
|
||||
self.format.is_compressed()
|
||||
}
|
||||
|
||||
/// Validate JPEG frame data
|
||||
pub fn is_valid_jpeg(&self) -> bool {
|
||||
if !self.is_jpeg() {
|
||||
return false;
|
||||
}
|
||||
if self.data.len() < 125 {
|
||||
return false;
|
||||
}
|
||||
// Check JPEG header
|
||||
let start_marker = ((self.data[0] as u16) << 8) | self.data[1] as u16;
|
||||
if start_marker != 0xFFD8 {
|
||||
return false;
|
||||
}
|
||||
// Check JPEG end marker
|
||||
let end = self.data.len();
|
||||
let end_marker = ((self.data[end - 2] as u16) << 8) | self.data[end - 1] as u16;
|
||||
// Valid end markers: 0xFFD9, 0xD900, 0x0000 (padded)
|
||||
matches!(end_marker, 0xFFD9 | 0xD900 | 0x0000)
|
||||
}
|
||||
|
||||
/// Create an offline placeholder frame
|
||||
pub fn offline(resolution: Resolution, format: PixelFormat) -> Self {
|
||||
Self {
|
||||
data: Arc::new(Bytes::new()),
|
||||
hash: Arc::new(OnceLock::new()),
|
||||
resolution,
|
||||
format,
|
||||
stride: 0,
|
||||
key_frame: true,
|
||||
sequence: 0,
|
||||
capture_ts: Instant::now(),
|
||||
online: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Frame metadata without actual data (for logging/stats)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FrameMeta {
|
||||
pub resolution: Resolution,
|
||||
pub format: PixelFormat,
|
||||
pub size: usize,
|
||||
pub sequence: u64,
|
||||
pub key_frame: bool,
|
||||
pub online: bool,
|
||||
}
|
||||
|
||||
impl From<&VideoFrame> for FrameMeta {
|
||||
fn from(frame: &VideoFrame) -> Self {
|
||||
Self {
|
||||
resolution: frame.resolution,
|
||||
format: frame.format,
|
||||
size: frame.len(),
|
||||
sequence: frame.sequence,
|
||||
key_frame: frame.key_frame,
|
||||
online: frame.online,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ring buffer for storing recent frames
|
||||
pub struct FrameRing {
|
||||
frames: Vec<Option<VideoFrame>>,
|
||||
capacity: usize,
|
||||
write_pos: usize,
|
||||
count: usize,
|
||||
}
|
||||
|
||||
impl FrameRing {
|
||||
/// Create a new frame ring with specified capacity
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
assert!(capacity > 0, "Ring capacity must be > 0");
|
||||
Self {
|
||||
frames: (0..capacity).map(|_| None).collect(),
|
||||
capacity,
|
||||
write_pos: 0,
|
||||
count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a frame into the ring
|
||||
pub fn push(&mut self, frame: VideoFrame) {
|
||||
self.frames[self.write_pos] = Some(frame);
|
||||
self.write_pos = (self.write_pos + 1) % self.capacity;
|
||||
if self.count < self.capacity {
|
||||
self.count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the latest frame
|
||||
pub fn latest(&self) -> Option<&VideoFrame> {
|
||||
if self.count == 0 {
|
||||
return None;
|
||||
}
|
||||
let pos = if self.write_pos == 0 {
|
||||
self.capacity - 1
|
||||
} else {
|
||||
self.write_pos - 1
|
||||
};
|
||||
self.frames[pos].as_ref()
|
||||
}
|
||||
|
||||
/// Get number of frames in ring
|
||||
pub fn len(&self) -> usize {
|
||||
self.count
|
||||
}
|
||||
|
||||
/// Check if ring is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.count == 0
|
||||
}
|
||||
|
||||
/// Clear all frames
|
||||
pub fn clear(&mut self) {
|
||||
for frame in &mut self.frames {
|
||||
*frame = None;
|
||||
}
|
||||
self.write_pos = 0;
|
||||
self.count = 0;
|
||||
}
|
||||
}
|
||||
539
src/video/h264_pipeline.rs
Normal file
539
src/video/h264_pipeline.rs
Normal file
@@ -0,0 +1,539 @@
|
||||
//! H264 video encoding pipeline for WebRTC streaming
|
||||
//!
|
||||
//! This module provides a complete H264 encoding pipeline that connects:
|
||||
//! 1. Video capture (YUYV/MJPEG from V4L2)
|
||||
//! 2. Pixel conversion (YUYV → YUV420P) or JPEG decode
|
||||
//! 3. H264 encoding (via hwcodec)
|
||||
//! 4. RTP packetization and WebRTC track output
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{broadcast, watch, Mutex};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::convert::Nv12Converter;
|
||||
use crate::video::decoder::mjpeg::{MjpegVaapiDecoder, MjpegVaapiDecoderConfig};
|
||||
use crate::video::encoder::h264::{H264Config, H264Encoder};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
use crate::webrtc::rtp::{H264VideoTrack, H264VideoTrackConfig};
|
||||
|
||||
/// H264 pipeline configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct H264PipelineConfig {
|
||||
/// Input resolution
|
||||
pub resolution: Resolution,
|
||||
/// Input pixel format (YUYV, NV12, etc.)
|
||||
pub input_format: PixelFormat,
|
||||
/// Target bitrate in kbps
|
||||
pub bitrate_kbps: u32,
|
||||
/// Target FPS
|
||||
pub fps: u32,
|
||||
/// GOP size (keyframe interval in frames)
|
||||
pub gop_size: u32,
|
||||
/// Track ID for WebRTC
|
||||
pub track_id: String,
|
||||
/// Stream ID for WebRTC
|
||||
pub stream_id: String,
|
||||
}
|
||||
|
||||
impl Default for H264PipelineConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
resolution: Resolution::HD720,
|
||||
input_format: PixelFormat::Yuyv,
|
||||
bitrate_kbps: 8000,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
track_id: "video0".to_string(),
|
||||
stream_id: "one-kvm-stream".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// H264 pipeline statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct H264PipelineStats {
|
||||
/// Total frames captured
|
||||
pub frames_captured: u64,
|
||||
/// Total frames encoded
|
||||
pub frames_encoded: u64,
|
||||
/// Frames dropped (encoding too slow)
|
||||
pub frames_dropped: u64,
|
||||
/// Total bytes encoded
|
||||
pub bytes_encoded: u64,
|
||||
/// Keyframes encoded
|
||||
pub keyframes_encoded: u64,
|
||||
/// Average encoding time per frame (ms)
|
||||
pub avg_encode_time_ms: f32,
|
||||
/// Current encoding FPS
|
||||
pub current_fps: f32,
|
||||
/// Errors encountered
|
||||
pub errors: u64,
|
||||
}
|
||||
|
||||
/// H264 video encoding pipeline
|
||||
pub struct H264Pipeline {
|
||||
config: H264PipelineConfig,
|
||||
/// H264 encoder instance
|
||||
encoder: Arc<Mutex<Option<H264Encoder>>>,
|
||||
/// NV12 converter (for BGR24/RGB24/YUYV → NV12)
|
||||
nv12_converter: Arc<Mutex<Option<Nv12Converter>>>,
|
||||
/// MJPEG VAAPI decoder (for MJPEG input, outputs NV12)
|
||||
mjpeg_decoder: Arc<Mutex<Option<MjpegVaapiDecoder>>>,
|
||||
/// WebRTC video track
|
||||
video_track: Arc<H264VideoTrack>,
|
||||
/// Pipeline statistics
|
||||
stats: Arc<Mutex<H264PipelineStats>>,
|
||||
/// Running state
|
||||
running: watch::Sender<bool>,
|
||||
/// Encode time accumulator for averaging
|
||||
encode_times: Arc<Mutex<Vec<f32>>>,
|
||||
}
|
||||
|
||||
impl H264Pipeline {
|
||||
/// Create a new H264 pipeline
|
||||
pub fn new(config: H264PipelineConfig) -> Result<Self> {
|
||||
info!(
|
||||
"Creating H264 pipeline: {}x{} @ {} kbps, {} fps",
|
||||
config.resolution.width,
|
||||
config.resolution.height,
|
||||
config.bitrate_kbps,
|
||||
config.fps
|
||||
);
|
||||
|
||||
// Determine encoder input format based on pipeline input
|
||||
// NV12 is optimal for VAAPI, use it for all formats
|
||||
// VAAPI encoders typically only support NV12 input
|
||||
let encoder_input_format = crate::video::encoder::h264::H264InputFormat::Nv12;
|
||||
|
||||
// Create H264 encoder with appropriate input format
|
||||
let encoder_config = H264Config {
|
||||
base: crate::video::encoder::traits::EncoderConfig::h264(
|
||||
config.resolution,
|
||||
config.bitrate_kbps,
|
||||
),
|
||||
bitrate_kbps: config.bitrate_kbps,
|
||||
gop_size: config.gop_size,
|
||||
fps: config.fps,
|
||||
input_format: encoder_input_format,
|
||||
};
|
||||
|
||||
let encoder = H264Encoder::new(encoder_config)?;
|
||||
info!(
|
||||
"H264 encoder created: {} ({}) with {:?} input",
|
||||
encoder.codec_name(),
|
||||
encoder.encoder_type(),
|
||||
encoder_input_format
|
||||
);
|
||||
|
||||
// Create NV12 converter or MJPEG decoder based on input format
|
||||
// All formats are converted to NV12 for VAAPI encoder
|
||||
let (nv12_converter, mjpeg_decoder) = match config.input_format {
|
||||
// NV12 input - direct passthrough
|
||||
PixelFormat::Nv12 => {
|
||||
info!("NV12 input: direct passthrough to encoder");
|
||||
(None, None)
|
||||
}
|
||||
|
||||
// YUYV (4:2:2 packed) → NV12
|
||||
PixelFormat::Yuyv => {
|
||||
info!("YUYV input: converting to NV12");
|
||||
(Some(Nv12Converter::yuyv_to_nv12(config.resolution)), None)
|
||||
}
|
||||
|
||||
// RGB24 → NV12
|
||||
PixelFormat::Rgb24 => {
|
||||
info!("RGB24 input: converting to NV12");
|
||||
(Some(Nv12Converter::rgb24_to_nv12(config.resolution)), None)
|
||||
}
|
||||
|
||||
// BGR24 → NV12
|
||||
PixelFormat::Bgr24 => {
|
||||
info!("BGR24 input: converting to NV12");
|
||||
(Some(Nv12Converter::bgr24_to_nv12(config.resolution)), None)
|
||||
}
|
||||
|
||||
// MJPEG/JPEG → NV12 (via hwcodec decoder)
|
||||
PixelFormat::Mjpeg | PixelFormat::Jpeg => {
|
||||
let decoder_config = MjpegVaapiDecoderConfig {
|
||||
resolution: config.resolution,
|
||||
use_hwaccel: true,
|
||||
};
|
||||
let decoder = MjpegVaapiDecoder::new(decoder_config)?;
|
||||
info!(
|
||||
"MJPEG decoder created for H264 pipeline (outputs NV12)"
|
||||
);
|
||||
(None, Some(decoder))
|
||||
}
|
||||
|
||||
_ => {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Unsupported input format for H264 pipeline: {}",
|
||||
config.input_format
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Create WebRTC video track
|
||||
let track_config = H264VideoTrackConfig {
|
||||
track_id: config.track_id.clone(),
|
||||
stream_id: config.stream_id.clone(),
|
||||
resolution: config.resolution,
|
||||
bitrate_kbps: config.bitrate_kbps,
|
||||
fps: config.fps,
|
||||
profile_level_id: None, // Let browser negotiate the best profile
|
||||
};
|
||||
let video_track = Arc::new(H264VideoTrack::new(track_config));
|
||||
|
||||
let (running_tx, _) = watch::channel(false);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
encoder: Arc::new(Mutex::new(Some(encoder))),
|
||||
nv12_converter: Arc::new(Mutex::new(nv12_converter)),
|
||||
mjpeg_decoder: Arc::new(Mutex::new(mjpeg_decoder)),
|
||||
video_track,
|
||||
stats: Arc::new(Mutex::new(H264PipelineStats::default())),
|
||||
running: running_tx,
|
||||
encode_times: Arc::new(Mutex::new(Vec::with_capacity(100))),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the WebRTC video track
|
||||
pub fn video_track(&self) -> Arc<H264VideoTrack> {
|
||||
self.video_track.clone()
|
||||
}
|
||||
|
||||
/// Get current statistics
|
||||
pub async fn stats(&self) -> H264PipelineStats {
|
||||
self.stats.lock().await.clone()
|
||||
}
|
||||
|
||||
/// Check if pipeline is running
|
||||
pub fn is_running(&self) -> bool {
|
||||
*self.running.borrow()
|
||||
}
|
||||
|
||||
/// Start the encoding pipeline
|
||||
///
|
||||
/// This starts a background task that receives raw frames from the receiver,
|
||||
/// encodes them to H264, and sends them to the WebRTC track.
|
||||
pub async fn start(&self, mut frame_rx: broadcast::Receiver<Vec<u8>>) {
|
||||
if *self.running.borrow() {
|
||||
warn!("H264 pipeline already running");
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = self.running.send(true);
|
||||
info!("Starting H264 pipeline (input format: {})", self.config.input_format);
|
||||
|
||||
let encoder = self.encoder.lock().await.take();
|
||||
let nv12_converter = self.nv12_converter.lock().await.take();
|
||||
let mjpeg_decoder = self.mjpeg_decoder.lock().await.take();
|
||||
let video_track = self.video_track.clone();
|
||||
let stats = self.stats.clone();
|
||||
let encode_times = self.encode_times.clone();
|
||||
let config = self.config.clone();
|
||||
let mut running_rx = self.running.subscribe();
|
||||
|
||||
// Spawn encoding task
|
||||
tokio::spawn(async move {
|
||||
let mut encoder = match encoder {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
error!("No encoder available");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut nv12_converter = nv12_converter;
|
||||
let mut mjpeg_decoder = mjpeg_decoder;
|
||||
let mut frame_count: u64 = 0;
|
||||
let mut last_fps_time = Instant::now();
|
||||
let mut fps_frame_count: u64 = 0;
|
||||
|
||||
// Pre-allocated NV12 buffer for MJPEG decoder output (avoids per-frame allocation)
|
||||
let nv12_size = (config.resolution.width * config.resolution.height * 3 / 2) as usize;
|
||||
let mut nv12_buffer = vec![0u8; nv12_size];
|
||||
|
||||
// Flag for one-time warnings
|
||||
let mut size_mismatch_warned = false;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
_ = running_rx.changed() => {
|
||||
if !*running_rx.borrow() {
|
||||
info!("H264 pipeline stopping");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result = frame_rx.recv() => {
|
||||
match result {
|
||||
Ok(raw_frame) => {
|
||||
let start = Instant::now();
|
||||
|
||||
// Validate frame size for uncompressed formats
|
||||
if let Some(expected_size) = config.input_format.frame_size(config.resolution) {
|
||||
if raw_frame.len() != expected_size && !size_mismatch_warned {
|
||||
warn!(
|
||||
"Frame size mismatch: got {} bytes, expected {} for {} {}x{}",
|
||||
raw_frame.len(),
|
||||
expected_size,
|
||||
config.input_format,
|
||||
config.resolution.width,
|
||||
config.resolution.height
|
||||
);
|
||||
size_mismatch_warned = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Update captured count
|
||||
{
|
||||
let mut s = stats.lock().await;
|
||||
s.frames_captured += 1;
|
||||
}
|
||||
|
||||
// Convert to NV12 for VAAPI encoder
|
||||
// MJPEG -> NV12 (via VAAPI decoder)
|
||||
// BGR24/RGB24/YUYV -> NV12 (via NV12 converter)
|
||||
// NV12 -> pass through
|
||||
//
|
||||
// Optimized: avoid unnecessary allocations and copies
|
||||
frame_count += 1;
|
||||
fps_frame_count += 1;
|
||||
let pts_ms = (frame_count * 1000 / config.fps as u64) as i64;
|
||||
|
||||
let encode_result = if let Some(ref mut decoder) = mjpeg_decoder {
|
||||
// MJPEG input - decode to NV12 via VAAPI
|
||||
match decoder.decode(&raw_frame) {
|
||||
Ok(nv12_frame) => {
|
||||
// Calculate required size for this frame
|
||||
let required_size = (nv12_frame.width * nv12_frame.height * 3 / 2) as usize;
|
||||
|
||||
// Resize buffer if needed (handles resolution changes)
|
||||
if nv12_buffer.len() < required_size {
|
||||
debug!(
|
||||
"Resizing NV12 buffer: {} -> {} bytes (resolution: {}x{})",
|
||||
nv12_buffer.len(), required_size,
|
||||
nv12_frame.width, nv12_frame.height
|
||||
);
|
||||
nv12_buffer.resize(required_size, 0);
|
||||
}
|
||||
|
||||
// Copy to pre-allocated buffer (guaranteed to fit after resize)
|
||||
let written = nv12_frame.copy_to_packed_nv12(&mut nv12_buffer)
|
||||
.expect("BUG: buffer too small after resize");
|
||||
encoder.encode_raw(&nv12_buffer[..written], pts_ms)
|
||||
}
|
||||
Err(e) => {
|
||||
error!("MJPEG VAAPI decode failed: {}", e);
|
||||
let mut s = stats.lock().await;
|
||||
s.errors += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else if let Some(ref mut conv) = nv12_converter {
|
||||
// BGR24/RGB24/YUYV input - convert to NV12
|
||||
// Optimized: pass reference directly without copy
|
||||
match conv.convert(&raw_frame) {
|
||||
Ok(nv12_data) => encoder.encode_raw(nv12_data, pts_ms),
|
||||
Err(e) => {
|
||||
error!("NV12 conversion failed: {}", e);
|
||||
let mut s = stats.lock().await;
|
||||
s.errors += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// NV12 input - pass reference directly
|
||||
encoder.encode_raw(&raw_frame, pts_ms)
|
||||
};
|
||||
|
||||
match encode_result {
|
||||
Ok(frames) => {
|
||||
if !frames.is_empty() {
|
||||
let frame = &frames[0];
|
||||
let is_keyframe = frame.key == 1;
|
||||
|
||||
// Send to WebRTC track
|
||||
let duration = Duration::from_millis(
|
||||
1000 / config.fps as u64
|
||||
);
|
||||
|
||||
if let Err(e) = video_track
|
||||
.write_frame(&frame.data, duration, is_keyframe)
|
||||
.await
|
||||
{
|
||||
error!("Failed to write frame to track: {}", e);
|
||||
let mut s = stats.lock().await;
|
||||
s.errors += 1;
|
||||
} else {
|
||||
// Update stats
|
||||
let encode_time = start.elapsed().as_secs_f32() * 1000.0;
|
||||
let mut s = stats.lock().await;
|
||||
s.frames_encoded += 1;
|
||||
s.bytes_encoded += frame.data.len() as u64;
|
||||
if is_keyframe {
|
||||
s.keyframes_encoded += 1;
|
||||
}
|
||||
|
||||
// Update encode time average
|
||||
let mut times = encode_times.lock().await;
|
||||
times.push(encode_time);
|
||||
if times.len() > 100 {
|
||||
times.remove(0);
|
||||
}
|
||||
if !times.is_empty() {
|
||||
s.avg_encode_time_ms =
|
||||
times.iter().sum::<f32>() / times.len() as f32;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Encoding failed: {}", e);
|
||||
let mut s = stats.lock().await;
|
||||
s.errors += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Update FPS every second
|
||||
if last_fps_time.elapsed() >= Duration::from_secs(1) {
|
||||
let mut s = stats.lock().await;
|
||||
s.current_fps = fps_frame_count as f32
|
||||
/ last_fps_time.elapsed().as_secs_f32();
|
||||
fps_frame_count = 0;
|
||||
last_fps_time = Instant::now();
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
let mut s = stats.lock().await;
|
||||
s.frames_dropped += n;
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
info!("Frame channel closed, stopping H264 pipeline");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("H264 pipeline task exited");
|
||||
});
|
||||
}
|
||||
|
||||
/// Stop the encoding pipeline
|
||||
pub fn stop(&self) {
|
||||
if *self.running.borrow() {
|
||||
let _ = self.running.send(false);
|
||||
info!("Stopping H264 pipeline");
|
||||
}
|
||||
}
|
||||
|
||||
/// Request a keyframe (force IDR)
|
||||
pub async fn request_keyframe(&self) {
|
||||
// Note: hwcodec doesn't support on-demand keyframe requests
|
||||
// The encoder will produce keyframes based on GOP size
|
||||
debug!("Keyframe requested (will occur at next GOP boundary)");
|
||||
}
|
||||
|
||||
/// Update bitrate dynamically
|
||||
pub async fn set_bitrate(&self, bitrate_kbps: u32) -> Result<()> {
|
||||
if let Some(ref mut encoder) = *self.encoder.lock().await {
|
||||
encoder.set_bitrate(bitrate_kbps)?;
|
||||
info!("H264 pipeline bitrate updated to {} kbps", bitrate_kbps);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for H264 pipeline configuration
|
||||
pub struct H264PipelineBuilder {
|
||||
config: H264PipelineConfig,
|
||||
}
|
||||
|
||||
impl H264PipelineBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: H264PipelineConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolution(mut self, resolution: Resolution) -> Self {
|
||||
self.config.resolution = resolution;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn input_format(mut self, format: PixelFormat) -> Self {
|
||||
self.config.input_format = format;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn bitrate_kbps(mut self, bitrate: u32) -> Self {
|
||||
self.config.bitrate_kbps = bitrate;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn fps(mut self, fps: u32) -> Self {
|
||||
self.config.fps = fps;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn gop_size(mut self, gop: u32) -> Self {
|
||||
self.config.gop_size = gop;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn track_id(mut self, id: &str) -> Self {
|
||||
self.config.track_id = id.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream_id(mut self, id: &str) -> Self {
|
||||
self.config.stream_id = id.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<H264Pipeline> {
|
||||
H264Pipeline::new(self.config)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for H264PipelineBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pipeline_config_default() {
|
||||
let config = H264PipelineConfig::default();
|
||||
assert_eq!(config.resolution, Resolution::HD720);
|
||||
assert_eq!(config.bitrate_kbps, 2000);
|
||||
assert_eq!(config.fps, 30);
|
||||
assert_eq!(config.gop_size, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pipeline_builder() {
|
||||
let builder = H264PipelineBuilder::new()
|
||||
.resolution(Resolution::HD1080)
|
||||
.bitrate_kbps(4000)
|
||||
.fps(60)
|
||||
.input_format(PixelFormat::Yuyv);
|
||||
|
||||
assert_eq!(builder.config.resolution, Resolution::HD1080);
|
||||
assert_eq!(builder.config.bitrate_kbps, 4000);
|
||||
assert_eq!(builder.config.fps, 60);
|
||||
assert_eq!(builder.config.input_format, PixelFormat::Yuyv);
|
||||
}
|
||||
}
|
||||
29
src/video/mod.rs
Normal file
29
src/video/mod.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
//! Video capture and streaming module
|
||||
//!
|
||||
//! This module provides V4L2 video capture, encoding, and streaming functionality.
|
||||
|
||||
pub mod capture;
|
||||
pub mod convert;
|
||||
pub mod decoder;
|
||||
pub mod device;
|
||||
pub mod encoder;
|
||||
pub mod format;
|
||||
pub mod frame;
|
||||
pub mod h264_pipeline;
|
||||
pub mod shared_video_pipeline;
|
||||
pub mod stream_manager;
|
||||
pub mod streamer;
|
||||
pub mod video_session;
|
||||
|
||||
pub use capture::VideoCapturer;
|
||||
pub use convert::{MjpegDecoder, MjpegToYuv420Converter, PixelConverter, Yuv420pBuffer};
|
||||
pub use decoder::{MjpegVaapiDecoder, MjpegVaapiDecoderConfig};
|
||||
pub use device::{VideoDevice, VideoDeviceInfo};
|
||||
pub use encoder::{JpegEncoder, H264Encoder, H264EncoderType};
|
||||
pub use format::PixelFormat;
|
||||
pub use frame::VideoFrame;
|
||||
pub use h264_pipeline::{H264Pipeline, H264PipelineBuilder, H264PipelineConfig};
|
||||
pub use shared_video_pipeline::{EncodedVideoFrame, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats};
|
||||
pub use stream_manager::VideoStreamManager;
|
||||
pub use streamer::{Streamer, StreamerState};
|
||||
pub use video_session::{VideoSessionManager, VideoSessionManagerConfig, VideoSessionInfo, VideoSessionState, CodecInfo};
|
||||
940
src/video/shared_video_pipeline.rs
Normal file
940
src/video/shared_video_pipeline.rs
Normal file
@@ -0,0 +1,940 @@
|
||||
//! Universal shared video encoding pipeline
|
||||
//!
|
||||
//! Supports multiple codecs: H264, H265, VP8, VP9
|
||||
//! A single encoder broadcasts to multiple WebRTC sessions.
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! VideoCapturer (MJPEG/YUYV/NV12)
|
||||
//! |
|
||||
//! v (broadcast::Receiver<VideoFrame>)
|
||||
//! SharedVideoPipeline (single encoder)
|
||||
//! |
|
||||
//! v (broadcast::Sender<EncodedVideoFrame>)
|
||||
//! ┌────┴────┬────────┬────────┐
|
||||
//! v v v v
|
||||
//! Session1 Session2 Session3 ...
|
||||
//! ```
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{broadcast, watch, Mutex, RwLock};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::convert::{Nv12Converter, PixelConverter};
|
||||
use crate::video::decoder::mjpeg::{MjpegTurboDecoder, MjpegVaapiDecoder, MjpegVaapiDecoderConfig};
|
||||
use crate::video::encoder::h264::{H264Config, H264Encoder};
|
||||
use crate::video::encoder::h265::{H265Config, H265Encoder};
|
||||
use crate::video::encoder::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
|
||||
use crate::video::encoder::traits::EncoderConfig;
|
||||
use crate::video::encoder::vp8::{VP8Config, VP8Encoder};
|
||||
use crate::video::encoder::vp9::{VP9Config, VP9Encoder};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
use crate::video::frame::VideoFrame;
|
||||
|
||||
/// Encoded video frame for distribution
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EncodedVideoFrame {
|
||||
/// Encoded data (Annex B for H264/H265, raw for VP8/VP9)
|
||||
pub data: Bytes,
|
||||
/// Presentation timestamp in milliseconds
|
||||
pub pts_ms: i64,
|
||||
/// Whether this is a keyframe
|
||||
pub is_keyframe: bool,
|
||||
/// Frame sequence number
|
||||
pub sequence: u64,
|
||||
/// Frame duration
|
||||
pub duration: Duration,
|
||||
/// Codec type
|
||||
pub codec: VideoEncoderType,
|
||||
}
|
||||
|
||||
/// Shared video pipeline configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SharedVideoPipelineConfig {
|
||||
/// Input resolution
|
||||
pub resolution: Resolution,
|
||||
/// Input pixel format
|
||||
pub input_format: PixelFormat,
|
||||
/// Output codec type
|
||||
pub output_codec: VideoEncoderType,
|
||||
/// Target bitrate in kbps
|
||||
pub bitrate_kbps: u32,
|
||||
/// Target FPS
|
||||
pub fps: u32,
|
||||
/// GOP size
|
||||
pub gop_size: u32,
|
||||
/// Encoder backend (None = auto select best available)
|
||||
pub encoder_backend: Option<EncoderBackend>,
|
||||
}
|
||||
|
||||
impl Default for SharedVideoPipelineConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
resolution: Resolution::HD720,
|
||||
input_format: PixelFormat::Yuyv,
|
||||
output_codec: VideoEncoderType::H264,
|
||||
bitrate_kbps: 8000,
|
||||
fps: 30,
|
||||
gop_size: 30,
|
||||
encoder_backend: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SharedVideoPipelineConfig {
|
||||
/// Create H264 config
|
||||
pub fn h264(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
resolution,
|
||||
output_codec: VideoEncoderType::H264,
|
||||
bitrate_kbps,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create H265 config
|
||||
pub fn h265(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
resolution,
|
||||
output_codec: VideoEncoderType::H265,
|
||||
bitrate_kbps,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create VP8 config
|
||||
pub fn vp8(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
resolution,
|
||||
output_codec: VideoEncoderType::VP8,
|
||||
bitrate_kbps,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create VP9 config
|
||||
pub fn vp9(resolution: Resolution, bitrate_kbps: u32) -> Self {
|
||||
Self {
|
||||
resolution,
|
||||
output_codec: VideoEncoderType::VP9,
|
||||
bitrate_kbps,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pipeline statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SharedVideoPipelineStats {
|
||||
pub frames_captured: u64,
|
||||
pub frames_encoded: u64,
|
||||
pub frames_dropped: u64,
|
||||
pub bytes_encoded: u64,
|
||||
pub keyframes_encoded: u64,
|
||||
pub avg_encode_time_ms: f32,
|
||||
pub current_fps: f32,
|
||||
pub errors: u64,
|
||||
pub subscribers: u64,
|
||||
}
|
||||
|
||||
|
||||
/// Universal video encoder trait object
|
||||
#[allow(dead_code)]
|
||||
trait VideoEncoderTrait: Send {
|
||||
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>>;
|
||||
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()>;
|
||||
fn codec_name(&self) -> &str;
|
||||
fn request_keyframe(&mut self);
|
||||
}
|
||||
|
||||
/// Encoded frame from encoder
|
||||
#[allow(dead_code)]
|
||||
struct EncodedFrame {
|
||||
data: Vec<u8>,
|
||||
pts: i64,
|
||||
key: i32,
|
||||
}
|
||||
|
||||
/// H264 encoder wrapper
|
||||
struct H264EncoderWrapper(H264Encoder);
|
||||
|
||||
impl VideoEncoderTrait for H264EncoderWrapper {
|
||||
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
|
||||
let frames = self.0.encode_raw(data, pts_ms)?;
|
||||
Ok(frames
|
||||
.into_iter()
|
||||
.map(|f| EncodedFrame {
|
||||
data: f.data,
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.0.set_bitrate(bitrate_kbps)
|
||||
}
|
||||
|
||||
fn codec_name(&self) -> &str {
|
||||
self.0.codec_name()
|
||||
}
|
||||
|
||||
fn request_keyframe(&mut self) {
|
||||
self.0.request_keyframe()
|
||||
}
|
||||
}
|
||||
|
||||
/// H265 encoder wrapper
|
||||
struct H265EncoderWrapper(H265Encoder);
|
||||
|
||||
impl VideoEncoderTrait for H265EncoderWrapper {
|
||||
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
|
||||
let frames = self.0.encode_raw(data, pts_ms)?;
|
||||
Ok(frames
|
||||
.into_iter()
|
||||
.map(|f| EncodedFrame {
|
||||
data: f.data,
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.0.set_bitrate(bitrate_kbps)
|
||||
}
|
||||
|
||||
fn codec_name(&self) -> &str {
|
||||
self.0.codec_name()
|
||||
}
|
||||
|
||||
fn request_keyframe(&mut self) {
|
||||
self.0.request_keyframe()
|
||||
}
|
||||
}
|
||||
|
||||
/// VP8 encoder wrapper
|
||||
struct VP8EncoderWrapper(VP8Encoder);
|
||||
|
||||
impl VideoEncoderTrait for VP8EncoderWrapper {
|
||||
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
|
||||
let frames = self.0.encode_raw(data, pts_ms)?;
|
||||
Ok(frames
|
||||
.into_iter()
|
||||
.map(|f| EncodedFrame {
|
||||
data: f.data,
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.0.set_bitrate(bitrate_kbps)
|
||||
}
|
||||
|
||||
fn codec_name(&self) -> &str {
|
||||
self.0.codec_name()
|
||||
}
|
||||
|
||||
fn request_keyframe(&mut self) {
|
||||
// VP8 encoder doesn't support request_keyframe yet
|
||||
}
|
||||
}
|
||||
|
||||
/// VP9 encoder wrapper
|
||||
struct VP9EncoderWrapper(VP9Encoder);
|
||||
|
||||
impl VideoEncoderTrait for VP9EncoderWrapper {
|
||||
fn encode_raw(&mut self, data: &[u8], pts_ms: i64) -> Result<Vec<EncodedFrame>> {
|
||||
let frames = self.0.encode_raw(data, pts_ms)?;
|
||||
Ok(frames
|
||||
.into_iter()
|
||||
.map(|f| EncodedFrame {
|
||||
data: f.data,
|
||||
pts: f.pts,
|
||||
key: f.key,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn set_bitrate(&mut self, bitrate_kbps: u32) -> Result<()> {
|
||||
self.0.set_bitrate(bitrate_kbps)
|
||||
}
|
||||
|
||||
fn codec_name(&self) -> &str {
|
||||
self.0.codec_name()
|
||||
}
|
||||
|
||||
fn request_keyframe(&mut self) {
|
||||
// VP9 encoder doesn't support request_keyframe yet
|
||||
}
|
||||
}
|
||||
|
||||
/// Universal shared video pipeline
|
||||
pub struct SharedVideoPipeline {
|
||||
config: RwLock<SharedVideoPipelineConfig>,
|
||||
encoder: Mutex<Option<Box<dyn VideoEncoderTrait + Send>>>,
|
||||
nv12_converter: Mutex<Option<Nv12Converter>>,
|
||||
yuv420p_converter: Mutex<Option<PixelConverter>>,
|
||||
mjpeg_decoder: Mutex<Option<MjpegVaapiDecoder>>,
|
||||
/// Turbojpeg decoder for direct MJPEG->YUV420P (optimized for software encoders)
|
||||
mjpeg_turbo_decoder: Mutex<Option<MjpegTurboDecoder>>,
|
||||
nv12_buffer: Mutex<Vec<u8>>,
|
||||
/// YUV420P buffer for turbojpeg decoder output
|
||||
yuv420p_buffer: Mutex<Vec<u8>>,
|
||||
/// Whether the encoder needs YUV420P (true) or NV12 (false)
|
||||
encoder_needs_yuv420p: Mutex<bool>,
|
||||
frame_tx: broadcast::Sender<EncodedVideoFrame>,
|
||||
stats: Mutex<SharedVideoPipelineStats>,
|
||||
running: watch::Sender<bool>,
|
||||
running_rx: watch::Receiver<bool>,
|
||||
encode_times: Mutex<Vec<f32>>,
|
||||
sequence: Mutex<u64>,
|
||||
/// Atomic flag for keyframe request (avoids lock contention)
|
||||
keyframe_requested: AtomicBool,
|
||||
}
|
||||
|
||||
impl SharedVideoPipeline {
|
||||
/// Create a new shared video pipeline
|
||||
pub fn new(config: SharedVideoPipelineConfig) -> Result<Arc<Self>> {
|
||||
info!(
|
||||
"Creating shared video pipeline: {} {}x{} @ {} kbps (input: {})",
|
||||
config.output_codec,
|
||||
config.resolution.width,
|
||||
config.resolution.height,
|
||||
config.bitrate_kbps,
|
||||
config.input_format
|
||||
);
|
||||
|
||||
let (frame_tx, _) = broadcast::channel(16);
|
||||
let (running_tx, running_rx) = watch::channel(false);
|
||||
let nv12_size = (config.resolution.width * config.resolution.height * 3 / 2) as usize;
|
||||
let yuv420p_size = nv12_size; // Same size as NV12
|
||||
|
||||
let pipeline = Arc::new(Self {
|
||||
config: RwLock::new(config),
|
||||
encoder: Mutex::new(None),
|
||||
nv12_converter: Mutex::new(None),
|
||||
yuv420p_converter: Mutex::new(None),
|
||||
mjpeg_decoder: Mutex::new(None),
|
||||
mjpeg_turbo_decoder: Mutex::new(None),
|
||||
nv12_buffer: Mutex::new(vec![0u8; nv12_size]),
|
||||
yuv420p_buffer: Mutex::new(vec![0u8; yuv420p_size]),
|
||||
encoder_needs_yuv420p: Mutex::new(false),
|
||||
frame_tx,
|
||||
stats: Mutex::new(SharedVideoPipelineStats::default()),
|
||||
running: running_tx,
|
||||
running_rx,
|
||||
encode_times: Mutex::new(Vec::with_capacity(100)),
|
||||
sequence: Mutex::new(0),
|
||||
keyframe_requested: AtomicBool::new(false),
|
||||
});
|
||||
|
||||
Ok(pipeline)
|
||||
}
|
||||
|
||||
/// Initialize encoder based on config
|
||||
async fn init_encoder(&self) -> Result<()> {
|
||||
let config = self.config.read().await.clone();
|
||||
let registry = EncoderRegistry::global();
|
||||
|
||||
// Helper to get codec name for specific backend
|
||||
let get_codec_name = |format: VideoEncoderType, backend: Option<EncoderBackend>| -> Option<String> {
|
||||
match backend {
|
||||
Some(b) => registry.encoder_with_backend(format, b).map(|e| e.codec_name.clone()),
|
||||
None => registry.best_encoder(format, false).map(|e| e.codec_name.clone()),
|
||||
}
|
||||
};
|
||||
|
||||
// Create encoder based on codec type
|
||||
let encoder: Box<dyn VideoEncoderTrait + Send> = match config.output_codec {
|
||||
VideoEncoderType::H264 => {
|
||||
let encoder_config = H264Config {
|
||||
base: EncoderConfig::h264(config.resolution, config.bitrate_kbps),
|
||||
bitrate_kbps: config.bitrate_kbps,
|
||||
gop_size: config.gop_size,
|
||||
fps: config.fps,
|
||||
input_format: crate::video::encoder::h264::H264InputFormat::Nv12,
|
||||
};
|
||||
|
||||
let encoder = if let Some(ref backend) = config.encoder_backend {
|
||||
// Specific backend requested
|
||||
let codec_name = get_codec_name(VideoEncoderType::H264, Some(*backend))
|
||||
.ok_or_else(|| AppError::VideoError(format!(
|
||||
"Backend {:?} does not support H.264", backend
|
||||
)))?;
|
||||
info!("Creating H264 encoder with backend {:?} (codec: {})", backend, codec_name);
|
||||
H264Encoder::with_codec(encoder_config, &codec_name)?
|
||||
} else {
|
||||
// Auto select
|
||||
H264Encoder::new(encoder_config)?
|
||||
};
|
||||
|
||||
info!("Created H264 encoder: {}", encoder.codec_name());
|
||||
Box::new(H264EncoderWrapper(encoder))
|
||||
}
|
||||
VideoEncoderType::H265 => {
|
||||
let encoder_config = H265Config::low_latency(config.resolution, config.bitrate_kbps);
|
||||
|
||||
let encoder = if let Some(ref backend) = config.encoder_backend {
|
||||
let codec_name = get_codec_name(VideoEncoderType::H265, Some(*backend))
|
||||
.ok_or_else(|| AppError::VideoError(format!(
|
||||
"Backend {:?} does not support H.265", backend
|
||||
)))?;
|
||||
info!("Creating H265 encoder with backend {:?} (codec: {})", backend, codec_name);
|
||||
H265Encoder::with_codec(encoder_config, &codec_name)?
|
||||
} else {
|
||||
H265Encoder::new(encoder_config)?
|
||||
};
|
||||
|
||||
info!("Created H265 encoder: {}", encoder.codec_name());
|
||||
Box::new(H265EncoderWrapper(encoder))
|
||||
}
|
||||
VideoEncoderType::VP8 => {
|
||||
let encoder_config = VP8Config::low_latency(config.resolution, config.bitrate_kbps);
|
||||
|
||||
let encoder = if let Some(ref backend) = config.encoder_backend {
|
||||
let codec_name = get_codec_name(VideoEncoderType::VP8, Some(*backend))
|
||||
.ok_or_else(|| AppError::VideoError(format!(
|
||||
"Backend {:?} does not support VP8", backend
|
||||
)))?;
|
||||
info!("Creating VP8 encoder with backend {:?} (codec: {})", backend, codec_name);
|
||||
VP8Encoder::with_codec(encoder_config, &codec_name)?
|
||||
} else {
|
||||
VP8Encoder::new(encoder_config)?
|
||||
};
|
||||
|
||||
info!("Created VP8 encoder: {}", encoder.codec_name());
|
||||
Box::new(VP8EncoderWrapper(encoder))
|
||||
}
|
||||
VideoEncoderType::VP9 => {
|
||||
let encoder_config = VP9Config::low_latency(config.resolution, config.bitrate_kbps);
|
||||
|
||||
let encoder = if let Some(ref backend) = config.encoder_backend {
|
||||
let codec_name = get_codec_name(VideoEncoderType::VP9, Some(*backend))
|
||||
.ok_or_else(|| AppError::VideoError(format!(
|
||||
"Backend {:?} does not support VP9", backend
|
||||
)))?;
|
||||
info!("Creating VP9 encoder with backend {:?} (codec: {})", backend, codec_name);
|
||||
VP9Encoder::with_codec(encoder_config, &codec_name)?
|
||||
} else {
|
||||
VP9Encoder::new(encoder_config)?
|
||||
};
|
||||
|
||||
info!("Created VP9 encoder: {}", encoder.codec_name());
|
||||
Box::new(VP9EncoderWrapper(encoder))
|
||||
}
|
||||
};
|
||||
|
||||
// Determine if encoder needs YUV420P (software encoders) or NV12 (hardware encoders)
|
||||
let codec_name = encoder.codec_name();
|
||||
let needs_yuv420p = codec_name.contains("libvpx") || codec_name.contains("libx265");
|
||||
|
||||
info!(
|
||||
"Encoder {} needs {} format",
|
||||
codec_name,
|
||||
if needs_yuv420p { "YUV420P" } else { "NV12" }
|
||||
);
|
||||
|
||||
// Create converter or decoder based on input format and encoder needs
|
||||
info!("Initializing input format handler for: {} -> {}",
|
||||
config.input_format,
|
||||
if needs_yuv420p { "YUV420P" } else { "NV12" });
|
||||
|
||||
let (nv12_converter, yuv420p_converter, mjpeg_decoder, mjpeg_turbo_decoder) = if needs_yuv420p {
|
||||
// Software encoder needs YUV420P
|
||||
match config.input_format {
|
||||
PixelFormat::Yuv420 => {
|
||||
info!("Using direct YUV420P input (no conversion)");
|
||||
(None, None, None, None)
|
||||
}
|
||||
PixelFormat::Yuyv => {
|
||||
info!("Using YUYV->YUV420P converter");
|
||||
(None, Some(PixelConverter::yuyv_to_yuv420p(config.resolution)), None, None)
|
||||
}
|
||||
PixelFormat::Nv12 => {
|
||||
info!("Using NV12->YUV420P converter");
|
||||
(None, Some(PixelConverter::nv12_to_yuv420p(config.resolution)), None, None)
|
||||
}
|
||||
PixelFormat::Rgb24 => {
|
||||
info!("Using RGB24->YUV420P converter");
|
||||
(None, Some(PixelConverter::rgb24_to_yuv420p(config.resolution)), None, None)
|
||||
}
|
||||
PixelFormat::Bgr24 => {
|
||||
info!("Using BGR24->YUV420P converter");
|
||||
(None, Some(PixelConverter::bgr24_to_yuv420p(config.resolution)), None, None)
|
||||
}
|
||||
PixelFormat::Mjpeg | PixelFormat::Jpeg => {
|
||||
// Use turbojpeg for direct MJPEG->YUV420P (no intermediate NV12)
|
||||
info!("Using turbojpeg MJPEG decoder (direct YUV420P output)");
|
||||
let turbo_decoder = MjpegTurboDecoder::new(config.resolution)?;
|
||||
(None, None, None, Some(turbo_decoder))
|
||||
}
|
||||
_ => {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Unsupported input format: {}",
|
||||
config.input_format
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Hardware encoder needs NV12
|
||||
match config.input_format {
|
||||
PixelFormat::Nv12 => {
|
||||
info!("Using direct NV12 input (no conversion)");
|
||||
(None, None, None, None)
|
||||
}
|
||||
PixelFormat::Yuyv => {
|
||||
info!("Using YUYV->NV12 converter");
|
||||
(Some(Nv12Converter::yuyv_to_nv12(config.resolution)), None, None, None)
|
||||
}
|
||||
PixelFormat::Rgb24 => {
|
||||
info!("Using RGB24->NV12 converter");
|
||||
(Some(Nv12Converter::rgb24_to_nv12(config.resolution)), None, None, None)
|
||||
}
|
||||
PixelFormat::Bgr24 => {
|
||||
info!("Using BGR24->NV12 converter");
|
||||
(Some(Nv12Converter::bgr24_to_nv12(config.resolution)), None, None, None)
|
||||
}
|
||||
PixelFormat::Mjpeg | PixelFormat::Jpeg => {
|
||||
info!("Using MJPEG decoder (NV12 output)");
|
||||
let decoder_config = MjpegVaapiDecoderConfig {
|
||||
resolution: config.resolution,
|
||||
use_hwaccel: true,
|
||||
};
|
||||
let decoder = MjpegVaapiDecoder::new(decoder_config)?;
|
||||
(None, None, Some(decoder), None)
|
||||
}
|
||||
_ => {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Unsupported input format: {}",
|
||||
config.input_format
|
||||
)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
*self.encoder.lock().await = Some(encoder);
|
||||
*self.nv12_converter.lock().await = nv12_converter;
|
||||
*self.yuv420p_converter.lock().await = yuv420p_converter;
|
||||
*self.mjpeg_decoder.lock().await = mjpeg_decoder;
|
||||
*self.mjpeg_turbo_decoder.lock().await = mjpeg_turbo_decoder;
|
||||
*self.encoder_needs_yuv420p.lock().await = needs_yuv420p;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Subscribe to encoded frames
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<EncodedVideoFrame> {
|
||||
self.frame_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get subscriber count
|
||||
pub fn subscriber_count(&self) -> usize {
|
||||
self.frame_tx.receiver_count()
|
||||
}
|
||||
|
||||
/// Request encoder to produce a keyframe on next encode
|
||||
///
|
||||
/// This is useful when a new client connects and needs an immediate
|
||||
/// keyframe to start decoding the video stream.
|
||||
///
|
||||
/// Uses an atomic flag to avoid lock contention with the encoding loop.
|
||||
pub async fn request_keyframe(&self) {
|
||||
self.keyframe_requested.store(true, Ordering::Release);
|
||||
info!("[Pipeline] Keyframe requested for new client");
|
||||
}
|
||||
|
||||
/// Get current stats
|
||||
pub async fn stats(&self) -> SharedVideoPipelineStats {
|
||||
let mut stats = self.stats.lock().await.clone();
|
||||
stats.subscribers = self.frame_tx.receiver_count() as u64;
|
||||
stats
|
||||
}
|
||||
|
||||
/// Check if running
|
||||
pub fn is_running(&self) -> bool {
|
||||
*self.running_rx.borrow()
|
||||
}
|
||||
|
||||
/// Get current codec
|
||||
pub async fn current_codec(&self) -> VideoEncoderType {
|
||||
self.config.read().await.output_codec
|
||||
}
|
||||
|
||||
/// Switch codec (requires restart)
|
||||
pub async fn switch_codec(&self, codec: VideoEncoderType) -> Result<()> {
|
||||
let was_running = self.is_running();
|
||||
|
||||
if was_running {
|
||||
self.stop();
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
config.output_codec = codec;
|
||||
}
|
||||
|
||||
// Clear encoder state
|
||||
*self.encoder.lock().await = None;
|
||||
*self.nv12_converter.lock().await = None;
|
||||
*self.yuv420p_converter.lock().await = None;
|
||||
*self.mjpeg_decoder.lock().await = None;
|
||||
*self.mjpeg_turbo_decoder.lock().await = None;
|
||||
*self.encoder_needs_yuv420p.lock().await = false;
|
||||
|
||||
info!("Switched to {} codec", codec);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start the pipeline
|
||||
pub async fn start(self: &Arc<Self>, mut frame_rx: broadcast::Receiver<VideoFrame>) -> Result<()> {
|
||||
if *self.running_rx.borrow() {
|
||||
warn!("Pipeline already running");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.init_encoder().await?;
|
||||
let _ = self.running.send(true);
|
||||
|
||||
let config = self.config.read().await.clone();
|
||||
info!("Starting {} pipeline", config.output_codec);
|
||||
|
||||
let pipeline = self.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut frame_count: u64 = 0;
|
||||
let mut last_fps_time = Instant::now();
|
||||
let mut fps_frame_count: u64 = 0;
|
||||
let mut running_rx = pipeline.running_rx.clone();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
_ = running_rx.changed() => {
|
||||
if !*running_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result = frame_rx.recv() => {
|
||||
match result {
|
||||
Ok(video_frame) => {
|
||||
pipeline.stats.lock().await.frames_captured += 1;
|
||||
|
||||
if pipeline.frame_tx.receiver_count() == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
match pipeline.encode_frame(&video_frame, frame_count).await {
|
||||
Ok(Some(encoded_frame)) => {
|
||||
let encode_time = start.elapsed().as_secs_f32() * 1000.0;
|
||||
let _ = pipeline.frame_tx.send(encoded_frame.clone());
|
||||
|
||||
let is_keyframe = encoded_frame.is_keyframe;
|
||||
|
||||
// Update stats
|
||||
{
|
||||
let mut s = pipeline.stats.lock().await;
|
||||
s.frames_encoded += 1;
|
||||
s.bytes_encoded += encoded_frame.data.len() as u64;
|
||||
if is_keyframe {
|
||||
s.keyframes_encoded += 1;
|
||||
}
|
||||
|
||||
let mut times = pipeline.encode_times.lock().await;
|
||||
times.push(encode_time);
|
||||
if times.len() > 100 {
|
||||
times.remove(0);
|
||||
}
|
||||
if !times.is_empty() {
|
||||
s.avg_encode_time_ms = times.iter().sum::<f32>() / times.len() as f32;
|
||||
}
|
||||
}
|
||||
|
||||
frame_count += 1;
|
||||
fps_frame_count += 1;
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
error!("Encoding failed: {}", e);
|
||||
pipeline.stats.lock().await.errors += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if last_fps_time.elapsed() >= Duration::from_secs(1) {
|
||||
let mut s = pipeline.stats.lock().await;
|
||||
s.current_fps = fps_frame_count as f32 / last_fps_time.elapsed().as_secs_f32();
|
||||
fps_frame_count = 0;
|
||||
last_fps_time = Instant::now();
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
pipeline.stats.lock().await.frames_dropped += n;
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Video pipeline stopped");
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encode a single frame
|
||||
async fn encode_frame(&self, frame: &VideoFrame, frame_count: u64) -> Result<Option<EncodedVideoFrame>> {
|
||||
let config = self.config.read().await;
|
||||
let raw_frame = frame.data();
|
||||
let fps = config.fps;
|
||||
let codec = config.output_codec;
|
||||
drop(config);
|
||||
|
||||
let pts_ms = (frame_count * 1000 / fps as u64) as i64;
|
||||
|
||||
// Debug log for H265
|
||||
if codec == VideoEncoderType::H265 && frame_count % 30 == 1 {
|
||||
debug!(
|
||||
"[Pipeline-H265] Processing frame #{}: input_size={}, pts_ms={}",
|
||||
frame_count,
|
||||
raw_frame.len(),
|
||||
pts_ms
|
||||
);
|
||||
}
|
||||
|
||||
let mut mjpeg_decoder = self.mjpeg_decoder.lock().await;
|
||||
let mut mjpeg_turbo_decoder = self.mjpeg_turbo_decoder.lock().await;
|
||||
let mut nv12_converter = self.nv12_converter.lock().await;
|
||||
let mut yuv420p_converter = self.yuv420p_converter.lock().await;
|
||||
let needs_yuv420p = *self.encoder_needs_yuv420p.lock().await;
|
||||
let mut encoder_guard = self.encoder.lock().await;
|
||||
|
||||
let encoder = encoder_guard.as_mut().ok_or_else(|| {
|
||||
AppError::VideoError("Encoder not initialized".to_string())
|
||||
})?;
|
||||
|
||||
// Check and consume keyframe request (atomic, no lock contention)
|
||||
if self.keyframe_requested.swap(false, Ordering::AcqRel) {
|
||||
encoder.request_keyframe();
|
||||
debug!("[Pipeline] Keyframe will be generated for this frame");
|
||||
}
|
||||
|
||||
let encode_result = if mjpeg_turbo_decoder.is_some() {
|
||||
// Optimized path: MJPEG -> YUV420P directly via turbojpeg (for software encoders)
|
||||
let turbo = mjpeg_turbo_decoder.as_mut().unwrap();
|
||||
let mut yuv420p_buffer = self.yuv420p_buffer.lock().await;
|
||||
let written = turbo.decode_to_yuv420p_buffer(raw_frame, &mut yuv420p_buffer)
|
||||
.map_err(|e| AppError::VideoError(format!("turbojpeg decode failed: {}", e)))?;
|
||||
encoder.encode_raw(&yuv420p_buffer[..written], pts_ms)
|
||||
} else if mjpeg_decoder.is_some() {
|
||||
// MJPEG input: decode to NV12 (for hardware encoders)
|
||||
let decoder = mjpeg_decoder.as_mut().unwrap();
|
||||
let nv12_frame = decoder.decode(raw_frame)
|
||||
.map_err(|e| AppError::VideoError(format!("MJPEG decode failed: {}", e)))?;
|
||||
|
||||
let required_size = (nv12_frame.width * nv12_frame.height * 3 / 2) as usize;
|
||||
let mut nv12_buffer = self.nv12_buffer.lock().await;
|
||||
if nv12_buffer.len() < required_size {
|
||||
nv12_buffer.resize(required_size, 0);
|
||||
}
|
||||
|
||||
let written = nv12_frame.copy_to_packed_nv12(&mut nv12_buffer)
|
||||
.expect("Buffer too small");
|
||||
|
||||
// Debug log for H265 after MJPEG decode
|
||||
if codec == VideoEncoderType::H265 && frame_count % 30 == 1 {
|
||||
debug!(
|
||||
"[Pipeline-H265] MJPEG decoded: nv12_size={}, frame_width={}, frame_height={}",
|
||||
written, nv12_frame.width, nv12_frame.height
|
||||
);
|
||||
}
|
||||
|
||||
encoder.encode_raw(&nv12_buffer[..written], pts_ms)
|
||||
} else if needs_yuv420p && yuv420p_converter.is_some() {
|
||||
// Software encoder with direct input conversion to YUV420P
|
||||
let conv = yuv420p_converter.as_mut().unwrap();
|
||||
let yuv420p_data = conv.convert(raw_frame)
|
||||
.map_err(|e| AppError::VideoError(format!("YUV420P conversion failed: {}", e)))?;
|
||||
encoder.encode_raw(yuv420p_data, pts_ms)
|
||||
} else if nv12_converter.is_some() {
|
||||
// Hardware encoder with input conversion to NV12
|
||||
let conv = nv12_converter.as_mut().unwrap();
|
||||
let nv12_data = conv.convert(raw_frame)
|
||||
.map_err(|e| AppError::VideoError(format!("NV12 conversion failed: {}", e)))?;
|
||||
encoder.encode_raw(nv12_data, pts_ms)
|
||||
} else {
|
||||
// Direct input (already in correct format)
|
||||
encoder.encode_raw(raw_frame, pts_ms)
|
||||
};
|
||||
|
||||
drop(encoder_guard);
|
||||
drop(nv12_converter);
|
||||
drop(yuv420p_converter);
|
||||
drop(mjpeg_decoder);
|
||||
drop(mjpeg_turbo_decoder);
|
||||
|
||||
match encode_result {
|
||||
Ok(frames) => {
|
||||
if !frames.is_empty() {
|
||||
let encoded = frames.into_iter().next().unwrap();
|
||||
let is_keyframe = encoded.key == 1;
|
||||
|
||||
let sequence = {
|
||||
let mut seq = self.sequence.lock().await;
|
||||
*seq += 1;
|
||||
*seq
|
||||
};
|
||||
|
||||
// Debug log for H265 encoded frame
|
||||
if codec == VideoEncoderType::H265 && (is_keyframe || frame_count % 30 == 1) {
|
||||
debug!(
|
||||
"[Pipeline-H265] Encoded frame #{}: output_size={}, keyframe={}, sequence={}",
|
||||
frame_count,
|
||||
encoded.data.len(),
|
||||
is_keyframe,
|
||||
sequence
|
||||
);
|
||||
|
||||
// Log H265 NAL unit types in the encoded data
|
||||
if is_keyframe {
|
||||
let nal_types = parse_h265_nal_types(&encoded.data);
|
||||
debug!("[Pipeline-H265] Keyframe NAL types: {:?}", nal_types);
|
||||
}
|
||||
}
|
||||
|
||||
let config = self.config.read().await;
|
||||
Ok(Some(EncodedVideoFrame {
|
||||
data: Bytes::from(encoded.data),
|
||||
pts_ms,
|
||||
is_keyframe,
|
||||
sequence,
|
||||
duration: Duration::from_millis(1000 / config.fps as u64),
|
||||
codec,
|
||||
}))
|
||||
} else {
|
||||
if codec == VideoEncoderType::H265 {
|
||||
warn!("[Pipeline-H265] Encoder returned no frames for frame #{}", frame_count);
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if codec == VideoEncoderType::H265 {
|
||||
error!("[Pipeline-H265] Encode error at frame #{}: {}", frame_count, e);
|
||||
}
|
||||
Err(e)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Stop the pipeline
|
||||
pub fn stop(&self) {
|
||||
if *self.running_rx.borrow() {
|
||||
let _ = self.running.send(false);
|
||||
info!("Stopping video pipeline");
|
||||
}
|
||||
}
|
||||
|
||||
/// Set bitrate
|
||||
pub async fn set_bitrate(&self, bitrate_kbps: u32) -> Result<()> {
|
||||
if let Some(ref mut encoder) = *self.encoder.lock().await {
|
||||
encoder.set_bitrate(bitrate_kbps)?;
|
||||
self.config.write().await.bitrate_kbps = bitrate_kbps;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current config
|
||||
pub async fn config(&self) -> SharedVideoPipelineConfig {
|
||||
self.config.read().await.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SharedVideoPipeline {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.running.send(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse H265 NAL unit types from Annex B data
|
||||
fn parse_h265_nal_types(data: &[u8]) -> Vec<(u8, usize)> {
|
||||
let mut nal_types = Vec::new();
|
||||
let mut i = 0;
|
||||
|
||||
while i < data.len() {
|
||||
// Find start code
|
||||
let nal_start = if i + 4 <= data.len()
|
||||
&& data[i] == 0
|
||||
&& data[i + 1] == 0
|
||||
&& data[i + 2] == 0
|
||||
&& data[i + 3] == 1
|
||||
{
|
||||
i + 4
|
||||
} else if i + 3 <= data.len()
|
||||
&& data[i] == 0
|
||||
&& data[i + 1] == 0
|
||||
&& data[i + 2] == 1
|
||||
{
|
||||
i + 3
|
||||
} else {
|
||||
i += 1;
|
||||
continue;
|
||||
};
|
||||
|
||||
if nal_start >= data.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Find next start code to get NAL size
|
||||
let mut nal_end = data.len();
|
||||
let mut j = nal_start + 1;
|
||||
while j + 3 <= data.len() {
|
||||
if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1)
|
||||
|| (j + 4 <= data.len()
|
||||
&& data[j] == 0
|
||||
&& data[j + 1] == 0
|
||||
&& data[j + 2] == 0
|
||||
&& data[j + 3] == 1)
|
||||
{
|
||||
nal_end = j;
|
||||
break;
|
||||
}
|
||||
j += 1;
|
||||
}
|
||||
|
||||
// H265 NAL type is in bits 1-6 of first byte
|
||||
let nal_type = (data[nal_start] >> 1) & 0x3F;
|
||||
let nal_size = nal_end - nal_start;
|
||||
nal_types.push((nal_type, nal_size));
|
||||
i = nal_end;
|
||||
}
|
||||
|
||||
nal_types
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pipeline_config() {
|
||||
let h264 = SharedVideoPipelineConfig::h264(Resolution::HD1080, 4000);
|
||||
assert_eq!(h264.output_codec, VideoEncoderType::H264);
|
||||
|
||||
let h265 = SharedVideoPipelineConfig::h265(Resolution::HD720, 2000);
|
||||
assert_eq!(h265.output_codec, VideoEncoderType::H265);
|
||||
}
|
||||
}
|
||||
574
src/video/stream_manager.rs
Normal file
574
src/video/stream_manager.rs
Normal file
@@ -0,0 +1,574 @@
|
||||
//! Video Stream Manager
|
||||
//!
|
||||
//! Unified manager for video streaming that supports single-mode operation.
|
||||
//! At any given time, only one streaming mode (MJPEG or WebRTC) is active.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! VideoStreamManager (Public API - Single Entry Point)
|
||||
//! │
|
||||
//! ├── mode: StreamMode (current active mode)
|
||||
//! │
|
||||
//! ├── MJPEG Mode
|
||||
//! │ └── Streamer ──► MjpegStreamHandler
|
||||
//! │ (Future: MjpegStreamer with WsAudio/WsHid)
|
||||
//! │
|
||||
//! └── WebRTC Mode
|
||||
//! └── WebRtcStreamer ──► H264SessionManager
|
||||
//! (Extensible: H264, VP8, VP9, H265)
|
||||
//! ```
|
||||
//!
|
||||
//! # Design Goals
|
||||
//!
|
||||
//! 1. **Single Entry Point**: All video operations go through VideoStreamManager
|
||||
//! 2. **Mode Isolation**: MJPEG and WebRTC modes are cleanly separated
|
||||
//! 3. **Extensible Codecs**: WebRTC supports multiple video codecs (H264 now, others reserved)
|
||||
//! 4. **Simplified API**: Complex configuration flows are encapsulated
|
||||
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::config::{ConfigStore, StreamMode};
|
||||
use crate::error::Result;
|
||||
use crate::events::{EventBus, SystemEvent, VideoDeviceInfo};
|
||||
use crate::hid::HidController;
|
||||
use crate::stream::MjpegStreamHandler;
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
use crate::video::streamer::{Streamer, StreamerState};
|
||||
use crate::webrtc::WebRtcStreamer;
|
||||
|
||||
/// Video stream manager configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamManagerConfig {
|
||||
/// Initial streaming mode
|
||||
pub mode: StreamMode,
|
||||
/// Video device path
|
||||
pub device: Option<String>,
|
||||
/// Video format
|
||||
pub format: PixelFormat,
|
||||
/// Resolution
|
||||
pub resolution: Resolution,
|
||||
/// FPS
|
||||
pub fps: u32,
|
||||
}
|
||||
|
||||
impl Default for StreamManagerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: StreamMode::Mjpeg,
|
||||
device: None,
|
||||
format: PixelFormat::Mjpeg,
|
||||
resolution: Resolution::HD1080,
|
||||
fps: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unified video stream manager
|
||||
///
|
||||
/// Manages both MJPEG and WebRTC streaming modes, ensuring only one is active
|
||||
/// at any given time. This reduces resource usage and simplifies the architecture.
|
||||
///
|
||||
/// # Components
|
||||
///
|
||||
/// - **Streamer**: Handles video capture and MJPEG distribution (current implementation)
|
||||
/// - **WebRtcStreamer**: High-level WebRTC manager with multi-codec support (new)
|
||||
/// - **H264SessionManager**: Legacy WebRTC manager (for backward compatibility)
|
||||
pub struct VideoStreamManager {
|
||||
/// Current streaming mode
|
||||
mode: RwLock<StreamMode>,
|
||||
/// MJPEG streamer (handles video capture and MJPEG distribution)
|
||||
streamer: Arc<Streamer>,
|
||||
/// WebRTC streamer (unified WebRTC manager with multi-codec support)
|
||||
webrtc_streamer: Arc<WebRtcStreamer>,
|
||||
/// Event bus for notifications
|
||||
events: RwLock<Option<Arc<EventBus>>>,
|
||||
/// Configuration store
|
||||
config_store: RwLock<Option<ConfigStore>>,
|
||||
/// Mode switching lock to prevent concurrent switch requests
|
||||
switching: AtomicBool,
|
||||
}
|
||||
|
||||
impl VideoStreamManager {
|
||||
/// Create a new video stream manager with WebRtcStreamer
|
||||
pub fn with_webrtc_streamer(
|
||||
streamer: Arc<Streamer>,
|
||||
webrtc_streamer: Arc<WebRtcStreamer>,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
mode: RwLock::new(StreamMode::Mjpeg),
|
||||
streamer,
|
||||
webrtc_streamer,
|
||||
events: RwLock::new(None),
|
||||
config_store: RwLock::new(None),
|
||||
switching: AtomicBool::new(false),
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if mode switching is in progress
|
||||
pub fn is_switching(&self) -> bool {
|
||||
self.switching.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Set event bus for notifications
|
||||
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Set configuration store
|
||||
pub async fn set_config_store(&self, config: ConfigStore) {
|
||||
*self.config_store.write().await = Some(config);
|
||||
}
|
||||
|
||||
/// Get current streaming mode
|
||||
pub async fn current_mode(&self) -> StreamMode {
|
||||
self.mode.read().await.clone()
|
||||
}
|
||||
|
||||
/// Check if MJPEG mode is active
|
||||
pub async fn is_mjpeg_enabled(&self) -> bool {
|
||||
*self.mode.read().await == StreamMode::Mjpeg
|
||||
}
|
||||
|
||||
/// Check if WebRTC mode is active
|
||||
pub async fn is_webrtc_enabled(&self) -> bool {
|
||||
*self.mode.read().await == StreamMode::WebRTC
|
||||
}
|
||||
|
||||
/// Get the underlying streamer (for MJPEG mode)
|
||||
pub fn streamer(&self) -> Arc<Streamer> {
|
||||
self.streamer.clone()
|
||||
}
|
||||
|
||||
/// Get the WebRTC streamer (unified interface with multi-codec support)
|
||||
pub fn webrtc_streamer(&self) -> Arc<WebRtcStreamer> {
|
||||
self.webrtc_streamer.clone()
|
||||
}
|
||||
|
||||
/// Get the MJPEG stream handler
|
||||
pub fn mjpeg_handler(&self) -> Arc<MjpegStreamHandler> {
|
||||
self.streamer.mjpeg_handler()
|
||||
}
|
||||
|
||||
/// Initialize with a specific mode
|
||||
pub async fn init_with_mode(self: &Arc<Self>, mode: StreamMode) -> Result<()> {
|
||||
info!("Initializing video stream manager with mode: {:?}", mode);
|
||||
*self.mode.write().await = mode.clone();
|
||||
|
||||
// Check if streamer is already initialized (capturer exists)
|
||||
let needs_init = self.streamer.state().await == StreamerState::Uninitialized;
|
||||
|
||||
if needs_init {
|
||||
match mode {
|
||||
StreamMode::Mjpeg => {
|
||||
// Initialize MJPEG streamer
|
||||
if let Err(e) = self.streamer.init_auto().await {
|
||||
warn!("Failed to auto-initialize MJPEG streamer: {}", e);
|
||||
}
|
||||
}
|
||||
StreamMode::WebRTC => {
|
||||
// WebRTC is initialized on-demand when clients connect
|
||||
// But we still need to initialize the video capture
|
||||
if let Err(e) = self.streamer.init_auto().await {
|
||||
warn!("Failed to auto-initialize video capture for WebRTC: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Always reconnect frame source after initialization
|
||||
// This ensures WebRTC has the correct frame_tx from the current capturer
|
||||
if let Some(frame_tx) = self.streamer.frame_sender().await {
|
||||
// Synchronize WebRTC config with actual capture format
|
||||
let (format, resolution, fps) = self.streamer.current_video_config().await;
|
||||
info!(
|
||||
"Reconnecting frame source to WebRTC after init: {}x{} {:?} @ {}fps (receiver_count={})",
|
||||
resolution.width, resolution.height, format, fps, frame_tx.receiver_count()
|
||||
);
|
||||
self.webrtc_streamer.update_video_config(resolution, format, fps).await;
|
||||
self.webrtc_streamer.set_video_source(frame_tx).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Switch streaming mode
|
||||
///
|
||||
/// This will:
|
||||
/// 1. Acquire switching lock (prevent concurrent switches)
|
||||
/// 2. Notify clients of the mode change
|
||||
/// 3. Stop the current mode
|
||||
/// 4. Start the new mode (ensuring video capture runs for WebRTC)
|
||||
/// 5. Update configuration
|
||||
pub async fn switch_mode(self: &Arc<Self>, new_mode: StreamMode) -> Result<()> {
|
||||
let current_mode = self.mode.read().await.clone();
|
||||
|
||||
if current_mode == new_mode {
|
||||
debug!("Already in {:?} mode, no switch needed", new_mode);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Acquire switching lock - prevent concurrent switch requests
|
||||
if self.switching.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
|
||||
debug!("Mode switch already in progress, ignoring duplicate request");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Use a helper to ensure we release the lock when done
|
||||
let result = self.do_switch_mode(current_mode, new_mode.clone()).await;
|
||||
self.switching.store(false, Ordering::SeqCst);
|
||||
result
|
||||
}
|
||||
|
||||
/// Internal implementation of mode switching (called with lock held)
|
||||
async fn do_switch_mode(self: &Arc<Self>, current_mode: StreamMode, new_mode: StreamMode) -> Result<()> {
|
||||
info!("Switching video mode: {:?} -> {:?}", current_mode, new_mode);
|
||||
|
||||
// Get the actual mode strings (with codec info for WebRTC)
|
||||
let new_mode_str = match &new_mode {
|
||||
StreamMode::Mjpeg => "mjpeg".to_string(),
|
||||
StreamMode::WebRTC => {
|
||||
let codec = self.webrtc_streamer.current_video_codec().await;
|
||||
codec_to_string(codec)
|
||||
}
|
||||
};
|
||||
let previous_mode_str = match ¤t_mode {
|
||||
StreamMode::Mjpeg => "mjpeg".to_string(),
|
||||
StreamMode::WebRTC => {
|
||||
let codec = self.webrtc_streamer.current_video_codec().await;
|
||||
codec_to_string(codec)
|
||||
}
|
||||
};
|
||||
|
||||
// 1. Publish mode change event (clients should prepare to reconnect)
|
||||
self.publish_event(SystemEvent::StreamModeChanged {
|
||||
mode: new_mode_str,
|
||||
previous_mode: previous_mode_str,
|
||||
})
|
||||
.await;
|
||||
|
||||
// 2. Stop current mode
|
||||
match current_mode {
|
||||
StreamMode::Mjpeg => {
|
||||
info!("Stopping MJPEG streaming");
|
||||
// Only stop MJPEG distribution, keep video capture running for WebRTC
|
||||
self.streamer.mjpeg_handler().set_offline();
|
||||
if let Err(e) = self.streamer.stop().await {
|
||||
warn!("Error stopping MJPEG streamer: {}", e);
|
||||
}
|
||||
}
|
||||
StreamMode::WebRTC => {
|
||||
info!("Closing all WebRTC sessions");
|
||||
let closed = self.webrtc_streamer.close_all_sessions().await;
|
||||
if closed > 0 {
|
||||
info!("Closed {} WebRTC sessions", closed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Update mode
|
||||
*self.mode.write().await = new_mode.clone();
|
||||
|
||||
// 4. Start new mode
|
||||
match new_mode {
|
||||
StreamMode::Mjpeg => {
|
||||
info!("Starting MJPEG streaming");
|
||||
if let Err(e) = self.streamer.start().await {
|
||||
error!("Failed to start MJPEG streamer: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
StreamMode::WebRTC => {
|
||||
// WebRTC mode: ensure video capture is running for H264 encoding
|
||||
info!("Activating WebRTC mode");
|
||||
|
||||
// Initialize streamer if not already initialized
|
||||
if self.streamer.state().await == StreamerState::Uninitialized {
|
||||
info!("Initializing video capture for WebRTC");
|
||||
if let Err(e) = self.streamer.init_auto().await {
|
||||
error!("Failed to initialize video capture for WebRTC: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Start video capture if not streaming
|
||||
if self.streamer.state().await != StreamerState::Streaming {
|
||||
info!("Starting video capture for WebRTC");
|
||||
if let Err(e) = self.streamer.start().await {
|
||||
error!("Failed to start video capture for WebRTC: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait a bit for capture to stabilize
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
// Connect frame source to WebRTC with correct format
|
||||
if let Some(frame_tx) = self.streamer.frame_sender().await {
|
||||
// Synchronize WebRTC config with actual capture format
|
||||
let (format, resolution, fps) = self.streamer.current_video_config().await;
|
||||
info!(
|
||||
"Connecting frame source to WebRTC pipeline: {}x{} {:?} @ {}fps",
|
||||
resolution.width, resolution.height, format, fps
|
||||
);
|
||||
self.webrtc_streamer.update_video_config(resolution, format, fps).await;
|
||||
self.webrtc_streamer.set_video_source(frame_tx).await;
|
||||
} else {
|
||||
warn!("No frame source available for WebRTC - sessions may fail to receive video");
|
||||
}
|
||||
|
||||
info!("WebRTC mode activated (sessions created on-demand)");
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Update configuration store if available
|
||||
if let Some(ref config_store) = *self.config_store.read().await {
|
||||
let mut config = (*config_store.get()).clone();
|
||||
config.stream.mode = new_mode.clone();
|
||||
if let Err(e) = config_store.set(config).await {
|
||||
warn!("Failed to persist stream mode to config: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
info!("Video mode switched to {:?}", new_mode);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply video configuration (device, format, resolution, fps)
|
||||
///
|
||||
/// This is called when video settings change. It will restart the
|
||||
/// appropriate streaming pipeline based on current mode.
|
||||
pub async fn apply_video_config(
|
||||
self: &Arc<Self>,
|
||||
device_path: &str,
|
||||
format: PixelFormat,
|
||||
resolution: Resolution,
|
||||
fps: u32,
|
||||
) -> Result<()> {
|
||||
let mode = self.mode.read().await.clone();
|
||||
|
||||
info!(
|
||||
"Applying video config: {} {:?} {}x{} @ {} fps (mode: {:?})",
|
||||
device_path, format, resolution.width, resolution.height, fps, mode
|
||||
);
|
||||
|
||||
// Apply to streamer (handles video capture)
|
||||
self.streamer
|
||||
.apply_video_config(device_path, format, resolution, fps)
|
||||
.await?;
|
||||
|
||||
// Update WebRTC config if in WebRTC mode
|
||||
if mode == StreamMode::WebRTC {
|
||||
self.webrtc_streamer
|
||||
.update_video_config(resolution, format, fps)
|
||||
.await;
|
||||
|
||||
// Restart video capture for WebRTC (it was stopped during config change)
|
||||
info!("Restarting video capture for WebRTC after config change");
|
||||
if let Err(e) = self.streamer.start().await {
|
||||
error!("Failed to restart video capture for WebRTC: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// Wait a bit for capture to stabilize
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
// Reconnect frame source with the new capturer
|
||||
if let Some(frame_tx) = self.streamer.frame_sender().await {
|
||||
// Note: update_video_config was already called above with the requested config,
|
||||
// but verify that actual capture matches
|
||||
let (actual_format, actual_resolution, actual_fps) = self.streamer.current_video_config().await;
|
||||
if actual_format != format || actual_resolution != resolution || actual_fps != fps {
|
||||
info!(
|
||||
"Actual capture config differs from requested, updating WebRTC: {}x{} {:?} @ {}fps",
|
||||
actual_resolution.width, actual_resolution.height, actual_format, actual_fps
|
||||
);
|
||||
self.webrtc_streamer.update_video_config(actual_resolution, actual_format, actual_fps).await;
|
||||
}
|
||||
info!("Reconnecting frame source to WebRTC after config change");
|
||||
self.webrtc_streamer.set_video_source(frame_tx).await;
|
||||
} else {
|
||||
warn!("No frame source available after config change");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start streaming (based on current mode)
|
||||
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
||||
let mode = self.mode.read().await.clone();
|
||||
|
||||
match mode {
|
||||
StreamMode::Mjpeg => {
|
||||
self.streamer.start().await?;
|
||||
}
|
||||
StreamMode::WebRTC => {
|
||||
// Ensure video capture is running
|
||||
if self.streamer.state().await == StreamerState::Uninitialized {
|
||||
self.streamer.init_auto().await?;
|
||||
}
|
||||
if self.streamer.state().await != StreamerState::Streaming {
|
||||
self.streamer.start().await?;
|
||||
}
|
||||
|
||||
// Connect frame source with correct format
|
||||
if let Some(frame_tx) = self.streamer.frame_sender().await {
|
||||
// Synchronize WebRTC config with actual capture format
|
||||
let (format, resolution, fps) = self.streamer.current_video_config().await;
|
||||
self.webrtc_streamer.update_video_config(resolution, format, fps).await;
|
||||
self.webrtc_streamer.set_video_source(frame_tx).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop streaming
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
let mode = self.mode.read().await.clone();
|
||||
|
||||
match mode {
|
||||
StreamMode::Mjpeg => {
|
||||
self.streamer.stop().await?;
|
||||
}
|
||||
StreamMode::WebRTC => {
|
||||
self.webrtc_streamer.close_all_sessions().await;
|
||||
self.streamer.stop().await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get video device info for device_info event
|
||||
pub async fn get_video_info(&self) -> VideoDeviceInfo {
|
||||
let stats = self.streamer.stats().await;
|
||||
let state = self.streamer.state().await;
|
||||
let device = self.streamer.current_device().await;
|
||||
let mode = self.mode.read().await.clone();
|
||||
|
||||
// For WebRTC mode, return specific codec type (h264, h265, vp8, vp9)
|
||||
// instead of generic "webrtc" to prevent frontend from defaulting to h264
|
||||
let stream_mode = match &mode {
|
||||
StreamMode::Mjpeg => "mjpeg".to_string(),
|
||||
StreamMode::WebRTC => {
|
||||
let codec = self.webrtc_streamer.current_video_codec().await;
|
||||
codec_to_string(codec)
|
||||
}
|
||||
};
|
||||
|
||||
VideoDeviceInfo {
|
||||
available: state != StreamerState::Uninitialized,
|
||||
device: device.map(|d| d.path.display().to_string()),
|
||||
format: stats.format,
|
||||
resolution: stats.resolution,
|
||||
fps: stats.target_fps,
|
||||
online: state == StreamerState::Streaming,
|
||||
stream_mode,
|
||||
config_changing: self.streamer.is_config_changing(),
|
||||
error: if state == StreamerState::Error {
|
||||
Some("Video stream error".to_string())
|
||||
} else if state == StreamerState::NoSignal {
|
||||
Some("No video signal".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Get MJPEG client count
|
||||
pub fn mjpeg_client_count(&self) -> u64 {
|
||||
self.streamer.mjpeg_handler().client_count()
|
||||
}
|
||||
|
||||
/// Get WebRTC session count
|
||||
pub async fn webrtc_session_count(&self) -> usize {
|
||||
self.webrtc_streamer.session_count().await
|
||||
}
|
||||
|
||||
/// Set HID controller for WebRTC DataChannel
|
||||
pub async fn set_hid_controller(&self, hid: Arc<HidController>) {
|
||||
self.webrtc_streamer.set_hid_controller(hid).await;
|
||||
}
|
||||
|
||||
/// Set audio enabled state for WebRTC
|
||||
pub async fn set_webrtc_audio_enabled(&self, enabled: bool) -> Result<()> {
|
||||
self.webrtc_streamer.set_audio_enabled(enabled).await
|
||||
}
|
||||
|
||||
/// Check if WebRTC audio is enabled
|
||||
pub async fn is_webrtc_audio_enabled(&self) -> bool {
|
||||
self.webrtc_streamer.is_audio_enabled().await
|
||||
}
|
||||
|
||||
/// Reconnect audio sources for all WebRTC sessions
|
||||
/// Call this after audio controller restarts (e.g., quality change)
|
||||
pub async fn reconnect_webrtc_audio_sources(&self) {
|
||||
self.webrtc_streamer.reconnect_audio_sources().await;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Delegated methods from Streamer (for backward compatibility)
|
||||
// =========================================================================
|
||||
|
||||
/// List available video devices
|
||||
pub async fn list_devices(&self) -> crate::error::Result<Vec<crate::video::device::VideoDeviceInfo>> {
|
||||
self.streamer.list_devices().await
|
||||
}
|
||||
|
||||
/// Get streamer statistics
|
||||
pub async fn stats(&self) -> crate::video::streamer::StreamerStats {
|
||||
self.streamer.stats().await
|
||||
}
|
||||
|
||||
/// Check if config is being changed
|
||||
pub fn is_config_changing(&self) -> bool {
|
||||
self.streamer.is_config_changing()
|
||||
}
|
||||
|
||||
/// Check if streaming is active
|
||||
pub async fn is_streaming(&self) -> bool {
|
||||
self.streamer.is_streaming().await
|
||||
}
|
||||
|
||||
/// Get frame sender for video frames
|
||||
pub async fn frame_sender(&self) -> Option<tokio::sync::broadcast::Sender<crate::video::frame::VideoFrame>> {
|
||||
self.streamer.frame_sender().await
|
||||
}
|
||||
|
||||
/// Publish event to event bus
|
||||
async fn publish_event(&self, event: SystemEvent) {
|
||||
if let Some(ref events) = *self.events.read().await {
|
||||
events.publish(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert VideoCodecType to lowercase string for frontend
|
||||
fn codec_to_string(codec: crate::video::encoder::VideoCodecType) -> String {
|
||||
match codec {
|
||||
crate::video::encoder::VideoCodecType::H264 => "h264".to_string(),
|
||||
crate::video::encoder::VideoCodecType::H265 => "h265".to_string(),
|
||||
crate::video::encoder::VideoCodecType::VP8 => "vp8".to_string(),
|
||||
crate::video::encoder::VideoCodecType::VP9 => "vp9".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::video::encoder::VideoCodecType;
|
||||
|
||||
#[test]
|
||||
fn test_codec_to_string() {
|
||||
assert_eq!(codec_to_string(VideoCodecType::H264), "h264");
|
||||
assert_eq!(codec_to_string(VideoCodecType::H265), "h265");
|
||||
assert_eq!(codec_to_string(VideoCodecType::VP8), "vp8");
|
||||
assert_eq!(codec_to_string(VideoCodecType::VP9), "vp9");
|
||||
}
|
||||
}
|
||||
892
src/video/streamer.rs
Normal file
892
src/video/streamer.rs
Normal file
@@ -0,0 +1,892 @@
|
||||
//! Video streamer that integrates capture and streaming
|
||||
//!
|
||||
//! This module provides a high-level interface for video capture and streaming,
|
||||
//! managing the lifecycle of the capture thread and MJPEG/WebRTC distribution.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use super::capture::{CaptureConfig, CaptureState, VideoCapturer};
|
||||
use super::device::{enumerate_devices, find_best_device, VideoDeviceInfo};
|
||||
use super::format::{PixelFormat, Resolution};
|
||||
use super::frame::VideoFrame;
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::events::{EventBus, SystemEvent};
|
||||
use crate::stream::MjpegStreamHandler;
|
||||
|
||||
/// Streamer configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamerConfig {
|
||||
/// Device path (None = auto-detect)
|
||||
pub device_path: Option<PathBuf>,
|
||||
/// Desired resolution
|
||||
pub resolution: Resolution,
|
||||
/// Desired format
|
||||
pub format: PixelFormat,
|
||||
/// Desired FPS
|
||||
pub fps: u32,
|
||||
/// JPEG quality (1-100)
|
||||
pub jpeg_quality: u8,
|
||||
}
|
||||
|
||||
impl Default for StreamerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_path: None,
|
||||
resolution: Resolution::HD1080,
|
||||
format: PixelFormat::Mjpeg,
|
||||
fps: 30,
|
||||
jpeg_quality: 80,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streamer state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StreamerState {
|
||||
/// Not initialized
|
||||
Uninitialized,
|
||||
/// Ready but not streaming
|
||||
Ready,
|
||||
/// Actively streaming
|
||||
Streaming,
|
||||
/// No video signal
|
||||
NoSignal,
|
||||
/// Error occurred
|
||||
Error,
|
||||
/// Device was lost (unplugged)
|
||||
DeviceLost,
|
||||
/// Device is being recovered (reconnecting)
|
||||
Recovering,
|
||||
}
|
||||
|
||||
/// Video streamer service
|
||||
pub struct Streamer {
|
||||
config: RwLock<StreamerConfig>,
|
||||
capturer: RwLock<Option<Arc<VideoCapturer>>>,
|
||||
mjpeg_handler: Arc<MjpegStreamHandler>,
|
||||
current_device: RwLock<Option<VideoDeviceInfo>>,
|
||||
state: RwLock<StreamerState>,
|
||||
start_lock: tokio::sync::Mutex<()>,
|
||||
/// Event bus for broadcasting state changes (optional)
|
||||
events: RwLock<Option<Arc<EventBus>>>,
|
||||
/// Last published state (for change detection)
|
||||
last_published_state: RwLock<Option<StreamerState>>,
|
||||
/// Flag to indicate config is being changed (prevents auto-start during config change)
|
||||
config_changing: std::sync::atomic::AtomicBool,
|
||||
/// Flag to indicate background tasks (stats, cleanup, monitor) have been started
|
||||
/// These tasks should only be started once per Streamer instance
|
||||
background_tasks_started: std::sync::atomic::AtomicBool,
|
||||
/// Device recovery retry count
|
||||
recovery_retry_count: std::sync::atomic::AtomicU32,
|
||||
/// Device recovery in progress flag
|
||||
recovery_in_progress: std::sync::atomic::AtomicBool,
|
||||
/// Last lost device path (for recovery)
|
||||
last_lost_device: RwLock<Option<String>>,
|
||||
/// Last lost device reason (for logging)
|
||||
last_lost_reason: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl Streamer {
|
||||
/// Create a new streamer
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
config: RwLock::new(StreamerConfig::default()),
|
||||
capturer: RwLock::new(None),
|
||||
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
|
||||
current_device: RwLock::new(None),
|
||||
state: RwLock::new(StreamerState::Uninitialized),
|
||||
start_lock: tokio::sync::Mutex::new(()),
|
||||
events: RwLock::new(None),
|
||||
last_published_state: RwLock::new(None),
|
||||
config_changing: std::sync::atomic::AtomicBool::new(false),
|
||||
background_tasks_started: std::sync::atomic::AtomicBool::new(false),
|
||||
recovery_retry_count: std::sync::atomic::AtomicU32::new(0),
|
||||
recovery_in_progress: std::sync::atomic::AtomicBool::new(false),
|
||||
last_lost_device: RwLock::new(None),
|
||||
last_lost_reason: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with specific config
|
||||
pub fn with_config(config: StreamerConfig) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
config: RwLock::new(config),
|
||||
capturer: RwLock::new(None),
|
||||
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
|
||||
current_device: RwLock::new(None),
|
||||
state: RwLock::new(StreamerState::Uninitialized),
|
||||
start_lock: tokio::sync::Mutex::new(()),
|
||||
events: RwLock::new(None),
|
||||
last_published_state: RwLock::new(None),
|
||||
config_changing: std::sync::atomic::AtomicBool::new(false),
|
||||
background_tasks_started: std::sync::atomic::AtomicBool::new(false),
|
||||
recovery_retry_count: std::sync::atomic::AtomicU32::new(0),
|
||||
recovery_in_progress: std::sync::atomic::AtomicBool::new(false),
|
||||
last_lost_device: RwLock::new(None),
|
||||
last_lost_reason: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current state as SystemEvent
|
||||
pub async fn current_state_event(&self) -> SystemEvent {
|
||||
let state = *self.state.read().await;
|
||||
let device = self.current_device.read().await.as_ref().map(|d| d.path.display().to_string());
|
||||
|
||||
SystemEvent::StreamStateChanged {
|
||||
state: match state {
|
||||
StreamerState::Uninitialized => "uninitialized".to_string(),
|
||||
StreamerState::Ready => "ready".to_string(),
|
||||
StreamerState::Streaming => "streaming".to_string(),
|
||||
StreamerState::NoSignal => "no_signal".to_string(),
|
||||
StreamerState::Error => "error".to_string(),
|
||||
StreamerState::DeviceLost => "device_lost".to_string(),
|
||||
StreamerState::Recovering => "recovering".to_string(),
|
||||
},
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Get current state
|
||||
pub async fn state(&self) -> StreamerState {
|
||||
*self.state.read().await
|
||||
}
|
||||
|
||||
/// Check if config is currently being changed
|
||||
/// When true, auto-start should be blocked to prevent device busy errors
|
||||
pub fn is_config_changing(&self) -> bool {
|
||||
self.config_changing.load(std::sync::atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Get MJPEG handler for stream endpoints
|
||||
pub fn mjpeg_handler(&self) -> Arc<MjpegStreamHandler> {
|
||||
self.mjpeg_handler.clone()
|
||||
}
|
||||
|
||||
/// Get frame sender for WebRTC integration
|
||||
/// Returns None if no capturer is initialized
|
||||
pub async fn frame_sender(&self) -> Option<broadcast::Sender<VideoFrame>> {
|
||||
let capturer = self.capturer.read().await;
|
||||
capturer.as_ref().map(|c| c.frame_sender())
|
||||
}
|
||||
|
||||
/// Subscribe to video frames
|
||||
/// Returns None if no capturer is initialized
|
||||
pub async fn subscribe_frames(&self) -> Option<broadcast::Receiver<VideoFrame>> {
|
||||
let capturer = self.capturer.read().await;
|
||||
capturer.as_ref().map(|c| c.subscribe())
|
||||
}
|
||||
|
||||
/// Get current device info
|
||||
pub async fn current_device(&self) -> Option<VideoDeviceInfo> {
|
||||
self.current_device.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get current video configuration (format, resolution, fps)
|
||||
pub async fn current_video_config(&self) -> (PixelFormat, Resolution, u32) {
|
||||
let config = self.config.read().await;
|
||||
(config.format, config.resolution, config.fps)
|
||||
}
|
||||
|
||||
/// List available video devices
|
||||
pub async fn list_devices(&self) -> Result<Vec<VideoDeviceInfo>> {
|
||||
enumerate_devices()
|
||||
}
|
||||
|
||||
/// Validate and apply requested video parameters without auto-selection
|
||||
pub async fn apply_video_config(
|
||||
self: &Arc<Self>,
|
||||
device_path: &str,
|
||||
format: PixelFormat,
|
||||
resolution: Resolution,
|
||||
fps: u32,
|
||||
) -> Result<()> {
|
||||
// Set config_changing flag to prevent frontend mode sync during config change
|
||||
self.config_changing.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
|
||||
let result = self.apply_video_config_inner(device_path, format, resolution, fps).await;
|
||||
|
||||
// Clear the flag after config change is complete
|
||||
// The stream will be started by MJPEG client connection, not here
|
||||
self.config_changing.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Internal implementation of apply_video_config
|
||||
async fn apply_video_config_inner(
|
||||
self: &Arc<Self>,
|
||||
device_path: &str,
|
||||
format: PixelFormat,
|
||||
resolution: Resolution,
|
||||
fps: u32,
|
||||
) -> Result<()> {
|
||||
// Publish "config changing" event
|
||||
self.publish_event(SystemEvent::StreamConfigChanging {
|
||||
reason: "device_switch".to_string(),
|
||||
})
|
||||
.await;
|
||||
|
||||
let devices = enumerate_devices()?;
|
||||
let device = devices
|
||||
.into_iter()
|
||||
.find(|d| d.path.to_string_lossy() == device_path)
|
||||
.ok_or_else(|| AppError::VideoError("Video device not found".to_string()))?;
|
||||
|
||||
// Validate format
|
||||
let fmt_info = device
|
||||
.formats
|
||||
.iter()
|
||||
.find(|f| f.format == format)
|
||||
.ok_or_else(|| AppError::VideoError("Requested format not supported".to_string()))?;
|
||||
|
||||
// Validate resolution
|
||||
if !fmt_info.resolutions.is_empty()
|
||||
&& !fmt_info
|
||||
.resolutions
|
||||
.iter()
|
||||
.any(|r| r.width == resolution.width && r.height == resolution.height)
|
||||
{
|
||||
return Err(AppError::VideoError("Requested resolution not supported".to_string()));
|
||||
}
|
||||
|
||||
// IMPORTANT: Disconnect all MJPEG clients FIRST before stopping capture
|
||||
// This prevents race conditions where clients try to reconnect and reopen the device
|
||||
info!("Disconnecting all MJPEG clients before config change...");
|
||||
self.mjpeg_handler.disconnect_all_clients();
|
||||
|
||||
// Give clients time to receive the disconnect signal and close their connections
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
// Stop existing capturer and wait for device release
|
||||
{
|
||||
// Take ownership of the old capturer to ensure it's dropped
|
||||
let old_capturer = self.capturer.write().await.take();
|
||||
if let Some(capturer) = old_capturer {
|
||||
info!("Stopping existing capture before applying new config...");
|
||||
if let Err(e) = capturer.stop().await {
|
||||
warn!("Error stopping old capturer: {}", e);
|
||||
}
|
||||
// Explicitly drop the capturer to release V4L2 resources
|
||||
drop(capturer);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Update config
|
||||
{
|
||||
let mut cfg = self.config.write().await;
|
||||
cfg.device_path = Some(device.path.clone());
|
||||
cfg.format = format;
|
||||
cfg.resolution = resolution;
|
||||
cfg.fps = fps;
|
||||
}
|
||||
|
||||
// Recreate capturer
|
||||
let capture_config = CaptureConfig {
|
||||
device_path: device.path.clone(),
|
||||
resolution,
|
||||
format,
|
||||
fps,
|
||||
jpeg_quality: self.config.read().await.jpeg_quality,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let capturer = Arc::new(VideoCapturer::new(capture_config));
|
||||
*self.capturer.write().await = Some(capturer.clone());
|
||||
*self.current_device.write().await = Some(device.clone());
|
||||
*self.state.write().await = StreamerState::Ready;
|
||||
|
||||
// Publish "config applied" event
|
||||
info!("Publishing StreamConfigApplied event: {}x{} {:?} @ {}fps",
|
||||
resolution.width, resolution.height, format, fps);
|
||||
self.publish_event(SystemEvent::StreamConfigApplied {
|
||||
device: device_path.to_string(),
|
||||
resolution: (resolution.width, resolution.height),
|
||||
format: format!("{:?}", format),
|
||||
fps,
|
||||
})
|
||||
.await;
|
||||
|
||||
// Note: We don't auto-start here anymore.
|
||||
// The stream will be started when MJPEG client connects (handlers.rs:790)
|
||||
// This avoids race conditions between config change and client reconnection.
|
||||
info!("Config applied, stream will start when client connects");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize with auto-detected device
|
||||
pub async fn init_auto(self: &Arc<Self>) -> Result<()> {
|
||||
info!("Auto-detecting video device...");
|
||||
|
||||
let device = find_best_device()?;
|
||||
info!("Found device: {} ({})", device.name, device.path.display());
|
||||
|
||||
self.init_with_device(device).await
|
||||
}
|
||||
|
||||
/// Initialize with specific device
|
||||
pub async fn init_with_device(self: &Arc<Self>, device: VideoDeviceInfo) -> Result<()> {
|
||||
info!(
|
||||
"Initializing streamer with device: {} ({})",
|
||||
device.name,
|
||||
device.path.display()
|
||||
);
|
||||
|
||||
// Determine best format for this device
|
||||
let config = self.config.read().await;
|
||||
let format = self.select_format(&device, config.format)?;
|
||||
let resolution = self.select_resolution(&device, &format, config.resolution)?;
|
||||
|
||||
drop(config);
|
||||
|
||||
// Update config with actual values
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
config.device_path = Some(device.path.clone());
|
||||
config.format = format;
|
||||
config.resolution = resolution;
|
||||
}
|
||||
|
||||
// Store device info
|
||||
*self.current_device.write().await = Some(device.clone());
|
||||
|
||||
// Create capturer
|
||||
let config = self.config.read().await;
|
||||
let capture_config = CaptureConfig {
|
||||
device_path: device.path.clone(),
|
||||
resolution: config.resolution,
|
||||
format: config.format,
|
||||
fps: config.fps,
|
||||
jpeg_quality: config.jpeg_quality,
|
||||
..Default::default()
|
||||
};
|
||||
drop(config);
|
||||
|
||||
let capturer = Arc::new(VideoCapturer::new(capture_config));
|
||||
*self.capturer.write().await = Some(capturer);
|
||||
|
||||
*self.state.write().await = StreamerState::Ready;
|
||||
|
||||
info!("Streamer initialized: {} @ {}", format, resolution);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Select best format for device
|
||||
fn select_format(&self, device: &VideoDeviceInfo, preferred: PixelFormat) -> Result<PixelFormat> {
|
||||
// Check if preferred format is available
|
||||
if device.formats.iter().any(|f| f.format == preferred) {
|
||||
return Ok(preferred);
|
||||
}
|
||||
|
||||
// Select best available format
|
||||
device
|
||||
.formats
|
||||
.first()
|
||||
.map(|f| f.format)
|
||||
.ok_or_else(|| AppError::VideoError("No supported formats found".to_string()))
|
||||
}
|
||||
|
||||
/// Select best resolution for format
|
||||
fn select_resolution(
|
||||
&self,
|
||||
device: &VideoDeviceInfo,
|
||||
format: &PixelFormat,
|
||||
preferred: Resolution,
|
||||
) -> Result<Resolution> {
|
||||
let format_info = device
|
||||
.formats
|
||||
.iter()
|
||||
.find(|f| &f.format == format)
|
||||
.ok_or_else(|| AppError::VideoError("Format not found".to_string()))?;
|
||||
|
||||
// Check if preferred resolution is available
|
||||
if format_info.resolutions.is_empty()
|
||||
|| format_info.resolutions.iter().any(|r| {
|
||||
r.width == preferred.width && r.height == preferred.height
|
||||
})
|
||||
{
|
||||
return Ok(preferred);
|
||||
}
|
||||
|
||||
// Select largest available resolution
|
||||
format_info
|
||||
.resolutions
|
||||
.first()
|
||||
.map(|r| r.resolution())
|
||||
.ok_or_else(|| AppError::VideoError("No resolutions available".to_string()))
|
||||
}
|
||||
|
||||
/// Restart the capturer only (for recovery - doesn't spawn new monitor)
|
||||
///
|
||||
/// This is a simpler version of start() used during device recovery.
|
||||
/// It doesn't spawn a new state monitor since the existing one is still active.
|
||||
async fn restart_capturer(&self) -> Result<()> {
|
||||
let capturer = self.capturer.read().await;
|
||||
let capturer = capturer
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::VideoError("Capturer not initialized".to_string()))?;
|
||||
|
||||
// Start capture
|
||||
capturer.start().await?;
|
||||
|
||||
// Set MJPEG handler online
|
||||
self.mjpeg_handler.set_online();
|
||||
|
||||
// Start frame distribution task
|
||||
let mjpeg_handler = self.mjpeg_handler.clone();
|
||||
let mut frame_rx = capturer.subscribe();
|
||||
|
||||
tokio::spawn(async move {
|
||||
debug!("Recovery frame distribution task started");
|
||||
loop {
|
||||
match frame_rx.recv().await {
|
||||
Ok(frame) => {
|
||||
mjpeg_handler.update_frame(frame);
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
trace!("Frame distribution lagged by {} frames", n);
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
|
||||
debug!("Frame channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start streaming
|
||||
pub async fn start(self: &Arc<Self>) -> Result<()> {
|
||||
let _lock = self.start_lock.lock().await;
|
||||
|
||||
let state = self.state().await;
|
||||
if state == StreamerState::Streaming {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if state == StreamerState::Uninitialized {
|
||||
// Auto-initialize if not done
|
||||
self.init_auto().await?;
|
||||
}
|
||||
|
||||
let capturer = self.capturer.read().await;
|
||||
let capturer = capturer
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::VideoError("Capturer not initialized".to_string()))?;
|
||||
|
||||
// Start capture
|
||||
capturer.start().await?;
|
||||
|
||||
// Set MJPEG handler online before starting frame distribution
|
||||
// This is important after config changes where disconnect_all_clients() set it offline
|
||||
self.mjpeg_handler.set_online();
|
||||
|
||||
// Start frame distribution task
|
||||
let mjpeg_handler = self.mjpeg_handler.clone();
|
||||
let mut frame_rx = capturer.subscribe();
|
||||
let state_ref = Arc::downgrade(self);
|
||||
|
||||
tokio::spawn(async move {
|
||||
info!("Frame distribution task started");
|
||||
loop {
|
||||
match frame_rx.recv().await {
|
||||
Ok(frame) => {
|
||||
mjpeg_handler.update_frame(frame);
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
trace!("Frame distribution lagged by {} frames", n);
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
|
||||
debug!("Frame channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if streamer still exists
|
||||
if state_ref.upgrade().is_none() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
info!("Frame distribution task ended");
|
||||
});
|
||||
|
||||
// Monitor capture state
|
||||
let mut state_rx = capturer.state_watch();
|
||||
let state_ref = Arc::downgrade(self);
|
||||
let mjpeg_handler = self.mjpeg_handler.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while state_rx.changed().await.is_ok() {
|
||||
let capture_state = *state_rx.borrow();
|
||||
match capture_state {
|
||||
CaptureState::Running => {
|
||||
if let Some(streamer) = state_ref.upgrade() {
|
||||
*streamer.state.write().await = StreamerState::Streaming;
|
||||
}
|
||||
}
|
||||
CaptureState::NoSignal => {
|
||||
mjpeg_handler.set_offline();
|
||||
if let Some(streamer) = state_ref.upgrade() {
|
||||
*streamer.state.write().await = StreamerState::NoSignal;
|
||||
}
|
||||
}
|
||||
CaptureState::Stopped => {
|
||||
mjpeg_handler.set_offline();
|
||||
if let Some(streamer) = state_ref.upgrade() {
|
||||
*streamer.state.write().await = StreamerState::Ready;
|
||||
}
|
||||
}
|
||||
CaptureState::Error => {
|
||||
mjpeg_handler.set_offline();
|
||||
if let Some(streamer) = state_ref.upgrade() {
|
||||
*streamer.state.write().await = StreamerState::Error;
|
||||
}
|
||||
}
|
||||
CaptureState::DeviceLost => {
|
||||
mjpeg_handler.set_offline();
|
||||
if let Some(streamer) = state_ref.upgrade() {
|
||||
*streamer.state.write().await = StreamerState::DeviceLost;
|
||||
// Start device recovery task (fire and forget)
|
||||
let streamer_clone = Arc::clone(&streamer);
|
||||
tokio::spawn(async move {
|
||||
streamer_clone.start_device_recovery_internal().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
CaptureState::Starting => {
|
||||
// Starting state - device is initializing, no action needed
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Start background tasks only once per Streamer instance
|
||||
// Use compare_exchange to atomically check and set the flag
|
||||
if self.background_tasks_started
|
||||
.compare_exchange(false, true, std::sync::atomic::Ordering::SeqCst, std::sync::atomic::Ordering::SeqCst)
|
||||
.is_ok()
|
||||
{
|
||||
info!("Starting background tasks (stats, cleanup, monitor)");
|
||||
|
||||
// Start stats broadcast task (sends stats updates every 1 second)
|
||||
let stats_ref = Arc::downgrade(self);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
if let Some(streamer) = stats_ref.upgrade() {
|
||||
let clients_stat = streamer.mjpeg_handler().get_clients_stat();
|
||||
let clients = clients_stat.len() as u64;
|
||||
|
||||
streamer.publish_event(SystemEvent::StreamStatsUpdate {
|
||||
clients,
|
||||
clients_stat,
|
||||
}).await;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Start client cleanup task (removes stale clients every 5s)
|
||||
self.mjpeg_handler.clone().start_cleanup_task();
|
||||
|
||||
// Start auto-pause monitor task (stops stream if no clients)
|
||||
let monitor_ref = Arc::downgrade(self);
|
||||
let monitor_handler = self.mjpeg_handler.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(2));
|
||||
let mut zero_since: Option<std::time::Instant> = None;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let Some(streamer) = monitor_ref.upgrade() else { break; };
|
||||
|
||||
// Check auto-pause configuration
|
||||
let config = monitor_handler.auto_pause_config();
|
||||
if !config.enabled {
|
||||
zero_since = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
let count = monitor_handler.client_count();
|
||||
|
||||
if count == 0 {
|
||||
if zero_since.is_none() {
|
||||
zero_since = Some(std::time::Instant::now());
|
||||
info!("No clients connected, starting shutdown timer ({}s)", config.shutdown_delay_secs);
|
||||
} else if let Some(since) = zero_since {
|
||||
if since.elapsed().as_secs() >= config.shutdown_delay_secs {
|
||||
info!("Auto-pausing stream (no clients for {}s)", config.shutdown_delay_secs);
|
||||
if let Err(e) = streamer.stop().await {
|
||||
error!("Auto-pause failed: {}", e);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if zero_since.is_some() {
|
||||
info!("Clients reconnected, canceling auto-pause");
|
||||
zero_since = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
debug!("Background tasks already started, skipping");
|
||||
}
|
||||
|
||||
*self.state.write().await = StreamerState::Streaming;
|
||||
|
||||
// Publish state change event so DeviceInfo broadcaster can update frontend
|
||||
self.publish_event(self.current_state_event().await).await;
|
||||
|
||||
info!("Streaming started");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop streaming
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
if let Some(capturer) = self.capturer.read().await.as_ref() {
|
||||
capturer.stop().await?;
|
||||
}
|
||||
|
||||
self.mjpeg_handler.set_offline();
|
||||
*self.state.write().await = StreamerState::Ready;
|
||||
|
||||
// Publish state change event so DeviceInfo broadcaster can update frontend
|
||||
self.publish_event(self.current_state_event().await).await;
|
||||
|
||||
info!("Streaming stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if streaming
|
||||
pub async fn is_streaming(&self) -> bool {
|
||||
self.state().await == StreamerState::Streaming
|
||||
}
|
||||
|
||||
/// Get stream statistics
|
||||
pub async fn stats(&self) -> StreamerStats {
|
||||
let capturer = self.capturer.read().await;
|
||||
let capture_stats = if let Some(c) = capturer.as_ref() {
|
||||
Some(c.stats().await)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let config = self.config.read().await;
|
||||
|
||||
StreamerStats {
|
||||
state: self.state().await,
|
||||
device: self.current_device().await.map(|d| d.name),
|
||||
format: Some(config.format.to_string()),
|
||||
resolution: Some((config.resolution.width, config.resolution.height)),
|
||||
clients: self.mjpeg_handler.client_count(),
|
||||
target_fps: config.fps,
|
||||
fps: capture_stats.as_ref().map(|s| s.current_fps).unwrap_or(0.0),
|
||||
frames_captured: capture_stats.as_ref().map(|s| s.frames_captured).unwrap_or(0),
|
||||
frames_dropped: capture_stats.as_ref().map(|s| s.frames_dropped).unwrap_or(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Publish event to event bus (if configured)
|
||||
/// For StreamStateChanged events, only publishes if state actually changed (de-duplication)
|
||||
async fn publish_event(&self, event: SystemEvent) {
|
||||
if let Some(events) = self.events.read().await.as_ref() {
|
||||
// For state change events, check if state actually changed
|
||||
if let SystemEvent::StreamStateChanged { ref state, .. } = event {
|
||||
let current_state = match state.as_str() {
|
||||
"uninitialized" => StreamerState::Uninitialized,
|
||||
"ready" => StreamerState::Ready,
|
||||
"streaming" => StreamerState::Streaming,
|
||||
"no_signal" => StreamerState::NoSignal,
|
||||
"error" => StreamerState::Error,
|
||||
"device_lost" => StreamerState::DeviceLost,
|
||||
"recovering" => StreamerState::Recovering,
|
||||
_ => StreamerState::Error,
|
||||
};
|
||||
|
||||
let mut last_state = self.last_published_state.write().await;
|
||||
if *last_state == Some(current_state) {
|
||||
// State hasn't changed, skip publishing
|
||||
trace!("Skipping duplicate stream state event: {}", state);
|
||||
return;
|
||||
}
|
||||
*last_state = Some(current_state);
|
||||
}
|
||||
|
||||
events.publish(event);
|
||||
}
|
||||
}
|
||||
|
||||
/// Start device recovery task (internal implementation)
|
||||
///
|
||||
/// This method starts a background task that attempts to reconnect
|
||||
/// to the video device after it was lost. It retries every 1 second
|
||||
/// until the device is recovered.
|
||||
async fn start_device_recovery_internal(self: &Arc<Self>) {
|
||||
// Check if recovery is already in progress
|
||||
if self.recovery_in_progress.swap(true, std::sync::atomic::Ordering::SeqCst) {
|
||||
debug!("Device recovery already in progress, skipping");
|
||||
return;
|
||||
}
|
||||
|
||||
// Get last lost device info from capturer
|
||||
let (device, reason) = {
|
||||
let capturer = self.capturer.read().await;
|
||||
if let Some(cap) = capturer.as_ref() {
|
||||
cap.last_error().unwrap_or_else(|| {
|
||||
let device_path = self.current_device.blocking_read()
|
||||
.as_ref()
|
||||
.map(|d| d.path.display().to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
(device_path, "Device lost".to_string())
|
||||
})
|
||||
} else {
|
||||
("unknown".to_string(), "Device lost".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
// Store error info
|
||||
*self.last_lost_device.write().await = Some(device.clone());
|
||||
*self.last_lost_reason.write().await = Some(reason.clone());
|
||||
self.recovery_retry_count.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
// Publish device lost event
|
||||
self.publish_event(SystemEvent::StreamDeviceLost {
|
||||
device: device.clone(),
|
||||
reason: reason.clone(),
|
||||
}).await;
|
||||
|
||||
// Start recovery task
|
||||
let streamer = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
let device_path = device.clone();
|
||||
|
||||
loop {
|
||||
let attempt = streamer.recovery_retry_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if still in device lost state
|
||||
let current_state = *streamer.state.read().await;
|
||||
if current_state != StreamerState::DeviceLost && current_state != StreamerState::Recovering {
|
||||
info!("Stream state changed during recovery, stopping recovery task");
|
||||
break;
|
||||
}
|
||||
|
||||
// Update state to Recovering
|
||||
*streamer.state.write().await = StreamerState::Recovering;
|
||||
|
||||
// Publish reconnecting event (every 5 attempts to avoid spam)
|
||||
if attempt == 1 || attempt % 5 == 0 {
|
||||
streamer.publish_event(SystemEvent::StreamReconnecting {
|
||||
device: device_path.clone(),
|
||||
attempt,
|
||||
}).await;
|
||||
info!("Attempting to recover video device {} (attempt {})", device_path, attempt);
|
||||
}
|
||||
|
||||
// Wait before retry (1 second)
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
|
||||
// Check if device file exists
|
||||
let device_exists = std::path::Path::new(&device_path).exists();
|
||||
if !device_exists {
|
||||
debug!("Device {} not present yet", device_path);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to restart capture
|
||||
match streamer.restart_capturer().await {
|
||||
Ok(_) => {
|
||||
info!("Video device {} recovered after {} attempts", device_path, attempt);
|
||||
streamer.recovery_in_progress.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
|
||||
// Publish recovered event
|
||||
streamer.publish_event(SystemEvent::StreamRecovered {
|
||||
device: device_path.clone(),
|
||||
}).await;
|
||||
|
||||
// Clear error info
|
||||
*streamer.last_lost_device.write().await = None;
|
||||
*streamer.last_lost_reason.write().await = None;
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to restart capture (attempt {}): {}", attempt, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
streamer.recovery_in_progress.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Streamer {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
config: RwLock::new(StreamerConfig::default()),
|
||||
capturer: RwLock::new(None),
|
||||
mjpeg_handler: Arc::new(MjpegStreamHandler::new()),
|
||||
current_device: RwLock::new(None),
|
||||
state: RwLock::new(StreamerState::Uninitialized),
|
||||
start_lock: tokio::sync::Mutex::new(()),
|
||||
events: RwLock::new(None),
|
||||
last_published_state: RwLock::new(None),
|
||||
config_changing: std::sync::atomic::AtomicBool::new(false),
|
||||
background_tasks_started: std::sync::atomic::AtomicBool::new(false),
|
||||
recovery_retry_count: std::sync::atomic::AtomicU32::new(0),
|
||||
recovery_in_progress: std::sync::atomic::AtomicBool::new(false),
|
||||
last_lost_device: RwLock::new(None),
|
||||
last_lost_reason: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streamer statistics
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct StreamerStats {
|
||||
pub state: StreamerState,
|
||||
pub device: Option<String>,
|
||||
pub format: Option<String>,
|
||||
pub resolution: Option<(u32, u32)>,
|
||||
pub clients: u64,
|
||||
/// Target FPS from configuration
|
||||
pub target_fps: u32,
|
||||
/// Current actual FPS
|
||||
pub fps: f32,
|
||||
pub frames_captured: u64,
|
||||
pub frames_dropped: u64,
|
||||
}
|
||||
|
||||
impl serde::Serialize for StreamerState {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let s = match self {
|
||||
StreamerState::Uninitialized => "uninitialized",
|
||||
StreamerState::Ready => "ready",
|
||||
StreamerState::Streaming => "streaming",
|
||||
StreamerState::NoSignal => "no_signal",
|
||||
StreamerState::Error => "error",
|
||||
StreamerState::DeviceLost => "device_lost",
|
||||
StreamerState::Recovering => "recovering",
|
||||
};
|
||||
serializer.serialize_str(s)
|
||||
}
|
||||
}
|
||||
595
src/video/video_session.rs
Normal file
595
src/video/video_session.rs
Normal file
@@ -0,0 +1,595 @@
|
||||
//! Video session management with multi-codec support
|
||||
//!
|
||||
//! This module provides session management for video streaming with:
|
||||
//! - Multi-codec support (H264, H265, VP8, VP9)
|
||||
//! - Session lifecycle management
|
||||
//! - Dynamic codec switching
|
||||
//! - Statistics and monitoring
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::encoder::registry::{EncoderBackend, EncoderRegistry, VideoEncoderType};
|
||||
use super::format::Resolution;
|
||||
use super::frame::VideoFrame;
|
||||
use super::shared_video_pipeline::{
|
||||
EncodedVideoFrame, SharedVideoPipeline, SharedVideoPipelineConfig, SharedVideoPipelineStats,
|
||||
};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Maximum concurrent video sessions
|
||||
const MAX_VIDEO_SESSIONS: usize = 8;
|
||||
|
||||
/// Video session state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VideoSessionState {
|
||||
/// Session created but not started
|
||||
Created,
|
||||
/// Session is active and streaming
|
||||
Active,
|
||||
/// Session is paused
|
||||
Paused,
|
||||
/// Session is closing
|
||||
Closing,
|
||||
/// Session is closed
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VideoSessionState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
VideoSessionState::Created => write!(f, "Created"),
|
||||
VideoSessionState::Active => write!(f, "Active"),
|
||||
VideoSessionState::Paused => write!(f, "Paused"),
|
||||
VideoSessionState::Closing => write!(f, "Closing"),
|
||||
VideoSessionState::Closed => write!(f, "Closed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Video session information
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VideoSessionInfo {
|
||||
/// Session ID
|
||||
pub session_id: String,
|
||||
/// Current codec
|
||||
pub codec: VideoEncoderType,
|
||||
/// Session state
|
||||
pub state: VideoSessionState,
|
||||
/// Creation time
|
||||
pub created_at: Instant,
|
||||
/// Last activity time
|
||||
pub last_activity: Instant,
|
||||
/// Frames received
|
||||
pub frames_received: u64,
|
||||
/// Bytes received
|
||||
pub bytes_received: u64,
|
||||
}
|
||||
|
||||
/// Individual video session
|
||||
struct VideoSession {
|
||||
/// Session ID
|
||||
session_id: String,
|
||||
/// Codec for this session
|
||||
codec: VideoEncoderType,
|
||||
/// Session state
|
||||
state: VideoSessionState,
|
||||
/// Creation time
|
||||
created_at: Instant,
|
||||
/// Last activity time
|
||||
last_activity: Instant,
|
||||
/// Frame receiver
|
||||
frame_rx: Option<broadcast::Receiver<EncodedVideoFrame>>,
|
||||
/// Stats
|
||||
frames_received: u64,
|
||||
bytes_received: u64,
|
||||
}
|
||||
|
||||
impl VideoSession {
|
||||
fn new(session_id: String, codec: VideoEncoderType) -> Self {
|
||||
let now = Instant::now();
|
||||
Self {
|
||||
session_id,
|
||||
codec,
|
||||
state: VideoSessionState::Created,
|
||||
created_at: now,
|
||||
last_activity: now,
|
||||
frame_rx: None,
|
||||
frames_received: 0,
|
||||
bytes_received: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn info(&self) -> VideoSessionInfo {
|
||||
VideoSessionInfo {
|
||||
session_id: self.session_id.clone(),
|
||||
codec: self.codec,
|
||||
state: self.state,
|
||||
created_at: self.created_at,
|
||||
last_activity: self.last_activity,
|
||||
frames_received: self.frames_received,
|
||||
bytes_received: self.bytes_received,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Video session manager configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VideoSessionManagerConfig {
|
||||
/// Default codec
|
||||
pub default_codec: VideoEncoderType,
|
||||
/// Default resolution
|
||||
pub resolution: Resolution,
|
||||
/// Default bitrate (kbps)
|
||||
pub bitrate_kbps: u32,
|
||||
/// Default FPS
|
||||
pub fps: u32,
|
||||
/// Session timeout (seconds)
|
||||
pub session_timeout_secs: u64,
|
||||
/// Encoder backend (None = auto select best available)
|
||||
pub encoder_backend: Option<EncoderBackend>,
|
||||
}
|
||||
|
||||
impl Default for VideoSessionManagerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_codec: VideoEncoderType::H264,
|
||||
resolution: Resolution::HD720,
|
||||
bitrate_kbps: 8000,
|
||||
fps: 30,
|
||||
session_timeout_secs: 300,
|
||||
encoder_backend: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Video session manager
|
||||
///
|
||||
/// Manages video encoding sessions with multi-codec support.
|
||||
/// A single encoder is shared across all sessions with the same codec.
|
||||
pub struct VideoSessionManager {
|
||||
/// Configuration
|
||||
config: VideoSessionManagerConfig,
|
||||
/// Active sessions
|
||||
sessions: RwLock<HashMap<String, VideoSession>>,
|
||||
/// Current pipeline (shared across sessions with same codec)
|
||||
pipeline: RwLock<Option<Arc<SharedVideoPipeline>>>,
|
||||
/// Current codec (active pipeline codec)
|
||||
current_codec: RwLock<Option<VideoEncoderType>>,
|
||||
/// Video frame source
|
||||
frame_source: RwLock<Option<broadcast::Receiver<VideoFrame>>>,
|
||||
}
|
||||
|
||||
impl VideoSessionManager {
|
||||
/// Create a new video session manager
|
||||
pub fn new(config: VideoSessionManagerConfig) -> Self {
|
||||
info!(
|
||||
"Creating video session manager with default codec: {}",
|
||||
config.default_codec
|
||||
);
|
||||
|
||||
Self {
|
||||
config,
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
pipeline: RwLock::new(None),
|
||||
current_codec: RwLock::new(None),
|
||||
frame_source: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(VideoSessionManagerConfig::default())
|
||||
}
|
||||
|
||||
/// Set the video frame source
|
||||
pub async fn set_frame_source(&self, rx: broadcast::Receiver<VideoFrame>) {
|
||||
*self.frame_source.write().await = Some(rx);
|
||||
}
|
||||
|
||||
/// Get available codecs based on hardware capabilities
|
||||
pub fn available_codecs(&self) -> Vec<VideoEncoderType> {
|
||||
EncoderRegistry::global().selectable_formats()
|
||||
}
|
||||
|
||||
/// Check if a codec is available
|
||||
pub fn is_codec_available(&self, codec: VideoEncoderType) -> bool {
|
||||
let hardware_only = codec.hardware_only();
|
||||
EncoderRegistry::global().is_format_available(codec, hardware_only)
|
||||
}
|
||||
|
||||
/// Create a new video session
|
||||
pub async fn create_session(&self, codec: Option<VideoEncoderType>) -> Result<String> {
|
||||
let sessions = self.sessions.read().await;
|
||||
if sessions.len() >= MAX_VIDEO_SESSIONS {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Maximum video sessions ({}) reached",
|
||||
MAX_VIDEO_SESSIONS
|
||||
)));
|
||||
}
|
||||
drop(sessions);
|
||||
|
||||
// Use specified codec or default
|
||||
let codec = codec.unwrap_or(self.config.default_codec);
|
||||
|
||||
// Verify codec is available
|
||||
if !self.is_codec_available(codec) {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Codec {} is not available on this system",
|
||||
codec
|
||||
)));
|
||||
}
|
||||
|
||||
// Generate session ID
|
||||
let session_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Create session
|
||||
let session = VideoSession::new(session_id.clone(), codec);
|
||||
|
||||
// Store session
|
||||
let mut sessions = self.sessions.write().await;
|
||||
sessions.insert(session_id.clone(), session);
|
||||
|
||||
info!(
|
||||
"Video session created: {} (codec: {})",
|
||||
session_id, codec
|
||||
);
|
||||
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
/// Start a video session (subscribe to encoded frames)
|
||||
pub async fn start_session(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Result<broadcast::Receiver<EncodedVideoFrame>> {
|
||||
// Ensure pipeline is running with correct codec
|
||||
self.ensure_pipeline_for_session(session_id).await?;
|
||||
|
||||
let mut sessions = self.sessions.write().await;
|
||||
let session = sessions
|
||||
.get_mut(session_id)
|
||||
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
|
||||
|
||||
// Get pipeline and subscribe
|
||||
let pipeline = self.pipeline.read().await;
|
||||
let pipeline = pipeline
|
||||
.as_ref()
|
||||
.ok_or_else(|| AppError::VideoError("Pipeline not initialized".to_string()))?;
|
||||
|
||||
let rx = pipeline.subscribe();
|
||||
session.frame_rx = Some(pipeline.subscribe());
|
||||
session.state = VideoSessionState::Active;
|
||||
session.last_activity = Instant::now();
|
||||
|
||||
info!("Video session started: {}", session_id);
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
/// Ensure pipeline is running with correct codec for session
|
||||
async fn ensure_pipeline_for_session(&self, session_id: &str) -> Result<()> {
|
||||
let sessions = self.sessions.read().await;
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
|
||||
let required_codec = session.codec;
|
||||
drop(sessions);
|
||||
|
||||
let current_codec = *self.current_codec.read().await;
|
||||
|
||||
// Check if we need to create or switch pipeline
|
||||
if current_codec != Some(required_codec) {
|
||||
self.switch_pipeline_codec(required_codec).await?;
|
||||
}
|
||||
|
||||
// Ensure pipeline is started
|
||||
let pipeline = self.pipeline.read().await;
|
||||
if let Some(ref pipe) = *pipeline {
|
||||
if !pipe.is_running() {
|
||||
// Need frame source to start
|
||||
let frame_rx = {
|
||||
let source = self.frame_source.read().await;
|
||||
source.as_ref().map(|rx| rx.resubscribe())
|
||||
};
|
||||
|
||||
if let Some(rx) = frame_rx {
|
||||
drop(pipeline);
|
||||
let pipeline = self.pipeline.read().await;
|
||||
if let Some(ref pipe) = *pipeline {
|
||||
pipe.start(rx).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Switch pipeline to different codec
|
||||
async fn switch_pipeline_codec(&self, codec: VideoEncoderType) -> Result<()> {
|
||||
info!("Switching pipeline to codec: {}", codec);
|
||||
|
||||
// Stop existing pipeline
|
||||
{
|
||||
let pipeline = self.pipeline.read().await;
|
||||
if let Some(ref pipe) = *pipeline {
|
||||
pipe.stop();
|
||||
}
|
||||
}
|
||||
|
||||
// Create new pipeline config
|
||||
let pipeline_config = SharedVideoPipelineConfig {
|
||||
resolution: self.config.resolution,
|
||||
input_format: crate::video::format::PixelFormat::Mjpeg, // Common input
|
||||
output_codec: codec,
|
||||
bitrate_kbps: self.config.bitrate_kbps,
|
||||
fps: self.config.fps,
|
||||
gop_size: 30,
|
||||
encoder_backend: self.config.encoder_backend,
|
||||
};
|
||||
|
||||
// Create new pipeline
|
||||
let new_pipeline = SharedVideoPipeline::new(pipeline_config)?;
|
||||
|
||||
// Update state
|
||||
*self.pipeline.write().await = Some(new_pipeline);
|
||||
*self.current_codec.write().await = Some(codec);
|
||||
|
||||
info!("Pipeline switched to codec: {}", codec);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get session info
|
||||
pub async fn get_session(&self, session_id: &str) -> Option<VideoSessionInfo> {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(session_id).map(|s| s.info())
|
||||
}
|
||||
|
||||
/// List all sessions
|
||||
pub async fn list_sessions(&self) -> Vec<VideoSessionInfo> {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.values().map(|s| s.info()).collect()
|
||||
}
|
||||
|
||||
/// Pause a session
|
||||
pub async fn pause_session(&self, session_id: &str) -> Result<()> {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
let session = sessions
|
||||
.get_mut(session_id)
|
||||
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
|
||||
|
||||
session.state = VideoSessionState::Paused;
|
||||
session.last_activity = Instant::now();
|
||||
|
||||
debug!("Video session paused: {}", session_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resume a session
|
||||
pub async fn resume_session(&self, session_id: &str) -> Result<()> {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
let session = sessions
|
||||
.get_mut(session_id)
|
||||
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
|
||||
|
||||
session.state = VideoSessionState::Active;
|
||||
session.last_activity = Instant::now();
|
||||
|
||||
debug!("Video session resumed: {}", session_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Close a session
|
||||
pub async fn close_session(&self, session_id: &str) -> Result<()> {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
if let Some(mut session) = sessions.remove(session_id) {
|
||||
session.state = VideoSessionState::Closed;
|
||||
session.frame_rx = None;
|
||||
info!("Video session closed: {}", session_id);
|
||||
}
|
||||
|
||||
// If no more sessions, consider stopping pipeline
|
||||
if sessions.is_empty() {
|
||||
drop(sessions);
|
||||
self.maybe_stop_pipeline().await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop pipeline if no active sessions
|
||||
async fn maybe_stop_pipeline(&self) {
|
||||
let sessions = self.sessions.read().await;
|
||||
let has_active = sessions
|
||||
.values()
|
||||
.any(|s| s.state == VideoSessionState::Active);
|
||||
drop(sessions);
|
||||
|
||||
if !has_active {
|
||||
let pipeline = self.pipeline.read().await;
|
||||
if let Some(ref pipe) = *pipeline {
|
||||
pipe.stop();
|
||||
debug!("Pipeline stopped - no active sessions");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cleanup stale/timed out sessions
|
||||
pub async fn cleanup_stale_sessions(&self) {
|
||||
let timeout = std::time::Duration::from_secs(self.config.session_timeout_secs);
|
||||
let now = Instant::now();
|
||||
|
||||
let stale_ids: Vec<String> = {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions
|
||||
.iter()
|
||||
.filter(|(_, s)| {
|
||||
(s.state == VideoSessionState::Paused
|
||||
|| s.state == VideoSessionState::Created)
|
||||
&& now.duration_since(s.last_activity) > timeout
|
||||
})
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect()
|
||||
};
|
||||
|
||||
if !stale_ids.is_empty() {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
for id in stale_ids {
|
||||
info!("Removing stale video session: {}", id);
|
||||
sessions.remove(&id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get session count
|
||||
pub async fn session_count(&self) -> usize {
|
||||
self.sessions.read().await.len()
|
||||
}
|
||||
|
||||
/// Get active session count
|
||||
pub async fn active_session_count(&self) -> usize {
|
||||
self.sessions
|
||||
.read()
|
||||
.await
|
||||
.values()
|
||||
.filter(|s| s.state == VideoSessionState::Active)
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Get pipeline statistics
|
||||
pub async fn pipeline_stats(&self) -> Option<SharedVideoPipelineStats> {
|
||||
let pipeline = self.pipeline.read().await;
|
||||
if let Some(ref pipe) = *pipeline {
|
||||
Some(pipe.stats().await)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current active codec
|
||||
pub async fn current_codec(&self) -> Option<VideoEncoderType> {
|
||||
*self.current_codec.read().await
|
||||
}
|
||||
|
||||
/// Set bitrate for current pipeline
|
||||
pub async fn set_bitrate(&self, bitrate_kbps: u32) -> Result<()> {
|
||||
let pipeline = self.pipeline.read().await;
|
||||
if let Some(ref pipe) = *pipeline {
|
||||
pipe.set_bitrate(bitrate_kbps).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Request keyframe for all sessions
|
||||
pub async fn request_keyframe(&self) {
|
||||
// This would be implemented if encoders support forced keyframes
|
||||
warn!("Keyframe request not yet implemented");
|
||||
}
|
||||
|
||||
/// Change codec for a session (requires restart)
|
||||
pub async fn change_session_codec(
|
||||
&self,
|
||||
session_id: &str,
|
||||
new_codec: VideoEncoderType,
|
||||
) -> Result<()> {
|
||||
if !self.is_codec_available(new_codec) {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Codec {} is not available",
|
||||
new_codec
|
||||
)));
|
||||
}
|
||||
|
||||
let mut sessions = self.sessions.write().await;
|
||||
let session = sessions
|
||||
.get_mut(session_id)
|
||||
.ok_or_else(|| AppError::NotFound(format!("Session not found: {}", session_id)))?;
|
||||
|
||||
let old_codec = session.codec;
|
||||
session.codec = new_codec;
|
||||
session.state = VideoSessionState::Created; // Require restart
|
||||
session.frame_rx = None;
|
||||
session.last_activity = Instant::now();
|
||||
|
||||
info!(
|
||||
"Session {} codec changed: {} -> {}",
|
||||
session_id, old_codec, new_codec
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get codec info
|
||||
pub fn get_codec_info(&self, codec: VideoEncoderType) -> Option<CodecInfo> {
|
||||
let registry = EncoderRegistry::global();
|
||||
let encoder = registry.best_encoder(codec, codec.hardware_only())?;
|
||||
|
||||
Some(CodecInfo {
|
||||
codec_type: codec,
|
||||
codec_name: encoder.codec_name.clone(),
|
||||
backend: encoder.backend.to_string(),
|
||||
is_hardware: encoder.is_hardware,
|
||||
})
|
||||
}
|
||||
|
||||
/// List all available codecs with their info
|
||||
pub fn list_codec_info(&self) -> Vec<CodecInfo> {
|
||||
self.available_codecs()
|
||||
.iter()
|
||||
.filter_map(|c| self.get_codec_info(*c))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Codec information
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodecInfo {
|
||||
/// Codec type
|
||||
pub codec_type: VideoEncoderType,
|
||||
/// FFmpeg codec name
|
||||
pub codec_name: String,
|
||||
/// Backend (VAAPI, NVENC, etc.)
|
||||
pub backend: String,
|
||||
/// Whether this is hardware accelerated
|
||||
pub is_hardware: bool,
|
||||
}
|
||||
|
||||
impl Default for VideoSessionManager {
|
||||
fn default() -> Self {
|
||||
Self::with_defaults()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_session_state_display() {
|
||||
assert_eq!(VideoSessionState::Active.to_string(), "Active");
|
||||
assert_eq!(VideoSessionState::Closed.to_string(), "Closed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_available_codecs() {
|
||||
let manager = VideoSessionManager::with_defaults();
|
||||
let codecs = manager.available_codecs();
|
||||
println!("Available codecs: {:?}", codecs);
|
||||
// H264 should always be available (software fallback)
|
||||
assert!(codecs.contains(&VideoEncoderType::H264));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codec_info() {
|
||||
let manager = VideoSessionManager::with_defaults();
|
||||
let info = manager.get_codec_info(VideoEncoderType::H264);
|
||||
if let Some(info) = info {
|
||||
println!(
|
||||
"H264: {} ({}, hardware={})",
|
||||
info.codec_name, info.backend, info.is_hardware
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
257
src/web/audio_ws.rs
Normal file
257
src/web/audio_ws.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
//! Audio WebSocket handler for MJPEG mode
|
||||
//!
|
||||
//! Provides a dedicated WebSocket endpoint (`/api/ws/audio`) for streaming
|
||||
//! Opus-encoded audio data in binary format.
|
||||
//!
|
||||
//! ## Binary Protocol
|
||||
//!
|
||||
//! Each audio packet is sent as a binary WebSocket message with the following format:
|
||||
//!
|
||||
//! ```text
|
||||
//! Byte 0: Type (0x02 = audio)
|
||||
//! Bytes 1-4: Timestamp (u32 LE, milliseconds since stream start)
|
||||
//! Bytes 5-6: Duration (u16 LE, milliseconds)
|
||||
//! Bytes 7-10: Sequence (u32 LE)
|
||||
//! Bytes 11-14: Data length (u32 LE)
|
||||
//! Bytes 15+: Opus encoded data
|
||||
//! ```
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
State,
|
||||
},
|
||||
response::Response,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::audio::OpusFrame;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Audio packet type identifier
|
||||
const AUDIO_PACKET_TYPE: u8 = 0x02;
|
||||
|
||||
/// Audio WebSocket upgrade handler
|
||||
///
|
||||
/// Upgrades HTTP connection to WebSocket for audio streaming.
|
||||
pub async fn audio_ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Response {
|
||||
ws.on_upgrade(move |socket| handle_audio_socket(socket, state))
|
||||
}
|
||||
|
||||
/// Handle audio WebSocket connection
|
||||
async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Try to get Opus frame subscription
|
||||
let opus_rx = match state.audio.subscribe_opus_async().await {
|
||||
Some(rx) => rx,
|
||||
None => {
|
||||
warn!("Audio not streaming, rejecting WebSocket connection");
|
||||
// Send error message before closing
|
||||
let _ = sender
|
||||
.send(Message::Text(
|
||||
r#"{"error": "Audio not streaming"}"#.to_string(),
|
||||
))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut opus_rx = opus_rx;
|
||||
let stream_start = Instant::now();
|
||||
|
||||
info!("Audio WebSocket client connected");
|
||||
|
||||
// Track connection for cleanup
|
||||
let mut closed = false;
|
||||
|
||||
// Use interval instead of sleep for more efficient keepalive
|
||||
let mut ping_interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Receive Opus frames and send to client
|
||||
opus_result = opus_rx.recv() => {
|
||||
match opus_result {
|
||||
Ok(frame) => {
|
||||
let binary = encode_audio_packet(&frame, stream_start);
|
||||
if sender.send(Message::Binary(binary)).await.is_err() {
|
||||
debug!("Failed to send audio frame, client disconnected");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
warn!("Audio WebSocket client lagged by {} frames", n);
|
||||
// Continue - just skip the missed frames
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
info!("Audio stream closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle client messages (ping/close)
|
||||
msg = receiver.next() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
debug!("Audio WebSocket client requested close");
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
if sender.send(Message::Pong(data)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Pong(_))) => {
|
||||
// Pong received, connection is alive
|
||||
}
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
// Handle potential control messages
|
||||
debug!("Received text message on audio WS: {}", text);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("Audio WebSocket receive error: {}", e);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
// Connection closed
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Periodic ping to keep connection alive (using interval)
|
||||
_ = ping_interval.tick() => {
|
||||
if sender.send(Message::Ping(vec![])).await.is_err() {
|
||||
warn!("Failed to send ping, disconnecting");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !closed {
|
||||
// Try to send close message
|
||||
let _ = sender.send(Message::Close(None)).await;
|
||||
}
|
||||
|
||||
info!("Audio WebSocket client disconnected");
|
||||
}
|
||||
|
||||
/// Encode Opus frame to binary packet format
|
||||
///
|
||||
/// ## Format
|
||||
///
|
||||
/// | Offset | Size | Description |
|
||||
/// |--------|------|-------------|
|
||||
/// | 0 | 1 | Packet type (0x02 for audio) |
|
||||
/// | 1 | 4 | Timestamp (u32 LE, ms since start) |
|
||||
/// | 5 | 2 | Duration (u16 LE, ms) |
|
||||
/// | 7 | 4 | Sequence number (u32 LE) |
|
||||
/// | 11 | 4 | Data length (u32 LE) |
|
||||
/// | 15 | N | Opus encoded data |
|
||||
fn encode_audio_packet(frame: &OpusFrame, stream_start: Instant) -> Vec<u8> {
|
||||
let timestamp_ms = stream_start.elapsed().as_millis() as u32;
|
||||
let data_len = frame.data.len() as u32;
|
||||
|
||||
let mut buf = Vec::with_capacity(15 + frame.data.len());
|
||||
|
||||
// Header
|
||||
buf.push(AUDIO_PACKET_TYPE);
|
||||
buf.extend_from_slice(×tamp_ms.to_le_bytes());
|
||||
buf.extend_from_slice(&(frame.duration_ms as u16).to_le_bytes());
|
||||
buf.extend_from_slice(&(frame.sequence as u32).to_le_bytes());
|
||||
buf.extend_from_slice(&data_len.to_le_bytes());
|
||||
|
||||
// Opus data
|
||||
buf.extend_from_slice(&frame.data);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Decode audio packet from binary format (for testing/debugging)
|
||||
#[allow(dead_code)]
|
||||
pub fn decode_audio_packet(data: &[u8]) -> Option<AudioPacketHeader> {
|
||||
if data.len() < 15 {
|
||||
return None;
|
||||
}
|
||||
|
||||
if data[0] != AUDIO_PACKET_TYPE {
|
||||
return None;
|
||||
}
|
||||
|
||||
let timestamp = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
|
||||
let duration_ms = u16::from_le_bytes([data[5], data[6]]);
|
||||
let sequence = u32::from_le_bytes([data[7], data[8], data[9], data[10]]);
|
||||
let data_length = u32::from_le_bytes([data[11], data[12], data[13], data[14]]);
|
||||
|
||||
Some(AudioPacketHeader {
|
||||
packet_type: data[0],
|
||||
timestamp,
|
||||
duration_ms,
|
||||
sequence,
|
||||
data_length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Audio packet header (for decoding/testing)
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct AudioPacketHeader {
|
||||
pub packet_type: u8,
|
||||
pub timestamp: u32,
|
||||
pub duration_ms: u16,
|
||||
pub sequence: u32,
|
||||
pub data_length: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use bytes::Bytes;
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_packet() {
|
||||
let frame = OpusFrame {
|
||||
data: Bytes::from(vec![1, 2, 3, 4, 5]),
|
||||
duration_ms: 20,
|
||||
sequence: 42,
|
||||
timestamp: Instant::now(),
|
||||
rtp_timestamp: 0,
|
||||
};
|
||||
|
||||
let stream_start = Instant::now();
|
||||
let encoded = encode_audio_packet(&frame, stream_start);
|
||||
|
||||
assert!(encoded.len() >= 15);
|
||||
assert_eq!(encoded[0], AUDIO_PACKET_TYPE);
|
||||
|
||||
let header = decode_audio_packet(&encoded).unwrap();
|
||||
assert_eq!(header.packet_type, AUDIO_PACKET_TYPE);
|
||||
assert_eq!(header.duration_ms, 20);
|
||||
assert_eq!(header.sequence, 42);
|
||||
assert_eq!(header.data_length, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_invalid_packet() {
|
||||
// Too short
|
||||
assert!(decode_audio_packet(&[]).is_none());
|
||||
assert!(decode_audio_packet(&[0x02; 10]).is_none());
|
||||
|
||||
// Wrong type
|
||||
let mut bad = vec![0x01; 20];
|
||||
assert!(decode_audio_packet(&bad).is_none());
|
||||
}
|
||||
}
|
||||
362
src/web/handlers/config/apply.rs
Normal file
362
src/web/handlers/config/apply.rs
Normal file
@@ -0,0 +1,362 @@
|
||||
//! 配置热重载逻辑
|
||||
//!
|
||||
//! 从 handlers.rs 中抽取的配置应用函数,负责将配置变更应用到各个子系统。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::*;
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 应用 Video 配置变更
|
||||
pub async fn apply_video_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &VideoConfig,
|
||||
new_config: &VideoConfig,
|
||||
) -> Result<()> {
|
||||
// 检查配置是否实际变更
|
||||
if old_config == new_config {
|
||||
tracing::info!("Video config unchanged, skipping reload");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::info!("Applying video config changes...");
|
||||
|
||||
let device = new_config
|
||||
.device
|
||||
.clone()
|
||||
.ok_or_else(|| AppError::BadRequest("video_device is required".to_string()))?;
|
||||
|
||||
let format = new_config
|
||||
.format
|
||||
.as_ref()
|
||||
.and_then(|f| {
|
||||
serde_json::from_value::<crate::video::format::PixelFormat>(
|
||||
serde_json::Value::String(f.clone()),
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.unwrap_or(crate::video::format::PixelFormat::Mjpeg);
|
||||
|
||||
let resolution =
|
||||
crate::video::format::Resolution::new(new_config.width, new_config.height);
|
||||
|
||||
// Step 1: 更新 WebRTC streamer 配置(停止现有 pipeline 和 sessions)
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.update_video_config(resolution, format, new_config.fps)
|
||||
.await;
|
||||
tracing::info!("WebRTC streamer config updated");
|
||||
|
||||
// Step 2: 应用视频配置到 streamer(重新创建 capturer)
|
||||
state
|
||||
.stream_manager
|
||||
.streamer()
|
||||
.apply_video_config(&device, format, resolution, new_config.fps)
|
||||
.await
|
||||
.map_err(|e| AppError::VideoError(format!("Failed to apply video config: {}", e)))?;
|
||||
tracing::info!("Video config applied to streamer");
|
||||
|
||||
// Step 3: 重启 streamer
|
||||
if let Err(e) = state.stream_manager.start().await {
|
||||
tracing::error!("Failed to start streamer after config change: {}", e);
|
||||
} else {
|
||||
tracing::info!("Streamer started after config change");
|
||||
}
|
||||
|
||||
// Step 4: 更新 WebRTC frame source
|
||||
if let Some(frame_tx) = state.stream_manager.frame_sender().await {
|
||||
let receiver_count = frame_tx.receiver_count();
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.set_video_source(frame_tx)
|
||||
.await;
|
||||
tracing::info!(
|
||||
"WebRTC streamer frame source updated (receiver_count={})",
|
||||
receiver_count
|
||||
);
|
||||
} else {
|
||||
tracing::warn!("No frame source available after config change");
|
||||
}
|
||||
|
||||
tracing::info!("Video config applied successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 Stream 配置变更
|
||||
pub async fn apply_stream_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &StreamConfig,
|
||||
new_config: &StreamConfig,
|
||||
) -> Result<()> {
|
||||
tracing::info!("Applying stream config changes...");
|
||||
|
||||
// 更新编码器后端
|
||||
if old_config.encoder != new_config.encoder {
|
||||
let encoder_backend = new_config.encoder.to_backend();
|
||||
tracing::info!(
|
||||
"Updating encoder backend to: {:?} (from config: {:?})",
|
||||
encoder_backend,
|
||||
new_config.encoder
|
||||
);
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.update_encoder_backend(encoder_backend)
|
||||
.await;
|
||||
}
|
||||
|
||||
// 更新码率
|
||||
if old_config.bitrate_kbps != new_config.bitrate_kbps {
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.set_bitrate(new_config.bitrate_kbps)
|
||||
.await
|
||||
.ok(); // Ignore error if no active stream
|
||||
}
|
||||
|
||||
// 更新 ICE 配置 (STUN/TURN)
|
||||
let ice_changed = old_config.stun_server != new_config.stun_server
|
||||
|| old_config.turn_server != new_config.turn_server
|
||||
|| old_config.turn_username != new_config.turn_username
|
||||
|| old_config.turn_password != new_config.turn_password;
|
||||
|
||||
if ice_changed {
|
||||
tracing::info!(
|
||||
"Updating ICE config: STUN={:?}, TURN={:?}",
|
||||
new_config.stun_server,
|
||||
new_config.turn_server
|
||||
);
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.update_ice_config(
|
||||
new_config.stun_server.clone(),
|
||||
new_config.turn_server.clone(),
|
||||
new_config.turn_username.clone(),
|
||||
new_config.turn_password.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Stream config applied: encoder={:?}, bitrate={} kbps",
|
||||
new_config.encoder,
|
||||
new_config.bitrate_kbps
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 HID 配置变更
|
||||
pub async fn apply_hid_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &HidConfig,
|
||||
new_config: &HidConfig,
|
||||
) -> Result<()> {
|
||||
// 检查是否需要重载
|
||||
if old_config.backend == new_config.backend
|
||||
&& old_config.ch9329_port == new_config.ch9329_port
|
||||
&& old_config.ch9329_baudrate == new_config.ch9329_baudrate
|
||||
&& old_config.otg_udc == new_config.otg_udc
|
||||
{
|
||||
tracing::info!("HID config unchanged, skipping reload");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::info!("Applying HID config changes...");
|
||||
|
||||
let new_hid_backend = match new_config.backend {
|
||||
HidBackend::Otg => crate::hid::HidBackendType::Otg,
|
||||
HidBackend::Ch9329 => crate::hid::HidBackendType::Ch9329 {
|
||||
port: new_config.ch9329_port.clone(),
|
||||
baud_rate: new_config.ch9329_baudrate,
|
||||
},
|
||||
HidBackend::None => crate::hid::HidBackendType::None,
|
||||
};
|
||||
|
||||
state
|
||||
.hid
|
||||
.reload(new_hid_backend)
|
||||
.await
|
||||
.map_err(|e| AppError::Config(format!("HID reload failed: {}", e)))?;
|
||||
|
||||
tracing::info!("HID backend reloaded successfully: {:?}", new_config.backend);
|
||||
|
||||
// When switching to OTG backend, automatically enable MSD if not already enabled
|
||||
// OTG HID and MSD share the same USB gadget, so it makes sense to enable both
|
||||
if new_config.backend == HidBackend::Otg && old_config.backend != HidBackend::Otg {
|
||||
let msd_guard = state.msd.read().await;
|
||||
if msd_guard.is_none() {
|
||||
drop(msd_guard); // Release read lock before acquiring write lock
|
||||
|
||||
tracing::info!("OTG HID enabled, automatically initializing MSD...");
|
||||
|
||||
// Get MSD config from store
|
||||
let config = state.config.get();
|
||||
|
||||
let msd = crate::msd::MsdController::new(
|
||||
state.otg_service.clone(),
|
||||
&config.msd.images_path,
|
||||
&config.msd.drive_path,
|
||||
);
|
||||
|
||||
if let Err(e) = msd.init().await {
|
||||
tracing::warn!("Failed to auto-initialize MSD for OTG: {}", e);
|
||||
} else {
|
||||
let events = state.events.clone();
|
||||
msd.set_event_bus(events).await;
|
||||
*state.msd.write().await = Some(msd);
|
||||
tracing::info!("MSD automatically initialized for OTG mode");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 MSD 配置变更
|
||||
pub async fn apply_msd_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &MsdConfig,
|
||||
new_config: &MsdConfig,
|
||||
) -> Result<()> {
|
||||
tracing::info!("MSD config sent, checking if reload needed...");
|
||||
tracing::debug!("Old MSD config: {:?}", old_config);
|
||||
tracing::debug!("New MSD config: {:?}", new_config);
|
||||
|
||||
// Check if MSD enabled state changed
|
||||
let old_msd_enabled = old_config.enabled;
|
||||
let new_msd_enabled = new_config.enabled;
|
||||
|
||||
tracing::info!("MSD enabled: old={}, new={}", old_msd_enabled, new_msd_enabled);
|
||||
|
||||
if old_msd_enabled != new_msd_enabled {
|
||||
if new_msd_enabled {
|
||||
// MSD was disabled, now enabled - need to initialize
|
||||
tracing::info!("MSD enabled in config, initializing...");
|
||||
|
||||
let msd = crate::msd::MsdController::new(
|
||||
state.otg_service.clone(),
|
||||
&new_config.images_path,
|
||||
&new_config.drive_path,
|
||||
);
|
||||
msd.init().await.map_err(|e| {
|
||||
AppError::Config(format!("MSD initialization failed: {}", e))
|
||||
})?;
|
||||
|
||||
// Set event bus
|
||||
let events = state.events.clone();
|
||||
msd.set_event_bus(events).await;
|
||||
|
||||
// Store the initialized controller
|
||||
*state.msd.write().await = Some(msd);
|
||||
tracing::info!("MSD initialized successfully");
|
||||
} else {
|
||||
// MSD was enabled, now disabled - shutdown
|
||||
tracing::info!("MSD disabled in config, shutting down...");
|
||||
|
||||
if let Some(msd) = state.msd.write().await.as_mut() {
|
||||
if let Err(e) = msd.shutdown().await {
|
||||
tracing::warn!("MSD shutdown failed: {}", e);
|
||||
}
|
||||
}
|
||||
*state.msd.write().await = None;
|
||||
tracing::info!("MSD shutdown complete");
|
||||
}
|
||||
} else {
|
||||
tracing::info!(
|
||||
"MSD enabled state unchanged ({}), no reload needed",
|
||||
new_msd_enabled
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 ATX 配置变更
|
||||
pub async fn apply_atx_config(
|
||||
state: &Arc<AppState>,
|
||||
_old_config: &AtxConfig,
|
||||
new_config: &AtxConfig,
|
||||
) -> Result<()> {
|
||||
tracing::info!("Applying ATX config changes...");
|
||||
|
||||
// Convert AtxConfig to AtxControllerConfig
|
||||
let controller_config = new_config.to_controller_config();
|
||||
|
||||
// Reload the ATX controller with new configuration
|
||||
let atx_guard = state.atx.read().await;
|
||||
if let Some(atx) = atx_guard.as_ref() {
|
||||
if let Err(e) = atx.reload(controller_config).await {
|
||||
tracing::error!("ATX reload failed: {}", e);
|
||||
return Err(AppError::Config(format!("ATX reload failed: {}", e)));
|
||||
}
|
||||
tracing::info!("ATX controller reloaded successfully");
|
||||
} else {
|
||||
// ATX controller not initialized, create a new one if enabled
|
||||
drop(atx_guard);
|
||||
|
||||
if new_config.enabled {
|
||||
tracing::info!("ATX enabled in config, initializing...");
|
||||
|
||||
let atx = crate::atx::AtxController::new(controller_config);
|
||||
if let Err(e) = atx.init().await {
|
||||
tracing::warn!("ATX initialization failed: {}", e);
|
||||
} else {
|
||||
*state.atx.write().await = Some(atx);
|
||||
tracing::info!("ATX controller initialized successfully");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 Audio 配置变更
|
||||
pub async fn apply_audio_config(
|
||||
state: &Arc<AppState>,
|
||||
_old_config: &AudioConfig,
|
||||
new_config: &AudioConfig,
|
||||
) -> Result<()> {
|
||||
tracing::info!("Applying audio config changes...");
|
||||
|
||||
// Create audio controller config from new config
|
||||
let audio_config = crate::audio::AudioControllerConfig {
|
||||
enabled: new_config.enabled,
|
||||
device: new_config.device.clone(),
|
||||
quality: crate::audio::AudioQuality::from_str(&new_config.quality),
|
||||
};
|
||||
|
||||
// Update audio controller
|
||||
if let Err(e) = state.audio.update_config(audio_config).await {
|
||||
tracing::error!("Audio config update failed: {}", e);
|
||||
// Don't fail - audio errors are not critical
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Audio config applied: enabled={}, device={}",
|
||||
new_config.enabled,
|
||||
new_config.device
|
||||
);
|
||||
}
|
||||
|
||||
// Also update WebRTC audio enabled state
|
||||
if let Err(e) = state
|
||||
.stream_manager
|
||||
.set_webrtc_audio_enabled(new_config.enabled)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to update WebRTC audio state: {}", e);
|
||||
} else {
|
||||
tracing::info!("WebRTC audio enabled: {}", new_config.enabled);
|
||||
}
|
||||
|
||||
// Reconnect audio sources for existing WebRTC sessions
|
||||
if new_config.enabled {
|
||||
state.stream_manager.reconnect_webrtc_audio_sources().await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
46
src/web/handlers/config/atx.rs
Normal file
46
src/web/handlers/config/atx.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! ATX 配置 Handler
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::AtxConfig;
|
||||
use crate::error::Result;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::apply::apply_atx_config;
|
||||
use super::types::AtxConfigUpdate;
|
||||
|
||||
/// 获取 ATX 配置
|
||||
pub async fn get_atx_config(State(state): State<Arc<AppState>>) -> Json<AtxConfig> {
|
||||
Json(state.config.get().atx.clone())
|
||||
}
|
||||
|
||||
/// 更新 ATX 配置
|
||||
pub async fn update_atx_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<AtxConfigUpdate>,
|
||||
) -> Result<Json<AtxConfig>> {
|
||||
// 1. 验证请求
|
||||
req.validate()?;
|
||||
|
||||
// 2. 获取旧配置
|
||||
let old_atx_config = state.config.get().atx.clone();
|
||||
|
||||
// 3. 应用更新到配置存储
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
req.apply_to(&mut config.atx);
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 4. 获取新配置
|
||||
let new_atx_config = state.config.get().atx.clone();
|
||||
|
||||
// 5. 应用到子系统(热重载)
|
||||
if let Err(e) = apply_atx_config(&state, &old_atx_config, &new_atx_config).await {
|
||||
tracing::error!("Failed to apply ATX config: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(new_atx_config))
|
||||
}
|
||||
46
src/web/handlers/config/audio.rs
Normal file
46
src/web/handlers/config/audio.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! Audio 配置 Handler
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::AudioConfig;
|
||||
use crate::error::Result;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::apply::apply_audio_config;
|
||||
use super::types::AudioConfigUpdate;
|
||||
|
||||
/// 获取 Audio 配置
|
||||
pub async fn get_audio_config(State(state): State<Arc<AppState>>) -> Json<AudioConfig> {
|
||||
Json(state.config.get().audio.clone())
|
||||
}
|
||||
|
||||
/// 更新 Audio 配置
|
||||
pub async fn update_audio_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<AudioConfigUpdate>,
|
||||
) -> Result<Json<AudioConfig>> {
|
||||
// 1. 验证请求
|
||||
req.validate()?;
|
||||
|
||||
// 2. 获取旧配置
|
||||
let old_audio_config = state.config.get().audio.clone();
|
||||
|
||||
// 3. 应用更新到配置存储
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
req.apply_to(&mut config.audio);
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 4. 获取新配置
|
||||
let new_audio_config = state.config.get().audio.clone();
|
||||
|
||||
// 5. 应用到子系统(热重载)
|
||||
if let Err(e) = apply_audio_config(&state, &old_audio_config, &new_audio_config).await {
|
||||
tracing::error!("Failed to apply audio config: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(new_audio_config))
|
||||
}
|
||||
46
src/web/handlers/config/hid.rs
Normal file
46
src/web/handlers/config/hid.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! HID 配置 Handler
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::HidConfig;
|
||||
use crate::error::Result;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::apply::apply_hid_config;
|
||||
use super::types::HidConfigUpdate;
|
||||
|
||||
/// 获取 HID 配置
|
||||
pub async fn get_hid_config(State(state): State<Arc<AppState>>) -> Json<HidConfig> {
|
||||
Json(state.config.get().hid.clone())
|
||||
}
|
||||
|
||||
/// 更新 HID 配置
|
||||
pub async fn update_hid_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<HidConfigUpdate>,
|
||||
) -> Result<Json<HidConfig>> {
|
||||
// 1. 验证请求
|
||||
req.validate()?;
|
||||
|
||||
// 2. 获取旧配置
|
||||
let old_hid_config = state.config.get().hid.clone();
|
||||
|
||||
// 3. 应用更新到配置存储
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
req.apply_to(&mut config.hid);
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 4. 获取新配置
|
||||
let new_hid_config = state.config.get().hid.clone();
|
||||
|
||||
// 5. 应用到子系统(热重载)
|
||||
if let Err(e) = apply_hid_config(&state, &old_hid_config, &new_hid_config).await {
|
||||
tracing::error!("Failed to apply HID config: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(new_hid_config))
|
||||
}
|
||||
48
src/web/handlers/config/mod.rs
Normal file
48
src/web/handlers/config/mod.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
//! 配置管理 Handler 模块
|
||||
//!
|
||||
//! 提供 RESTful 域分离的配置 API:
|
||||
//! - GET /api/config/video - 获取视频配置
|
||||
//! - PATCH /api/config/video - 更新视频配置
|
||||
//! - GET /api/config/stream - 获取流配置
|
||||
//! - PATCH /api/config/stream - 更新流配置
|
||||
//! - GET /api/config/hid - 获取 HID 配置
|
||||
//! - PATCH /api/config/hid - 更新 HID 配置
|
||||
//! - GET /api/config/msd - 获取 MSD 配置
|
||||
//! - PATCH /api/config/msd - 更新 MSD 配置
|
||||
//! - GET /api/config/atx - 获取 ATX 配置
|
||||
//! - PATCH /api/config/atx - 更新 ATX 配置
|
||||
//! - GET /api/config/audio - 获取音频配置
|
||||
//! - PATCH /api/config/audio - 更新音频配置
|
||||
|
||||
mod apply;
|
||||
mod types;
|
||||
|
||||
mod video;
|
||||
mod stream;
|
||||
mod hid;
|
||||
mod msd;
|
||||
mod atx;
|
||||
mod audio;
|
||||
|
||||
// 导出 handler 函数
|
||||
pub use video::{get_video_config, update_video_config};
|
||||
pub use stream::{get_stream_config, update_stream_config};
|
||||
pub use hid::{get_hid_config, update_hid_config};
|
||||
pub use msd::{get_msd_config, update_msd_config};
|
||||
pub use atx::{get_atx_config, update_atx_config};
|
||||
pub use audio::{get_audio_config, update_audio_config};
|
||||
|
||||
// 保留全局配置查询(向后兼容)
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// 获取完整配置
|
||||
pub async fn get_all_config(State(state): State<Arc<AppState>>) -> Json<AppConfig> {
|
||||
let mut config = (*state.config.get()).clone();
|
||||
// 不暴露敏感信息
|
||||
config.auth.totp_secret = None;
|
||||
Json(config)
|
||||
}
|
||||
46
src/web/handlers/config/msd.rs
Normal file
46
src/web/handlers/config/msd.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! MSD 配置 Handler
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::MsdConfig;
|
||||
use crate::error::Result;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::apply::apply_msd_config;
|
||||
use super::types::MsdConfigUpdate;
|
||||
|
||||
/// 获取 MSD 配置
|
||||
pub async fn get_msd_config(State(state): State<Arc<AppState>>) -> Json<MsdConfig> {
|
||||
Json(state.config.get().msd.clone())
|
||||
}
|
||||
|
||||
/// 更新 MSD 配置
|
||||
pub async fn update_msd_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<MsdConfigUpdate>,
|
||||
) -> Result<Json<MsdConfig>> {
|
||||
// 1. 验证请求
|
||||
req.validate()?;
|
||||
|
||||
// 2. 获取旧配置
|
||||
let old_msd_config = state.config.get().msd.clone();
|
||||
|
||||
// 3. 应用更新到配置存储
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
req.apply_to(&mut config.msd);
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 4. 获取新配置
|
||||
let new_msd_config = state.config.get().msd.clone();
|
||||
|
||||
// 5. 应用到子系统(热重载)
|
||||
if let Err(e) = apply_msd_config(&state, &old_msd_config, &new_msd_config).await {
|
||||
tracing::error!("Failed to apply MSD config: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(new_msd_config))
|
||||
}
|
||||
46
src/web/handlers/config/stream.rs
Normal file
46
src/web/handlers/config/stream.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! Stream 配置 Handler
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::apply::apply_stream_config;
|
||||
use super::types::{StreamConfigResponse, StreamConfigUpdate};
|
||||
|
||||
/// 获取 Stream 配置
|
||||
pub async fn get_stream_config(State(state): State<Arc<AppState>>) -> Json<StreamConfigResponse> {
|
||||
let config = state.config.get();
|
||||
Json(StreamConfigResponse::from(&config.stream))
|
||||
}
|
||||
|
||||
/// 更新 Stream 配置
|
||||
pub async fn update_stream_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<StreamConfigUpdate>,
|
||||
) -> Result<Json<StreamConfigResponse>> {
|
||||
// 1. 验证请求
|
||||
req.validate()?;
|
||||
|
||||
// 2. 获取旧配置
|
||||
let old_stream_config = state.config.get().stream.clone();
|
||||
|
||||
// 3. 应用更新到配置存储
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
req.apply_to(&mut config.stream);
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 4. 获取新配置
|
||||
let new_stream_config = state.config.get().stream.clone();
|
||||
|
||||
// 5. 应用到子系统(热重载)
|
||||
if let Err(e) = apply_stream_config(&state, &old_stream_config, &new_stream_config).await {
|
||||
tracing::error!("Failed to apply stream config: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(StreamConfigResponse::from(&new_stream_config)))
|
||||
}
|
||||
396
src/web/handlers/config/types.rs
Normal file
396
src/web/handlers/config/types.rs
Normal file
@@ -0,0 +1,396 @@
|
||||
use serde::Deserialize;
|
||||
use typeshare::typeshare;
|
||||
use crate::config::*;
|
||||
use crate::error::AppError;
|
||||
|
||||
// ===== Video Config =====
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct VideoConfigUpdate {
|
||||
pub device: Option<String>,
|
||||
pub format: Option<String>,
|
||||
pub width: Option<u32>,
|
||||
pub height: Option<u32>,
|
||||
pub fps: Option<u32>,
|
||||
pub quality: Option<u32>,
|
||||
}
|
||||
|
||||
impl VideoConfigUpdate {
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
if let Some(width) = self.width {
|
||||
if !(320..=7680).contains(&width) {
|
||||
return Err(AppError::BadRequest("Invalid width: must be 320-7680".into()));
|
||||
}
|
||||
}
|
||||
if let Some(height) = self.height {
|
||||
if !(240..=4320).contains(&height) {
|
||||
return Err(AppError::BadRequest("Invalid height: must be 240-4320".into()));
|
||||
}
|
||||
}
|
||||
if let Some(fps) = self.fps {
|
||||
if !(1..=120).contains(&fps) {
|
||||
return Err(AppError::BadRequest("Invalid fps: must be 1-120".into()));
|
||||
}
|
||||
}
|
||||
if let Some(quality) = self.quality {
|
||||
if !(1..=100).contains(&quality) {
|
||||
return Err(AppError::BadRequest("Invalid quality: must be 1-100".into()));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn apply_to(&self, config: &mut VideoConfig) {
|
||||
if let Some(ref device) = self.device {
|
||||
config.device = Some(device.clone());
|
||||
}
|
||||
if let Some(ref format) = self.format {
|
||||
config.format = Some(format.clone());
|
||||
}
|
||||
if let Some(width) = self.width {
|
||||
config.width = width;
|
||||
}
|
||||
if let Some(height) = self.height {
|
||||
config.height = height;
|
||||
}
|
||||
if let Some(fps) = self.fps {
|
||||
config.fps = fps;
|
||||
}
|
||||
if let Some(quality) = self.quality {
|
||||
config.quality = quality;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Stream Config =====
|
||||
|
||||
/// Stream 配置响应(包含 has_turn_password 字段)
|
||||
#[typeshare]
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct StreamConfigResponse {
|
||||
pub mode: StreamMode,
|
||||
pub encoder: EncoderType,
|
||||
pub bitrate_kbps: u32,
|
||||
pub gop_size: u32,
|
||||
pub stun_server: Option<String>,
|
||||
pub turn_server: Option<String>,
|
||||
pub turn_username: Option<String>,
|
||||
/// 指示是否已设置 TURN 密码(实际密码不返回)
|
||||
pub has_turn_password: bool,
|
||||
}
|
||||
|
||||
impl From<&StreamConfig> for StreamConfigResponse {
|
||||
fn from(config: &StreamConfig) -> Self {
|
||||
Self {
|
||||
mode: config.mode.clone(),
|
||||
encoder: config.encoder.clone(),
|
||||
bitrate_kbps: config.bitrate_kbps,
|
||||
gop_size: config.gop_size,
|
||||
stun_server: config.stun_server.clone(),
|
||||
turn_server: config.turn_server.clone(),
|
||||
turn_username: config.turn_username.clone(),
|
||||
has_turn_password: config.turn_password.is_some(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamConfigUpdate {
|
||||
pub mode: Option<StreamMode>,
|
||||
pub encoder: Option<EncoderType>,
|
||||
pub bitrate_kbps: Option<u32>,
|
||||
pub gop_size: Option<u32>,
|
||||
/// STUN server URL (e.g., "stun:stun.l.google.com:19302")
|
||||
pub stun_server: Option<String>,
|
||||
/// TURN server URL (e.g., "turn:turn.example.com:3478")
|
||||
pub turn_server: Option<String>,
|
||||
/// TURN username
|
||||
pub turn_username: Option<String>,
|
||||
/// TURN password
|
||||
pub turn_password: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamConfigUpdate {
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
if let Some(bitrate) = self.bitrate_kbps {
|
||||
if !(1000..=15000).contains(&bitrate) {
|
||||
return Err(AppError::BadRequest("Bitrate must be 1000-15000 kbps".into()));
|
||||
}
|
||||
}
|
||||
if let Some(gop) = self.gop_size {
|
||||
if !(10..=300).contains(&gop) {
|
||||
return Err(AppError::BadRequest("GOP size must be 10-300".into()));
|
||||
}
|
||||
}
|
||||
// Validate STUN server format
|
||||
if let Some(ref stun) = self.stun_server {
|
||||
if !stun.is_empty() && !stun.starts_with("stun:") {
|
||||
return Err(AppError::BadRequest(
|
||||
"STUN server must start with 'stun:' (e.g., stun:stun.l.google.com:19302)".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
// Validate TURN server format
|
||||
if let Some(ref turn) = self.turn_server {
|
||||
if !turn.is_empty() && !turn.starts_with("turn:") && !turn.starts_with("turns:") {
|
||||
return Err(AppError::BadRequest(
|
||||
"TURN server must start with 'turn:' or 'turns:' (e.g., turn:turn.example.com:3478)".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn apply_to(&self, config: &mut StreamConfig) {
|
||||
if let Some(mode) = self.mode.clone() {
|
||||
config.mode = mode;
|
||||
}
|
||||
if let Some(encoder) = self.encoder.clone() {
|
||||
config.encoder = encoder;
|
||||
}
|
||||
if let Some(bitrate) = self.bitrate_kbps {
|
||||
config.bitrate_kbps = bitrate;
|
||||
}
|
||||
if let Some(gop) = self.gop_size {
|
||||
config.gop_size = gop;
|
||||
}
|
||||
// STUN/TURN settings - empty string means clear, Some("value") means set
|
||||
if let Some(ref stun) = self.stun_server {
|
||||
config.stun_server = if stun.is_empty() { None } else { Some(stun.clone()) };
|
||||
}
|
||||
if let Some(ref turn) = self.turn_server {
|
||||
config.turn_server = if turn.is_empty() { None } else { Some(turn.clone()) };
|
||||
}
|
||||
if let Some(ref username) = self.turn_username {
|
||||
config.turn_username = if username.is_empty() { None } else { Some(username.clone()) };
|
||||
}
|
||||
if let Some(ref password) = self.turn_password {
|
||||
config.turn_password = if password.is_empty() { None } else { Some(password.clone()) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== HID Config =====
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct HidConfigUpdate {
|
||||
pub backend: Option<HidBackend>,
|
||||
pub ch9329_port: Option<String>,
|
||||
pub ch9329_baudrate: Option<u32>,
|
||||
pub otg_udc: Option<String>,
|
||||
pub mouse_absolute: Option<bool>,
|
||||
}
|
||||
|
||||
impl HidConfigUpdate {
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
if let Some(baudrate) = self.ch9329_baudrate {
|
||||
let valid_rates = [9600, 19200, 38400, 57600, 115200];
|
||||
if !valid_rates.contains(&baudrate) {
|
||||
return Err(AppError::BadRequest(
|
||||
"Invalid baudrate: must be 9600, 19200, 38400, 57600, or 115200".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn apply_to(&self, config: &mut HidConfig) {
|
||||
if let Some(backend) = self.backend.clone() {
|
||||
config.backend = backend;
|
||||
}
|
||||
if let Some(ref port) = self.ch9329_port {
|
||||
config.ch9329_port = port.clone();
|
||||
}
|
||||
if let Some(baudrate) = self.ch9329_baudrate {
|
||||
config.ch9329_baudrate = baudrate;
|
||||
}
|
||||
if let Some(ref udc) = self.otg_udc {
|
||||
config.otg_udc = Some(udc.clone());
|
||||
}
|
||||
if let Some(absolute) = self.mouse_absolute {
|
||||
config.mouse_absolute = absolute;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== MSD Config =====
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MsdConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub images_path: Option<String>,
|
||||
pub drive_path: Option<String>,
|
||||
pub virtual_drive_size_mb: Option<u32>,
|
||||
}
|
||||
|
||||
impl MsdConfigUpdate {
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
if let Some(size) = self.virtual_drive_size_mb {
|
||||
if !(1..=10240).contains(&size) {
|
||||
return Err(AppError::BadRequest("Drive size must be 1-10240 MB".into()));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn apply_to(&self, config: &mut MsdConfig) {
|
||||
if let Some(enabled) = self.enabled {
|
||||
config.enabled = enabled;
|
||||
}
|
||||
if let Some(ref path) = self.images_path {
|
||||
config.images_path = path.clone();
|
||||
}
|
||||
if let Some(ref path) = self.drive_path {
|
||||
config.drive_path = path.clone();
|
||||
}
|
||||
if let Some(size) = self.virtual_drive_size_mb {
|
||||
config.virtual_drive_size_mb = size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== ATX Config =====
|
||||
|
||||
/// Update for a single ATX key configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AtxKeyConfigUpdate {
|
||||
pub driver: Option<crate::atx::AtxDriverType>,
|
||||
pub device: Option<String>,
|
||||
pub pin: Option<u32>,
|
||||
pub active_level: Option<crate::atx::ActiveLevel>,
|
||||
}
|
||||
|
||||
/// Update for LED sensing configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AtxLedConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub gpio_chip: Option<String>,
|
||||
pub gpio_pin: Option<u32>,
|
||||
pub inverted: Option<bool>,
|
||||
}
|
||||
|
||||
/// ATX configuration update request
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AtxConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
/// Power button configuration
|
||||
pub power: Option<AtxKeyConfigUpdate>,
|
||||
/// Reset button configuration
|
||||
pub reset: Option<AtxKeyConfigUpdate>,
|
||||
/// LED sensing configuration
|
||||
pub led: Option<AtxLedConfigUpdate>,
|
||||
/// Network interface for WOL packets (empty = auto)
|
||||
pub wol_interface: Option<String>,
|
||||
}
|
||||
|
||||
impl AtxConfigUpdate {
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
// Validate power key config if present
|
||||
if let Some(ref power) = self.power {
|
||||
Self::validate_key_config(power, "power")?;
|
||||
}
|
||||
// Validate reset key config if present
|
||||
if let Some(ref reset) = self.reset {
|
||||
Self::validate_key_config(reset, "reset")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_key_config(key: &AtxKeyConfigUpdate, name: &str) -> crate::error::Result<()> {
|
||||
if let Some(ref device) = key.device {
|
||||
if !device.is_empty() && !std::path::Path::new(device).exists() {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"{} device '{}' does not exist",
|
||||
name, device
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn apply_to(&self, config: &mut AtxConfig) {
|
||||
if let Some(enabled) = self.enabled {
|
||||
config.enabled = enabled;
|
||||
}
|
||||
if let Some(ref power) = self.power {
|
||||
Self::apply_key_update(power, &mut config.power);
|
||||
}
|
||||
if let Some(ref reset) = self.reset {
|
||||
Self::apply_key_update(reset, &mut config.reset);
|
||||
}
|
||||
if let Some(ref led) = self.led {
|
||||
Self::apply_led_update(led, &mut config.led);
|
||||
}
|
||||
if let Some(ref wol_interface) = self.wol_interface {
|
||||
config.wol_interface = wol_interface.clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_key_update(update: &AtxKeyConfigUpdate, config: &mut crate::atx::AtxKeyConfig) {
|
||||
if let Some(driver) = update.driver {
|
||||
config.driver = driver;
|
||||
}
|
||||
if let Some(ref device) = update.device {
|
||||
config.device = device.clone();
|
||||
}
|
||||
if let Some(pin) = update.pin {
|
||||
config.pin = pin;
|
||||
}
|
||||
if let Some(level) = update.active_level {
|
||||
config.active_level = level;
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_led_update(update: &AtxLedConfigUpdate, config: &mut crate::atx::AtxLedConfig) {
|
||||
if let Some(enabled) = update.enabled {
|
||||
config.enabled = enabled;
|
||||
}
|
||||
if let Some(ref chip) = update.gpio_chip {
|
||||
config.gpio_chip = chip.clone();
|
||||
}
|
||||
if let Some(pin) = update.gpio_pin {
|
||||
config.gpio_pin = pin;
|
||||
}
|
||||
if let Some(inverted) = update.inverted {
|
||||
config.inverted = inverted;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Audio Config =====
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AudioConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub device: Option<String>,
|
||||
pub quality: Option<String>,
|
||||
}
|
||||
|
||||
impl AudioConfigUpdate {
|
||||
pub fn validate(&self) -> crate::error::Result<()> {
|
||||
if let Some(ref quality) = self.quality {
|
||||
if !["voice", "balanced", "high"].contains(&quality.as_str()) {
|
||||
return Err(AppError::BadRequest(
|
||||
"Invalid quality: must be 'voice', 'balanced', or 'high'".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn apply_to(&self, config: &mut AudioConfig) {
|
||||
if let Some(enabled) = self.enabled {
|
||||
config.enabled = enabled;
|
||||
}
|
||||
if let Some(ref device) = self.device {
|
||||
config.device = device.clone();
|
||||
}
|
||||
if let Some(ref quality) = self.quality {
|
||||
config.quality = quality.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
47
src/web/handlers/config/video.rs
Normal file
47
src/web/handlers/config/video.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
//! Video 配置 Handler
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::VideoConfig;
|
||||
use crate::error::Result;
|
||||
use crate::state::AppState;
|
||||
|
||||
use super::apply::apply_video_config;
|
||||
use super::types::VideoConfigUpdate;
|
||||
|
||||
/// 获取 Video 配置
|
||||
pub async fn get_video_config(State(state): State<Arc<AppState>>) -> Json<VideoConfig> {
|
||||
Json(state.config.get().video.clone())
|
||||
}
|
||||
|
||||
/// 更新 Video 配置
|
||||
pub async fn update_video_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<VideoConfigUpdate>,
|
||||
) -> Result<Json<VideoConfig>> {
|
||||
// 1. 验证请求
|
||||
req.validate()?;
|
||||
|
||||
// 2. 获取旧配置
|
||||
let old_video_config = state.config.get().video.clone();
|
||||
|
||||
// 3. 应用更新到配置存储
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
req.apply_to(&mut config.video);
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 4. 获取新配置
|
||||
let new_video_config = state.config.get().video.clone();
|
||||
|
||||
// 5. 应用到子系统(热重载)
|
||||
if let Err(e) = apply_video_config(&state, &old_video_config, &new_video_config).await {
|
||||
tracing::error!("Failed to apply video config: {}", e);
|
||||
// 根据用户选择,仅记录错误,不回滚
|
||||
}
|
||||
|
||||
Ok(Json(new_video_config))
|
||||
}
|
||||
14
src/web/handlers/devices.rs
Normal file
14
src/web/handlers/devices.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
//! Device discovery handlers
|
||||
//!
|
||||
//! Provides API endpoints for discovering available hardware devices.
|
||||
|
||||
use axum::Json;
|
||||
|
||||
use crate::atx::{discover_devices, AtxDevices};
|
||||
|
||||
/// GET /api/devices/atx - List available ATX devices
|
||||
///
|
||||
/// Returns lists of available GPIO chips and USB HID relay devices.
|
||||
pub async fn list_atx_devices() -> Json<AtxDevices> {
|
||||
Json(discover_devices())
|
||||
}
|
||||
352
src/web/handlers/extensions.rs
Normal file
352
src/web/handlers/extensions.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
//! Extension management API handlers
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use typeshare::typeshare;
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::extensions::{
|
||||
EasytierConfig, EasytierInfo, ExtensionId, ExtensionInfo, ExtensionLogs,
|
||||
ExtensionsStatus, GostcConfig, GostcInfo, TtydConfig, TtydInfo,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
|
||||
// ============================================================================
|
||||
// Get all extensions status
|
||||
// ============================================================================
|
||||
|
||||
/// Get status of all extensions
|
||||
/// GET /api/extensions
|
||||
pub async fn list_extensions(State(state): State<Arc<AppState>>) -> Json<ExtensionsStatus> {
|
||||
let config = state.config.get();
|
||||
let mgr = &state.extensions;
|
||||
|
||||
Json(ExtensionsStatus {
|
||||
ttyd: TtydInfo {
|
||||
available: mgr.check_available(ExtensionId::Ttyd),
|
||||
status: mgr.status(ExtensionId::Ttyd).await,
|
||||
config: config.extensions.ttyd.clone(),
|
||||
},
|
||||
gostc: GostcInfo {
|
||||
available: mgr.check_available(ExtensionId::Gostc),
|
||||
status: mgr.status(ExtensionId::Gostc).await,
|
||||
config: config.extensions.gostc.clone(),
|
||||
},
|
||||
easytier: EasytierInfo {
|
||||
available: mgr.check_available(ExtensionId::Easytier),
|
||||
status: mgr.status(ExtensionId::Easytier).await,
|
||||
config: config.extensions.easytier.clone(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Individual extension status
|
||||
// ============================================================================
|
||||
|
||||
/// Get status of a single extension
|
||||
/// GET /api/extensions/:id
|
||||
pub async fn get_extension(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ExtensionInfo>> {
|
||||
let ext_id: ExtensionId = id
|
||||
.parse()
|
||||
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
|
||||
|
||||
let mgr = &state.extensions;
|
||||
|
||||
Ok(Json(ExtensionInfo {
|
||||
available: mgr.check_available(ext_id),
|
||||
status: mgr.status(ext_id).await,
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Start/Stop extensions
|
||||
// ============================================================================
|
||||
|
||||
/// Start an extension
|
||||
/// POST /api/extensions/:id/start
|
||||
pub async fn start_extension(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ExtensionInfo>> {
|
||||
let ext_id: ExtensionId = id
|
||||
.parse()
|
||||
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
|
||||
|
||||
let config = state.config.get();
|
||||
let mgr = &state.extensions;
|
||||
|
||||
// Start the extension
|
||||
mgr.start(ext_id, &config.extensions)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(e))?;
|
||||
|
||||
// Return updated status
|
||||
Ok(Json(ExtensionInfo {
|
||||
available: mgr.check_available(ext_id),
|
||||
status: mgr.status(ext_id).await,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Stop an extension
|
||||
/// POST /api/extensions/:id/stop
|
||||
pub async fn stop_extension(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ExtensionInfo>> {
|
||||
let ext_id: ExtensionId = id
|
||||
.parse()
|
||||
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
|
||||
|
||||
let mgr = &state.extensions;
|
||||
|
||||
// Stop the extension
|
||||
mgr.stop(ext_id)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(e))?;
|
||||
|
||||
// Return updated status
|
||||
Ok(Json(ExtensionInfo {
|
||||
available: mgr.check_available(ext_id),
|
||||
status: mgr.status(ext_id).await,
|
||||
}))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Extension logs
|
||||
// ============================================================================
|
||||
|
||||
/// Query parameters for logs
|
||||
#[derive(Deserialize, Default)]
|
||||
pub struct LogsQuery {
|
||||
/// Number of lines to return (default: 100, max: 500)
|
||||
pub lines: Option<usize>,
|
||||
}
|
||||
|
||||
/// Get extension logs
|
||||
/// GET /api/extensions/:id/logs
|
||||
pub async fn get_extension_logs(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<LogsQuery>,
|
||||
) -> Result<Json<ExtensionLogs>> {
|
||||
let ext_id: ExtensionId = id
|
||||
.parse()
|
||||
.map_err(|_| AppError::NotFound(format!("Unknown extension: {}", id)))?;
|
||||
|
||||
let lines = params.lines.unwrap_or(100).min(500);
|
||||
let logs = state.extensions.logs(ext_id, lines).await;
|
||||
|
||||
Ok(Json(ExtensionLogs { id: ext_id, logs }))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Update extension config
|
||||
// ============================================================================
|
||||
|
||||
/// Update ttyd config
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct TtydConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub port: Option<u16>,
|
||||
pub shell: Option<String>,
|
||||
pub credential: Option<String>,
|
||||
}
|
||||
|
||||
/// Update gostc config
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GostcConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub addr: Option<String>,
|
||||
pub key: Option<String>,
|
||||
pub tls: Option<bool>,
|
||||
}
|
||||
|
||||
/// Update easytier config
|
||||
#[typeshare]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EasytierConfigUpdate {
|
||||
pub enabled: Option<bool>,
|
||||
pub network_name: Option<String>,
|
||||
pub network_secret: Option<String>,
|
||||
pub peer_urls: Option<Vec<String>>,
|
||||
pub virtual_ip: Option<String>,
|
||||
}
|
||||
|
||||
/// Update ttyd configuration
|
||||
/// PATCH /api/extensions/ttyd/config
|
||||
pub async fn update_ttyd_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<TtydConfigUpdate>,
|
||||
) -> Result<Json<TtydConfig>> {
|
||||
// Get current config
|
||||
let was_enabled = state.config.get().extensions.ttyd.enabled;
|
||||
|
||||
// Update config
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
let ttyd = &mut config.extensions.ttyd;
|
||||
if let Some(enabled) = req.enabled {
|
||||
ttyd.enabled = enabled;
|
||||
}
|
||||
if let Some(port) = req.port {
|
||||
ttyd.port = port;
|
||||
}
|
||||
if let Some(ref shell) = req.shell {
|
||||
ttyd.shell = shell.clone();
|
||||
}
|
||||
if req.credential.is_some() {
|
||||
ttyd.credential = req.credential.clone();
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
let new_config = state.config.get();
|
||||
let is_enabled = new_config.extensions.ttyd.enabled;
|
||||
|
||||
// Handle enable/disable state change
|
||||
if was_enabled && !is_enabled {
|
||||
// Was running, now disabled - stop it
|
||||
state.extensions.stop(ExtensionId::Ttyd).await.ok();
|
||||
} else if !was_enabled && is_enabled {
|
||||
// Was disabled, now enabled - start it
|
||||
if state.extensions.check_available(ExtensionId::Ttyd) {
|
||||
state
|
||||
.extensions
|
||||
.start(ExtensionId::Ttyd, &new_config.extensions)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(new_config.extensions.ttyd.clone()))
|
||||
}
|
||||
|
||||
/// Update gostc configuration
|
||||
/// PATCH /api/extensions/gostc/config
|
||||
pub async fn update_gostc_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<GostcConfigUpdate>,
|
||||
) -> Result<Json<GostcConfig>> {
|
||||
let was_enabled = state.config.get().extensions.gostc.enabled;
|
||||
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
let gostc = &mut config.extensions.gostc;
|
||||
if let Some(enabled) = req.enabled {
|
||||
gostc.enabled = enabled;
|
||||
}
|
||||
if let Some(ref addr) = req.addr {
|
||||
gostc.addr = addr.clone();
|
||||
}
|
||||
if let Some(ref key) = req.key {
|
||||
gostc.key = key.clone();
|
||||
}
|
||||
if let Some(tls) = req.tls {
|
||||
gostc.tls = tls;
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
let new_config = state.config.get();
|
||||
let is_enabled = new_config.extensions.gostc.enabled;
|
||||
let has_key = !new_config.extensions.gostc.key.is_empty();
|
||||
|
||||
if was_enabled && !is_enabled {
|
||||
state.extensions.stop(ExtensionId::Gostc).await.ok();
|
||||
} else if !was_enabled && is_enabled && has_key {
|
||||
if state.extensions.check_available(ExtensionId::Gostc) {
|
||||
state
|
||||
.extensions
|
||||
.start(ExtensionId::Gostc, &new_config.extensions)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(new_config.extensions.gostc.clone()))
|
||||
}
|
||||
|
||||
/// Update easytier configuration
|
||||
/// PATCH /api/extensions/easytier/config
|
||||
pub async fn update_easytier_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<EasytierConfigUpdate>,
|
||||
) -> Result<Json<EasytierConfig>> {
|
||||
let was_enabled = state.config.get().extensions.easytier.enabled;
|
||||
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
let et = &mut config.extensions.easytier;
|
||||
if let Some(enabled) = req.enabled {
|
||||
et.enabled = enabled;
|
||||
}
|
||||
if let Some(ref name) = req.network_name {
|
||||
et.network_name = name.clone();
|
||||
}
|
||||
if let Some(ref secret) = req.network_secret {
|
||||
et.network_secret = secret.clone();
|
||||
}
|
||||
if let Some(ref peers) = req.peer_urls {
|
||||
et.peer_urls = peers.clone();
|
||||
}
|
||||
if req.virtual_ip.is_some() {
|
||||
et.virtual_ip = req.virtual_ip.clone();
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
let new_config = state.config.get();
|
||||
let is_enabled = new_config.extensions.easytier.enabled;
|
||||
let has_name = !new_config.extensions.easytier.network_name.is_empty();
|
||||
|
||||
if was_enabled && !is_enabled {
|
||||
state.extensions.stop(ExtensionId::Easytier).await.ok();
|
||||
} else if !was_enabled && is_enabled && has_name {
|
||||
if state.extensions.check_available(ExtensionId::Easytier) {
|
||||
state
|
||||
.extensions
|
||||
.start(ExtensionId::Easytier, &new_config.extensions)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(new_config.extensions.easytier.clone()))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Ttyd status for console (simplified)
|
||||
// ============================================================================
|
||||
|
||||
/// Simple ttyd status for console view
|
||||
#[typeshare]
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct TtydStatus {
|
||||
pub available: bool,
|
||||
pub running: bool,
|
||||
}
|
||||
|
||||
/// Get ttyd status for console view
|
||||
/// GET /api/extensions/ttyd/status
|
||||
pub async fn get_ttyd_status(State(state): State<Arc<AppState>>) -> Json<TtydStatus> {
|
||||
let mgr = &state.extensions;
|
||||
let status = mgr.status(ExtensionId::Ttyd).await;
|
||||
|
||||
Json(TtydStatus {
|
||||
available: mgr.check_available(ExtensionId::Ttyd),
|
||||
running: status.is_running(),
|
||||
})
|
||||
}
|
||||
2583
src/web/handlers/mod.rs
Normal file
2583
src/web/handlers/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
239
src/web/handlers/terminal.rs
Normal file
239
src/web/handlers/terminal.rs
Normal file
@@ -0,0 +1,239 @@
|
||||
//! Terminal proxy handler - reverse proxy to ttyd via Unix socket
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{
|
||||
ws::{Message as AxumMessage, WebSocket, WebSocketUpgrade},
|
||||
OriginalUri, Path, State,
|
||||
},
|
||||
http::{Request, StatusCode},
|
||||
response::Response,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::UnixStream;
|
||||
use tokio_tungstenite::tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
http::HeaderValue,
|
||||
Message as TungsteniteMessage,
|
||||
};
|
||||
|
||||
use crate::error::AppError;
|
||||
use crate::extensions::TTYD_SOCKET_PATH;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Handle WebSocket upgrade for terminal
|
||||
pub async fn terminal_ws(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
OriginalUri(original_uri): OriginalUri,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Response {
|
||||
let query_string = original_uri
|
||||
.query()
|
||||
.map(|q| format!("?{}", q))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Use the tty subprotocol that ttyd expects
|
||||
ws.protocols(["tty"])
|
||||
.on_upgrade(move |socket| handle_terminal_websocket(socket, query_string))
|
||||
}
|
||||
|
||||
/// Handle terminal WebSocket connection - bridge browser and ttyd
|
||||
async fn handle_terminal_websocket(client_ws: WebSocket, query_string: String) {
|
||||
// Connect to ttyd Unix socket
|
||||
let unix_stream = match UnixStream::connect(TTYD_SOCKET_PATH).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to connect to ttyd socket: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Build WebSocket request for ttyd with tty subprotocol
|
||||
let uri_str = format!("ws://localhost/api/terminal/ws{}", query_string);
|
||||
let mut request = match uri_str.into_client_request() {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to create WebSocket request: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
request.headers_mut().insert(
|
||||
"Sec-WebSocket-Protocol",
|
||||
HeaderValue::from_static("tty"),
|
||||
);
|
||||
|
||||
// Create WebSocket connection to ttyd
|
||||
let ws_stream = match tokio_tungstenite::client_async(request, unix_stream).await {
|
||||
Ok((ws, _)) => ws,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to establish WebSocket with ttyd: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Split both WebSocket connections
|
||||
let (mut client_tx, mut client_rx) = client_ws.split();
|
||||
let (mut ttyd_tx, mut ttyd_rx) = ws_stream.split();
|
||||
|
||||
// Forward messages from browser to ttyd
|
||||
let client_to_ttyd = tokio::spawn(async move {
|
||||
while let Some(msg) = client_rx.next().await {
|
||||
let ttyd_msg = match msg {
|
||||
Ok(AxumMessage::Text(text)) => TungsteniteMessage::Text(text),
|
||||
Ok(AxumMessage::Binary(data)) => TungsteniteMessage::Binary(data),
|
||||
Ok(AxumMessage::Ping(data)) => TungsteniteMessage::Ping(data),
|
||||
Ok(AxumMessage::Pong(data)) => TungsteniteMessage::Pong(data),
|
||||
Ok(AxumMessage::Close(_)) => {
|
||||
let _ = ttyd_tx.send(TungsteniteMessage::Close(None)).await;
|
||||
break;
|
||||
}
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
if ttyd_tx.send(ttyd_msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Forward messages from ttyd to browser
|
||||
let ttyd_to_client = tokio::spawn(async move {
|
||||
while let Some(msg) = ttyd_rx.next().await {
|
||||
let client_msg = match msg {
|
||||
Ok(TungsteniteMessage::Text(text)) => AxumMessage::Text(text),
|
||||
Ok(TungsteniteMessage::Binary(data)) => AxumMessage::Binary(data),
|
||||
Ok(TungsteniteMessage::Ping(data)) => AxumMessage::Ping(data),
|
||||
Ok(TungsteniteMessage::Pong(data)) => AxumMessage::Pong(data),
|
||||
Ok(TungsteniteMessage::Close(_)) => {
|
||||
let _ = client_tx.send(AxumMessage::Close(None)).await;
|
||||
break;
|
||||
}
|
||||
Ok(TungsteniteMessage::Frame(_)) => continue,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
if client_tx.send(client_msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for either direction to complete
|
||||
tokio::select! {
|
||||
_ = client_to_ttyd => {}
|
||||
_ = ttyd_to_client => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy HTTP requests to ttyd
|
||||
pub async fn terminal_proxy(
|
||||
State(_state): State<Arc<AppState>>,
|
||||
path: Option<Path<String>>,
|
||||
req: Request<Body>,
|
||||
) -> Result<Response, AppError> {
|
||||
let path_str = path.map(|p| p.0).unwrap_or_default();
|
||||
|
||||
// Connect to ttyd Unix socket
|
||||
let mut unix_stream = UnixStream::connect(TTYD_SOCKET_PATH)
|
||||
.await
|
||||
.map_err(|e| AppError::ServiceUnavailable(format!("ttyd not running: {}", e)))?;
|
||||
|
||||
// Build HTTP request to forward
|
||||
let method = req.method().as_str();
|
||||
let query = req.uri().query().map(|q| format!("?{}", q)).unwrap_or_default();
|
||||
let uri_path = if path_str.is_empty() {
|
||||
format!("/api/terminal/{}", query)
|
||||
} else {
|
||||
format!("/api/terminal/{}{}", path_str, query)
|
||||
};
|
||||
|
||||
// Forward relevant headers
|
||||
let mut headers_str = String::new();
|
||||
for (name, value) in req.headers() {
|
||||
if let Ok(v) = value.to_str() {
|
||||
let name_lower = name.as_str().to_lowercase();
|
||||
if !matches!(
|
||||
name_lower.as_str(),
|
||||
"connection" | "keep-alive" | "transfer-encoding" | "upgrade"
|
||||
) {
|
||||
headers_str.push_str(&format!("{}: {}\r\n", name, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let http_request = format!(
|
||||
"{} {} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n{}\r\n",
|
||||
method, uri_path, headers_str
|
||||
);
|
||||
|
||||
// Send request
|
||||
unix_stream
|
||||
.write_all(http_request.as_bytes())
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to send request: {}", e)))?;
|
||||
|
||||
// Read response
|
||||
let mut response_buf = Vec::new();
|
||||
unix_stream
|
||||
.read_to_end(&mut response_buf)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read response: {}", e)))?;
|
||||
|
||||
// Parse HTTP response
|
||||
let response_str = String::from_utf8_lossy(&response_buf);
|
||||
let header_end = response_str
|
||||
.find("\r\n\r\n")
|
||||
.ok_or_else(|| AppError::Internal("Invalid HTTP response".to_string()))?;
|
||||
|
||||
let headers_part = &response_str[..header_end];
|
||||
let body_start = header_end + 4;
|
||||
|
||||
// Parse status line
|
||||
let status_line = headers_part
|
||||
.lines()
|
||||
.next()
|
||||
.ok_or_else(|| AppError::Internal("Missing status line".to_string()))?;
|
||||
let status_code: u16 = status_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(200);
|
||||
|
||||
// Build response
|
||||
let mut builder = Response::builder().status(StatusCode::from_u16(status_code).unwrap_or(StatusCode::OK));
|
||||
|
||||
// Forward response headers
|
||||
for line in headers_part.lines().skip(1) {
|
||||
if let Some((name, value)) = line.split_once(':') {
|
||||
let name = name.trim();
|
||||
let value = value.trim();
|
||||
if !matches!(
|
||||
name.to_lowercase().as_str(),
|
||||
"connection" | "keep-alive" | "transfer-encoding"
|
||||
) {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let body = if body_start < response_buf.len() {
|
||||
Body::from(response_buf[body_start..].to_vec())
|
||||
} else {
|
||||
Body::empty()
|
||||
};
|
||||
|
||||
builder
|
||||
.body(body)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to build response: {}", e)))
|
||||
}
|
||||
|
||||
/// Terminal index page
|
||||
pub async fn terminal_index(
|
||||
State(state): State<Arc<AppState>>,
|
||||
req: Request<Body>,
|
||||
) -> Result<Response, AppError> {
|
||||
terminal_proxy(State(state), None, req).await
|
||||
}
|
||||
12
src/web/mod.rs
Normal file
12
src/web/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
mod audio_ws;
|
||||
mod routes;
|
||||
mod handlers;
|
||||
mod static_files;
|
||||
mod ws;
|
||||
|
||||
pub use audio_ws::audio_ws_handler;
|
||||
pub use routes::create_router;
|
||||
// StaticAssets is only available in release mode (embedded assets)
|
||||
#[cfg(not(debug_assertions))]
|
||||
pub use static_files::StaticAssets;
|
||||
pub use ws::ws_handler;
|
||||
178
src/web/routes.rs
Normal file
178
src/web/routes.rs
Normal file
@@ -0,0 +1,178 @@
|
||||
use axum::{
|
||||
extract::DefaultBodyLimit,
|
||||
middleware,
|
||||
routing::{any, delete, get, patch, post, put},
|
||||
Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tower_http::{
|
||||
cors::{Any, CorsLayer},
|
||||
trace::TraceLayer,
|
||||
};
|
||||
|
||||
use super::audio_ws::audio_ws_handler;
|
||||
use super::handlers;
|
||||
use super::ws::ws_handler;
|
||||
use crate::auth::{auth_middleware, require_admin};
|
||||
use crate::hid::websocket::ws_hid_handler;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Create the main application router
|
||||
pub fn create_router(state: Arc<AppState>) -> Router {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
// Public routes (no auth required)
|
||||
// Note: /info moved to user_routes for security (contains hostname, IPs, etc.)
|
||||
let public_routes = Router::new()
|
||||
.route("/health", get(handlers::health_check))
|
||||
.route("/auth/login", post(handlers::login))
|
||||
.route("/setup", get(handlers::setup_status))
|
||||
.route("/setup/init", post(handlers::setup_init));
|
||||
|
||||
// User routes (authenticated users - both regular and admin)
|
||||
let user_routes = Router::new()
|
||||
.route("/info", get(handlers::system_info))
|
||||
.route("/auth/logout", post(handlers::logout))
|
||||
.route("/auth/check", get(handlers::auth_check))
|
||||
.route("/devices", get(handlers::list_devices))
|
||||
// WebSocket endpoint for real-time events
|
||||
.route("/ws", any(ws_handler))
|
||||
// Stream control endpoints
|
||||
.route("/stream/status", get(handlers::stream_state))
|
||||
.route("/stream/start", post(handlers::stream_start))
|
||||
.route("/stream/stop", post(handlers::stream_stop))
|
||||
.route("/stream/mode", get(handlers::stream_mode_get))
|
||||
.route("/stream/mode", post(handlers::stream_mode_set))
|
||||
.route("/stream/bitrate", post(handlers::stream_set_bitrate))
|
||||
.route("/stream/codecs", get(handlers::stream_codecs_list))
|
||||
// WebRTC endpoints
|
||||
.route("/webrtc/session", post(handlers::webrtc_create_session))
|
||||
.route("/webrtc/offer", post(handlers::webrtc_offer))
|
||||
.route("/webrtc/ice", post(handlers::webrtc_ice_candidate))
|
||||
.route("/webrtc/ice-servers", get(handlers::webrtc_ice_servers))
|
||||
.route("/webrtc/status", get(handlers::webrtc_status))
|
||||
.route("/webrtc/close", post(handlers::webrtc_close_session))
|
||||
// HID endpoints
|
||||
.route("/hid/status", get(handlers::hid_status))
|
||||
.route("/hid/reset", post(handlers::hid_reset))
|
||||
// WebSocket HID endpoint (for MJPEG mode)
|
||||
.route("/ws/hid", any(ws_hid_handler))
|
||||
// Audio endpoints
|
||||
.route("/audio/status", get(handlers::audio_status))
|
||||
.route("/audio/start", post(handlers::start_audio_streaming))
|
||||
.route("/audio/stop", post(handlers::stop_audio_streaming))
|
||||
.route("/audio/quality", post(handlers::set_audio_quality))
|
||||
.route("/audio/device", post(handlers::select_audio_device))
|
||||
.route("/audio/devices", get(handlers::list_audio_devices))
|
||||
// Audio WebSocket endpoint
|
||||
.route("/ws/audio", any(audio_ws_handler))
|
||||
// User can change their own password (handler will check ownership)
|
||||
.route("/users/:id/password", post(handlers::change_user_password));
|
||||
|
||||
// Admin-only routes (require admin privileges)
|
||||
let admin_routes = Router::new()
|
||||
// Configuration management (domain-separated endpoints)
|
||||
.route("/config", get(handlers::config::get_all_config))
|
||||
.route("/config", post(handlers::update_config))
|
||||
.route("/config/video", get(handlers::config::get_video_config))
|
||||
.route("/config/video", patch(handlers::config::update_video_config))
|
||||
.route("/config/stream", get(handlers::config::get_stream_config))
|
||||
.route("/config/stream", patch(handlers::config::update_stream_config))
|
||||
.route("/config/hid", get(handlers::config::get_hid_config))
|
||||
.route("/config/hid", patch(handlers::config::update_hid_config))
|
||||
.route("/config/msd", get(handlers::config::get_msd_config))
|
||||
.route("/config/msd", patch(handlers::config::update_msd_config))
|
||||
.route("/config/atx", get(handlers::config::get_atx_config))
|
||||
.route("/config/atx", patch(handlers::config::update_atx_config))
|
||||
.route("/config/audio", get(handlers::config::get_audio_config))
|
||||
.route("/config/audio", patch(handlers::config::update_audio_config))
|
||||
// MSD (Mass Storage Device) endpoints
|
||||
.route("/msd/status", get(handlers::msd_status))
|
||||
.route("/msd/images", get(handlers::msd_images_list))
|
||||
.route("/msd/images/download", post(handlers::msd_image_download))
|
||||
.route("/msd/images/download/cancel", post(handlers::msd_image_download_cancel))
|
||||
.route("/msd/images/:id", get(handlers::msd_image_get))
|
||||
.route("/msd/images/:id", delete(handlers::msd_image_delete))
|
||||
.route("/msd/connect", post(handlers::msd_connect))
|
||||
.route("/msd/disconnect", post(handlers::msd_disconnect))
|
||||
// MSD Virtual Drive endpoints
|
||||
.route("/msd/drive", get(handlers::msd_drive_info))
|
||||
.route("/msd/drive", delete(handlers::msd_drive_delete))
|
||||
.route("/msd/drive/init", post(handlers::msd_drive_init))
|
||||
.route("/msd/drive/files", get(handlers::msd_drive_files))
|
||||
.route("/msd/drive/files/*path", get(handlers::msd_drive_download))
|
||||
.route("/msd/drive/files/*path", delete(handlers::msd_drive_file_delete))
|
||||
.route("/msd/drive/mkdir/*path", post(handlers::msd_drive_mkdir))
|
||||
// ATX (Power Control) endpoints
|
||||
.route("/atx/status", get(handlers::atx_status))
|
||||
.route("/atx/power", post(handlers::atx_power))
|
||||
.route("/atx/wol", post(handlers::atx_wol))
|
||||
// Device discovery endpoints
|
||||
.route("/devices/atx", get(handlers::devices::list_atx_devices))
|
||||
// User management endpoints
|
||||
.route("/users", get(handlers::list_users))
|
||||
.route("/users", post(handlers::create_user))
|
||||
.route("/users/:id", put(handlers::update_user))
|
||||
.route("/users/:id", delete(handlers::delete_user))
|
||||
// Extension management endpoints
|
||||
.route("/extensions", get(handlers::extensions::list_extensions))
|
||||
.route("/extensions/:id", get(handlers::extensions::get_extension))
|
||||
.route("/extensions/:id/start", post(handlers::extensions::start_extension))
|
||||
.route("/extensions/:id/stop", post(handlers::extensions::stop_extension))
|
||||
.route("/extensions/:id/logs", get(handlers::extensions::get_extension_logs))
|
||||
.route("/extensions/ttyd/config", patch(handlers::extensions::update_ttyd_config))
|
||||
.route("/extensions/ttyd/status", get(handlers::extensions::get_ttyd_status))
|
||||
.route("/extensions/gostc/config", patch(handlers::extensions::update_gostc_config))
|
||||
.route("/extensions/easytier/config", patch(handlers::extensions::update_easytier_config))
|
||||
// Terminal (ttyd) reverse proxy - WebSocket and HTTP
|
||||
.route("/terminal", get(handlers::terminal::terminal_index))
|
||||
.route("/terminal/", get(handlers::terminal::terminal_index))
|
||||
.route("/terminal/ws", get(handlers::terminal::terminal_ws))
|
||||
.route("/terminal/*path", get(handlers::terminal::terminal_proxy))
|
||||
// Apply admin middleware to all admin routes
|
||||
.layer(middleware::from_fn_with_state(state.clone(), require_admin));
|
||||
|
||||
// Combine protected routes (user + admin)
|
||||
let protected_routes = Router::new()
|
||||
.merge(user_routes)
|
||||
.merge(admin_routes);
|
||||
|
||||
// Stream endpoints (accessible with auth, but typically embedded in pages)
|
||||
let stream_routes = Router::new()
|
||||
.route("/stream", get(handlers::mjpeg_stream))
|
||||
.route("/stream/mjpeg", get(handlers::mjpeg_stream))
|
||||
.route("/snapshot", get(handlers::snapshot));
|
||||
|
||||
// Large file upload routes (MSD images and drive files)
|
||||
// Use streaming upload to support files larger than available RAM
|
||||
// Disable body limit for streaming uploads - files are written directly to disk
|
||||
let upload_routes = Router::new()
|
||||
.route("/msd/images", post(handlers::msd_image_upload))
|
||||
.route("/msd/drive/files", post(handlers::msd_drive_upload))
|
||||
.layer(DefaultBodyLimit::disable());
|
||||
|
||||
// Combine API routes
|
||||
let api_routes = Router::new()
|
||||
.merge(public_routes)
|
||||
.merge(protected_routes)
|
||||
.merge(stream_routes)
|
||||
.merge(upload_routes)
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
// Static file serving
|
||||
let static_routes = super::static_files::static_file_router();
|
||||
|
||||
// Main router
|
||||
Router::new()
|
||||
.nest("/api", api_routes)
|
||||
.merge(static_routes)
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors)
|
||||
.with_state(state)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user