mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2026-06-14 03:32:00 +08:00
refactor: 删除部分多余的代码和注释
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
//! ALSA audio capture implementation
|
||||
|
||||
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
|
||||
use alsa::{Direction, ValueOr, PCM};
|
||||
use bytes::Bytes;
|
||||
@@ -14,30 +12,23 @@ use crate::error::{AppError, Result};
|
||||
use crate::utils::LogThrottler;
|
||||
use crate::{error_throttled, warn_throttled};
|
||||
|
||||
/// Audio capture configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioConfig {
|
||||
/// ALSA device name (e.g., "hw:0,0" or "default")
|
||||
pub device_name: String,
|
||||
/// Sample rate in Hz
|
||||
pub sample_rate: u32,
|
||||
/// Number of channels (1 = mono, 2 = stereo)
|
||||
pub channels: u32,
|
||||
/// Samples per frame (for Opus, typically 480 for 10ms at 48kHz)
|
||||
pub frame_size: u32,
|
||||
/// Buffer size in frames
|
||||
pub buffer_frames: u32,
|
||||
/// Period size in frames
|
||||
pub period_frames: u32,
|
||||
}
|
||||
|
||||
impl Default for AudioConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
device_name: "default".to_string(),
|
||||
device_name: String::new(),
|
||||
sample_rate: 48000,
|
||||
channels: 2,
|
||||
frame_size: 960, // 20ms at 48kHz (good for Opus)
|
||||
frame_size: 960,
|
||||
buffer_frames: 4096,
|
||||
period_frames: 960,
|
||||
}
|
||||
@@ -45,7 +36,6 @@ impl Default for AudioConfig {
|
||||
}
|
||||
|
||||
impl AudioConfig {
|
||||
/// Create config for a specific device (48 kHz stereo only; must match ALSA hardware).
|
||||
pub fn for_device(device: &AudioDeviceInfo) -> Self {
|
||||
Self {
|
||||
device_name: device.name.clone(),
|
||||
@@ -53,36 +43,26 @@ impl AudioConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Bytes per sample (16-bit signed)
|
||||
pub fn bytes_per_sample(&self) -> u32 {
|
||||
2 * self.channels
|
||||
}
|
||||
|
||||
/// Bytes per frame
|
||||
pub fn bytes_per_frame(&self) -> usize {
|
||||
(self.frame_size * self.bytes_per_sample()) as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio frame data
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioFrame {
|
||||
/// Raw PCM data (S16LE interleaved)
|
||||
pub data: Bytes,
|
||||
/// Sample rate
|
||||
pub sample_rate: u32,
|
||||
/// Number of channels
|
||||
pub channels: u32,
|
||||
/// Number of samples per channel
|
||||
pub samples: u32,
|
||||
/// Frame sequence number
|
||||
pub sequence: u64,
|
||||
/// Capture timestamp
|
||||
pub timestamp: Instant,
|
||||
}
|
||||
|
||||
impl AudioFrame {
|
||||
/// One capture block: `sample_rate` must be the **hardware** rate (e.g. ALSA `actual_rate`).
|
||||
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
|
||||
let bps = 2 * channels;
|
||||
Self {
|
||||
@@ -96,7 +76,6 @@ impl AudioFrame {
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio capture state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CaptureState {
|
||||
Stopped,
|
||||
@@ -104,7 +83,6 @@ pub enum CaptureState {
|
||||
Error,
|
||||
}
|
||||
|
||||
/// ALSA audio capturer
|
||||
pub struct AudioCapturer {
|
||||
config: AudioConfig,
|
||||
state: Arc<watch::Sender<CaptureState>>,
|
||||
@@ -113,15 +91,13 @@ pub struct AudioCapturer {
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
sequence: Arc<AtomicU64>,
|
||||
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
/// Log throttler to prevent log flooding
|
||||
log_throttler: LogThrottler,
|
||||
}
|
||||
|
||||
impl AudioCapturer {
|
||||
/// Create a new audio capturer
|
||||
pub fn new(config: AudioConfig) -> Self {
|
||||
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
|
||||
let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency
|
||||
let (frame_tx, _) = broadcast::channel(16);
|
||||
|
||||
Self {
|
||||
config,
|
||||
@@ -135,22 +111,18 @@ impl AudioCapturer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current state
|
||||
pub fn state(&self) -> CaptureState {
|
||||
*self.state_rx.borrow()
|
||||
}
|
||||
|
||||
/// Subscribe to state changes
|
||||
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
|
||||
self.state_rx.clone()
|
||||
}
|
||||
|
||||
/// Subscribe to audio frames
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
|
||||
self.frame_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Start capturing
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
if self.state() == CaptureState::Running {
|
||||
return Ok(());
|
||||
@@ -171,14 +143,27 @@ impl AudioCapturer {
|
||||
let log_throttler = self.log_throttler.clone();
|
||||
|
||||
let handle = tokio::task::spawn_blocking(move || {
|
||||
capture_loop(config, state, frame_tx, stop_flag, sequence, log_throttler);
|
||||
let result = run_capture(
|
||||
&config,
|
||||
&state,
|
||||
&frame_tx,
|
||||
&stop_flag,
|
||||
&sequence,
|
||||
&log_throttler,
|
||||
);
|
||||
|
||||
if let Err(e) = result {
|
||||
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
|
||||
let _ = state.send(CaptureState::Error);
|
||||
} else {
|
||||
let _ = state.send(CaptureState::Stopped);
|
||||
}
|
||||
});
|
||||
|
||||
*self.capture_handle.lock().await = Some(handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop capturing
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
info!("Stopping audio capture");
|
||||
self.stop_flag.store(true, Ordering::SeqCst);
|
||||
@@ -191,38 +176,11 @@ impl AudioCapturer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if running
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.state() == CaptureState::Running
|
||||
}
|
||||
}
|
||||
|
||||
/// Main capture loop
|
||||
fn capture_loop(
|
||||
config: AudioConfig,
|
||||
state: Arc<watch::Sender<CaptureState>>,
|
||||
frame_tx: broadcast::Sender<AudioFrame>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
sequence: Arc<AtomicU64>,
|
||||
log_throttler: LogThrottler,
|
||||
) {
|
||||
let result = run_capture(
|
||||
&config,
|
||||
&state,
|
||||
&frame_tx,
|
||||
&stop_flag,
|
||||
&sequence,
|
||||
&log_throttler,
|
||||
);
|
||||
|
||||
if let Err(e) = result {
|
||||
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
|
||||
let _ = state.send(CaptureState::Error);
|
||||
} else {
|
||||
let _ = state.send(CaptureState::Stopped);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_capture(
|
||||
config: &AudioConfig,
|
||||
state: &watch::Sender<CaptureState>,
|
||||
@@ -231,7 +189,6 @@ fn run_capture(
|
||||
sequence: &AtomicU64,
|
||||
log_throttler: &LogThrottler,
|
||||
) -> Result<()> {
|
||||
// Open ALSA device
|
||||
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
|
||||
AppError::AudioError(format!(
|
||||
"Failed to open audio device {}: {}",
|
||||
@@ -239,7 +196,6 @@ fn run_capture(
|
||||
))
|
||||
})?;
|
||||
|
||||
// Configure hardware parameters
|
||||
{
|
||||
let hwp = HwParams::any(&pcm)
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?;
|
||||
@@ -266,7 +222,6 @@ fn run_capture(
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?;
|
||||
}
|
||||
|
||||
// Fixed 48 kHz stereo: fail if hardware negotiated something else.
|
||||
let hw_now = pcm.hw_params_current().map_err(|e| {
|
||||
AppError::AudioError(format!("Failed to read hw_params after apply: {}", e))
|
||||
})?;
|
||||
@@ -290,13 +245,11 @@ fn run_capture(
|
||||
}
|
||||
info!("Audio capture: 48000 Hz, 2 ch");
|
||||
|
||||
// Prepare for capture
|
||||
pcm.prepare()
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?;
|
||||
|
||||
let _ = state.send(CaptureState::Running);
|
||||
|
||||
// Sized from actual period — `readi` may return up to ~one period of frames per call.
|
||||
let period_frames = pcm
|
||||
.hw_params_current()
|
||||
.ok()
|
||||
@@ -308,9 +261,7 @@ fn run_capture(
|
||||
let bytes_per_frame = (config.channels as usize) * 2;
|
||||
let mut buffer = vec![0u8; buf_frames * bytes_per_frame];
|
||||
|
||||
// Capture loop
|
||||
while !stop_flag.load(Ordering::Relaxed) {
|
||||
// Check PCM state
|
||||
match pcm.state() {
|
||||
State::XRun => {
|
||||
warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering");
|
||||
@@ -329,9 +280,7 @@ fn run_capture(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Get IO handle and read audio data directly as bytes
|
||||
// Note: Use io() instead of io_checked() because USB audio devices
|
||||
// typically don't support mmap, which io_checked() requires
|
||||
// io_bytes: USB capture often lacks mmap (io_checked requires it).
|
||||
let io: IO<u8> = pcm.io_bytes();
|
||||
|
||||
match io.readi(&mut buffer) {
|
||||
@@ -340,10 +289,8 @@ fn run_capture(
|
||||
continue;
|
||||
}
|
||||
|
||||
// Calculate actual byte count
|
||||
let byte_count = frames_read * config.channels as usize * 2;
|
||||
|
||||
// Directly use the buffer slice (already in correct byte format)
|
||||
let seq = sequence.fetch_add(1, Ordering::Relaxed);
|
||||
let frame = AudioFrame::new_interleaved(
|
||||
Bytes::copy_from_slice(&buffer[..byte_count]),
|
||||
@@ -352,7 +299,6 @@ fn run_capture(
|
||||
seq,
|
||||
);
|
||||
|
||||
// Send to subscribers
|
||||
if frame_tx.receiver_count() > 0 {
|
||||
if let Err(e) = frame_tx.send(frame) {
|
||||
debug!("No audio receivers: {}", e);
|
||||
@@ -360,14 +306,11 @@ fn run_capture(
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Check for buffer overrun (EPIPE = 32 on Linux)
|
||||
let desc = e.to_string();
|
||||
if desc.contains("EPIPE") || desc.contains("Broken pipe") {
|
||||
// Buffer overrun
|
||||
warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun");
|
||||
let _ = pcm.prepare();
|
||||
} else if desc.contains("No such device") || desc.contains("ENODEV") {
|
||||
// Device disconnected - use longer throttle for this
|
||||
error_throttled!(log_throttler, "no_device", "Audio read error: {}", e);
|
||||
} else {
|
||||
error_throttled!(log_throttler, "read_error", "Audio read error: {}", e);
|
||||
|
||||
@@ -1,35 +1,31 @@
|
||||
//! Audio controller for high-level audio management
|
||||
//!
|
||||
//! Provides device enumeration, selection, quality control, and streaming management.
|
||||
//! Device selection, quality presets, streaming.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
use super::capture::AudioConfig;
|
||||
use super::device::{enumerate_audio_devices_with_current, AudioDeviceInfo};
|
||||
use super::device::{
|
||||
enumerate_audio_devices_with_current, find_best_audio_device, AudioDeviceInfo,
|
||||
};
|
||||
use super::encoder::{OpusConfig, OpusFrame};
|
||||
use super::monitor::{AudioHealthMonitor, AudioHealthStatus};
|
||||
use super::monitor::AudioHealthMonitor;
|
||||
use super::streamer::{AudioStreamer, AudioStreamerConfig};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::events::EventBus;
|
||||
|
||||
/// Audio quality presets
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AudioQuality {
|
||||
/// Low bandwidth voice (32kbps)
|
||||
Voice,
|
||||
/// Balanced quality (64kbps) - default
|
||||
#[default]
|
||||
Balanced,
|
||||
/// High quality audio (128kbps)
|
||||
High,
|
||||
}
|
||||
|
||||
impl AudioQuality {
|
||||
/// Get the bitrate for this quality level
|
||||
pub fn bitrate(&self) -> u32 {
|
||||
match self {
|
||||
AudioQuality::Voice => 32000,
|
||||
@@ -38,17 +34,6 @@ impl AudioQuality {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse from string
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"voice" | "low" => AudioQuality::Voice,
|
||||
"high" | "music" => AudioQuality::High,
|
||||
_ => AudioQuality::Balanced,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to OpusConfig
|
||||
pub fn to_opus_config(&self) -> OpusConfig {
|
||||
match self {
|
||||
AudioQuality::Voice => OpusConfig::voice(),
|
||||
@@ -58,6 +43,22 @@ impl AudioQuality {
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for AudioQuality {
|
||||
type Err = AppError;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.trim().to_lowercase().as_str() {
|
||||
"voice" => Ok(Self::Voice),
|
||||
"balanced" => Ok(Self::Balanced),
|
||||
"high" => Ok(Self::High),
|
||||
_ => Err(AppError::BadRequest(format!(
|
||||
"invalid audio quality {:?} (expected voice, balanced, or high)",
|
||||
s.trim()
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AudioQuality {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
@@ -68,17 +69,10 @@ impl std::fmt::Display for AudioQuality {
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio controller configuration
|
||||
///
|
||||
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
|
||||
/// These are optimal for Opus encoding and match WebRTC requirements.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioControllerConfig {
|
||||
/// Whether audio is enabled
|
||||
pub enabled: bool,
|
||||
/// Selected device name
|
||||
pub device: String,
|
||||
/// Audio quality preset
|
||||
pub quality: AudioQuality,
|
||||
}
|
||||
|
||||
@@ -86,74 +80,52 @@ impl Default for AudioControllerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
device: "default".to_string(),
|
||||
device: String::new(),
|
||||
quality: AudioQuality::Balanced,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Current audio status
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct AudioStatus {
|
||||
/// Whether audio feature is enabled
|
||||
pub enabled: bool,
|
||||
/// Whether audio is currently streaming
|
||||
pub streaming: bool,
|
||||
/// Currently selected device
|
||||
pub device: Option<String>,
|
||||
/// Current quality preset
|
||||
pub quality: AudioQuality,
|
||||
/// Number of connected subscribers
|
||||
pub subscriber_count: usize,
|
||||
/// Error message if any
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Audio controller
|
||||
///
|
||||
/// High-level interface for audio management, providing:
|
||||
/// - Device enumeration and selection
|
||||
/// - Quality control
|
||||
/// - Stream start/stop
|
||||
/// - Status reporting
|
||||
pub struct AudioController {
|
||||
config: RwLock<AudioControllerConfig>,
|
||||
streamer: RwLock<Option<Arc<AudioStreamer>>>,
|
||||
devices: RwLock<Vec<AudioDeviceInfo>>,
|
||||
event_bus: RwLock<Option<Arc<EventBus>>>,
|
||||
last_error: RwLock<Option<String>>,
|
||||
/// Health monitor for error tracking and recovery
|
||||
monitor: Arc<AudioHealthMonitor>,
|
||||
}
|
||||
|
||||
impl AudioController {
|
||||
/// Create a new audio controller with configuration
|
||||
pub fn new(config: AudioControllerConfig) -> Self {
|
||||
Self {
|
||||
config: RwLock::new(config),
|
||||
streamer: RwLock::new(None),
|
||||
devices: RwLock::new(Vec::new()),
|
||||
event_bus: RwLock::new(None),
|
||||
last_error: RwLock::new(None),
|
||||
monitor: Arc::new(AudioHealthMonitor::with_defaults()),
|
||||
monitor: Arc::new(AudioHealthMonitor::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set event bus for internal state notifications.
|
||||
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
|
||||
*self.event_bus.write().await = Some(event_bus);
|
||||
}
|
||||
|
||||
/// Mark the device-info snapshot as stale.
|
||||
async fn mark_device_info_dirty(&self) {
|
||||
if let Some(ref bus) = *self.event_bus.read().await {
|
||||
bus.mark_device_info_dirty();
|
||||
}
|
||||
}
|
||||
|
||||
/// List available audio capture devices
|
||||
pub async fn list_devices(&self) -> Result<Vec<AudioDeviceInfo>> {
|
||||
// Get current device if streaming (it may be busy and unable to be opened)
|
||||
let current_device = if self.is_streaming().await {
|
||||
Some(self.config.read().await.device.clone())
|
||||
} else {
|
||||
@@ -165,41 +137,23 @@ impl AudioController {
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
/// Refresh device list and cache it
|
||||
pub async fn refresh_devices(&self) -> Result<()> {
|
||||
// Get current device if streaming (it may be busy and unable to be opened)
|
||||
let current_device = if self.is_streaming().await {
|
||||
Some(self.config.read().await.device.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
|
||||
*self.devices.write().await = devices;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get cached device list
|
||||
pub async fn get_cached_devices(&self) -> Vec<AudioDeviceInfo> {
|
||||
self.devices.read().await.clone()
|
||||
}
|
||||
|
||||
/// Select audio device
|
||||
pub async fn select_device(&self, device: &str) -> Result<()> {
|
||||
// Validate device exists
|
||||
let devices = self.list_devices().await?;
|
||||
let found = devices
|
||||
.iter()
|
||||
.any(|d| d.name == device || d.description.contains(device));
|
||||
|
||||
if !found && device != "default" {
|
||||
if !found {
|
||||
return Err(AppError::AudioError(format!(
|
||||
"Audio device not found: {}",
|
||||
device
|
||||
)));
|
||||
}
|
||||
|
||||
// Update config
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
config.device = device.to_string();
|
||||
@@ -207,7 +161,6 @@ impl AudioController {
|
||||
|
||||
info!("Audio device selected: {}", device);
|
||||
|
||||
// If streaming, restart with new device
|
||||
if self.is_streaming().await {
|
||||
self.stop_streaming().await?;
|
||||
self.start_streaming().await?;
|
||||
@@ -216,15 +169,12 @@ impl AudioController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set audio quality
|
||||
pub async fn set_quality(&self, quality: AudioQuality) -> Result<()> {
|
||||
// Update config
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
config.quality = quality;
|
||||
}
|
||||
|
||||
// Update streamer if running
|
||||
if let Some(ref streamer) = *self.streamer.read().await {
|
||||
streamer.set_bitrate(quality.bitrate()).await?;
|
||||
}
|
||||
@@ -237,44 +187,45 @@ impl AudioController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start audio streaming
|
||||
pub async fn start_streaming(&self) -> Result<()> {
|
||||
let config = self.config.read().await.clone();
|
||||
|
||||
if !config.enabled {
|
||||
return Err(AppError::AudioError("Audio is disabled".to_string()));
|
||||
{
|
||||
let config = self.config.read().await;
|
||||
if !config.enabled {
|
||||
return Err(AppError::AudioError("Audio is disabled".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// Check if already streaming
|
||||
if self.is_streaming().await {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Starting audio streaming with device: {}", config.device);
|
||||
|
||||
// Clear any previous error
|
||||
*self.last_error.write().await = None;
|
||||
|
||||
// Create streamer config (fixed 48kHz stereo)
|
||||
let streamer_config = AudioStreamerConfig {
|
||||
capture: AudioConfig {
|
||||
device_name: config.device.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
opus: config.quality.to_opus_config(),
|
||||
let (device_name, quality) = {
|
||||
let mut cfg = self.config.write().await;
|
||||
if cfg.device.trim().is_empty() {
|
||||
let best = find_best_audio_device()?;
|
||||
cfg.device = best.name;
|
||||
}
|
||||
(cfg.device.clone(), cfg.quality)
|
||||
};
|
||||
|
||||
info!("Starting audio streaming with device: {}", device_name);
|
||||
|
||||
self.monitor.prepare_retry_attempt();
|
||||
|
||||
let streamer_config = AudioStreamerConfig {
|
||||
capture: AudioConfig {
|
||||
device_name: device_name.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
opus: quality.to_opus_config(),
|
||||
};
|
||||
|
||||
// Create and start streamer
|
||||
let streamer = Arc::new(AudioStreamer::with_config(streamer_config));
|
||||
|
||||
if let Err(e) = streamer.start().await {
|
||||
let error_msg = format!("Failed to start audio: {}", e);
|
||||
*self.last_error.write().await = Some(error_msg.clone());
|
||||
|
||||
// Report error to health monitor
|
||||
self.monitor
|
||||
.report_error(Some(&config.device), &error_msg, "start_failed")
|
||||
.await;
|
||||
self.monitor.report_error(&error_msg, "start_failed").await;
|
||||
|
||||
self.mark_device_info_dirty().await;
|
||||
|
||||
@@ -283,9 +234,8 @@ impl AudioController {
|
||||
|
||||
*self.streamer.write().await = Some(streamer);
|
||||
|
||||
// Report recovery if we were in an error state
|
||||
if self.monitor.is_error().await {
|
||||
self.monitor.report_recovered(Some(&config.device)).await;
|
||||
self.monitor.report_recovered().await;
|
||||
}
|
||||
|
||||
self.mark_device_info_dirty().await;
|
||||
@@ -294,7 +244,6 @@ impl AudioController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop audio streaming
|
||||
pub async fn stop_streaming(&self) -> Result<()> {
|
||||
if let Some(streamer) = self.streamer.write().await.take() {
|
||||
streamer.stop().await?;
|
||||
@@ -306,7 +255,6 @@ impl AudioController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if currently streaming
|
||||
pub async fn is_streaming(&self) -> bool {
|
||||
if let Some(ref streamer) = *self.streamer.read().await {
|
||||
streamer.is_running()
|
||||
@@ -315,45 +263,37 @@ impl AudioController {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current status
|
||||
pub async fn status(&self) -> AudioStatus {
|
||||
let config = self.config.read().await;
|
||||
let streaming = self.is_streaming().await;
|
||||
let error = self.last_error.read().await.clone();
|
||||
let (enabled, device_str, quality) = {
|
||||
let c = self.config.read().await;
|
||||
(c.enabled, c.device.clone(), c.quality)
|
||||
};
|
||||
let error = self.monitor.error_message().await;
|
||||
|
||||
let subscriber_count = if let Some(ref streamer) = *self.streamer.read().await {
|
||||
streamer.stats().await.subscriber_count
|
||||
let (streaming, subscriber_count) = if let Some(ref streamer) = *self.streamer.read().await
|
||||
{
|
||||
let streaming = streamer.is_running();
|
||||
let subscriber_count = streamer.stats().subscriber_count;
|
||||
(streaming, subscriber_count)
|
||||
} else {
|
||||
0
|
||||
(false, 0)
|
||||
};
|
||||
|
||||
AudioStatus {
|
||||
enabled: config.enabled,
|
||||
enabled,
|
||||
streaming,
|
||||
device: if streaming || config.enabled {
|
||||
Some(config.device.clone())
|
||||
device: if streaming || enabled {
|
||||
Some(device_str)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
quality: config.quality,
|
||||
quality,
|
||||
subscriber_count,
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to Opus frames (for WebSocket clients)
|
||||
pub fn subscribe_opus(&self) -> Option<tokio::sync::mpsc::Receiver<Arc<OpusFrame>>> {
|
||||
if let Ok(guard) = self.streamer.try_read() {
|
||||
guard.as_ref().map(|s| s.subscribe_opus())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to Opus frames (async version)
|
||||
pub async fn subscribe_opus_async(
|
||||
&self,
|
||||
) -> Option<tokio::sync::mpsc::Receiver<Arc<OpusFrame>>> {
|
||||
pub async fn subscribe_opus(&self) -> Option<tokio::sync::mpsc::Receiver<Arc<OpusFrame>>> {
|
||||
self.streamer
|
||||
.read()
|
||||
.await
|
||||
@@ -361,7 +301,6 @@ impl AudioController {
|
||||
.map(|s| s.subscribe_opus())
|
||||
}
|
||||
|
||||
/// Enable or disable audio
|
||||
pub async fn set_enabled(&self, enabled: bool) -> Result<()> {
|
||||
{
|
||||
let mut config = self.config.write().await;
|
||||
@@ -376,21 +315,15 @@ impl AudioController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update full configuration
|
||||
pub async fn update_config(&self, new_config: AudioControllerConfig) -> Result<()> {
|
||||
let was_streaming = self.is_streaming().await;
|
||||
|
||||
// Stop streaming if running (device/quality/enabled may all change)
|
||||
if was_streaming {
|
||||
self.stop_streaming().await?;
|
||||
}
|
||||
|
||||
// Update config
|
||||
*self.config.write().await = new_config.clone();
|
||||
|
||||
// Start whenever audio is enabled — not only when we were already streaming.
|
||||
// Otherwise PATCH /config/audio alone leaves enabled=true with no capture until
|
||||
// POST /audio/start, which races WebRTC reconnect and matches "apply twice" reports.
|
||||
if new_config.enabled {
|
||||
self.start_streaming().await?;
|
||||
}
|
||||
@@ -398,25 +331,9 @@ impl AudioController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the controller
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
self.stop_streaming().await
|
||||
}
|
||||
|
||||
/// Get the health monitor reference
|
||||
pub fn monitor(&self) -> &Arc<AudioHealthMonitor> {
|
||||
&self.monitor
|
||||
}
|
||||
|
||||
/// Get current health status
|
||||
pub async fn health_status(&self) -> AudioHealthStatus {
|
||||
self.monitor.status().await
|
||||
}
|
||||
|
||||
/// Check if the audio is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.monitor.is_healthy().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AudioController {
|
||||
@@ -438,12 +355,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_audio_quality_from_str() {
|
||||
assert_eq!(AudioQuality::from_str("voice"), AudioQuality::Voice);
|
||||
assert_eq!(AudioQuality::from_str("low"), AudioQuality::Voice);
|
||||
assert_eq!(AudioQuality::from_str("balanced"), AudioQuality::Balanced);
|
||||
assert_eq!(AudioQuality::from_str("high"), AudioQuality::High);
|
||||
assert_eq!(AudioQuality::from_str("music"), AudioQuality::High);
|
||||
assert_eq!(AudioQuality::from_str("unknown"), AudioQuality::Balanced);
|
||||
assert_eq!(
|
||||
"voice".parse::<AudioQuality>().unwrap(),
|
||||
AudioQuality::Voice
|
||||
);
|
||||
assert_eq!(
|
||||
"balanced".parse::<AudioQuality>().unwrap(),
|
||||
AudioQuality::Balanced
|
||||
);
|
||||
assert_eq!("high".parse::<AudioQuality>().unwrap(), AudioQuality::High);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_audio_quality_from_str_rejects_aliases_and_unknown() {
|
||||
assert!("low".parse::<AudioQuality>().is_err());
|
||||
assert!("music".parse::<AudioQuality>().is_err());
|
||||
assert!("unknown".parse::<AudioQuality>().is_err());
|
||||
assert!("".parse::<AudioQuality>().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! Audio device enumeration using ALSA
|
||||
|
||||
use alsa::pcm::HwParams;
|
||||
use alsa::{Direction, PCM};
|
||||
use serde::Serialize;
|
||||
@@ -7,54 +5,30 @@ use tracing::{debug, info, warn};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Audio device information
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct AudioDeviceInfo {
|
||||
/// Device name (e.g., "hw:0,0" or "default")
|
||||
pub name: String,
|
||||
/// Human-readable description
|
||||
pub description: String,
|
||||
/// Card index
|
||||
pub card_index: i32,
|
||||
/// Device index
|
||||
pub device_index: i32,
|
||||
/// Supported sample rates
|
||||
pub sample_rates: Vec<u32>,
|
||||
/// Supported channel counts
|
||||
pub channels: Vec<u32>,
|
||||
/// Is this a capture device
|
||||
pub is_capture: bool,
|
||||
/// Is this an HDMI audio device (likely from capture card)
|
||||
pub is_hdmi: bool,
|
||||
/// USB bus info for matching with video devices (e.g., "1-1" from USB path)
|
||||
pub usb_bus: Option<String>,
|
||||
}
|
||||
|
||||
impl AudioDeviceInfo {
|
||||
/// Get ALSA device name
|
||||
pub fn alsa_name(&self) -> String {
|
||||
format!("hw:{},{}", self.card_index, self.device_index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get USB bus info for an audio card by reading sysfs
|
||||
/// Returns the USB port path like "1-1" or "1-2.3"
|
||||
fn get_usb_bus_info(card_index: i32) -> Option<String> {
|
||||
if card_index < 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Read the device symlink: /sys/class/sound/cardX/device -> ../../usb1/1-1/1-1:1.0
|
||||
let device_path = format!("/sys/class/sound/card{}/device", card_index);
|
||||
let link_target = std::fs::read_link(&device_path).ok()?;
|
||||
let link_str = link_target.to_string_lossy();
|
||||
|
||||
// Extract USB port from path like "../../usb1/1-1/1-1:1.0" or "../../1-1/1-1:1.0"
|
||||
// We want the "1-1" part (USB bus-port)
|
||||
for component in link_str.split('/') {
|
||||
// Match patterns like "1-1", "1-2", "1-1.2", "2-1.3.1"
|
||||
if component.contains('-') && !component.contains(':') {
|
||||
// Verify it looks like a USB port (starts with digit)
|
||||
if component
|
||||
.chars()
|
||||
.next()
|
||||
@@ -69,22 +43,15 @@ fn get_usb_bus_info(card_index: i32) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Enumerate available audio capture devices
|
||||
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
|
||||
enumerate_audio_devices_with_current(None)
|
||||
}
|
||||
|
||||
/// Enumerate available audio capture devices, with option to include a currently-in-use device
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `current_device` - Optional device name that is currently in use. This device will be
|
||||
/// included in the list even if it cannot be opened (because it's already open by us).
|
||||
pub fn enumerate_audio_devices_with_current(
|
||||
current_device: Option<&str>,
|
||||
) -> Result<Vec<AudioDeviceInfo>> {
|
||||
let mut devices = Vec::new();
|
||||
|
||||
// Try to enumerate cards
|
||||
let cards = alsa::card::Iter::new();
|
||||
|
||||
for card_result in cards {
|
||||
@@ -102,104 +69,71 @@ pub fn enumerate_audio_devices_with_current(
|
||||
|
||||
debug!("Found audio card {}: {}", card_index, card_longname);
|
||||
|
||||
// Check if this looks like an HDMI capture device
|
||||
let is_hdmi = card_longname.to_lowercase().contains("hdmi")
|
||||
|| card_longname.to_lowercase().contains("capture")
|
||||
|| card_longname.to_lowercase().contains("usb");
|
||||
let long_lower = card_longname.to_lowercase();
|
||||
let is_hdmi = long_lower.contains("hdmi")
|
||||
|| long_lower.contains("capture")
|
||||
|| long_lower.contains("usb");
|
||||
|
||||
// Get USB bus info for this card
|
||||
let usb_bus = get_usb_bus_info(card_index);
|
||||
|
||||
// Try to open each device on this card for capture
|
||||
for device_index in 0..8 {
|
||||
let device_name = format!("hw:{},{}", card_index, device_index);
|
||||
|
||||
// Check if this is the currently-in-use device
|
||||
let is_current_device = current_device == Some(device_name.as_str());
|
||||
|
||||
// Try to open for capture
|
||||
let mut push_info =
|
||||
|sample_rates: Vec<u32>, channels: Vec<u32>, description: String| {
|
||||
devices.push(AudioDeviceInfo {
|
||||
name: device_name.clone(),
|
||||
description,
|
||||
card_index,
|
||||
device_index,
|
||||
sample_rates,
|
||||
channels,
|
||||
is_capture: true,
|
||||
is_hdmi,
|
||||
usb_bus: usb_bus.clone(),
|
||||
});
|
||||
};
|
||||
|
||||
match PCM::new(&device_name, Direction::Capture, false) {
|
||||
Ok(pcm) => {
|
||||
// Query capabilities
|
||||
let (sample_rates, channels) = query_device_caps(&pcm);
|
||||
|
||||
if !sample_rates.is_empty() && !channels.is_empty() {
|
||||
devices.push(AudioDeviceInfo {
|
||||
name: device_name,
|
||||
description: format!("{} - Device {}", card_longname, device_index),
|
||||
card_index,
|
||||
device_index,
|
||||
push_info(
|
||||
sample_rates,
|
||||
channels,
|
||||
is_capture: true,
|
||||
is_hdmi,
|
||||
usb_bus: usb_bus.clone(),
|
||||
});
|
||||
format!("{} - Device {}", card_longname, device_index),
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Device doesn't exist or can't be opened for capture
|
||||
// But if it's the current device, include it anyway (it's busy because we're using it)
|
||||
if is_current_device {
|
||||
debug!(
|
||||
"Device {} is busy (in use by us), adding with default caps",
|
||||
device_name
|
||||
);
|
||||
devices.push(AudioDeviceInfo {
|
||||
name: device_name,
|
||||
description: format!(
|
||||
"{} - Device {} (in use)",
|
||||
card_longname, device_index
|
||||
),
|
||||
card_index,
|
||||
device_index,
|
||||
// Use common default capabilities for HDMI capture devices
|
||||
sample_rates: vec![44100, 48000],
|
||||
channels: vec![2],
|
||||
is_capture: true,
|
||||
is_hdmi,
|
||||
usb_bus: usb_bus.clone(),
|
||||
});
|
||||
push_info(
|
||||
vec![44100, 48000],
|
||||
vec![2],
|
||||
format!("{} - Device {} (in use)", card_longname, device_index),
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for "default" device
|
||||
if let Ok(pcm) = PCM::new("default", Direction::Capture, false) {
|
||||
let (sample_rates, channels) = query_device_caps(&pcm);
|
||||
if !sample_rates.is_empty() {
|
||||
devices.insert(
|
||||
0,
|
||||
AudioDeviceInfo {
|
||||
name: "default".to_string(),
|
||||
description: "Default Audio Device".to_string(),
|
||||
card_index: -1,
|
||||
device_index: -1,
|
||||
sample_rates,
|
||||
channels,
|
||||
is_capture: true,
|
||||
is_hdmi: false,
|
||||
usb_bus: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
info!("Found {} audio capture devices", devices.len());
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
/// Query device capabilities
|
||||
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
|
||||
let hwp = match HwParams::any(pcm) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return (vec![], vec![]),
|
||||
};
|
||||
|
||||
// Common sample rates to check
|
||||
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
|
||||
let mut supported_rates = Vec::new();
|
||||
|
||||
@@ -209,7 +143,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check channel counts
|
||||
let mut supported_channels = Vec::new();
|
||||
for ch in 1..=8 {
|
||||
if hwp.test_channels(ch).is_ok() {
|
||||
@@ -220,8 +153,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
|
||||
(supported_rates, supported_channels)
|
||||
}
|
||||
|
||||
/// Find the best audio device for capture
|
||||
/// Prefers HDMI/capture devices over built-in microphones
|
||||
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
|
||||
let devices = enumerate_audio_devices()?;
|
||||
|
||||
@@ -231,23 +162,24 @@ pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
|
||||
));
|
||||
}
|
||||
|
||||
// First, look for HDMI/capture card devices that support 48kHz stereo
|
||||
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
|
||||
for device in &devices {
|
||||
if device.is_hdmi && device.sample_rates.contains(&48000) && device.channels.contains(&2) {
|
||||
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
|
||||
continue;
|
||||
}
|
||||
if device.is_hdmi {
|
||||
info!("Selected HDMI audio device: {}", device.description);
|
||||
return Ok(device.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Then look for any device supporting 48kHz stereo
|
||||
for device in &devices {
|
||||
if device.sample_rates.contains(&48000) && device.channels.contains(&2) {
|
||||
info!("Selected audio device: {}", device.description);
|
||||
return Ok(device.clone());
|
||||
if first_48k_stereo.is_none() {
|
||||
first_48k_stereo = Some(device);
|
||||
}
|
||||
}
|
||||
if let Some(device) = first_48k_stereo {
|
||||
info!("Selected audio device: {}", device.description);
|
||||
return Ok(device.clone());
|
||||
}
|
||||
|
||||
// Fall back to first device
|
||||
let device = devices.into_iter().next().unwrap();
|
||||
warn!(
|
||||
"Using fallback audio device: {} (may not support optimal settings)",
|
||||
@@ -262,10 +194,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_enumerate_devices() {
|
||||
// This test may not find devices in CI environment
|
||||
let result = enumerate_audio_devices();
|
||||
println!("Audio devices: {:?}", result);
|
||||
// Just verify it doesn't panic
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,19 @@
|
||||
//! Opus audio encoder for WebRTC
|
||||
//! Opus encoder.
|
||||
|
||||
use audiopus::coder::GenericCtl;
|
||||
use audiopus::{coder::Encoder, Application, Bitrate, Channels, SampleRate};
|
||||
use bytes::Bytes;
|
||||
use std::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
use super::capture::AudioFrame;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Opus encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpusConfig {
|
||||
/// Sample rate (must be 8000, 12000, 16000, 24000, or 48000)
|
||||
pub sample_rate: u32,
|
||||
/// Channels (1 or 2)
|
||||
pub channels: u32,
|
||||
/// Target bitrate in bps
|
||||
pub bitrate: u32,
|
||||
/// Application mode
|
||||
pub application: OpusApplication,
|
||||
/// Enable forward error correction
|
||||
pub fec: bool,
|
||||
}
|
||||
|
||||
@@ -29,7 +22,7 @@ impl Default for OpusConfig {
|
||||
Self {
|
||||
sample_rate: 48000,
|
||||
channels: 2,
|
||||
bitrate: 64000, // 64 kbps
|
||||
bitrate: 64000,
|
||||
application: OpusApplication::Audio,
|
||||
fec: true,
|
||||
}
|
||||
@@ -37,7 +30,6 @@ impl Default for OpusConfig {
|
||||
}
|
||||
|
||||
impl OpusConfig {
|
||||
/// Create config for voice (lower latency)
|
||||
pub fn voice() -> Self {
|
||||
Self {
|
||||
application: OpusApplication::Voip,
|
||||
@@ -46,7 +38,6 @@ impl OpusConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config for music (higher quality)
|
||||
pub fn music() -> Self {
|
||||
Self {
|
||||
application: OpusApplication::Audio,
|
||||
@@ -82,30 +73,18 @@ impl OpusConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Opus application mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum OpusApplication {
|
||||
/// Voice over IP
|
||||
Voip,
|
||||
/// General audio
|
||||
Audio,
|
||||
/// Low delay mode
|
||||
LowDelay,
|
||||
}
|
||||
|
||||
/// Encoded Opus frame
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpusFrame {
|
||||
/// Encoded Opus data
|
||||
pub data: Bytes,
|
||||
/// Duration in milliseconds
|
||||
pub duration_ms: u32,
|
||||
/// Sequence number
|
||||
pub sequence: u64,
|
||||
/// Timestamp
|
||||
pub timestamp: Instant,
|
||||
/// RTP timestamp (samples)
|
||||
pub rtp_timestamp: u32,
|
||||
}
|
||||
|
||||
impl OpusFrame {
|
||||
@@ -118,20 +97,14 @@ impl OpusFrame {
|
||||
}
|
||||
}
|
||||
|
||||
/// Opus encoder
|
||||
pub struct OpusEncoder {
|
||||
config: OpusConfig,
|
||||
encoder: Encoder,
|
||||
/// Output buffer
|
||||
output_buffer: Vec<u8>,
|
||||
/// Frame counter for RTP timestamp
|
||||
frame_count: u64,
|
||||
/// Samples per frame
|
||||
samples_per_frame: u32,
|
||||
}
|
||||
|
||||
impl OpusEncoder {
|
||||
/// Create a new Opus encoder
|
||||
pub fn new(config: OpusConfig) -> Result<Self> {
|
||||
let sample_rate = config.to_audiopus_sample_rate();
|
||||
let channels = config.to_audiopus_channels();
|
||||
@@ -140,7 +113,6 @@ impl OpusEncoder {
|
||||
let mut encoder = Encoder::new(sample_rate, channels, application)
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e)))?;
|
||||
|
||||
// Configure encoder
|
||||
encoder
|
||||
.set_bitrate(Bitrate::BitsPerSecond(config.bitrate as i32))
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
|
||||
@@ -151,9 +123,6 @@ impl OpusEncoder {
|
||||
.map_err(|e| AppError::AudioError(format!("Failed to enable FEC: {:?}", e)))?;
|
||||
}
|
||||
|
||||
// Calculate samples per frame (20ms at sample_rate)
|
||||
let samples_per_frame = config.sample_rate / 50;
|
||||
|
||||
info!(
|
||||
"Opus encoder created: {}Hz {}ch {}bps",
|
||||
config.sample_rate, config.channels, config.bitrate
|
||||
@@ -162,18 +131,11 @@ impl OpusEncoder {
|
||||
Ok(Self {
|
||||
config,
|
||||
encoder,
|
||||
output_buffer: vec![0u8; 4000], // Max Opus frame size
|
||||
output_buffer: vec![0u8; 4000],
|
||||
frame_count: 0,
|
||||
samples_per_frame,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with default configuration
|
||||
pub fn default_config() -> Result<Self> {
|
||||
Self::new(OpusConfig::default())
|
||||
}
|
||||
|
||||
/// Encode PCM audio data (S16LE interleaved)
|
||||
pub fn encode(&mut self, pcm_data: &[i16]) -> Result<OpusFrame> {
|
||||
let encoded_len = self
|
||||
.encoder
|
||||
@@ -182,7 +144,6 @@ impl OpusEncoder {
|
||||
|
||||
let samples = pcm_data.len() as u32 / self.config.channels;
|
||||
let duration_ms = (samples * 1000) / self.config.sample_rate;
|
||||
let rtp_timestamp = (self.frame_count * self.samples_per_frame as u64) as u32;
|
||||
|
||||
self.frame_count += 1;
|
||||
|
||||
@@ -190,27 +151,18 @@ impl OpusEncoder {
|
||||
data: Bytes::copy_from_slice(&self.output_buffer[..encoded_len]),
|
||||
duration_ms,
|
||||
sequence: self.frame_count - 1,
|
||||
timestamp: Instant::now(),
|
||||
rtp_timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode from AudioFrame
|
||||
///
|
||||
/// Uses zero-copy conversion from bytes to i16 samples via bytemuck.
|
||||
pub fn encode_frame(&mut self, frame: &AudioFrame) -> Result<OpusFrame> {
|
||||
// Zero-copy: directly cast bytes to i16 slice
|
||||
// AudioFrame.data is S16LE format, which matches native little-endian i16
|
||||
let samples: &[i16] = bytemuck::cast_slice(&frame.data);
|
||||
self.encode(samples)
|
||||
}
|
||||
|
||||
/// Get encoder configuration
|
||||
pub fn config(&self) -> &OpusConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Reset encoder state
|
||||
pub fn reset(&mut self) -> Result<()> {
|
||||
self.encoder
|
||||
.reset_state()
|
||||
@@ -219,7 +171,6 @@ impl OpusEncoder {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set bitrate dynamically
|
||||
pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> {
|
||||
self.encoder
|
||||
.set_bitrate(Bitrate::BitsPerSecond(bitrate as i32))
|
||||
@@ -228,15 +179,6 @@ impl OpusEncoder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio encoder statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EncoderStats {
|
||||
pub frames_encoded: u64,
|
||||
pub bytes_output: u64,
|
||||
pub avg_frame_size: usize,
|
||||
pub current_bitrate: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -261,13 +203,12 @@ mod tests {
|
||||
let config = OpusConfig::default();
|
||||
let mut encoder = OpusEncoder::new(config).unwrap();
|
||||
|
||||
// 20ms of stereo silence at 48kHz
|
||||
let silence = vec![0i16; 960 * 2];
|
||||
let result = encoder.encode(&silence);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let frame = result.unwrap();
|
||||
assert!(!frame.is_empty());
|
||||
assert!(frame.len() < silence.len() * 2); // Should be compressed
|
||||
assert!(frame.len() < silence.len() * 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,4 @@
|
||||
//! Audio capture and encoding module
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - ALSA audio capture
|
||||
//! - Opus encoding for WebRTC
|
||||
//! - Audio device enumeration
|
||||
//! - Audio streaming pipeline
|
||||
//! - High-level audio controller
|
||||
//! - Device health monitoring
|
||||
//! ALSA capture, Opus encode, device enumeration, streaming, controller, health monitor.
|
||||
|
||||
pub mod capture;
|
||||
pub mod controller;
|
||||
@@ -19,5 +11,5 @@ pub use capture::{AudioCapturer, AudioConfig, AudioFrame};
|
||||
pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus};
|
||||
pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo};
|
||||
pub use encoder::{OpusConfig, OpusEncoder, OpusFrame};
|
||||
pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig};
|
||||
pub use monitor::{AudioHealthMonitor, AudioHealthStatus};
|
||||
pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
|
||||
|
||||
@@ -1,114 +1,58 @@
|
||||
//! Audio device health monitoring
|
||||
//!
|
||||
//! This module provides health monitoring for audio capture devices, including:
|
||||
//! - Device connectivity checks
|
||||
//! - Automatic reconnection on failure
|
||||
//! - Error tracking
|
||||
//! - Log throttling to prevent log flooding
|
||||
//! Audio device health and logging throttle for repeated failures.
|
||||
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::time::Duration;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::utils::LogThrottler;
|
||||
|
||||
/// Audio health status
|
||||
const LOG_THROTTLE_SECS: u64 = 5;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Default)]
|
||||
pub enum AudioHealthStatus {
|
||||
/// Device is healthy and operational
|
||||
#[default]
|
||||
Healthy,
|
||||
/// Device has an error, attempting recovery
|
||||
Error {
|
||||
/// Human-readable error reason
|
||||
reason: String,
|
||||
/// Error code for programmatic handling
|
||||
error_code: String,
|
||||
/// Number of recovery attempts made
|
||||
retry_count: u32,
|
||||
},
|
||||
/// Device is disconnected or not available
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
/// Audio health monitor configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AudioMonitorConfig {
|
||||
/// Retry interval when device is lost (milliseconds)
|
||||
pub retry_interval_ms: u64,
|
||||
/// Maximum retry attempts before giving up (0 = infinite)
|
||||
pub max_retries: u32,
|
||||
/// Log throttle interval in seconds
|
||||
pub log_throttle_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for AudioMonitorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
retry_interval_ms: 1000,
|
||||
max_retries: 0, // infinite retry
|
||||
log_throttle_secs: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio health monitor
|
||||
///
|
||||
/// Monitors audio device health and manages error recovery.
|
||||
pub struct AudioHealthMonitor {
|
||||
/// Current health status
|
||||
status: RwLock<AudioHealthStatus>,
|
||||
/// Log throttler to prevent log flooding
|
||||
throttler: LogThrottler,
|
||||
/// Configuration
|
||||
config: AudioMonitorConfig,
|
||||
/// Current retry count
|
||||
retry_count: AtomicU32,
|
||||
/// Last error code (for change detection)
|
||||
last_error_code: RwLock<Option<String>>,
|
||||
/// Hide `error_message` while a new capture attempt is in flight (internal error state unchanged).
|
||||
suppress_display: AtomicBool,
|
||||
}
|
||||
|
||||
impl AudioHealthMonitor {
|
||||
/// Create a new audio health monitor with the specified configuration
|
||||
pub fn new(config: AudioMonitorConfig) -> Self {
|
||||
let throttle_secs = config.log_throttle_secs;
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
status: RwLock::new(AudioHealthStatus::Healthy),
|
||||
throttler: LogThrottler::with_secs(throttle_secs),
|
||||
config,
|
||||
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
|
||||
retry_count: AtomicU32::new(0),
|
||||
last_error_code: RwLock::new(None),
|
||||
suppress_display: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new audio health monitor with default configuration
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(AudioMonitorConfig::default())
|
||||
/// Clears the error string exposed via [`Self::error_message`] until the next outcome (`report_error` or recovery).
|
||||
pub fn prepare_retry_attempt(&self) {
|
||||
self.suppress_display.store(true, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report an error from audio operations
|
||||
///
|
||||
/// This method is called when an audio operation fails. It:
|
||||
/// 1. Updates the health status
|
||||
/// 2. Logs the error (with throttling)
|
||||
/// 3. Updates in-memory error state
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The audio device name (if known)
|
||||
/// * `reason` - Human-readable error description
|
||||
/// * `error_code` - Error code for programmatic handling
|
||||
pub async fn report_error(&self, _device: Option<&str>, reason: &str, error_code: &str) {
|
||||
pub async fn report_error(&self, reason: &str, error_code: &str) {
|
||||
self.suppress_display.store(false, Ordering::Relaxed);
|
||||
|
||||
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if error code changed
|
||||
let error_changed = {
|
||||
let last = self.last_error_code.read().await;
|
||||
last.as_ref().map(|s| s.as_str()) != Some(error_code)
|
||||
};
|
||||
|
||||
// Log with throttling (always log if error type changed)
|
||||
let throttle_key = format!("audio_{}", error_code);
|
||||
if error_changed || self.throttler.should_log(&throttle_key) {
|
||||
warn!(
|
||||
@@ -117,34 +61,22 @@ impl AudioHealthMonitor {
|
||||
);
|
||||
}
|
||||
|
||||
// Update last error code
|
||||
*self.last_error_code.write().await = Some(error_code.to_string());
|
||||
|
||||
// Update status
|
||||
*self.status.write().await = AudioHealthStatus::Error {
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
retry_count: count,
|
||||
};
|
||||
}
|
||||
|
||||
/// Report that the device has recovered
|
||||
///
|
||||
/// This method is called when the audio device successfully reconnects.
|
||||
/// It resets the error state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `device` - The audio device name
|
||||
pub async fn report_recovered(&self, _device: Option<&str>) {
|
||||
pub async fn report_recovered(&self) {
|
||||
let prev_status = self.status.read().await.clone();
|
||||
|
||||
// Only report recovery if we were in an error state
|
||||
if prev_status != AudioHealthStatus::Healthy {
|
||||
let retry_count = self.retry_count.load(Ordering::Relaxed);
|
||||
info!("Audio recovered after {} retries", retry_count);
|
||||
|
||||
// Reset state
|
||||
self.suppress_display.store(false, Ordering::Relaxed);
|
||||
self.retry_count.store(0, Ordering::Relaxed);
|
||||
self.throttler.clear("audio_");
|
||||
*self.last_error_code.write().await = None;
|
||||
@@ -152,58 +84,30 @@ impl AudioHealthMonitor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current health status
|
||||
pub async fn status(&self) -> AudioHealthStatus {
|
||||
self.status.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get the current retry count
|
||||
pub fn retry_count(&self) -> u32 {
|
||||
self.retry_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Check if the monitor is in an error state
|
||||
pub async fn is_error(&self) -> bool {
|
||||
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
|
||||
}
|
||||
|
||||
/// Check if the monitor is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
matches!(*self.status.read().await, AudioHealthStatus::Healthy)
|
||||
}
|
||||
|
||||
/// Reset the monitor to healthy state without publishing events
|
||||
///
|
||||
/// This is useful during initialization.
|
||||
pub async fn reset(&self) {
|
||||
self.suppress_display.store(false, Ordering::Relaxed);
|
||||
self.retry_count.store(0, Ordering::Relaxed);
|
||||
*self.last_error_code.write().await = None;
|
||||
*self.status.write().await = AudioHealthStatus::Healthy;
|
||||
self.throttler.clear_all();
|
||||
}
|
||||
|
||||
/// Get the configuration
|
||||
pub fn config(&self) -> &AudioMonitorConfig {
|
||||
&self.config
|
||||
pub async fn status(&self) -> AudioHealthStatus {
|
||||
self.status.read().await.clone()
|
||||
}
|
||||
|
||||
/// Check if we should continue retrying
|
||||
///
|
||||
/// Returns `false` if max_retries is set and we've exceeded it.
|
||||
pub fn should_retry(&self) -> bool {
|
||||
if self.config.max_retries == 0 {
|
||||
return true; // Infinite retry
|
||||
}
|
||||
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
|
||||
pub fn retry_count(&self) -> u32 {
|
||||
self.retry_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get the retry interval
|
||||
pub fn retry_interval(&self) -> Duration {
|
||||
Duration::from_millis(self.config.retry_interval_ms)
|
||||
pub async fn is_error(&self) -> bool {
|
||||
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
|
||||
}
|
||||
|
||||
/// Get the current error message if in error state
|
||||
pub async fn error_message(&self) -> Option<String> {
|
||||
if self.suppress_display.load(Ordering::Relaxed) {
|
||||
return None;
|
||||
}
|
||||
match &*self.status.read().await {
|
||||
AudioHealthStatus::Error { reason, .. } => Some(reason.clone()),
|
||||
_ => None,
|
||||
@@ -213,7 +117,7 @@ impl AudioHealthMonitor {
|
||||
|
||||
impl Default for AudioHealthMonitor {
|
||||
fn default() -> Self {
|
||||
Self::with_defaults()
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,32 +127,25 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initial_status() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
assert!(monitor.is_healthy().await);
|
||||
let monitor = AudioHealthMonitor::new();
|
||||
assert!(!monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_error() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
let monitor = AudioHealthMonitor::new();
|
||||
|
||||
monitor
|
||||
.report_error(Some("hw:0,0"), "Device not found", "device_disconnected")
|
||||
.report_error("Device not found", "device_disconnected")
|
||||
.await;
|
||||
|
||||
assert!(monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 1);
|
||||
|
||||
if let AudioHealthStatus::Error {
|
||||
reason,
|
||||
error_code,
|
||||
retry_count,
|
||||
} = monitor.status().await
|
||||
{
|
||||
if let AudioHealthStatus::Error { reason, error_code } = monitor.status().await {
|
||||
assert_eq!(reason, "Device not found");
|
||||
assert_eq!(error_code, "device_disconnected");
|
||||
assert_eq!(retry_count, 1);
|
||||
} else {
|
||||
panic!("Expected Error status");
|
||||
}
|
||||
@@ -256,39 +153,52 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_report_recovered() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
let monitor = AudioHealthMonitor::new();
|
||||
|
||||
// First report an error
|
||||
monitor
|
||||
.report_error(Some("default"), "Capture failed", "capture_error")
|
||||
.report_error("Capture failed", "capture_error")
|
||||
.await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
// Then report recovery
|
||||
monitor.report_recovered(Some("default")).await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
monitor.report_recovered().await;
|
||||
assert!(!monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_count_increments() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
let monitor = AudioHealthMonitor::new();
|
||||
|
||||
for i in 1..=5 {
|
||||
monitor.report_error(None, "Error", "io_error").await;
|
||||
monitor.report_error("Error", "io_error").await;
|
||||
assert_eq!(monitor.retry_count(), i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reset() {
|
||||
let monitor = AudioHealthMonitor::with_defaults();
|
||||
let monitor = AudioHealthMonitor::new();
|
||||
|
||||
monitor.report_error(None, "Error", "io_error").await;
|
||||
monitor.report_error("Error", "io_error").await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
monitor.reset().await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert!(!monitor.is_error().await);
|
||||
assert_eq!(monitor.retry_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_prepare_retry_hides_error_until_next_failure() {
|
||||
let monitor = AudioHealthMonitor::new();
|
||||
|
||||
monitor.report_error("bad", "e").await;
|
||||
assert_eq!(monitor.error_message().await.as_deref(), Some("bad"));
|
||||
|
||||
monitor.prepare_retry_attempt();
|
||||
assert!(monitor.is_error().await);
|
||||
assert!(monitor.error_message().await.is_none());
|
||||
|
||||
monitor.report_error("still bad", "e").await;
|
||||
assert_eq!(monitor.error_message().await.as_deref(), Some("still bad"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
//! Audio streaming pipeline
|
||||
//!
|
||||
//! ALSA capture (48 kHz stereo only) → fixed Opus 20 ms frames → `mpsc` fan-out per subscriber.
|
||||
//! ALSA 48 kHz stereo → Opus 20 ms frames, fan-out per subscriber.
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{broadcast, mpsc, watch, Mutex as AsyncMutex, RwLock};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
@@ -14,34 +11,25 @@ use crate::error::{AppError, Result};
|
||||
use bytemuck;
|
||||
use bytes::Bytes;
|
||||
|
||||
/// Stereo 48 kHz: 20 ms = 960 frames × 2 channels (S16LE).
|
||||
/// 48 kHz stereo: 20 ms = 960 × 2 samples (S16LE).
|
||||
const OPUS_STEREO_SAMPLES: usize = 960 * 2;
|
||||
|
||||
/// Audio stream state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum AudioStreamState {
|
||||
/// Stream is stopped
|
||||
#[default]
|
||||
Stopped,
|
||||
/// Stream is starting up
|
||||
Starting,
|
||||
/// Stream is running
|
||||
Running,
|
||||
/// Stream encountered an error
|
||||
Error,
|
||||
}
|
||||
|
||||
/// Audio streamer configuration
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AudioStreamerConfig {
|
||||
/// Audio capture configuration
|
||||
pub capture: AudioConfig,
|
||||
/// Opus encoder configuration
|
||||
pub opus: OpusConfig,
|
||||
}
|
||||
|
||||
impl AudioStreamerConfig {
|
||||
/// Create config for a specific device with default quality
|
||||
pub fn for_device(device_name: &str) -> Self {
|
||||
Self {
|
||||
capture: AudioConfig {
|
||||
@@ -52,45 +40,32 @@ impl AudioStreamerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create config with specified bitrate
|
||||
pub fn with_bitrate(mut self, bitrate: u32) -> Self {
|
||||
self.opus.bitrate = bitrate;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio stream statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AudioStreamStats {
|
||||
/// Frames encoded to Opus
|
||||
/// Number of active subscribers
|
||||
pub subscriber_count: usize,
|
||||
}
|
||||
|
||||
/// Audio streamer
|
||||
///
|
||||
/// Manages the audio capture → encode → mpsc fan-out pipeline.
|
||||
pub struct AudioStreamer {
|
||||
config: RwLock<AudioStreamerConfig>,
|
||||
state: watch::Sender<AudioStreamState>,
|
||||
state_rx: watch::Receiver<AudioStreamState>,
|
||||
capturer: RwLock<Option<Arc<AudioCapturer>>>,
|
||||
encoder: Arc<AsyncMutex<Option<OpusEncoder>>>,
|
||||
/// One `mpsc::Sender` per subscriber (like shared video pipeline).
|
||||
opus_subscribers: Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
|
||||
stats: Arc<AsyncMutex<AudioStreamStats>>,
|
||||
sequence: AtomicU64,
|
||||
stream_start_time: RwLock<Option<Instant>>,
|
||||
stop_flag: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl AudioStreamer {
|
||||
/// Create a new audio streamer with default configuration
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(AudioStreamerConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new audio streamer with specified configuration
|
||||
pub fn with_config(config: AudioStreamerConfig) -> Self {
|
||||
let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped);
|
||||
|
||||
@@ -101,31 +76,24 @@ impl AudioStreamer {
|
||||
capturer: RwLock::new(None),
|
||||
encoder: Arc::new(AsyncMutex::new(None)),
|
||||
opus_subscribers: Arc::new(Mutex::new(Vec::new())),
|
||||
stats: Arc::new(AsyncMutex::new(AudioStreamStats::default())),
|
||||
sequence: AtomicU64::new(0),
|
||||
stream_start_time: RwLock::new(None),
|
||||
stop_flag: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current state
|
||||
pub fn state(&self) -> AudioStreamState {
|
||||
*self.state_rx.borrow()
|
||||
}
|
||||
|
||||
/// Subscribe to state changes
|
||||
pub fn state_watch(&self) -> watch::Receiver<AudioStreamState> {
|
||||
self.state_rx.clone()
|
||||
}
|
||||
|
||||
/// Subscribe to Opus frames (each packet is one encoded 20 ms frame).
|
||||
pub fn subscribe_opus(&self) -> mpsc::Receiver<Arc<OpusFrame>> {
|
||||
let (tx, rx) = mpsc::channel::<Arc<OpusFrame>>(128);
|
||||
self.opus_subscribers.lock().unwrap().push(tx);
|
||||
rx
|
||||
}
|
||||
|
||||
/// Get number of active subscribers
|
||||
pub fn subscriber_count(&self) -> usize {
|
||||
self.opus_subscribers
|
||||
.lock()
|
||||
@@ -135,14 +103,12 @@ impl AudioStreamer {
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Get current statistics
|
||||
pub async fn stats(&self) -> AudioStreamStats {
|
||||
let mut stats = self.stats.lock().await.clone();
|
||||
stats.subscriber_count = self.subscriber_count();
|
||||
stats
|
||||
pub fn stats(&self) -> AudioStreamStats {
|
||||
AudioStreamStats {
|
||||
subscriber_count: self.subscriber_count(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update configuration (only when stopped)
|
||||
pub async fn set_config(&self, config: AudioStreamerConfig) -> Result<()> {
|
||||
if self.state() != AudioStreamState::Stopped {
|
||||
return Err(AppError::AudioError(
|
||||
@@ -153,12 +119,9 @@ impl AudioStreamer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update bitrate dynamically (can be done while streaming)
|
||||
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
|
||||
// Update config
|
||||
self.config.write().await.opus.bitrate = bitrate;
|
||||
|
||||
// Update encoder if running
|
||||
if let Some(ref mut encoder) = *self.encoder.lock().await {
|
||||
encoder.set_bitrate(bitrate)?;
|
||||
}
|
||||
@@ -167,7 +130,6 @@ impl AudioStreamer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start the audio stream
|
||||
pub async fn start(&self) -> Result<()> {
|
||||
if self.state() == AudioStreamState::Running {
|
||||
return Ok(());
|
||||
@@ -186,28 +148,14 @@ impl AudioStreamer {
|
||||
config.opus.bitrate
|
||||
);
|
||||
|
||||
// Create capturer
|
||||
let capturer = Arc::new(AudioCapturer::new(config.capture.clone()));
|
||||
*self.capturer.write().await = Some(capturer.clone());
|
||||
|
||||
// Create encoder
|
||||
let encoder = OpusEncoder::new(config.opus.clone())?;
|
||||
*self.encoder.lock().await = Some(encoder);
|
||||
|
||||
// Start capture
|
||||
capturer.start().await?;
|
||||
|
||||
// Reset stats
|
||||
{
|
||||
let mut stats = self.stats.lock().await;
|
||||
*stats = AudioStreamStats::default();
|
||||
}
|
||||
|
||||
// Record start time
|
||||
*self.stream_start_time.write().await = Some(Instant::now());
|
||||
self.sequence.store(0, Ordering::SeqCst);
|
||||
|
||||
// Start encoding task
|
||||
let capturer_for_task = capturer.clone();
|
||||
let encoder = self.encoder.clone();
|
||||
let opus_subscribers = self.opus_subscribers.clone();
|
||||
@@ -215,14 +163,19 @@ impl AudioStreamer {
|
||||
let stop_flag = self.stop_flag.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
Self::stream_task(capturer_for_task, encoder, opus_subscribers, state, stop_flag)
|
||||
.await;
|
||||
Self::stream_task(
|
||||
capturer_for_task,
|
||||
encoder,
|
||||
opus_subscribers,
|
||||
state,
|
||||
stop_flag,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the audio stream
|
||||
pub async fn stop(&self) -> Result<()> {
|
||||
if self.state() == AudioStreamState::Stopped {
|
||||
return Ok(());
|
||||
@@ -230,18 +183,14 @@ impl AudioStreamer {
|
||||
|
||||
info!("Stopping audio stream");
|
||||
|
||||
// Signal stop
|
||||
self.stop_flag.store(true, Ordering::SeqCst);
|
||||
|
||||
// Stop capturer
|
||||
if let Some(ref capturer) = *self.capturer.read().await {
|
||||
capturer.stop().await?;
|
||||
}
|
||||
|
||||
// Clear resources — drop Opus senders so mpsc receivers see end-of-stream
|
||||
*self.capturer.write().await = None;
|
||||
*self.encoder.lock().await = None;
|
||||
*self.stream_start_time.write().await = None;
|
||||
self.opus_subscribers.lock().unwrap().clear();
|
||||
|
||||
let _ = self.state.send(AudioStreamState::Stopped);
|
||||
@@ -249,7 +198,6 @@ impl AudioStreamer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if streaming
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.state() == AudioStreamState::Running
|
||||
}
|
||||
|
||||
@@ -8,20 +8,16 @@ use axum::{
|
||||
use axum_extra::extract::CookieJar;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::ErrorResponse;
|
||||
use crate::state::AppState;
|
||||
use crate::web::ErrorResponse;
|
||||
|
||||
/// Session cookie name
|
||||
pub const SESSION_COOKIE: &str = "one_kvm_session";
|
||||
|
||||
/// Extract session ID from request
|
||||
pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option<String> {
|
||||
// First try cookie
|
||||
if let Some(cookie) = cookies.get(SESSION_COOKIE) {
|
||||
return Some(cookie.value().to_string());
|
||||
}
|
||||
|
||||
// Then try Authorization header (Bearer token)
|
||||
if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
|
||||
if let Ok(auth_str) = auth_header.to_str() {
|
||||
if let Some(token) = auth_str.strip_prefix("Bearer ") {
|
||||
@@ -33,7 +29,6 @@ pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap)
|
||||
None
|
||||
}
|
||||
|
||||
/// Authentication middleware
|
||||
pub async fn auth_middleware(
|
||||
State(state): State<Arc<AppState>>,
|
||||
cookies: CookieJar,
|
||||
@@ -41,29 +36,23 @@ pub async fn auth_middleware(
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let raw_path = request.uri().path();
|
||||
// When this middleware is mounted under /api, Axum strips the prefix for the inner router.
|
||||
// Normalize the path so checks work whether it is mounted or not.
|
||||
// Mounted under /api: inner path may lack prefix; normalize for whitelist checks.
|
||||
let path = raw_path.strip_prefix("/api").unwrap_or(raw_path);
|
||||
|
||||
// Check if system is initialized
|
||||
if !state.config.is_initialized() {
|
||||
// Allow only setup-related endpoints when not initialized
|
||||
if is_setup_public_endpoint(path) {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
}
|
||||
|
||||
// Public endpoints that don't require auth
|
||||
if is_public_endpoint(path) {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
|
||||
// Extract session ID
|
||||
let session_id = extract_session_id(&cookies, request.headers());
|
||||
|
||||
if let Some(session_id) = session_id {
|
||||
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
|
||||
// Add session to request extensions
|
||||
request.extensions_mut().insert(session);
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
@@ -87,9 +76,7 @@ fn unauthorized_response(message: &str) -> Response {
|
||||
(StatusCode::UNAUTHORIZED, Json(body)).into_response()
|
||||
}
|
||||
|
||||
/// Check if endpoint is public (no auth required)
|
||||
fn is_public_endpoint(path: &str) -> bool {
|
||||
// Note: paths here are relative to /api since middleware is applied within the nested router
|
||||
matches!(
|
||||
path,
|
||||
"/" | "/auth/login" | "/health" | "/setup" | "/setup/init"
|
||||
@@ -102,7 +89,6 @@ fn is_public_endpoint(path: &str) -> bool {
|
||||
|| path.ends_with(".svg")
|
||||
}
|
||||
|
||||
/// Setup-only endpoints allowed before initialization.
|
||||
fn is_setup_public_endpoint(path: &str) -> bool {
|
||||
matches!(
|
||||
path,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod middleware;
|
||||
mod password;
|
||||
mod rfc3339;
|
||||
mod session;
|
||||
mod user;
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ use argon2::{
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Hash a password using Argon2
|
||||
pub fn hash_password(password: &str) -> Result<String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
@@ -16,7 +15,6 @@ pub fn hash_password(password: &str) -> Result<String> {
|
||||
.map_err(|e| AppError::Internal(format!("Password hashing failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Verify a password against a hash
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
|
||||
let parsed_hash = PasswordHash::new(hash)
|
||||
.map_err(|e| AppError::Internal(format!("Invalid password hash: {}", e)))?;
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
//! RFC3339 strings in SQLite; structs use `time::serde::rfc3339`.
|
||||
|
||||
use time::format_description::well_known::Rfc3339;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
/// Parse DB text; bad input → `now_utc()`.
|
||||
pub fn parse(s: &str) -> OffsetDateTime {
|
||||
OffsetDateTime::parse(s, &Rfc3339).unwrap_or_else(|_| OffsetDateTime::now_utc())
|
||||
}
|
||||
|
||||
pub fn format(dt: OffsetDateTime) -> String {
|
||||
dt.format(&Rfc3339).expect("RFC3339 format")
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Pool, Sqlite};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use time::{Duration, OffsetDateTime};
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::rfc3339;
|
||||
use crate::error::Result;
|
||||
|
||||
/// Session data
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
pub id: String,
|
||||
@@ -19,29 +19,25 @@ pub struct Session {
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Check if session is expired
|
||||
pub fn is_expired(&self) -> bool {
|
||||
OffsetDateTime::now_utc() > self.expires_at
|
||||
}
|
||||
}
|
||||
|
||||
/// Session store backed by SQLite
|
||||
#[derive(Clone)]
|
||||
pub struct SessionStore {
|
||||
pool: Pool<Sqlite>,
|
||||
inner: Arc<RwLock<HashMap<String, Session>>>,
|
||||
default_ttl: Duration,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
/// Create a new session store
|
||||
pub fn new(pool: Pool<Sqlite>, ttl_secs: i64) -> Self {
|
||||
pub fn new(ttl_secs: i64) -> Self {
|
||||
Self {
|
||||
pool,
|
||||
inner: Arc::new(RwLock::new(HashMap::new())),
|
||||
default_ttl: Duration::seconds(ttl_secs),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session
|
||||
pub async fn create(&self, user_id: &str) -> Result<Session> {
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let session = Session {
|
||||
@@ -52,105 +48,57 @@ impl SessionStore {
|
||||
data: None,
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO sessions (id, user_id, created_at, expires_at, data)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||
"#,
|
||||
)
|
||||
.bind(&session.id)
|
||||
.bind(&session.user_id)
|
||||
.bind(rfc3339::format(session.created_at))
|
||||
.bind(rfc3339::format(session.expires_at))
|
||||
.bind(session.data.as_ref().map(|d| d.to_string()))
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
let mut guard = self.inner.write().await;
|
||||
guard.insert(session.id.clone(), session.clone());
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
/// Get a session by ID
|
||||
pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
|
||||
let row: Option<(String, String, String, String, Option<String>)> = sqlx::query_as(
|
||||
"SELECT id, user_id, created_at, expires_at, data FROM sessions WHERE id = ?1",
|
||||
)
|
||||
.bind(session_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
match row {
|
||||
Some((id, user_id, created_at, expires_at, data)) => {
|
||||
let session = Session {
|
||||
id,
|
||||
user_id,
|
||||
created_at: rfc3339::parse(&created_at),
|
||||
expires_at: rfc3339::parse(&expires_at),
|
||||
data: data.and_then(|d| serde_json::from_str(&d).ok()),
|
||||
};
|
||||
|
||||
if session.is_expired() {
|
||||
self.delete(&session.id).await?;
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(session))
|
||||
}
|
||||
}
|
||||
None => Ok(None),
|
||||
let mut guard = self.inner.write().await;
|
||||
let Some(session) = guard.get(session_id).cloned() else {
|
||||
return Ok(None);
|
||||
};
|
||||
if session.is_expired() {
|
||||
guard.remove(session_id);
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(session))
|
||||
}
|
||||
|
||||
/// Delete a session
|
||||
pub async fn delete(&self, session_id: &str) -> Result<()> {
|
||||
sqlx::query("DELETE FROM sessions WHERE id = ?1")
|
||||
.bind(session_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
let mut guard = self.inner.write().await;
|
||||
guard.remove(session_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all expired sessions
|
||||
pub async fn cleanup_expired(&self) -> Result<u64> {
|
||||
let now = rfc3339::format(OffsetDateTime::now_utc());
|
||||
let result = sqlx::query("DELETE FROM sessions WHERE expires_at < ?1")
|
||||
.bind(now)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
let mut guard = self.inner.write().await;
|
||||
let before = guard.len();
|
||||
guard.retain(|_, s| !s.is_expired());
|
||||
Ok((before - guard.len()) as u64)
|
||||
}
|
||||
|
||||
/// Delete all sessions
|
||||
pub async fn delete_all(&self) -> Result<u64> {
|
||||
let result = sqlx::query("DELETE FROM sessions")
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
let mut guard = self.inner.write().await;
|
||||
let n = guard.len() as u64;
|
||||
guard.clear();
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
/// Delete all sessions for a specific user
|
||||
pub async fn delete_by_user_id(&self, user_id: &str) -> Result<u64> {
|
||||
let result = sqlx::query("DELETE FROM sessions WHERE user_id = ?1")
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// List all session IDs
|
||||
pub async fn list_ids(&self) -> Result<Vec<String>> {
|
||||
let rows: Vec<(String,)> = sqlx::query_as("SELECT id FROM sessions")
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
Ok(rows.into_iter().map(|(id,)| id).collect())
|
||||
let guard = self.inner.read().await;
|
||||
Ok(guard.keys().cloned().collect())
|
||||
}
|
||||
|
||||
/// Extend session expiration
|
||||
pub async fn extend(&self, session_id: &str) -> Result<()> {
|
||||
let new_expires = OffsetDateTime::now_utc() + self.default_ttl;
|
||||
sqlx::query("UPDATE sessions SET expires_at = ?1 WHERE id = ?2")
|
||||
.bind(rfc3339::format(new_expires))
|
||||
.bind(session_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
let mut guard = self.inner.write().await;
|
||||
if let Some(session) = guard.get_mut(session_id) {
|
||||
if session.is_expired() {
|
||||
guard.remove(session_id);
|
||||
} else {
|
||||
session.expires_at = OffsetDateTime::now_utc() + self.default_ttl;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
155
src/auth/user.rs
155
src/auth/user.rs
@@ -1,122 +1,99 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Pool, Sqlite};
|
||||
use time::format_description::well_known::Rfc3339;
|
||||
use time::OffsetDateTime;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::password::{hash_password, verify_password};
|
||||
use super::rfc3339;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// User row type from database
|
||||
type UserRow = (String, String, String, String, String);
|
||||
type UserRow = (String, String, String);
|
||||
|
||||
/// User data
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub id: String,
|
||||
pub username: String,
|
||||
#[serde(skip_serializing)]
|
||||
pub password_hash: String,
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
pub created_at: OffsetDateTime,
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
pub updated_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
impl User {
|
||||
/// Convert from database row to User
|
||||
fn from_row(row: UserRow) -> Self {
|
||||
let (id, username, password_hash, created_at, updated_at) = row;
|
||||
let (id, username, password_hash) = row;
|
||||
Self {
|
||||
id,
|
||||
username,
|
||||
password_hash,
|
||||
created_at: rfc3339::parse(&created_at),
|
||||
updated_at: rfc3339::parse(&updated_at),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// User store backed by SQLite
|
||||
#[derive(Clone)]
|
||||
pub struct UserStore {
|
||||
pool: Pool<Sqlite>,
|
||||
}
|
||||
|
||||
impl UserStore {
|
||||
/// Create a new user store
|
||||
pub fn new(pool: Pool<Sqlite>) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
/// Create a new user
|
||||
pub async fn create(&self, username: &str, password: &str) -> Result<User> {
|
||||
// Check if username already exists
|
||||
if self.get_by_username(username).await?.is_some() {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"Username '{}' already exists",
|
||||
username
|
||||
)));
|
||||
/// The single local user, or `None` if none exists. Errors if more than one row is present.
|
||||
pub async fn single_user(&self) -> Result<Option<User>> {
|
||||
let mut rows: Vec<UserRow> = sqlx::query_as(
|
||||
"SELECT id, username, password_hash FROM users ORDER BY rowid ASC LIMIT 2",
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
match rows.len() {
|
||||
0 => Ok(None),
|
||||
1 => Ok(Some(User::from_row(rows.remove(0)))),
|
||||
_ => Err(AppError::Internal(
|
||||
"Multiple user accounts in database; this build supports only one".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_first_user(&self, username: &str, password: &str) -> Result<User> {
|
||||
if self.single_user().await?.is_some() {
|
||||
return Err(AppError::BadRequest(
|
||||
"A user account already exists".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let password_hash = hash_password(password)?;
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let user = User {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
username: username.to_string(),
|
||||
password_hash,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO users (id, username, password_hash, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||
INSERT INTO users (id, username, password_hash)
|
||||
VALUES (?1, ?2, ?3)
|
||||
"#,
|
||||
)
|
||||
.bind(&user.id)
|
||||
.bind(&user.username)
|
||||
.bind(&user.password_hash)
|
||||
.bind(rfc3339::format(user.created_at))
|
||||
.bind(rfc3339::format(user.updated_at))
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// Get user by ID
|
||||
pub async fn get(&self, user_id: &str) -> Result<Option<User>> {
|
||||
let row: Option<UserRow> = sqlx::query_as(
|
||||
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE id = ?1",
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(User::from_row))
|
||||
}
|
||||
|
||||
/// Get user by username
|
||||
pub async fn get_by_username(&self, username: &str) -> Result<Option<User>> {
|
||||
let row: Option<UserRow> = sqlx::query_as(
|
||||
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE username = ?1",
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(row.map(User::from_row))
|
||||
}
|
||||
|
||||
/// Verify user credentials
|
||||
pub async fn verify(&self, username: &str, password: &str) -> Result<Option<User>> {
|
||||
let user = match self.get_by_username(username).await? {
|
||||
Some(user) => user,
|
||||
let user = match self.single_user().await? {
|
||||
Some(u) => u,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
if user.username != username {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if verify_password(password, &user.password_hash)? {
|
||||
Ok(Some(user))
|
||||
} else {
|
||||
@@ -124,15 +101,23 @@ impl UserStore {
|
||||
}
|
||||
}
|
||||
|
||||
/// Update user password
|
||||
pub async fn update_password(&self, user_id: &str, new_password: &str) -> Result<()> {
|
||||
let user = self
|
||||
.single_user()
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
|
||||
|
||||
if user.id != user_id {
|
||||
return Err(AppError::AuthError("Invalid session".to_string()));
|
||||
}
|
||||
|
||||
let password_hash = hash_password(new_password)?;
|
||||
let now = OffsetDateTime::now_utc();
|
||||
|
||||
let result =
|
||||
sqlx::query("UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3")
|
||||
.bind(&password_hash)
|
||||
.bind(rfc3339::format(now))
|
||||
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
@@ -144,21 +129,24 @@ impl UserStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update username
|
||||
pub async fn update_username(&self, user_id: &str, new_username: &str) -> Result<()> {
|
||||
if let Some(existing) = self.get_by_username(new_username).await? {
|
||||
if existing.id != user_id {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
"Username '{}' already exists",
|
||||
new_username
|
||||
)));
|
||||
}
|
||||
let user = self
|
||||
.single_user()
|
||||
.await?
|
||||
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
|
||||
|
||||
if user.id != user_id {
|
||||
return Err(AppError::AuthError("Invalid session".to_string()));
|
||||
}
|
||||
|
||||
if new_username == user.username {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let result = sqlx::query("UPDATE users SET username = ?1, updated_at = ?2 WHERE id = ?3")
|
||||
.bind(new_username)
|
||||
.bind(rfc3339::format(now))
|
||||
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
@@ -169,37 +157,4 @@ impl UserStore {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all users
|
||||
pub async fn list(&self) -> Result<Vec<User>> {
|
||||
let rows: Vec<UserRow> = sqlx::query_as(
|
||||
"SELECT id, username, password_hash, created_at, updated_at FROM users ORDER BY created_at",
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(User::from_row).collect())
|
||||
}
|
||||
|
||||
/// Delete user by ID
|
||||
pub async fn delete(&self, user_id: &str) -> Result<()> {
|
||||
let result = sqlx::query("DELETE FROM users WHERE id = ?1")
|
||||
.bind(user_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(AppError::NotFound("User not found".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if any users exist
|
||||
pub async fn has_users(&self) -> Result<bool> {
|
||||
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
Ok(count.0 > 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
mod persistence;
|
||||
mod schema;
|
||||
mod store;
|
||||
|
||||
pub use persistence::ConfigChange;
|
||||
pub use schema::*;
|
||||
pub use store::ConfigStore;
|
||||
|
||||
5
src/config/persistence.rs
Normal file
5
src/config/persistence.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
/// Configuration change event
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConfigChange {
|
||||
pub key: String,
|
||||
}
|
||||
@@ -1,12 +1,85 @@
|
||||
use crate::video::encoder::BitratePreset;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use typeshare::typeshare;
|
||||
|
||||
// Re-export ExtensionsConfig from extensions module
|
||||
// Re-export domain config types that are embedded in AppConfig.
|
||||
// These are simple data types defined in their respective modules;
|
||||
// keeping the re-export here is acceptable since they flow inward.
|
||||
pub use crate::extensions::ExtensionsConfig;
|
||||
// Re-export RustDeskConfig from rustdesk module
|
||||
pub use crate::rustdesk::config::RustDeskConfig;
|
||||
|
||||
/// Bitrate preset for video encoding
|
||||
///
|
||||
/// Simplifies bitrate configuration by providing three intuitive presets
|
||||
/// plus a custom option for advanced users.
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
#[derive(Default)]
|
||||
pub enum BitratePreset {
|
||||
/// Speed priority: 1 Mbps, lowest latency, smaller GOP
|
||||
Speed,
|
||||
/// Balanced: 4 Mbps, good quality/latency tradeoff
|
||||
#[default]
|
||||
Balanced,
|
||||
/// Quality priority: 8 Mbps, best visual quality
|
||||
Quality,
|
||||
/// Custom bitrate in kbps (for advanced users)
|
||||
Custom(u32),
|
||||
}
|
||||
|
||||
impl BitratePreset {
|
||||
/// Get bitrate value in kbps
|
||||
pub fn bitrate_kbps(&self) -> u32 {
|
||||
match self {
|
||||
Self::Speed => 1000,
|
||||
Self::Balanced => 4000,
|
||||
Self::Quality => 8000,
|
||||
Self::Custom(kbps) => *kbps,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recommended GOP size based on preset
|
||||
pub fn gop_size(&self, fps: u32) -> u32 {
|
||||
match self {
|
||||
Self::Speed => (fps / 2).max(15),
|
||||
Self::Balanced => fps,
|
||||
Self::Quality => fps * 2,
|
||||
Self::Custom(_) => fps,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get quality preset name for encoder configuration
|
||||
pub fn quality_level(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Speed => "low",
|
||||
Self::Balanced => "medium",
|
||||
Self::Quality => "high",
|
||||
Self::Custom(_) => "medium",
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from kbps value, mapping to nearest preset or Custom
|
||||
pub fn from_kbps(kbps: u32) -> Self {
|
||||
match kbps {
|
||||
0..=1500 => Self::Speed,
|
||||
1501..=6000 => Self::Balanced,
|
||||
6001..=10000 => Self::Quality,
|
||||
_ => Self::Custom(kbps),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BitratePreset {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Speed => write!(f, "Speed (1 Mbps)"),
|
||||
Self::Balanced => write!(f, "Balanced (4 Mbps)"),
|
||||
Self::Quality => write!(f, "Quality (8 Mbps)"),
|
||||
Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Main application configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -179,27 +252,13 @@ pub enum OtgEndpointBudget {
|
||||
}
|
||||
|
||||
impl OtgEndpointBudget {
|
||||
pub fn default_for_udc_name(udc: Option<&str>) -> Self {
|
||||
if udc.is_some_and(crate::otg::configfs::is_low_endpoint_udc) {
|
||||
Self::Five
|
||||
} else {
|
||||
Self::Six
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolved(self, udc: Option<&str>) -> Self {
|
||||
/// Resolve endpoint limit assuming a known budget variant (not Auto).
|
||||
pub fn endpoint_limit_raw(&self) -> Option<u8> {
|
||||
match self {
|
||||
Self::Auto => Self::default_for_udc_name(udc),
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn endpoint_limit(self, udc: Option<&str>) -> Option<u8> {
|
||||
match self.resolved(udc) {
|
||||
Self::Five => Some(5),
|
||||
Self::Six => Some(6),
|
||||
Self::Unlimited => None,
|
||||
Self::Auto => unreachable!("auto budget must be resolved before use"),
|
||||
Self::Auto => None, // resolved via `HidConfig::resolved_otg_endpoint_limit`
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -356,32 +415,23 @@ impl Default for HidConfig {
|
||||
}
|
||||
|
||||
impl HidConfig {
|
||||
/// Resolve effective OTG HID functions from profile + custom selection.
|
||||
/// Pure logic, no external dependency.
|
||||
pub fn effective_otg_functions(&self) -> OtgHidFunctions {
|
||||
self.otg_profile.resolve_functions(&self.otg_functions)
|
||||
}
|
||||
|
||||
pub fn resolved_otg_udc(&self) -> Option<String> {
|
||||
crate::otg::configfs::resolve_udc_name(self.otg_udc.as_deref())
|
||||
}
|
||||
|
||||
pub fn resolved_otg_endpoint_budget(&self) -> OtgEndpointBudget {
|
||||
self.otg_endpoint_budget
|
||||
.resolved(self.resolved_otg_udc().as_deref())
|
||||
}
|
||||
|
||||
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
|
||||
self.otg_endpoint_budget
|
||||
.endpoint_limit(self.resolved_otg_udc().as_deref())
|
||||
}
|
||||
|
||||
/// Whether keyboard LED feedback is effectively enabled.
|
||||
pub fn effective_otg_keyboard_leds(&self) -> bool {
|
||||
self.otg_keyboard_leds && self.effective_otg_functions().keyboard
|
||||
}
|
||||
|
||||
/// Effective HID functions after applying all constraints.
|
||||
pub fn constrained_otg_functions(&self) -> OtgHidFunctions {
|
||||
self.effective_otg_functions()
|
||||
}
|
||||
|
||||
/// Calculate required endpoint count for the current function selection.
|
||||
pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 {
|
||||
let functions = self.effective_otg_functions();
|
||||
let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds());
|
||||
@@ -391,6 +441,7 @@ impl HidConfig {
|
||||
endpoints
|
||||
}
|
||||
|
||||
/// Validate endpoint budget for the current OTG configuration (UDC-aware when budget is Auto).
|
||||
pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> {
|
||||
if self.backend != HidBackend::Otg {
|
||||
return Ok(());
|
||||
@@ -403,8 +454,9 @@ impl HidConfig {
|
||||
));
|
||||
}
|
||||
|
||||
let resolved_limit = self.resolved_otg_endpoint_limit();
|
||||
let required = self.effective_otg_required_endpoints(msd_enabled);
|
||||
if let Some(limit) = self.resolved_otg_endpoint_limit() {
|
||||
if let Some(limit) = resolved_limit {
|
||||
if required > limit {
|
||||
return Err(crate::error::AppError::BadRequest(format!(
|
||||
"OTG selection requires {} endpoints, but the configured limit is {}",
|
||||
@@ -415,6 +467,40 @@ impl HidConfig {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Effective OTG UDC name (for change detection / service).
|
||||
#[inline]
|
||||
pub fn resolved_otg_udc(&self) -> Option<String> {
|
||||
if self.backend != HidBackend::Otg {
|
||||
return None;
|
||||
}
|
||||
self.otg_udc
|
||||
.as_ref()
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.or_else(|| crate::otg::OtgGadgetManager::find_udc())
|
||||
}
|
||||
|
||||
/// Resolved endpoint limit used for OTG gadget allocator / validation.
|
||||
#[inline]
|
||||
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
|
||||
if self.backend != HidBackend::Otg {
|
||||
return None;
|
||||
}
|
||||
match self.otg_endpoint_budget {
|
||||
OtgEndpointBudget::Five => Some(5),
|
||||
OtgEndpointBudget::Six => Some(6),
|
||||
OtgEndpointBudget::Unlimited => None,
|
||||
OtgEndpointBudget::Auto => {
|
||||
let udc = self.resolved_otg_udc().unwrap_or_default();
|
||||
if crate::otg::configfs::is_low_endpoint_udc(&udc) {
|
||||
Some(5)
|
||||
} else {
|
||||
Some(6)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD configuration
|
||||
@@ -511,7 +597,7 @@ impl Default for AudioConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
device: "default".to_string(),
|
||||
device: String::new(),
|
||||
quality: "balanced".to_string(),
|
||||
}
|
||||
}
|
||||
@@ -606,21 +692,6 @@ pub enum EncoderType {
|
||||
}
|
||||
|
||||
impl EncoderType {
|
||||
/// Convert to EncoderBackend for registry queries
|
||||
pub fn to_backend(&self) -> Option<crate::video::encoder::registry::EncoderBackend> {
|
||||
use crate::video::encoder::registry::EncoderBackend;
|
||||
match self {
|
||||
EncoderType::Auto => None,
|
||||
EncoderType::Software => Some(EncoderBackend::Software),
|
||||
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
|
||||
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
|
||||
EncoderType::Qsv => Some(EncoderBackend::Qsv),
|
||||
EncoderType::Amf => Some(EncoderBackend::Amf),
|
||||
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
|
||||
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get display name for UI
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
@@ -687,19 +758,17 @@ impl Default for StreamConfig {
|
||||
}
|
||||
|
||||
impl StreamConfig {
|
||||
/// Check if using public ICE servers (user left fields empty)
|
||||
/// Whether built-in / public ICE is used (no custom STUN or TURN URL configured).
|
||||
pub fn is_using_public_ice_servers(&self) -> bool {
|
||||
use crate::webrtc::config::public_ice;
|
||||
self.stun_server
|
||||
let no_custom_stun = self
|
||||
.stun_server
|
||||
.as_ref()
|
||||
.map(|s| s.is_empty())
|
||||
.unwrap_or(true)
|
||||
&& self
|
||||
.turn_server
|
||||
.as_ref()
|
||||
.map(|s| s.is_empty())
|
||||
.unwrap_or(true)
|
||||
&& public_ice::is_configured()
|
||||
.map_or(true, |s| s.trim().is_empty());
|
||||
let no_custom_turn = self
|
||||
.turn_server
|
||||
.as_ref()
|
||||
.map_or(true, |s| s.trim().is_empty());
|
||||
no_custom_stun && no_custom_turn
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use arc_swap::ArcSwap;
|
||||
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
|
||||
use std::path::Path;
|
||||
use sqlx::{Pool, Sqlite};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::persistence::ConfigChange;
|
||||
use super::AppConfig;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
@@ -23,127 +22,23 @@ pub struct ConfigStore {
|
||||
write_lock: Arc<Mutex<()>>,
|
||||
}
|
||||
|
||||
/// Configuration change event
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConfigChange {
|
||||
pub key: String,
|
||||
}
|
||||
|
||||
impl ConfigStore {
|
||||
/// Create a new configuration store
|
||||
pub async fn new(db_path: &Path) -> Result<Self> {
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = db_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
// SQLite uses single-writer mode, 2 connections is sufficient for embedded devices
|
||||
// One for reads, one for writes to avoid blocking
|
||||
.max_connections(2)
|
||||
// Set reasonable timeouts for embedded environments
|
||||
.acquire_timeout(Duration::from_secs(5))
|
||||
.idle_timeout(Duration::from_secs(300))
|
||||
.connect(&db_url)
|
||||
.await?;
|
||||
|
||||
// Initialize database schema
|
||||
Self::init_schema(&pool).await?;
|
||||
|
||||
// Load or create default config
|
||||
let config = Self::load_config(&pool).await?;
|
||||
let cache = Arc::new(ArcSwap::from_pointee(config));
|
||||
|
||||
let (change_tx, _) = broadcast::channel(16);
|
||||
|
||||
pub fn new(pool: Pool<Sqlite>) -> Result<Self> {
|
||||
// Load or create default config synchronously wrapper
|
||||
// (actual DB load is async, handled in init())
|
||||
Ok(Self {
|
||||
pool,
|
||||
cache,
|
||||
change_tx,
|
||||
cache: Arc::new(ArcSwap::from_pointee(AppConfig::default())),
|
||||
change_tx: broadcast::channel(16).0,
|
||||
write_lock: Arc::new(Mutex::new(())),
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialize database schema
|
||||
async fn init_schema(pool: &Pool<Sqlite>) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
expires_at TEXT NOT NULL,
|
||||
data TEXT
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS api_tokens (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
permissions TEXT NOT NULL,
|
||||
expires_at TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
last_used TEXT
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS wol_history (
|
||||
mac_address TEXT PRIMARY KEY,
|
||||
updated_at INTEGER NOT NULL
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
|
||||
ON wol_history(updated_at DESC)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
/// Load configuration from database (call after new())
|
||||
pub async fn load(&self) -> Result<()> {
|
||||
let config = Self::load_config(&self.pool).await?;
|
||||
self.cache.store(Arc::new(config));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -244,16 +139,12 @@ impl ConfigStore {
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.cache.load().initialized
|
||||
}
|
||||
|
||||
/// Get database pool for session management
|
||||
pub fn pool(&self) -> &Pool<Sqlite> {
|
||||
&self.pool
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::db::DatabasePool;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -261,7 +152,11 @@ mod tests {
|
||||
let dir = tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
|
||||
let store = ConfigStore::new(&db_path).await.unwrap();
|
||||
let db = DatabasePool::new(&db_path).await.unwrap();
|
||||
db.init_schema().await.unwrap();
|
||||
|
||||
let store = ConfigStore::new(db.clone_pool()).unwrap();
|
||||
store.load().await.unwrap();
|
||||
|
||||
// Check default config (now lock-free, no await needed)
|
||||
let config = store.get();
|
||||
@@ -282,7 +177,8 @@ mod tests {
|
||||
assert_eq!(config.web.http_port, 9000);
|
||||
|
||||
// Create new store instance and verify persistence
|
||||
let store2 = ConfigStore::new(&db_path).await.unwrap();
|
||||
let store2 = ConfigStore::new(db.clone_pool()).unwrap();
|
||||
store2.load().await.unwrap();
|
||||
let config = store2.get();
|
||||
assert!(config.initialized);
|
||||
assert_eq!(config.web.http_port, 9000);
|
||||
|
||||
3
src/db/mod.rs
Normal file
3
src/db/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod pool;
|
||||
|
||||
pub use pool::DatabasePool;
|
||||
119
src/db/pool.rs
Normal file
119
src/db/pool.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DatabasePool {
|
||||
pool: Pool<Sqlite>,
|
||||
}
|
||||
|
||||
impl DatabasePool {
|
||||
pub async fn new(db_path: &Path) -> Result<Self> {
|
||||
if let Some(parent) = db_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
.max_connections(4)
|
||||
.acquire_timeout(Duration::from_secs(5))
|
||||
.idle_timeout(Duration::from_secs(300))
|
||||
.connect(&db_url)
|
||||
.await?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
pub async fn init_schema(&self) -> Result<()> {
|
||||
self.create_config_table().await?;
|
||||
self.create_users_table().await?;
|
||||
self.create_api_tokens_table().await?;
|
||||
self.create_wol_history_table().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_config_table(&self) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_users_table(&self) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_api_tokens_table(&self) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS api_tokens (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
permissions TEXT NOT NULL,
|
||||
expires_at TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
last_used TEXT
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_wol_history_table(&self) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS wol_history (
|
||||
mac_address TEXT PRIMARY KEY,
|
||||
updated_at INTEGER NOT NULL
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
|
||||
ON wol_history(updated_at DESC)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn pool(&self) -> &Pool<Sqlite> {
|
||||
&self.pool
|
||||
}
|
||||
|
||||
pub fn clone_pool(&self) -> Pool<Sqlite> {
|
||||
self.pool.clone()
|
||||
}
|
||||
}
|
||||
56
src/error.rs
56
src/error.rs
@@ -1,12 +1,5 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Application-wide error type
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AppError {
|
||||
#[error("Authentication failed: {0}")]
|
||||
@@ -15,17 +8,14 @@ pub enum AppError {
|
||||
#[error("Not authenticated")]
|
||||
Unauthorized,
|
||||
|
||||
#[error("Forbidden: {0}")]
|
||||
Forbidden(String),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Bad request: {0}")]
|
||||
BadRequest(String),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("Persistence error: {0}")]
|
||||
Persistence(String),
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
@@ -42,9 +32,6 @@ pub enum AppError {
|
||||
#[error("Video error: {0}")]
|
||||
VideoError(String),
|
||||
|
||||
#[error("Video device lost [{device}]: {reason}")]
|
||||
VideoDeviceLost { device: String, reason: String },
|
||||
|
||||
/// No input signal while opening capture; `kind` is `SignalStatus` as string (`from_str`).
|
||||
#[error("Capture has no valid signal: {kind}")]
|
||||
CaptureNoSignal { kind: String },
|
||||
@@ -66,37 +53,10 @@ pub enum AppError {
|
||||
ServiceUnavailable(String),
|
||||
}
|
||||
|
||||
/// Error response body (unified success format)
|
||||
#[derive(Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
// Always return 200 OK - success/failure is indicated by the success field
|
||||
StatusCode::OK
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status_code();
|
||||
let body = ErrorResponse {
|
||||
success: false,
|
||||
message: self.to_string(),
|
||||
};
|
||||
|
||||
tracing::error!(
|
||||
error_type = std::any::type_name_of_val(&self),
|
||||
error_message = %body.message,
|
||||
"Request failed"
|
||||
);
|
||||
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type alias for handlers
|
||||
pub type Result<T> = std::result::Result<T, AppError>;
|
||||
|
||||
impl From<sqlx::Error> for AppError {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
AppError::Persistence(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,41 +1,28 @@
|
||||
//! Event system for real-time state notifications
|
||||
//!
|
||||
//! This module provides a global event bus for broadcasting system events
|
||||
//! to WebSocket clients and other subscribers.
|
||||
//! Event bus: [`SystemEvent`] fan-out to WebSocket subscribers and internal tasks.
|
||||
|
||||
pub mod types;
|
||||
|
||||
use self::types::EXACT_EVENT_TOPICS;
|
||||
|
||||
pub use types::{
|
||||
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent,
|
||||
TtydDeviceInfo, VideoDeviceInfo,
|
||||
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, LedState, MsdDeviceInfo,
|
||||
SystemEvent, TtydDeviceInfo, VideoDeviceInfo,
|
||||
};
|
||||
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
/// Event channel capacity (ring buffer size)
|
||||
const EVENT_CHANNEL_CAPACITY: usize = 256;
|
||||
|
||||
const EXACT_TOPICS: &[&str] = &[
|
||||
"stream.mode_switching",
|
||||
"stream.state_changed",
|
||||
"stream.config_changing",
|
||||
"stream.config_applied",
|
||||
"stream.device_lost",
|
||||
"stream.reconnecting",
|
||||
"stream.recovered",
|
||||
"stream.webrtc_ready",
|
||||
"stream.stats_update",
|
||||
"stream.mode_changed",
|
||||
"stream.mode_ready",
|
||||
"webrtc.ice_candidate",
|
||||
"webrtc.ice_complete",
|
||||
"msd.upload_progress",
|
||||
"msd.download_progress",
|
||||
"system.device_info",
|
||||
"error",
|
||||
];
|
||||
|
||||
const PREFIX_TOPICS: &[&str] = &["stream.*", "webrtc.*", "msd.*", "system.*"];
|
||||
fn collect_prefix_wildcards(exact: &[&'static str]) -> Vec<String> {
|
||||
use std::collections::BTreeSet;
|
||||
let mut segments = BTreeSet::new();
|
||||
for name in exact {
|
||||
if let Some((seg, _)) = name.split_once('.') {
|
||||
segments.insert(seg);
|
||||
}
|
||||
}
|
||||
segments.into_iter().map(|s| format!("{}.*", s)).collect()
|
||||
}
|
||||
|
||||
fn make_sender() -> broadcast::Sender<SystemEvent> {
|
||||
let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
||||
@@ -48,52 +35,23 @@ fn topic_prefix(event_name: &str) -> Option<String> {
|
||||
.map(|(prefix, _)| format!("{}.*", prefix))
|
||||
}
|
||||
|
||||
/// Global event bus for broadcasting system events
|
||||
///
|
||||
/// The event bus uses tokio's broadcast channel to distribute events
|
||||
/// to multiple subscribers. Events are delivered to all active subscribers.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use one_kvm::events::{EventBus, SystemEvent};
|
||||
///
|
||||
/// let bus = EventBus::new();
|
||||
///
|
||||
/// // Publish an event
|
||||
/// bus.publish(SystemEvent::StreamStateChanged {
|
||||
/// state: "streaming".to_string(),
|
||||
/// device: Some("/dev/video0".to_string()),
|
||||
/// reason: None,
|
||||
/// next_retry_ms: None,
|
||||
/// });
|
||||
///
|
||||
/// // Subscribe to events
|
||||
/// let mut rx = bus.subscribe();
|
||||
/// tokio::spawn(async move {
|
||||
/// while let Ok(event) = rx.recv().await {
|
||||
/// println!("Received event: {:?}", event);
|
||||
/// }
|
||||
/// });
|
||||
/// ```
|
||||
pub struct EventBus {
|
||||
tx: broadcast::Sender<SystemEvent>,
|
||||
exact_topics: std::collections::HashMap<&'static str, broadcast::Sender<SystemEvent>>,
|
||||
prefix_topics: std::collections::HashMap<&'static str, broadcast::Sender<SystemEvent>>,
|
||||
prefix_topics: std::collections::HashMap<String, broadcast::Sender<SystemEvent>>,
|
||||
device_info_dirty_tx: broadcast::Sender<()>,
|
||||
}
|
||||
|
||||
impl EventBus {
|
||||
/// Create a new event bus
|
||||
pub fn new() -> Self {
|
||||
let tx = make_sender();
|
||||
let exact_topics = EXACT_TOPICS
|
||||
let exact_topics = EXACT_EVENT_TOPICS
|
||||
.iter()
|
||||
.map(|topic| (*topic, make_sender()))
|
||||
.collect();
|
||||
let prefix_topics = PREFIX_TOPICS
|
||||
.iter()
|
||||
.map(|topic| (*topic, make_sender()))
|
||||
let prefix_topics = collect_prefix_wildcards(EXACT_EVENT_TOPICS)
|
||||
.into_iter()
|
||||
.map(|topic| (topic, make_sender()))
|
||||
.collect();
|
||||
let (device_info_dirty_tx, _dirty_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
||||
|
||||
@@ -105,10 +63,6 @@ impl EventBus {
|
||||
}
|
||||
}
|
||||
|
||||
/// Publish an event to all subscribers
|
||||
///
|
||||
/// If there are no active subscribers, the event is silently dropped.
|
||||
/// This is by design - events are fire-and-forget notifications.
|
||||
pub fn publish(&self, event: SystemEvent) {
|
||||
let event_name = event.event_name();
|
||||
|
||||
@@ -117,28 +71,18 @@ impl EventBus {
|
||||
}
|
||||
|
||||
if let Some(prefix) = topic_prefix(event_name) {
|
||||
if let Some(tx) = self.prefix_topics.get(prefix.as_str()) {
|
||||
if let Some(tx) = self.prefix_topics.get(&prefix) {
|
||||
let _ = tx.send(event.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// If no subscribers, send returns Err which is normal
|
||||
let _ = self.tx.send(event);
|
||||
}
|
||||
|
||||
/// Subscribe to events
|
||||
///
|
||||
/// Returns a receiver that will receive all future events.
|
||||
/// The receiver uses a ring buffer, so if a subscriber falls too far
|
||||
/// behind, it will receive a `Lagged` error and miss some events.
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<SystemEvent> {
|
||||
self.tx.subscribe()
|
||||
}
|
||||
|
||||
/// Subscribe to a specific topic.
|
||||
///
|
||||
/// Supports exact event names, namespace wildcards like `stream.*`, and
|
||||
/// `*` for the full event stream.
|
||||
pub fn subscribe_topic(&self, topic: &str) -> Option<broadcast::Receiver<SystemEvent>> {
|
||||
if topic == "*" {
|
||||
return Some(self.tx.subscribe());
|
||||
@@ -151,22 +95,14 @@ impl EventBus {
|
||||
self.exact_topics.get(topic).map(|tx| tx.subscribe())
|
||||
}
|
||||
|
||||
/// Mark the device-info snapshot as stale.
|
||||
///
|
||||
/// This is an internal trigger used to refresh the latest `system.device_info`
|
||||
/// snapshot without exposing another public WebSocket event.
|
||||
pub fn mark_device_info_dirty(&self) {
|
||||
let _ = self.device_info_dirty_tx.send(());
|
||||
}
|
||||
|
||||
/// Subscribe to internal device-info refresh triggers.
|
||||
pub fn subscribe_device_info_dirty(&self) -> broadcast::Receiver<()> {
|
||||
self.device_info_dirty_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Get the current number of active subscribers
|
||||
///
|
||||
/// Useful for monitoring and debugging.
|
||||
pub fn subscriber_count(&self) -> usize {
|
||||
self.tx.receiver_count()
|
||||
}
|
||||
@@ -263,7 +199,6 @@ mod tests {
|
||||
let bus = EventBus::new();
|
||||
assert_eq!(bus.subscriber_count(), 0);
|
||||
|
||||
// Should not panic when publishing with no subscribers
|
||||
bus.publish(SystemEvent::StreamStateChanged {
|
||||
state: "ready".to_string(),
|
||||
device: None,
|
||||
|
||||
@@ -1,165 +1,96 @@
|
||||
//! System event types
|
||||
//!
|
||||
//! Defines all event types that can be broadcast through the event bus.
|
||||
//! [`SystemEvent`] and device snapshot types (WebSocket / JSON).
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::hid::LedState;
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub struct LedState {
|
||||
pub num_lock: bool,
|
||||
pub caps_lock: bool,
|
||||
pub scroll_lock: bool,
|
||||
pub compose: bool,
|
||||
pub kana: bool,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Device Info Structures (for system.device_info event)
|
||||
// ============================================================================
|
||||
|
||||
/// Video device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VideoDeviceInfo {
|
||||
/// Whether video device is available
|
||||
pub available: bool,
|
||||
/// Device path (e.g., /dev/video0)
|
||||
pub device: Option<String>,
|
||||
/// Pixel format (e.g., "MJPEG", "YUYV")
|
||||
pub format: Option<String>,
|
||||
/// Resolution (width, height)
|
||||
pub resolution: Option<(u32, u32)>,
|
||||
/// Frames per second
|
||||
pub fps: u32,
|
||||
/// Whether stream is currently active
|
||||
pub online: bool,
|
||||
/// Current streaming mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
|
||||
pub stream_mode: String,
|
||||
/// Whether video config is currently being changed (frontend should skip mode sync)
|
||||
pub config_changing: bool,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// HID device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HidDeviceInfo {
|
||||
/// Whether HID backend is available
|
||||
pub available: bool,
|
||||
/// Backend type: "otg", "ch9329", "none"
|
||||
pub backend: String,
|
||||
/// Whether backend is initialized and ready
|
||||
pub initialized: bool,
|
||||
/// Whether backend is currently online
|
||||
pub online: bool,
|
||||
/// Whether absolute mouse positioning is supported
|
||||
pub supports_absolute_mouse: bool,
|
||||
/// Whether keyboard LED/status feedback is enabled.
|
||||
pub keyboard_leds_enabled: bool,
|
||||
/// Last known keyboard LED state.
|
||||
pub led_state: LedState,
|
||||
/// Device path (e.g., serial port for CH9329)
|
||||
pub device: Option<String>,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
/// Error code if any, None if OK
|
||||
pub error_code: Option<String>,
|
||||
}
|
||||
|
||||
/// MSD device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MsdDeviceInfo {
|
||||
/// Whether MSD is available
|
||||
pub available: bool,
|
||||
/// Operating mode: "none", "image", "drive"
|
||||
pub mode: String,
|
||||
/// Whether storage is connected to target
|
||||
pub connected: bool,
|
||||
/// Currently mounted image ID
|
||||
pub image_id: Option<String>,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// ATX device information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AtxDeviceInfo {
|
||||
/// Whether ATX controller is available
|
||||
pub available: bool,
|
||||
/// Backend type: "gpio", "usb_relay", "none"
|
||||
pub backend: String,
|
||||
/// Whether backend is initialized
|
||||
pub initialized: bool,
|
||||
/// Whether power is currently on
|
||||
pub power_on: bool,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Audio device information
|
||||
///
|
||||
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AudioDeviceInfo {
|
||||
/// Whether audio is enabled/available
|
||||
pub available: bool,
|
||||
/// Whether audio is currently streaming
|
||||
pub streaming: bool,
|
||||
/// Current audio device name
|
||||
pub device: Option<String>,
|
||||
/// Quality preset: "voice", "balanced", "high"
|
||||
pub quality: String,
|
||||
/// Error message if any, None if OK
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// ttyd status information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TtydDeviceInfo {
|
||||
/// Whether ttyd binary is available
|
||||
pub available: bool,
|
||||
/// Whether ttyd is currently running
|
||||
pub running: bool,
|
||||
}
|
||||
|
||||
/// Per-client statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClientStats {
|
||||
/// Client ID
|
||||
pub id: String,
|
||||
/// Current FPS for this client (frames sent in last second)
|
||||
pub fps: u32,
|
||||
/// Connected duration (seconds)
|
||||
pub connected_secs: u64,
|
||||
}
|
||||
|
||||
/// System event enumeration
|
||||
///
|
||||
/// All events are tagged with their event name for serialization.
|
||||
/// The `serde(tag = "event", content = "data")` attribute creates a
|
||||
/// JSON structure like:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "event": "stream.state_changed",
|
||||
/// "data": { "state": "streaming", "device": "/dev/video0" }
|
||||
/// }
|
||||
/// ```
|
||||
/// JSON: `{"event": "<name>", "data": { ... }}`.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "event", content = "data")]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum SystemEvent {
|
||||
// ============================================================================
|
||||
// Video Stream Events
|
||||
// ============================================================================
|
||||
/// Stream mode switching started (transactional, correlates all following events)
|
||||
///
|
||||
/// Sent immediately after a mode switch request is accepted.
|
||||
/// Clients can use `transition_id` to correlate subsequent `stream.*` events.
|
||||
#[serde(rename = "stream.mode_switching")]
|
||||
StreamModeSwitching {
|
||||
/// Unique transition ID for this mode switch transaction
|
||||
transition_id: String,
|
||||
/// Target mode: "mjpeg", "h264", "h265", "vp8", "vp9"
|
||||
to_mode: String,
|
||||
/// Previous mode: "mjpeg", "h264", "h265", "vp8", "vp9"
|
||||
from_mode: String,
|
||||
},
|
||||
|
||||
/// Stream state for the UI (`streaming`, `no_signal`, `device_lost`, `device_busy`, etc.).
|
||||
/// Optional `reason` / `next_retry_ms` are hints only; branch on `state`.
|
||||
#[serde(rename = "stream.state_changed")]
|
||||
StreamStateChanged {
|
||||
state: String,
|
||||
@@ -170,193 +101,122 @@ pub enum SystemEvent {
|
||||
next_retry_ms: Option<u64>,
|
||||
},
|
||||
|
||||
/// Stream configuration is being changed
|
||||
///
|
||||
/// Sent before applying new configuration to notify clients that
|
||||
/// the stream will be interrupted temporarily.
|
||||
#[serde(rename = "stream.config_changing")]
|
||||
StreamConfigChanging {
|
||||
/// Optional transition ID if this config change is part of a mode switch transaction
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
transition_id: Option<String>,
|
||||
/// Reason for change: "device_switch", "resolution_change", "format_change"
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// Stream configuration has been applied successfully
|
||||
///
|
||||
/// Sent after new configuration is active. Clients can reconnect now.
|
||||
#[serde(rename = "stream.config_applied")]
|
||||
StreamConfigApplied {
|
||||
/// Optional transition ID if this config change is part of a mode switch transaction
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
transition_id: Option<String>,
|
||||
/// Device path
|
||||
device: String,
|
||||
/// Resolution (width, height)
|
||||
resolution: (u32, u32),
|
||||
/// Pixel format: "mjpeg", "yuyv", etc.
|
||||
format: String,
|
||||
/// Frames per second
|
||||
fps: u32,
|
||||
},
|
||||
|
||||
/// Stream device was lost (disconnected or error)
|
||||
#[serde(rename = "stream.device_lost")]
|
||||
StreamDeviceLost {
|
||||
/// Device path that was lost
|
||||
device: String,
|
||||
/// Reason for loss
|
||||
reason: String,
|
||||
},
|
||||
StreamDeviceLost { device: String, reason: String },
|
||||
|
||||
/// Stream device is reconnecting
|
||||
#[serde(rename = "stream.reconnecting")]
|
||||
StreamReconnecting {
|
||||
/// Device path being reconnected
|
||||
device: String,
|
||||
/// Retry attempt number
|
||||
attempt: u32,
|
||||
},
|
||||
StreamReconnecting { device: String, attempt: u32 },
|
||||
|
||||
/// Stream device has recovered
|
||||
#[serde(rename = "stream.recovered")]
|
||||
StreamRecovered {
|
||||
/// Device path that was recovered
|
||||
device: String,
|
||||
},
|
||||
StreamRecovered { device: String },
|
||||
|
||||
/// WebRTC is ready to accept connections
|
||||
///
|
||||
/// Sent after video frame source is connected to WebRTC pipeline.
|
||||
/// Clients should wait for this event before attempting to create WebRTC sessions.
|
||||
#[serde(rename = "stream.webrtc_ready")]
|
||||
WebRTCReady {
|
||||
/// Optional transition ID if this readiness is part of a mode switch transaction
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
transition_id: Option<String>,
|
||||
/// Current video codec
|
||||
codec: String,
|
||||
/// Whether hardware encoding is being used
|
||||
hardware: bool,
|
||||
},
|
||||
|
||||
/// WebRTC ICE candidate (server -> client trickle)
|
||||
#[serde(rename = "webrtc.ice_candidate")]
|
||||
WebRTCIceCandidate {
|
||||
/// WebRTC session ID
|
||||
session_id: String,
|
||||
/// ICE candidate data
|
||||
candidate: crate::webrtc::signaling::IceCandidate,
|
||||
candidate: serde_json::Value,
|
||||
},
|
||||
|
||||
/// WebRTC ICE gathering complete (server -> client)
|
||||
#[serde(rename = "webrtc.ice_complete")]
|
||||
WebRTCIceComplete {
|
||||
/// WebRTC session ID
|
||||
session_id: String,
|
||||
},
|
||||
WebRTCIceComplete { session_id: String },
|
||||
|
||||
/// Stream statistics update (sent periodically for client stats)
|
||||
#[serde(rename = "stream.stats_update")]
|
||||
StreamStatsUpdate {
|
||||
/// Number of connected clients
|
||||
clients: u64,
|
||||
/// Per-client statistics (client_id -> client stats)
|
||||
/// Each client's FPS reflects the actual frames sent in the last second
|
||||
clients_stat: HashMap<String, ClientStats>,
|
||||
},
|
||||
|
||||
/// Stream mode changed (MJPEG <-> WebRTC)
|
||||
///
|
||||
/// Sent when the streaming mode is switched. Clients should disconnect
|
||||
/// from the current stream and reconnect using the new mode.
|
||||
#[serde(rename = "stream.mode_changed")]
|
||||
StreamModeChanged {
|
||||
/// Optional transition ID if this change is part of a mode switch transaction
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
transition_id: Option<String>,
|
||||
/// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
|
||||
mode: String,
|
||||
/// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
|
||||
previous_mode: String,
|
||||
},
|
||||
|
||||
/// Stream mode switching completed (transactional end marker)
|
||||
///
|
||||
/// Sent when the backend considers the new mode ready for clients to connect.
|
||||
#[serde(rename = "stream.mode_ready")]
|
||||
StreamModeReady {
|
||||
/// Unique transition ID for this mode switch transaction
|
||||
transition_id: String,
|
||||
/// Active mode after switch: "mjpeg", "h264", "h265", "vp8", "vp9"
|
||||
mode: String,
|
||||
},
|
||||
StreamModeReady { transition_id: String, mode: String },
|
||||
|
||||
// ============================================================================
|
||||
// MSD (Mass Storage Device) Events
|
||||
// ============================================================================
|
||||
/// File upload progress (for large file uploads)
|
||||
#[serde(rename = "msd.upload_progress")]
|
||||
MsdUploadProgress {
|
||||
/// Upload operation ID
|
||||
upload_id: String,
|
||||
/// Filename being uploaded
|
||||
filename: String,
|
||||
/// Bytes uploaded so far
|
||||
bytes_uploaded: u64,
|
||||
/// Total file size
|
||||
total_bytes: u64,
|
||||
/// Progress percentage (0.0 - 100.0)
|
||||
progress_pct: f32,
|
||||
},
|
||||
|
||||
/// Image download progress (for URL downloads)
|
||||
#[serde(rename = "msd.download_progress")]
|
||||
MsdDownloadProgress {
|
||||
/// Download operation ID
|
||||
download_id: String,
|
||||
/// Source URL
|
||||
url: String,
|
||||
/// Target filename
|
||||
filename: String,
|
||||
/// Bytes downloaded so far
|
||||
bytes_downloaded: u64,
|
||||
/// Total file size (None if unknown)
|
||||
total_bytes: Option<u64>,
|
||||
/// Progress percentage (0.0 - 100.0, None if total unknown)
|
||||
progress_pct: Option<f32>,
|
||||
/// Download status: "started", "in_progress", "completed", "failed"
|
||||
status: String,
|
||||
},
|
||||
|
||||
/// Complete device information (sent on WebSocket connect and state changes)
|
||||
#[serde(rename = "system.device_info")]
|
||||
DeviceInfo {
|
||||
/// Video device information
|
||||
video: VideoDeviceInfo,
|
||||
/// HID device information
|
||||
hid: HidDeviceInfo,
|
||||
/// MSD device information (None if MSD not enabled)
|
||||
msd: Option<MsdDeviceInfo>,
|
||||
/// ATX device information (None if ATX not enabled)
|
||||
atx: Option<AtxDeviceInfo>,
|
||||
/// Audio device information (None if audio not enabled)
|
||||
audio: Option<AudioDeviceInfo>,
|
||||
/// ttyd status information
|
||||
ttyd: TtydDeviceInfo,
|
||||
},
|
||||
|
||||
/// WebSocket error notification (for connection-level errors like lag)
|
||||
#[serde(rename = "error")]
|
||||
Error {
|
||||
/// Error message
|
||||
message: String,
|
||||
},
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// One entry per [`SystemEvent::event_name`]. `EventBus` builds `*.`-wildcard channels from the first segment; names without `.` (e.g. `error`) have no wildcard channel.
|
||||
pub(crate) const EXACT_EVENT_TOPICS: &[&str] = &[
|
||||
"stream.mode_switching",
|
||||
"stream.state_changed",
|
||||
"stream.config_changing",
|
||||
"stream.config_applied",
|
||||
"stream.device_lost",
|
||||
"stream.reconnecting",
|
||||
"stream.recovered",
|
||||
"stream.webrtc_ready",
|
||||
"stream.stats_update",
|
||||
"stream.mode_changed",
|
||||
"stream.mode_ready",
|
||||
"webrtc.ice_candidate",
|
||||
"webrtc.ice_complete",
|
||||
"msd.upload_progress",
|
||||
"msd.download_progress",
|
||||
"system.device_info",
|
||||
"error",
|
||||
];
|
||||
|
||||
impl SystemEvent {
|
||||
/// Get the event name (for filtering/routing)
|
||||
pub fn event_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::StreamModeSwitching { .. } => "stream.mode_switching",
|
||||
@@ -378,27 +238,6 @@ impl SystemEvent {
|
||||
Self::Error { .. } => "error",
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if event name matches a topic pattern
|
||||
///
|
||||
/// Supports wildcards:
|
||||
/// - `*` matches all events
|
||||
/// - `stream.*` matches all stream events
|
||||
/// - `stream.state_changed` matches exact event
|
||||
pub fn matches_topic(&self, topic: &str) -> bool {
|
||||
if topic == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
let event_name = self.event_name();
|
||||
|
||||
if topic.ends_with(".*") {
|
||||
let prefix = topic.trim_end_matches(".*");
|
||||
event_name.starts_with(prefix)
|
||||
} else {
|
||||
event_name == topic
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -417,19 +256,124 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matches_topic() {
|
||||
let event = SystemEvent::StreamStateChanged {
|
||||
state: "streaming".to_string(),
|
||||
device: None,
|
||||
reason: None,
|
||||
next_retry_ms: None,
|
||||
};
|
||||
fn exact_topics_covers_all_variants() {
|
||||
use std::collections::HashSet;
|
||||
|
||||
assert!(event.matches_topic("*"));
|
||||
assert!(event.matches_topic("stream.*"));
|
||||
assert!(event.matches_topic("stream.state_changed"));
|
||||
assert!(!event.matches_topic("msd.*"));
|
||||
assert!(!event.matches_topic("stream.config_changed"));
|
||||
let samples = vec![
|
||||
SystemEvent::StreamModeSwitching {
|
||||
transition_id: String::new(),
|
||||
to_mode: String::new(),
|
||||
from_mode: String::new(),
|
||||
},
|
||||
SystemEvent::StreamStateChanged {
|
||||
state: String::new(),
|
||||
device: None,
|
||||
reason: None,
|
||||
next_retry_ms: None,
|
||||
},
|
||||
SystemEvent::StreamConfigChanging {
|
||||
transition_id: None,
|
||||
reason: String::new(),
|
||||
},
|
||||
SystemEvent::StreamConfigApplied {
|
||||
transition_id: None,
|
||||
device: String::new(),
|
||||
resolution: (0, 0),
|
||||
format: String::new(),
|
||||
fps: 0,
|
||||
},
|
||||
SystemEvent::StreamDeviceLost {
|
||||
device: String::new(),
|
||||
reason: String::new(),
|
||||
},
|
||||
SystemEvent::StreamReconnecting {
|
||||
device: String::new(),
|
||||
attempt: 0,
|
||||
},
|
||||
SystemEvent::StreamRecovered {
|
||||
device: String::new(),
|
||||
},
|
||||
SystemEvent::WebRTCReady {
|
||||
transition_id: None,
|
||||
codec: String::new(),
|
||||
hardware: false,
|
||||
},
|
||||
SystemEvent::StreamStatsUpdate {
|
||||
clients: 0,
|
||||
clients_stat: HashMap::new(),
|
||||
},
|
||||
SystemEvent::StreamModeChanged {
|
||||
transition_id: None,
|
||||
mode: String::new(),
|
||||
previous_mode: String::new(),
|
||||
},
|
||||
SystemEvent::StreamModeReady {
|
||||
transition_id: String::new(),
|
||||
mode: String::new(),
|
||||
},
|
||||
SystemEvent::WebRTCIceCandidate {
|
||||
session_id: String::new(),
|
||||
candidate: serde_json::Value::Null,
|
||||
},
|
||||
SystemEvent::WebRTCIceComplete {
|
||||
session_id: String::new(),
|
||||
},
|
||||
SystemEvent::MsdUploadProgress {
|
||||
upload_id: String::new(),
|
||||
filename: String::new(),
|
||||
bytes_uploaded: 0,
|
||||
total_bytes: 0,
|
||||
progress_pct: 0.0,
|
||||
},
|
||||
SystemEvent::MsdDownloadProgress {
|
||||
download_id: String::new(),
|
||||
url: String::new(),
|
||||
filename: String::new(),
|
||||
bytes_downloaded: 0,
|
||||
total_bytes: None,
|
||||
progress_pct: None,
|
||||
status: String::new(),
|
||||
},
|
||||
SystemEvent::DeviceInfo {
|
||||
video: VideoDeviceInfo {
|
||||
available: false,
|
||||
device: None,
|
||||
format: None,
|
||||
resolution: None,
|
||||
fps: 0,
|
||||
online: false,
|
||||
stream_mode: String::new(),
|
||||
config_changing: false,
|
||||
error: None,
|
||||
},
|
||||
hid: HidDeviceInfo {
|
||||
available: false,
|
||||
backend: String::new(),
|
||||
initialized: false,
|
||||
online: false,
|
||||
supports_absolute_mouse: false,
|
||||
keyboard_leds_enabled: false,
|
||||
led_state: LedState::default(),
|
||||
device: None,
|
||||
error: None,
|
||||
error_code: None,
|
||||
},
|
||||
msd: None,
|
||||
atx: None,
|
||||
audio: None,
|
||||
ttyd: TtydDeviceInfo {
|
||||
available: false,
|
||||
running: false,
|
||||
},
|
||||
},
|
||||
SystemEvent::Error {
|
||||
message: String::new(),
|
||||
},
|
||||
];
|
||||
|
||||
let from_enum: HashSet<_> = samples.iter().map(|e| e.event_name()).collect();
|
||||
let from_const: HashSet<_> = super::EXACT_EVENT_TOPICS.iter().copied().collect();
|
||||
assert_eq!(from_enum, from_const);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! Extension process manager
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::path::Path;
|
||||
use std::process::Stdio;
|
||||
@@ -12,25 +10,18 @@ use tokio::sync::RwLock;
|
||||
use super::types::*;
|
||||
use crate::events::EventBus;
|
||||
|
||||
/// Maximum number of log lines to keep per extension
|
||||
const LOG_BUFFER_SIZE: usize = 200;
|
||||
|
||||
/// Number of log lines to buffer before flushing to shared storage
|
||||
const LOG_BATCH_SIZE: usize = 16;
|
||||
|
||||
/// Unix socket path for ttyd
|
||||
pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock";
|
||||
|
||||
/// Extension process with log buffer
|
||||
struct ExtensionProcess {
|
||||
child: Child,
|
||||
logs: Arc<RwLock<VecDeque<String>>>,
|
||||
}
|
||||
|
||||
/// Extension manager handles lifecycle of external processes
|
||||
pub struct ExtensionManager {
|
||||
processes: RwLock<HashMap<ExtensionId, ExtensionProcess>>,
|
||||
/// Cached availability status (checked once at startup)
|
||||
availability: HashMap<ExtensionId, bool>,
|
||||
event_bus: RwLock<Option<Arc<EventBus>>>,
|
||||
}
|
||||
@@ -42,9 +33,7 @@ impl Default for ExtensionManager {
|
||||
}
|
||||
|
||||
impl ExtensionManager {
|
||||
/// Create a new extension manager with cached availability
|
||||
pub fn new() -> Self {
|
||||
// Check availability once at startup
|
||||
let availability = ExtensionId::all()
|
||||
.iter()
|
||||
.map(|id| (*id, Path::new(id.binary_path()).exists()))
|
||||
@@ -57,7 +46,6 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set event bus for ttyd status notifications.
|
||||
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
|
||||
*self.event_bus.write().await = Some(event_bus);
|
||||
}
|
||||
@@ -72,12 +60,10 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the binary for an extension is available (cached)
|
||||
pub fn check_available(&self, id: ExtensionId) -> bool {
|
||||
*self.availability.get(&id).unwrap_or(&false)
|
||||
}
|
||||
|
||||
/// Get the current status of an extension
|
||||
pub async fn status(&self, id: ExtensionId) -> ExtensionStatus {
|
||||
if !self.check_available(id) {
|
||||
return ExtensionStatus::Unavailable;
|
||||
@@ -117,20 +103,13 @@ impl ExtensionManager {
|
||||
ExtensionStatus::Stopped
|
||||
}
|
||||
|
||||
/// Start an extension with the given configuration
|
||||
pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> {
|
||||
if !self.check_available(id) {
|
||||
return Err(format!(
|
||||
"{} not found at {}",
|
||||
id.display_name(),
|
||||
id.binary_path()
|
||||
));
|
||||
return Err(format!("{} not found at {}", id, id.binary_path()));
|
||||
}
|
||||
|
||||
// Stop existing process first
|
||||
self.stop(id).await.ok();
|
||||
|
||||
// Build command arguments
|
||||
let args = self.build_args(id, config).await?;
|
||||
|
||||
tracing::info!(
|
||||
@@ -146,11 +125,10 @@ impl ExtensionManager {
|
||||
.stderr(Stdio::piped())
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to start {}: {}", id.display_name(), e))?;
|
||||
.map_err(|e| format!("Failed to start {}: {}", id, e))?;
|
||||
|
||||
let logs = Arc::new(RwLock::new(VecDeque::with_capacity(LOG_BUFFER_SIZE)));
|
||||
|
||||
// Spawn log collector for stdout
|
||||
if let Some(stdout) = child.stdout.take() {
|
||||
let logs_clone = logs.clone();
|
||||
let id_clone = id;
|
||||
@@ -159,7 +137,6 @@ impl ExtensionManager {
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn log collector for stderr
|
||||
if let Some(stderr) = child.stderr.take() {
|
||||
let logs_clone = logs.clone();
|
||||
let id_clone = id;
|
||||
@@ -179,7 +156,6 @@ impl ExtensionManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop an extension
|
||||
pub async fn stop(&self, id: ExtensionId) -> Result<(), String> {
|
||||
let mut processes = self.processes.write().await;
|
||||
if let Some(mut proc) = processes.remove(&id) {
|
||||
@@ -193,7 +169,6 @@ impl ExtensionManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get recent logs for an extension
|
||||
pub async fn logs(&self, id: ExtensionId, lines: usize) -> Vec<String> {
|
||||
let processes = self.processes.read().await;
|
||||
if let Some(proc) = processes.get(&id) {
|
||||
@@ -205,7 +180,6 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect logs from a stream with batched writes to reduce lock contention
|
||||
async fn collect_logs<R: tokio::io::AsyncRead + Unpin>(
|
||||
id: ExtensionId,
|
||||
reader: R,
|
||||
@@ -221,13 +195,11 @@ impl ExtensionManager {
|
||||
tracing::debug!("[{}] {}", id, line);
|
||||
local_buffer.push(line);
|
||||
|
||||
// Flush when batch is full
|
||||
if local_buffer.len() >= LOG_BATCH_SIZE {
|
||||
Self::flush_logs(&logs, &mut local_buffer).await;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Stream ended, flush remaining logs
|
||||
if !local_buffer.is_empty() {
|
||||
Self::flush_logs(&logs, &mut local_buffer).await;
|
||||
}
|
||||
@@ -241,7 +213,6 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush buffered logs to shared storage
|
||||
async fn flush_logs(logs: &RwLock<VecDeque<String>>, buffer: &mut Vec<String>) {
|
||||
let mut logs = logs.write().await;
|
||||
for line in buffer.drain(..) {
|
||||
@@ -252,7 +223,6 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Build command arguments for an extension
|
||||
async fn build_args(
|
||||
&self,
|
||||
id: ExtensionId,
|
||||
@@ -262,18 +232,16 @@ impl ExtensionManager {
|
||||
ExtensionId::Ttyd => {
|
||||
let c = &config.ttyd;
|
||||
|
||||
// Prepare socket directory and clean up old socket (async)
|
||||
Self::prepare_ttyd_socket().await?;
|
||||
|
||||
let mut args = vec![
|
||||
"-i".to_string(),
|
||||
TTYD_SOCKET_PATH.to_string(), // Unix socket
|
||||
TTYD_SOCKET_PATH.to_string(),
|
||||
"-b".to_string(),
|
||||
"/api/terminal".to_string(), // Base path for reverse proxy
|
||||
"-W".to_string(), // Writable (allow input)
|
||||
"/api/terminal".to_string(),
|
||||
"-W".to_string(),
|
||||
];
|
||||
|
||||
// Add shell as last argument
|
||||
args.push(c.shell.clone());
|
||||
Ok(args)
|
||||
}
|
||||
@@ -289,15 +257,12 @@ impl ExtensionManager {
|
||||
|
||||
let mut args = Vec::new();
|
||||
|
||||
// Add TLS flag
|
||||
if c.tls {
|
||||
args.push("--tls=true".to_string());
|
||||
}
|
||||
|
||||
// Server address (validated non-empty above)
|
||||
args.extend(["-addr".to_string(), c.addr.trim().to_string()]);
|
||||
|
||||
// Add client key
|
||||
args.extend(["-key".to_string(), c.key.clone()]);
|
||||
|
||||
Ok(args)
|
||||
@@ -316,24 +281,19 @@ impl ExtensionManager {
|
||||
c.network_secret.clone(),
|
||||
];
|
||||
|
||||
// Add peer URLs
|
||||
for peer in &c.peer_urls {
|
||||
if !peer.is_empty() {
|
||||
args.extend(["--peers".to_string(), peer.clone()]);
|
||||
}
|
||||
}
|
||||
|
||||
// Add virtual IP: use -d for DHCP if empty, or -i for specific IP
|
||||
if let Some(ref ip) = c.virtual_ip {
|
||||
if !ip.is_empty() {
|
||||
// Use specific IP with -i (must include CIDR, e.g., 10.0.0.1/24)
|
||||
args.extend(["-i".to_string(), ip.clone()]);
|
||||
} else {
|
||||
// Empty string means use DHCP
|
||||
args.push("-d".to_string());
|
||||
}
|
||||
} else {
|
||||
// None means use DHCP
|
||||
args.push("-d".to_string());
|
||||
}
|
||||
|
||||
@@ -342,11 +302,9 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepare ttyd socket directory and clean up old socket file
|
||||
async fn prepare_ttyd_socket() -> Result<(), String> {
|
||||
let socket_path = Path::new(TTYD_SOCKET_PATH);
|
||||
|
||||
// Ensure socket directory exists
|
||||
if let Some(socket_dir) = socket_path.parent() {
|
||||
if !socket_dir.exists() {
|
||||
tokio::fs::create_dir_all(socket_dir)
|
||||
@@ -355,7 +313,6 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old socket file if exists
|
||||
if tokio::fs::try_exists(TTYD_SOCKET_PATH)
|
||||
.await
|
||||
.unwrap_or(false)
|
||||
@@ -368,9 +325,7 @@ impl ExtensionManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Health check - restart crashed processes that should be running
|
||||
pub async fn health_check(&self, config: &ExtensionsConfig) {
|
||||
// Collect extensions that need restart check
|
||||
let checks: Vec<_> = ExtensionId::all()
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
@@ -393,7 +348,6 @@ impl ExtensionManager {
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Check which ones need restart (single read lock)
|
||||
let needs_restart: Vec<_> = {
|
||||
let processes = self.processes.read().await;
|
||||
checks
|
||||
@@ -408,7 +362,6 @@ impl ExtensionManager {
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Restart all crashed extensions in parallel
|
||||
let restart_futures: Vec<_> = needs_restart
|
||||
.into_iter()
|
||||
.map(|id| async move {
|
||||
@@ -422,14 +375,12 @@ impl ExtensionManager {
|
||||
futures::future::join_all(restart_futures).await;
|
||||
}
|
||||
|
||||
/// Start all enabled extensions in parallel
|
||||
pub async fn start_enabled(&self, config: &ExtensionsConfig) {
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
let mut start_futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + '_>>> = Vec::new();
|
||||
|
||||
// Collect enabled extensions
|
||||
if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) {
|
||||
start_futures.push(Box::pin(async {
|
||||
if let Err(e) = self.start(ExtensionId::Ttyd, config).await {
|
||||
@@ -461,11 +412,9 @@ impl ExtensionManager {
|
||||
}));
|
||||
}
|
||||
|
||||
// Start all in parallel
|
||||
futures::future::join_all(start_futures).await;
|
||||
}
|
||||
|
||||
/// Stop all running extensions in parallel
|
||||
pub async fn stop_all(&self) {
|
||||
let stop_futures: Vec<_> = ExtensionId::all().iter().map(|id| self.stop(*id)).collect();
|
||||
futures::future::join_all(stop_futures).await;
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! Extensions module - manage external processes like ttyd, gostc, easytier
|
||||
|
||||
mod manager;
|
||||
mod types;
|
||||
|
||||
|
||||
@@ -1,23 +1,16 @@
|
||||
//! Extension types and configurations
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use typeshare::typeshare;
|
||||
|
||||
/// Extension identifier (fixed set of supported extensions)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ExtensionId {
|
||||
/// Web terminal (ttyd)
|
||||
Ttyd,
|
||||
/// NAT traversal client (gostc)
|
||||
Gostc,
|
||||
/// P2P VPN (easytier)
|
||||
Easytier,
|
||||
}
|
||||
|
||||
impl ExtensionId {
|
||||
/// Get the binary path for this extension
|
||||
pub fn binary_path(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ttyd => "/usr/bin/ttyd",
|
||||
@@ -26,16 +19,6 @@ impl ExtensionId {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the display name for this extension
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ttyd => "Web Terminal",
|
||||
Self::Gostc => "GOSTC Tunnel",
|
||||
Self::Easytier => "EasyTier VPN",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all extension IDs
|
||||
pub fn all() -> &'static [ExtensionId] {
|
||||
&[Self::Ttyd, Self::Gostc, Self::Easytier]
|
||||
}
|
||||
@@ -64,25 +47,14 @@ impl std::str::FromStr for ExtensionId {
|
||||
}
|
||||
}
|
||||
|
||||
/// Extension running status
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "state", content = "data", rename_all = "lowercase")]
|
||||
pub enum ExtensionStatus {
|
||||
/// Binary not found at expected path
|
||||
Unavailable,
|
||||
/// Extension is stopped
|
||||
Stopped,
|
||||
/// Extension is running
|
||||
Running {
|
||||
/// Process ID
|
||||
pid: u32,
|
||||
},
|
||||
/// Extension failed to start
|
||||
Failed {
|
||||
/// Error message
|
||||
error: String,
|
||||
},
|
||||
Running { pid: u32 },
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
impl ExtensionStatus {
|
||||
@@ -91,16 +63,11 @@ impl ExtensionStatus {
|
||||
}
|
||||
}
|
||||
|
||||
/// ttyd configuration (Web Terminal)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct TtydConfig {
|
||||
/// Enable auto-start
|
||||
pub enabled: bool,
|
||||
/// Port to listen on
|
||||
pub port: u16,
|
||||
/// Shell to execute
|
||||
pub shell: String,
|
||||
}
|
||||
|
||||
@@ -108,25 +75,19 @@ impl Default for TtydConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
port: 7681,
|
||||
shell: "/bin/bash".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// gostc configuration (NAT traversal based on FRP)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct GostcConfig {
|
||||
/// Enable auto-start
|
||||
pub enabled: bool,
|
||||
/// Server address (hostname or IP)
|
||||
pub addr: String,
|
||||
/// Client key from GOSTC management panel
|
||||
#[serde(skip_serializing_if = "String::is_empty")]
|
||||
pub key: String,
|
||||
/// Enable TLS
|
||||
pub tls: bool,
|
||||
}
|
||||
|
||||
@@ -141,28 +102,21 @@ impl Default for GostcConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// EasyTier configuration (P2P VPN)
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
#[derive(Default)]
|
||||
pub struct EasytierConfig {
|
||||
/// Enable auto-start
|
||||
pub enabled: bool,
|
||||
/// Network name
|
||||
pub network_name: String,
|
||||
/// Network secret/password
|
||||
#[serde(skip_serializing_if = "String::is_empty")]
|
||||
pub network_secret: String,
|
||||
/// Peer node URLs
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub peer_urls: Vec<String>,
|
||||
/// Virtual IP address (optional, auto-assigned if not set)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub virtual_ip: Option<String>,
|
||||
}
|
||||
|
||||
/// Combined extensions configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[serde(default)]
|
||||
@@ -172,53 +126,37 @@ pub struct ExtensionsConfig {
|
||||
pub easytier: EasytierConfig,
|
||||
}
|
||||
|
||||
/// Extension info with status and config
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExtensionInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
}
|
||||
|
||||
/// ttyd extension info
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TtydInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
/// Configuration
|
||||
pub config: TtydConfig,
|
||||
}
|
||||
|
||||
/// gostc extension info
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GostcInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
/// Configuration
|
||||
pub config: GostcConfig,
|
||||
}
|
||||
|
||||
/// easytier extension info
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EasytierInfo {
|
||||
/// Whether binary exists
|
||||
pub available: bool,
|
||||
/// Current status
|
||||
pub status: ExtensionStatus,
|
||||
/// Configuration
|
||||
pub config: EasytierConfig,
|
||||
}
|
||||
|
||||
/// All extensions status response
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExtensionsStatus {
|
||||
@@ -227,7 +165,6 @@ pub struct ExtensionsStatus {
|
||||
pub easytier: EasytierInfo,
|
||||
}
|
||||
|
||||
/// Extension logs response
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExtensionLogs {
|
||||
|
||||
@@ -1,73 +1,32 @@
|
||||
//! HID backend trait definition
|
||||
//! `HidBackend` trait plus serde `HidBackendType` (OTG | CH9329 | disabled).
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::watch;
|
||||
|
||||
use super::otg::LedState;
|
||||
use super::types::{ConsumerEvent, KeyboardEvent, MouseEvent};
|
||||
use crate::error::Result;
|
||||
use crate::events::LedState;
|
||||
|
||||
/// Default CH9329 baud rate
|
||||
fn default_ch9329_baud_rate() -> u32 {
|
||||
9600
|
||||
}
|
||||
|
||||
/// HID backend type
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
#[derive(Default)]
|
||||
pub enum HidBackendType {
|
||||
/// USB OTG gadget mode
|
||||
Otg,
|
||||
/// CH9329 serial HID controller
|
||||
Ch9329 {
|
||||
/// Serial port path
|
||||
port: String,
|
||||
/// Baud rate (default: 9600)
|
||||
#[serde(default = "default_ch9329_baud_rate")]
|
||||
baud_rate: u32,
|
||||
},
|
||||
/// No HID backend (disabled)
|
||||
#[default]
|
||||
None,
|
||||
}
|
||||
|
||||
impl HidBackendType {
|
||||
/// Check if OTG backend is available on this system
|
||||
pub fn otg_available() -> bool {
|
||||
// Check for USB gadget support
|
||||
std::path::Path::new("/sys/class/udc").exists()
|
||||
}
|
||||
|
||||
/// Detect the best available backend
|
||||
pub fn detect() -> Self {
|
||||
// Check for OTG gadget support
|
||||
if Self::otg_available() {
|
||||
return Self::Otg;
|
||||
}
|
||||
|
||||
// Check for common CH9329 serial ports
|
||||
let common_ports = [
|
||||
"/dev/ttyUSB0",
|
||||
"/dev/ttyUSB1",
|
||||
"/dev/ttyAMA0",
|
||||
"/dev/serial0",
|
||||
];
|
||||
|
||||
for port in &common_ports {
|
||||
if std::path::Path::new(port).exists() {
|
||||
return Self::Ch9329 {
|
||||
port: port.to_string(),
|
||||
baud_rate: 9600, // Use default baud rate for auto-detection
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Self::None
|
||||
}
|
||||
|
||||
/// Get backend name as string
|
||||
pub fn name_str(&self) -> &str {
|
||||
match self {
|
||||
Self::Otg => "otg",
|
||||
@@ -77,76 +36,40 @@ impl HidBackendType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Current runtime status reported by a HID backend.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct HidBackendRuntimeSnapshot {
|
||||
/// Whether the backend has been initialized and can accept requests.
|
||||
pub initialized: bool,
|
||||
/// Whether the backend is currently online and communicating successfully.
|
||||
pub online: bool,
|
||||
/// Whether absolute mouse positioning is supported.
|
||||
pub supports_absolute_mouse: bool,
|
||||
/// Whether keyboard LED/status feedback is currently enabled.
|
||||
pub keyboard_leds_enabled: bool,
|
||||
/// Last known keyboard LED state.
|
||||
pub led_state: LedState,
|
||||
/// Screen resolution for absolute mouse mode.
|
||||
pub screen_resolution: Option<(u32, u32)>,
|
||||
/// Device identifier associated with the backend, if any.
|
||||
pub device: Option<String>,
|
||||
/// Current user-facing error, if any.
|
||||
pub error: Option<String>,
|
||||
/// Current programmatic error code, if any.
|
||||
pub error_code: Option<String>,
|
||||
}
|
||||
|
||||
/// HID backend trait
|
||||
#[async_trait]
|
||||
pub trait HidBackend: Send + Sync {
|
||||
/// Initialize the backend
|
||||
async fn init(&self) -> Result<()>;
|
||||
|
||||
/// Send a keyboard event
|
||||
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()>;
|
||||
|
||||
/// Send a mouse event
|
||||
async fn send_mouse(&self, event: MouseEvent) -> Result<()>;
|
||||
|
||||
/// Send a consumer control event (multimedia keys)
|
||||
/// Default implementation returns an error (not supported)
|
||||
async fn send_consumer(&self, _event: ConsumerEvent) -> Result<()> {
|
||||
Err(crate::error::AppError::BadRequest(
|
||||
"Consumer control not supported by this backend".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Reset all inputs (release all keys/buttons)
|
||||
async fn reset(&self) -> Result<()>;
|
||||
|
||||
/// Shutdown the backend
|
||||
async fn shutdown(&self) -> Result<()>;
|
||||
|
||||
/// Get the current backend runtime snapshot.
|
||||
fn runtime_snapshot(&self) -> HidBackendRuntimeSnapshot;
|
||||
|
||||
/// Subscribe to backend runtime changes.
|
||||
fn subscribe_runtime(&self) -> watch::Receiver<()>;
|
||||
|
||||
/// Set screen resolution (for absolute mouse)
|
||||
fn set_screen_resolution(&mut self, _width: u32, _height: u32) {}
|
||||
}
|
||||
|
||||
/// HID backend information
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct HidBackendInfo {
|
||||
/// Backend name
|
||||
pub name: String,
|
||||
/// Backend type
|
||||
pub backend_type: String,
|
||||
/// Is initialized
|
||||
pub initialized: bool,
|
||||
/// Supports absolute mouse
|
||||
pub absolute_mouse: bool,
|
||||
/// Screen resolution (if absolute mouse)
|
||||
pub resolution: Option<(u32, u32)>,
|
||||
fn set_screen_resolution(&self, _width: u32, _height: u32) {}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
//! CH9329 Serial HID Controller backend
|
||||
//!
|
||||
//! CH9329 is a USB HID chip controlled via UART from WCH (沁恒).
|
||||
//! It supports keyboard, mouse (absolute + relative), and custom HID device emulation.
|
||||
//!
|
||||
//! ## Protocol Format
|
||||
//! CH9329 over UART — WCH *Serial Communication Protocol V1.0*.
|
||||
//! ```text
|
||||
//! ┌──────┬──────┬──────┬────────┬──────────────┬──────────┐
|
||||
//! │Header│ ADDR │ CMD │ LEN │ DATA │ SUM │
|
||||
@@ -11,16 +6,11 @@
|
||||
//! │57 AB │ 00 │ xx │ N │ N bytes │Checksum │
|
||||
//! └──────┴──────┴──────┴────────┴──────────────┴──────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! Checksum: Sum of ALL bytes including header (modulo 256)
|
||||
//!
|
||||
//! ## Reference
|
||||
//! Based on WCH CH9329 Serial Communication Protocol V1.0
|
||||
//! Sum of all octets modulo 256 (including header).
|
||||
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, Ordering};
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::thread;
|
||||
@@ -29,82 +19,51 @@ use tokio::sync::watch;
|
||||
use tracing::{info, trace, warn};
|
||||
|
||||
use super::backend::{HidBackend, HidBackendRuntimeSnapshot};
|
||||
use super::otg::LedState;
|
||||
use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::events::LedState;
|
||||
|
||||
// ============================================================================
|
||||
// Constants and Command Codes
|
||||
// ============================================================================
|
||||
|
||||
/// CH9329 packet header
|
||||
const PACKET_HEADER: [u8; 2] = [0x57, 0xAB];
|
||||
|
||||
/// Default address (accepts any address)
|
||||
const DEFAULT_ADDR: u8 = 0x00;
|
||||
|
||||
/// Default baud rate for CH9329
|
||||
pub const DEFAULT_BAUD_RATE: u32 = 9600;
|
||||
|
||||
/// Response timeout in milliseconds
|
||||
const RESPONSE_TIMEOUT_MS: u64 = 500;
|
||||
|
||||
/// Maximum data length in a packet
|
||||
const MAX_DATA_LEN: usize = 64;
|
||||
|
||||
/// CH9329 absolute mouse resolution
|
||||
const CH9329_MOUSE_RESOLUTION: u32 = 4096;
|
||||
|
||||
/// How often the worker probes the chip when idle.
|
||||
const PROBE_INTERVAL_MS: u64 = 100;
|
||||
|
||||
/// How long the worker waits before reopening the serial port after a failure.
|
||||
const RECONNECT_DELAY_MS: u64 = 2000;
|
||||
|
||||
/// Initial startup wait for the worker to confirm CH9329 is reachable.
|
||||
const INIT_WAIT_MS: u64 = 3000;
|
||||
|
||||
/// CH9329 command codes
|
||||
pub mod cmd {
|
||||
/// Get chip version, USB status, and LED status
|
||||
pub const GET_INFO: u8 = 0x01;
|
||||
/// Send standard keyboard data (8 bytes)
|
||||
pub const SEND_KB_GENERAL_DATA: u8 = 0x02;
|
||||
/// Send multimedia keyboard data
|
||||
pub const SEND_KB_MEDIA_DATA: u8 = 0x03;
|
||||
/// Send absolute mouse data
|
||||
pub const SEND_MS_ABS_DATA: u8 = 0x04;
|
||||
/// Send relative mouse data
|
||||
pub const SEND_MS_REL_DATA: u8 = 0x05;
|
||||
/// Send custom HID data
|
||||
pub const SEND_MY_HID_DATA: u8 = 0x06;
|
||||
/// Restore factory default configuration
|
||||
pub const SET_DEFAULT_CFG: u8 = 0x0C;
|
||||
/// Software reset
|
||||
pub const RESET: u8 = 0x0F;
|
||||
}
|
||||
|
||||
/// Response command mask (success = cmd | 0x80, error = cmd | 0xC0)
|
||||
const RESPONSE_SUCCESS_MASK: u8 = 0x80;
|
||||
const RESPONSE_ERROR_MASK: u8 = 0xC0;
|
||||
|
||||
/// CH9329 error codes
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum Ch9329Error {
|
||||
/// Command executed successfully
|
||||
Success = 0x00,
|
||||
/// Serial receive timeout
|
||||
Timeout = 0xE1,
|
||||
/// Invalid packet header
|
||||
InvalidHeader = 0xE2,
|
||||
/// Invalid command code
|
||||
InvalidCommand = 0xE3,
|
||||
/// Checksum mismatch
|
||||
ChecksumError = 0xE4,
|
||||
/// Parameter error
|
||||
ParameterError = 0xE5,
|
||||
/// Execution failed
|
||||
OperationFailed = 0xE6,
|
||||
}
|
||||
|
||||
@@ -137,29 +96,17 @@ impl std::fmt::Display for Ch9329Error {
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Chip Information
|
||||
// ============================================================================
|
||||
|
||||
/// CH9329 chip information
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ChipInfo {
|
||||
/// Chip version (e.g., "V1.0", "V1.1")
|
||||
pub version: String,
|
||||
/// Raw version byte
|
||||
pub version_raw: u8,
|
||||
/// USB connection status
|
||||
pub usb_connected: bool,
|
||||
/// Num Lock LED state
|
||||
pub num_lock: bool,
|
||||
/// Caps Lock LED state
|
||||
pub caps_lock: bool,
|
||||
/// Scroll Lock LED state
|
||||
pub scroll_lock: bool,
|
||||
}
|
||||
|
||||
impl ChipInfo {
|
||||
/// Parse chip info from response data (8 bytes)
|
||||
pub fn from_response(data: &[u8]) -> Option<Self> {
|
||||
if data.len() < 8 {
|
||||
return None;
|
||||
@@ -181,7 +128,6 @@ impl ChipInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// Keyboard LED status
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct LedStatus {
|
||||
pub num_lock: bool,
|
||||
@@ -199,98 +145,21 @@ impl From<u8> for LedStatus {
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// CH9329 work mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
#[derive(Default)]
|
||||
pub enum WorkMode {
|
||||
/// Mode 0: Standard USB Keyboard + Mouse (default)
|
||||
#[default]
|
||||
KeyboardMouse = 0x00,
|
||||
/// Mode 1: Standard USB Keyboard only
|
||||
KeyboardOnly = 0x01,
|
||||
/// Mode 2: Standard USB Mouse only
|
||||
MouseOnly = 0x02,
|
||||
/// Mode 3: Custom HID device
|
||||
CustomHid = 0x03,
|
||||
}
|
||||
|
||||
/// CH9329 serial communication mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[repr(u8)]
|
||||
#[derive(Default)]
|
||||
pub enum SerialMode {
|
||||
/// Mode 0: Protocol transmission mode (default)
|
||||
#[default]
|
||||
Protocol = 0x00,
|
||||
/// Mode 1: ASCII mode
|
||||
Ascii = 0x01,
|
||||
/// Mode 2: Transparent mode
|
||||
Transparent = 0x02,
|
||||
}
|
||||
|
||||
/// CH9329 configuration parameters
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Ch9329Config {
|
||||
/// Work mode
|
||||
pub work_mode: WorkMode,
|
||||
/// Serial communication mode
|
||||
pub serial_mode: SerialMode,
|
||||
/// Device address (0x00-0xFE, 0xFF = broadcast)
|
||||
pub address: u8,
|
||||
/// Baud rate
|
||||
pub baud_rate: u32,
|
||||
/// USB VID
|
||||
pub vid: u16,
|
||||
/// USB PID
|
||||
pub pid: u16,
|
||||
}
|
||||
|
||||
impl Default for Ch9329Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
work_mode: WorkMode::KeyboardMouse,
|
||||
serial_mode: SerialMode::Protocol,
|
||||
address: 0x00,
|
||||
baud_rate: 9600,
|
||||
vid: 0x1A86,
|
||||
pid: 0xE129,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Parsing
|
||||
// ============================================================================
|
||||
|
||||
/// Parsed response from CH9329
|
||||
#[derive(Debug)]
|
||||
pub struct Response {
|
||||
/// Address byte
|
||||
pub address: u8,
|
||||
/// Command code (with response bits)
|
||||
pub cmd: u8,
|
||||
/// Data payload
|
||||
pub data: Vec<u8>,
|
||||
/// Whether this is an error response
|
||||
pub is_error: bool,
|
||||
/// Error code (if is_error)
|
||||
pub error_code: Option<Ch9329Error>,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
/// Parse a response from raw bytes
|
||||
pub fn parse(bytes: &[u8]) -> Option<Self> {
|
||||
// Minimum: Header(2) + Addr(1) + Cmd(1) + Len(1) + Sum(1) = 6
|
||||
if bytes.len() < 6 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check header
|
||||
if bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] {
|
||||
return None;
|
||||
}
|
||||
@@ -299,12 +168,10 @@ impl Response {
|
||||
let cmd = bytes[3];
|
||||
let len = bytes[4] as usize;
|
||||
|
||||
// Check if we have enough bytes
|
||||
if bytes.len() < 5 + len + 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Verify checksum
|
||||
let expected_checksum = bytes[5 + len];
|
||||
let calculated_checksum = bytes[..5 + len]
|
||||
.iter()
|
||||
@@ -335,19 +202,13 @@ impl Response {
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if the response indicates success
|
||||
pub fn is_success(&self) -> bool {
|
||||
!self.is_error && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8)
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum packet size (header 2 + addr 1 + cmd 1 + len 1 + data 64 + checksum 1 = 70)
|
||||
const MAX_PACKET_SIZE: usize = 70;
|
||||
|
||||
// ============================================================================
|
||||
// CH9329 Backend Implementation
|
||||
// ============================================================================
|
||||
|
||||
struct Ch9329RuntimeState {
|
||||
initialized: AtomicBool,
|
||||
online: AtomicBool,
|
||||
@@ -424,47 +285,28 @@ enum WorkerCommand {
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
/// CH9329 HID backend
|
||||
pub struct Ch9329Backend {
|
||||
/// Serial port path
|
||||
port_path: String,
|
||||
/// Baud rate
|
||||
baud_rate: u32,
|
||||
/// Worker command sender
|
||||
worker_tx: Mutex<Option<mpsc::Sender<WorkerCommand>>>,
|
||||
/// Background worker thread
|
||||
worker_handle: Mutex<Option<thread::JoinHandle<()>>>,
|
||||
/// Current keyboard state
|
||||
keyboard_state: Mutex<KeyboardReport>,
|
||||
/// Current mouse button state
|
||||
mouse_buttons: AtomicU8,
|
||||
/// Screen width for absolute mouse coordinate conversion
|
||||
screen_width: u32,
|
||||
/// Screen height for absolute mouse coordinate conversion
|
||||
screen_height: u32,
|
||||
/// Cached chip information
|
||||
screen_resolution: RwLock<(u32, u32)>,
|
||||
chip_info: Arc<RwLock<Option<ChipInfo>>>,
|
||||
/// LED status cache
|
||||
led_status: Arc<RwLock<LedStatus>>,
|
||||
/// Device address (default 0x00)
|
||||
address: u8,
|
||||
/// Last absolute mouse X position (CH9329 coordinate: 0-4095)
|
||||
last_abs_x: AtomicU16,
|
||||
/// Last absolute mouse Y position (CH9329 coordinate: 0-4095)
|
||||
last_abs_y: AtomicU16,
|
||||
/// Whether relative mouse mode is active (set by incoming events)
|
||||
relative_mouse_active: AtomicBool,
|
||||
/// Shared runtime status updated only by the worker.
|
||||
runtime: Arc<Ch9329RuntimeState>,
|
||||
}
|
||||
|
||||
impl Ch9329Backend {
|
||||
/// Create a new CH9329 backend with default baud rate (9600)
|
||||
pub fn new(port_path: &str) -> Result<Self> {
|
||||
Self::with_baud_rate(port_path, DEFAULT_BAUD_RATE)
|
||||
}
|
||||
|
||||
/// Create a new CH9329 backend with custom baud rate
|
||||
pub fn with_baud_rate(port_path: &str, baud_rate: u32) -> Result<Self> {
|
||||
Ok(Self {
|
||||
port_path: port_path.to_string(),
|
||||
@@ -473,8 +315,7 @@ impl Ch9329Backend {
|
||||
worker_handle: Mutex::new(None),
|
||||
keyboard_state: Mutex::new(KeyboardReport::default()),
|
||||
mouse_buttons: AtomicU8::new(0),
|
||||
screen_width: 1920,
|
||||
screen_height: 1080,
|
||||
screen_resolution: RwLock::new((1920, 1080)),
|
||||
chip_info: Arc::new(RwLock::new(None)),
|
||||
led_status: Arc::new(RwLock::new(LedStatus::default())),
|
||||
address: DEFAULT_ADDR,
|
||||
@@ -489,12 +330,10 @@ impl Ch9329Backend {
|
||||
self.runtime.set_error(reason, error_code);
|
||||
}
|
||||
|
||||
/// Check if the serial port device file exists
|
||||
pub fn check_port_exists(&self) -> bool {
|
||||
std::path::Path::new(&self.port_path).exists()
|
||||
}
|
||||
|
||||
/// Convert serialport error to HidError
|
||||
fn serial_error_to_hid_error(e: serialport::Error, operation: &str) -> AppError {
|
||||
let error_code = match e.kind() {
|
||||
serialport::ErrorKind::NoDevice => "port_not_found",
|
||||
@@ -518,16 +357,11 @@ impl Ch9329Backend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate checksum for CH9329 packet (sum of ALL bytes including header)
|
||||
#[inline]
|
||||
fn calculate_checksum(data: &[u8]) -> u8 {
|
||||
data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x))
|
||||
}
|
||||
|
||||
/// Build a CH9329 packet into a stack-allocated buffer
|
||||
///
|
||||
/// Packet format: `[Header 0x57 0xAB] [Address] [Command] [Length] [Data] [Checksum]`
|
||||
/// Returns the packet buffer and the actual length
|
||||
#[inline]
|
||||
fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) {
|
||||
debug_assert!(
|
||||
@@ -539,25 +373,18 @@ impl Ch9329Backend {
|
||||
let packet_len = 6 + data.len();
|
||||
let mut packet = [0u8; MAX_PACKET_SIZE];
|
||||
|
||||
// Header (2 bytes)
|
||||
packet[0] = PACKET_HEADER[0];
|
||||
packet[1] = PACKET_HEADER[1];
|
||||
// Address (1 byte)
|
||||
packet[2] = address;
|
||||
// Command (1 byte)
|
||||
packet[3] = cmd;
|
||||
// Length (1 byte) - data length only
|
||||
packet[4] = len;
|
||||
// Data (N bytes)
|
||||
packet[5..5 + data.len()].copy_from_slice(data);
|
||||
// Checksum (1 byte) - sum of ALL bytes including header
|
||||
let checksum = Self::calculate_checksum(&packet[..5 + data.len()]);
|
||||
packet[5 + data.len()] = checksum;
|
||||
|
||||
(packet, packet_len)
|
||||
}
|
||||
|
||||
/// Build a CH9329 packet (legacy Vec version for compatibility)
|
||||
fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec<u8> {
|
||||
let (buf, len) = Self::build_packet_buf(address, cmd, data);
|
||||
buf[..len].to_vec()
|
||||
@@ -984,10 +811,6 @@ impl Ch9329Backend {
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// HidBackend Trait Implementation
|
||||
// ============================================================================
|
||||
|
||||
#[async_trait]
|
||||
impl HidBackend for Ch9329Backend {
|
||||
async fn init(&self) -> Result<()> {
|
||||
@@ -1060,7 +883,6 @@ impl HidBackend for Ch9329Backend {
|
||||
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
|
||||
let usb_key = event.key.to_hid_usage();
|
||||
|
||||
// Handle modifier keys separately
|
||||
if event.key.is_modifier() {
|
||||
let mut state = self.keyboard_state.lock();
|
||||
|
||||
@@ -1078,7 +900,6 @@ impl HidBackend for Ch9329Backend {
|
||||
} else {
|
||||
let mut state = self.keyboard_state.lock();
|
||||
|
||||
// Update modifiers from event
|
||||
state.modifiers = event.modifiers.to_hid_byte();
|
||||
|
||||
match event.event_type {
|
||||
@@ -1104,19 +925,15 @@ impl HidBackend for Ch9329Backend {
|
||||
|
||||
match event.event_type {
|
||||
MouseEventType::Move => {
|
||||
// Relative movement - send delta directly without inversion
|
||||
self.relative_mouse_active.store(true, Ordering::Relaxed);
|
||||
let dx = event.x.clamp(-127, 127) as i8;
|
||||
let dy = event.y.clamp(-127, 127) as i8;
|
||||
self.send_mouse_relative(buttons, dx, dy, 0)?;
|
||||
}
|
||||
MouseEventType::MoveAbs => {
|
||||
// Absolute movement
|
||||
self.relative_mouse_active.store(false, Ordering::Relaxed);
|
||||
// Frontend sends 0-32767 (HID standard), CH9329 expects 0-4095
|
||||
let x = ((event.x.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16;
|
||||
let y = ((event.y.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16;
|
||||
// Store last absolute position for click events
|
||||
self.last_abs_x.store(x, Ordering::Relaxed);
|
||||
self.last_abs_y.store(y, Ordering::Relaxed);
|
||||
self.send_mouse_absolute(buttons, x, y, 0)?;
|
||||
@@ -1153,7 +970,6 @@ impl HidBackend for Ch9329Backend {
|
||||
if self.relative_mouse_active.load(Ordering::Relaxed) {
|
||||
self.send_mouse_relative(buttons, 0, 0, event.scroll)?;
|
||||
} else {
|
||||
// Use absolute mouse for scroll with last position
|
||||
let x = self.last_abs_x.load(Ordering::Relaxed);
|
||||
let y = self.last_abs_y.load(Ordering::Relaxed);
|
||||
self.send_mouse_absolute(buttons, x, y, event.scroll)?;
|
||||
@@ -1165,7 +981,6 @@ impl HidBackend for Ch9329Backend {
|
||||
}
|
||||
|
||||
async fn reset(&self) -> Result<()> {
|
||||
// Reset keyboard
|
||||
{
|
||||
let mut state = self.keyboard_state.lock();
|
||||
state.clear();
|
||||
@@ -1174,14 +989,12 @@ impl HidBackend for Ch9329Backend {
|
||||
self.send_keyboard_report(&report)?;
|
||||
}
|
||||
|
||||
// Reset mouse
|
||||
self.mouse_buttons.store(0, Ordering::Relaxed);
|
||||
self.last_abs_x.store(0, Ordering::Relaxed);
|
||||
self.last_abs_y.store(0, Ordering::Relaxed);
|
||||
self.relative_mouse_active.store(false, Ordering::Relaxed);
|
||||
self.send_mouse_absolute(0, 0, 0, 0)?;
|
||||
|
||||
// Reset media keys
|
||||
let _ = self.release_media_keys();
|
||||
|
||||
info!("CH9329 HID state reset");
|
||||
@@ -1233,7 +1046,7 @@ impl HidBackend for Ch9329Backend {
|
||||
kana: false,
|
||||
}
|
||||
},
|
||||
screen_resolution: Some((self.screen_width, self.screen_height)),
|
||||
screen_resolution: Some(*self.screen_resolution.read()),
|
||||
device: Some(self.port_path.clone()),
|
||||
error: error.as_ref().map(|(reason, _)| reason.clone()),
|
||||
error_code: error.as_ref().map(|(_, code)| code.clone()),
|
||||
@@ -1244,125 +1057,21 @@ impl HidBackend for Ch9329Backend {
|
||||
self.runtime.subscribe()
|
||||
}
|
||||
|
||||
fn set_screen_resolution(&mut self, width: u32, height: u32) {
|
||||
self.screen_width = width;
|
||||
self.screen_height = height;
|
||||
fn set_screen_resolution(&self, width: u32, height: u32) {
|
||||
*self.screen_resolution.write() = (width, height);
|
||||
self.runtime.notify();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Detection and Helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Detect CH9329 on common serial ports
|
||||
pub fn detect_ch9329() -> Option<String> {
|
||||
let common_ports = [
|
||||
"/dev/ttyUSB0",
|
||||
"/dev/ttyUSB1",
|
||||
"/dev/ttyAMA0",
|
||||
"/dev/serial0",
|
||||
"/dev/ttyS0",
|
||||
];
|
||||
|
||||
// Try multiple baud rates
|
||||
let baud_rates = [9600, 115200];
|
||||
|
||||
for port_path in &common_ports {
|
||||
if !std::path::Path::new(port_path).exists() {
|
||||
continue;
|
||||
}
|
||||
|
||||
for &baud_rate in &baud_rates {
|
||||
if let Ok(mut port) = serialport::new(*port_path, baud_rate)
|
||||
.timeout(Duration::from_millis(200))
|
||||
.open()
|
||||
{
|
||||
// Build GET_INFO packet manually (address = 0x00)
|
||||
let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03];
|
||||
|
||||
if port.write_all(&packet).is_ok() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
|
||||
let mut response = [0u8; 16];
|
||||
if let Ok(n) = port.read(&mut response) {
|
||||
// Check for valid CH9329 response header
|
||||
if n >= 6
|
||||
&& response[0] == PACKET_HEADER[0]
|
||||
&& response[1] == PACKET_HEADER[1]
|
||||
{
|
||||
info!("CH9329 detected on {} @ {} baud", port_path, baud_rate);
|
||||
return Some(port_path.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Detect CH9329 and return both path and working baud rate
|
||||
pub fn detect_ch9329_with_baud() -> Option<(String, u32)> {
|
||||
let common_ports = [
|
||||
"/dev/ttyUSB0",
|
||||
"/dev/ttyUSB1",
|
||||
"/dev/ttyAMA0",
|
||||
"/dev/serial0",
|
||||
"/dev/ttyS0",
|
||||
];
|
||||
|
||||
let baud_rates = [9600, 115200, 57600, 38400, 19200];
|
||||
|
||||
for port_path in &common_ports {
|
||||
if !std::path::Path::new(port_path).exists() {
|
||||
continue;
|
||||
}
|
||||
|
||||
for &baud_rate in &baud_rates {
|
||||
if let Ok(mut port) = serialport::new(*port_path, baud_rate)
|
||||
.timeout(Duration::from_millis(200))
|
||||
.open()
|
||||
{
|
||||
let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03];
|
||||
|
||||
if port.write_all(&packet).is_ok() {
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
|
||||
let mut response = [0u8; 16];
|
||||
if let Ok(n) = port.read(&mut response) {
|
||||
if n >= 6
|
||||
&& response[0] == PACKET_HEADER[0]
|
||||
&& response[1] == PACKET_HEADER[1]
|
||||
{
|
||||
info!("CH9329 detected on {} @ {} baud", port_path, baud_rate);
|
||||
return Some((port_path.to_string(), baud_rate));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_packet_building() {
|
||||
// Test GET_INFO packet (no data)
|
||||
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]);
|
||||
assert_eq!(packet, vec![0x57, 0xAB, 0x00, 0x01, 0x00, 0x03]);
|
||||
|
||||
// Test keyboard packet (8 bytes data)
|
||||
let data = [0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; // 'A' key
|
||||
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data);
|
||||
|
||||
@@ -1372,7 +1081,6 @@ mod tests {
|
||||
assert_eq!(packet[3], cmd::SEND_KB_GENERAL_DATA); // Command
|
||||
assert_eq!(packet[4], 8); // Length (8 data bytes)
|
||||
assert_eq!(&packet[5..13], &data); // Data
|
||||
// Checksum = 0x57 + 0xAB + 0x00 + 0x02 + 0x08 + 0x00 + 0x00 + 0x04 + ... = 0x10
|
||||
let expected_checksum: u8 = packet[..13]
|
||||
.iter()
|
||||
.fold(0u8, |acc: u8, &x| acc.wrapping_add(x));
|
||||
@@ -1381,7 +1089,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_relative_mouse_packet() {
|
||||
// Test relative mouse: move right 50 pixels
|
||||
let data = [0x01, 0x00, 50u8, 0x00, 0x00];
|
||||
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data);
|
||||
|
||||
@@ -1397,12 +1104,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_checksum_calculation() {
|
||||
// Known packet: GET_INFO
|
||||
let packet = [0x57u8, 0xAB, 0x00, 0x01, 0x00];
|
||||
let checksum = Ch9329Backend::calculate_checksum(&packet);
|
||||
assert_eq!(checksum, 0x03);
|
||||
|
||||
// Known packet: Keyboard 'A' press
|
||||
let packet = [
|
||||
0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
];
|
||||
@@ -1412,7 +1117,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_response_parsing() {
|
||||
// Valid GET_INFO response
|
||||
let response_bytes = [
|
||||
0x57, 0xAB, // Header
|
||||
0x00, // Address
|
||||
@@ -1422,9 +1126,7 @@ mod tests {
|
||||
0xE0, // Checksum (calculated)
|
||||
];
|
||||
|
||||
// Note: checksum in test is just placeholder, parse will validate
|
||||
let _result = Response::parse(&response_bytes);
|
||||
// This will fail because checksum doesn't match, but structure is tested
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -2,21 +2,17 @@
|
||||
//!
|
||||
//! Reference: USB HID Usage Tables 1.12, Section 15 (Consumer Page 0x0C)
|
||||
|
||||
/// Consumer Control Usage codes for multimedia keys
|
||||
pub mod usage {
|
||||
// Transport Controls
|
||||
pub const PLAY_PAUSE: u16 = 0x00CD;
|
||||
pub const STOP: u16 = 0x00B7;
|
||||
pub const NEXT_TRACK: u16 = 0x00B5;
|
||||
pub const PREV_TRACK: u16 = 0x00B6;
|
||||
|
||||
// Volume Controls
|
||||
pub const MUTE: u16 = 0x00E2;
|
||||
pub const VOLUME_UP: u16 = 0x00E9;
|
||||
pub const VOLUME_DOWN: u16 = 0x00EA;
|
||||
}
|
||||
|
||||
/// Check if a usage code is valid
|
||||
pub fn is_valid_usage(usage: u16) -> bool {
|
||||
matches!(
|
||||
usage,
|
||||
|
||||
@@ -42,23 +42,19 @@ use super::{
|
||||
MouseEventType,
|
||||
};
|
||||
|
||||
/// Message types
|
||||
pub const MSG_KEYBOARD: u8 = 0x01;
|
||||
pub const MSG_MOUSE: u8 = 0x02;
|
||||
pub const MSG_CONSUMER: u8 = 0x03;
|
||||
|
||||
/// Keyboard event types
|
||||
pub const KB_EVENT_DOWN: u8 = 0x00;
|
||||
pub const KB_EVENT_UP: u8 = 0x01;
|
||||
|
||||
/// Mouse event types
|
||||
pub const MS_EVENT_MOVE: u8 = 0x00;
|
||||
pub const MS_EVENT_MOVE_ABS: u8 = 0x01;
|
||||
pub const MS_EVENT_DOWN: u8 = 0x02;
|
||||
pub const MS_EVENT_UP: u8 = 0x03;
|
||||
pub const MS_EVENT_SCROLL: u8 = 0x04;
|
||||
|
||||
/// Parsed HID event from DataChannel
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum HidChannelEvent {
|
||||
Keyboard(KeyboardEvent),
|
||||
@@ -66,7 +62,6 @@ pub enum HidChannelEvent {
|
||||
Consumer(ConsumerEvent),
|
||||
}
|
||||
|
||||
/// Parse a binary HID message from DataChannel
|
||||
pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
if data.is_empty() {
|
||||
warn!("Empty HID message");
|
||||
@@ -86,7 +81,6 @@ pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse keyboard message payload
|
||||
fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
if data.len() < 3 {
|
||||
warn!("Keyboard message too short: {} bytes", data.len());
|
||||
@@ -129,7 +123,6 @@ fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
}))
|
||||
}
|
||||
|
||||
/// Parse mouse message payload
|
||||
fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
if data.len() < 6 {
|
||||
warn!("Mouse message too short: {} bytes", data.len());
|
||||
@@ -148,11 +141,9 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
}
|
||||
};
|
||||
|
||||
// Parse coordinates as i16 LE (works for both relative and absolute)
|
||||
let x = i16::from_le_bytes([data[1], data[2]]) as i32;
|
||||
let y = i16::from_le_bytes([data[3], data[4]]) as i32;
|
||||
|
||||
// Button or scroll delta
|
||||
let (button, scroll) = match event_type {
|
||||
MouseEventType::Down | MouseEventType::Up => {
|
||||
let btn = match data[5] {
|
||||
@@ -178,7 +169,6 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
}))
|
||||
}
|
||||
|
||||
/// Parse consumer control message payload
|
||||
fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
if data.len() < 2 {
|
||||
warn!("Consumer message too short: {} bytes", data.len());
|
||||
@@ -190,7 +180,6 @@ fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
|
||||
Some(HidChannelEvent::Consumer(ConsumerEvent { usage }))
|
||||
}
|
||||
|
||||
/// Encode a keyboard event to binary format (for sending to client if needed)
|
||||
pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
|
||||
let event_type = match event.event_type {
|
||||
KeyEventType::Down => KB_EVENT_DOWN,
|
||||
@@ -207,40 +196,6 @@ pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
|
||||
]
|
||||
}
|
||||
|
||||
/// Encode a mouse event to binary format (for sending to client if needed)
|
||||
pub fn encode_mouse_event(event: &MouseEvent) -> Vec<u8> {
|
||||
let event_type = match event.event_type {
|
||||
MouseEventType::Move => MS_EVENT_MOVE,
|
||||
MouseEventType::MoveAbs => MS_EVENT_MOVE_ABS,
|
||||
MouseEventType::Down => MS_EVENT_DOWN,
|
||||
MouseEventType::Up => MS_EVENT_UP,
|
||||
MouseEventType::Scroll => MS_EVENT_SCROLL,
|
||||
};
|
||||
|
||||
let x_bytes = (event.x as i16).to_le_bytes();
|
||||
let y_bytes = (event.y as i16).to_le_bytes();
|
||||
|
||||
let extra = match event.event_type {
|
||||
MouseEventType::Down | MouseEventType::Up => event
|
||||
.button
|
||||
.as_ref()
|
||||
.map(|b| match b {
|
||||
MouseButton::Left => 0u8,
|
||||
MouseButton::Middle => 1u8,
|
||||
MouseButton::Right => 2u8,
|
||||
MouseButton::Back => 3u8,
|
||||
MouseButton::Forward => 4u8,
|
||||
})
|
||||
.unwrap_or(0),
|
||||
MouseEventType::Scroll => event.scroll as u8,
|
||||
_ => 0,
|
||||
};
|
||||
|
||||
vec![
|
||||
MSG_MOUSE, event_type, x_bytes[0], x_bytes[1], y_bytes[0], y_bytes[1], extra,
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use typeshare::typeshare;
|
||||
|
||||
/// Shared canonical keyboard key identifiers used across frontend and backend.
|
||||
///
|
||||
/// The enum names intentionally mirror `KeyboardEvent.code` style values so the
|
||||
/// browser, virtual keyboard, and HID backend can all speak the same language.
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum CanonicalKey {
|
||||
@@ -128,11 +124,6 @@ pub enum CanonicalKey {
|
||||
}
|
||||
|
||||
impl CanonicalKey {
|
||||
/// Convert the canonical key to a stable wire code.
|
||||
///
|
||||
/// The wire code intentionally matches the USB HID usage for keyboard page
|
||||
/// keys so existing low-level behavior stays intact while the semantic type
|
||||
/// becomes explicit.
|
||||
pub const fn to_hid_usage(self) -> u8 {
|
||||
match self {
|
||||
Self::KeyA => 0x04,
|
||||
@@ -255,7 +246,6 @@ impl CanonicalKey {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a wire code / USB HID usage to its canonical key.
|
||||
pub const fn from_hid_usage(usage: u8) -> Option<Self> {
|
||||
match usage {
|
||||
0x04 => Some(Self::KeyA),
|
||||
|
||||
167
src/hid/mod.rs
167
src/hid/mod.rs
@@ -1,15 +1,4 @@
|
||||
//! HID (Human Interface Device) control module
|
||||
//!
|
||||
//! This module provides keyboard and mouse control for remote KVM:
|
||||
//! - USB OTG gadget mode (native Linux USB gadget)
|
||||
//! - CH9329 serial HID controller
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! Web Client --> WebSocket/DataChannel --> HID Events --> Backend --> Target PC
|
||||
//! |
|
||||
//! [OTG | CH9329]
|
||||
//! ```
|
||||
//! HID path: browser (WebSocket or WebRTC DataChannel) → queue → OTG gadget or CH9329.
|
||||
|
||||
pub mod backend;
|
||||
pub mod ch9329;
|
||||
@@ -20,51 +9,26 @@ pub mod otg;
|
||||
pub mod types;
|
||||
pub mod websocket;
|
||||
|
||||
pub use crate::events::LedState;
|
||||
pub use backend::{HidBackend, HidBackendRuntimeSnapshot, HidBackendType};
|
||||
pub use keyboard::CanonicalKey;
|
||||
pub use otg::LedState;
|
||||
pub use types::{
|
||||
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent,
|
||||
MouseEventType,
|
||||
};
|
||||
|
||||
/// HID backend information
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HidInfo {
|
||||
/// Backend name
|
||||
pub name: String,
|
||||
/// Whether backend is initialized
|
||||
pub initialized: bool,
|
||||
/// Whether absolute mouse positioning is supported
|
||||
pub supports_absolute_mouse: bool,
|
||||
/// Screen resolution for absolute mouse
|
||||
pub screen_resolution: Option<(u32, u32)>,
|
||||
}
|
||||
|
||||
/// Unified HID runtime state used by snapshots and events.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct HidRuntimeState {
|
||||
/// Whether a backend is configured and expected to exist.
|
||||
pub available: bool,
|
||||
/// Stable backend key: "otg", "ch9329", "none".
|
||||
pub backend: String,
|
||||
/// Whether the backend is currently initialized and operational.
|
||||
pub initialized: bool,
|
||||
/// Whether the backend is currently online.
|
||||
pub online: bool,
|
||||
/// Whether absolute mouse positioning is supported.
|
||||
pub supports_absolute_mouse: bool,
|
||||
/// Whether keyboard LED/status feedback is enabled.
|
||||
pub keyboard_leds_enabled: bool,
|
||||
/// Last known keyboard LED state.
|
||||
pub led_state: LedState,
|
||||
/// Screen resolution for absolute mouse mode.
|
||||
pub screen_resolution: Option<(u32, u32)>,
|
||||
/// Device path associated with the backend, if any.
|
||||
pub device: Option<String>,
|
||||
/// Current user-facing error, if any.
|
||||
pub error: Option<String>,
|
||||
/// Current programmatic error code, if any.
|
||||
pub error_code: Option<String>,
|
||||
}
|
||||
|
||||
@@ -140,45 +104,29 @@ const HID_EVENT_QUEUE_CAPACITY: usize = 64;
|
||||
const HID_EVENT_SEND_TIMEOUT_MS: u64 = 30;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum HidEvent {
|
||||
enum QueuedHidEvent {
|
||||
Keyboard(KeyboardEvent),
|
||||
Mouse(MouseEvent),
|
||||
Consumer(ConsumerEvent),
|
||||
Reset,
|
||||
}
|
||||
|
||||
/// HID controller managing keyboard and mouse input
|
||||
pub struct HidController {
|
||||
/// OTG Service reference (only used when backend is OTG)
|
||||
otg_service: Option<Arc<OtgService>>,
|
||||
/// Active backend
|
||||
backend: Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
|
||||
/// Backend type (mutable for reload)
|
||||
backend_type: Arc<RwLock<HidBackendType>>,
|
||||
/// Event bus for broadcasting state changes (optional)
|
||||
events: Arc<tokio::sync::RwLock<Option<Arc<EventBus>>>>,
|
||||
/// Unified HID runtime state.
|
||||
runtime_state: Arc<RwLock<HidRuntimeState>>,
|
||||
/// HID event queue sender (non-blocking)
|
||||
hid_tx: mpsc::Sender<HidEvent>,
|
||||
/// HID event queue receiver (moved into worker on first start)
|
||||
hid_rx: Mutex<Option<mpsc::Receiver<HidEvent>>>,
|
||||
/// Coalesced mouse move (latest)
|
||||
hid_tx: mpsc::Sender<QueuedHidEvent>,
|
||||
hid_rx: Mutex<Option<mpsc::Receiver<QueuedHidEvent>>>,
|
||||
pending_move: Arc<parking_lot::Mutex<Option<MouseEvent>>>,
|
||||
/// Pending move flag (fast path)
|
||||
pending_move_flag: Arc<AtomicBool>,
|
||||
/// Worker task handle
|
||||
hid_worker: Mutex<Option<JoinHandle<()>>>,
|
||||
/// Backend runtime subscription task handle
|
||||
runtime_worker: Mutex<Option<JoinHandle<()>>>,
|
||||
/// Backend initialization fast flag
|
||||
backend_available: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl HidController {
|
||||
/// Create a new HID controller with specified backend
|
||||
///
|
||||
/// For OTG backend, otg_service should be provided to support hot-reload
|
||||
pub fn new(backend_type: HidBackendType, otg_service: Option<Arc<OtgService>>) -> Self {
|
||||
let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY);
|
||||
Self {
|
||||
@@ -199,12 +147,10 @@ impl HidController {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Initialize the HID backend
|
||||
pub async fn init(&self) -> Result<()> {
|
||||
let backend_type = self.backend_type.read().await.clone();
|
||||
let backend: Arc<dyn HidBackend> = match backend_type {
|
||||
@@ -256,7 +202,6 @@ impl HidController {
|
||||
*self.backend.write().await = Some(backend);
|
||||
self.sync_runtime_state_from_backend().await;
|
||||
|
||||
// Start HID event worker (once)
|
||||
self.start_event_worker().await;
|
||||
self.restart_runtime_worker().await;
|
||||
|
||||
@@ -264,12 +209,10 @@ impl HidController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the HID backend and release resources
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down HID controller");
|
||||
self.stop_runtime_worker().await;
|
||||
|
||||
// Close the backend
|
||||
if let Some(backend) = self.backend.write().await.take() {
|
||||
if let Err(e) = backend.shutdown().await {
|
||||
warn!("Error shutting down HID backend: {}", e);
|
||||
@@ -290,17 +233,15 @@ impl HidController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send keyboard event
|
||||
pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
|
||||
if !self.backend_available.load(Ordering::Acquire) {
|
||||
return Err(AppError::BadRequest(
|
||||
"HID backend not available".to_string(),
|
||||
));
|
||||
}
|
||||
self.enqueue_event(HidEvent::Keyboard(event)).await
|
||||
self.enqueue_event(QueuedHidEvent::Keyboard(event)).await
|
||||
}
|
||||
|
||||
/// Send mouse event
|
||||
pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
|
||||
if !self.backend_available.load(Ordering::Acquire) {
|
||||
return Err(AppError::BadRequest(
|
||||
@@ -312,81 +253,55 @@ impl HidController {
|
||||
event.event_type,
|
||||
MouseEventType::Move | MouseEventType::MoveAbs
|
||||
) {
|
||||
// Best-effort: drop/merge move events if queue is full
|
||||
self.enqueue_mouse_move(event)
|
||||
} else {
|
||||
self.enqueue_event(HidEvent::Mouse(event)).await
|
||||
self.enqueue_event(QueuedHidEvent::Mouse(event)).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Send consumer control event (multimedia keys)
|
||||
pub async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> {
|
||||
if !self.backend_available.load(Ordering::Acquire) {
|
||||
return Err(AppError::BadRequest(
|
||||
"HID backend not available".to_string(),
|
||||
));
|
||||
}
|
||||
self.enqueue_event(HidEvent::Consumer(event)).await
|
||||
self.enqueue_event(QueuedHidEvent::Consumer(event)).await
|
||||
}
|
||||
|
||||
/// Reset all keys (release all pressed keys)
|
||||
pub async fn reset(&self) -> Result<()> {
|
||||
if !self.backend_available.load(Ordering::Acquire) {
|
||||
return Ok(());
|
||||
}
|
||||
// Reset is important but best-effort; enqueue to avoid blocking
|
||||
self.enqueue_event(HidEvent::Reset).await
|
||||
self.enqueue_event(QueuedHidEvent::Reset).await
|
||||
}
|
||||
|
||||
/// Check if backend is available
|
||||
pub async fn is_available(&self) -> bool {
|
||||
self.backend_available.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
/// Get backend type
|
||||
pub async fn backend_type(&self) -> HidBackendType {
|
||||
self.backend_type.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get backend info
|
||||
pub async fn info(&self) -> Option<HidInfo> {
|
||||
let state = self.runtime_state.read().await.clone();
|
||||
if !state.available {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(HidInfo {
|
||||
name: state.backend,
|
||||
initialized: state.initialized,
|
||||
supports_absolute_mouse: state.supports_absolute_mouse,
|
||||
screen_resolution: state.screen_resolution,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current HID runtime state snapshot.
|
||||
pub async fn snapshot(&self) -> HidRuntimeState {
|
||||
self.runtime_state.read().await.clone()
|
||||
}
|
||||
|
||||
/// Reload the HID backend with new type
|
||||
pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> {
|
||||
info!("Reloading HID backend: {:?}", new_backend_type);
|
||||
self.backend_available.store(false, Ordering::Release);
|
||||
self.stop_runtime_worker().await;
|
||||
|
||||
// Shutdown existing backend first
|
||||
if let Some(backend) = self.backend.write().await.take() {
|
||||
if let Err(e) = backend.shutdown().await {
|
||||
warn!("Error shutting down old HID backend: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Create and initialize new backend
|
||||
let new_backend: Option<Arc<dyn HidBackend>> = match new_backend_type {
|
||||
HidBackendType::Otg => {
|
||||
info!("Initializing OTG HID backend");
|
||||
|
||||
// Get OtgService reference
|
||||
let otg_service = match self.otg_service.as_ref() {
|
||||
Some(svc) => svc,
|
||||
None => {
|
||||
@@ -398,28 +313,25 @@ impl HidController {
|
||||
};
|
||||
|
||||
match otg_service.hid_device_paths().await {
|
||||
Some(handles) => {
|
||||
// Create OtgBackend from handles
|
||||
match otg::OtgBackend::from_handles(handles) {
|
||||
Ok(backend) => {
|
||||
let backend = Arc::new(backend);
|
||||
match backend.init().await {
|
||||
Ok(_) => {
|
||||
info!("OTG backend initialized successfully");
|
||||
Some(backend)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to initialize OTG backend: {}", e);
|
||||
None
|
||||
}
|
||||
Some(handles) => match otg::OtgBackend::from_handles(handles) {
|
||||
Ok(backend) => {
|
||||
let backend = Arc::new(backend);
|
||||
match backend.init().await {
|
||||
Ok(_) => {
|
||||
info!("OTG backend initialized successfully");
|
||||
Some(backend)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to initialize OTG backend: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create OTG backend: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create OTG backend: {}", e);
|
||||
None
|
||||
}
|
||||
},
|
||||
None => {
|
||||
warn!("OTG HID paths are not available");
|
||||
None
|
||||
@@ -470,7 +382,6 @@ impl HidController {
|
||||
info!("HID backend reloaded successfully: {:?}", new_backend_type);
|
||||
self.start_event_worker().await;
|
||||
|
||||
// Update backend_type on success
|
||||
*self.backend_type.write().await = new_backend_type.clone();
|
||||
|
||||
self.sync_runtime_state_from_backend().await;
|
||||
@@ -481,7 +392,6 @@ impl HidController {
|
||||
warn!("HID backend reload resulted in no active backend");
|
||||
self.backend_available.store(false, Ordering::Release);
|
||||
|
||||
// Update backend_type even on failure (to reflect the attempted change)
|
||||
*self.backend_type.write().await = new_backend_type.clone();
|
||||
|
||||
let current = self.runtime_state.read().await.clone();
|
||||
@@ -541,11 +451,10 @@ impl HidController {
|
||||
|
||||
process_hid_event(event, &backend).await;
|
||||
|
||||
// After each event, flush latest move if pending
|
||||
if pending_move_flag.swap(false, Ordering::AcqRel) {
|
||||
let move_event = { pending_move.lock().take() };
|
||||
if let Some(move_event) = move_event {
|
||||
process_hid_event(HidEvent::Mouse(move_event), &backend).await;
|
||||
process_hid_event(QueuedHidEvent::Mouse(move_event), &backend).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -595,7 +504,7 @@ impl HidController {
|
||||
}
|
||||
|
||||
fn enqueue_mouse_move(&self, event: MouseEvent) -> Result<()> {
|
||||
match self.hid_tx.try_send(HidEvent::Mouse(event.clone())) {
|
||||
match self.hid_tx.try_send(QueuedHidEvent::Mouse(event.clone())) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
*self.pending_move.lock() = Some(event);
|
||||
@@ -608,11 +517,10 @@ impl HidController {
|
||||
}
|
||||
}
|
||||
|
||||
async fn enqueue_event(&self, event: HidEvent) -> Result<()> {
|
||||
async fn enqueue_event(&self, event: QueuedHidEvent) -> Result<()> {
|
||||
match self.hid_tx.try_send(event) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(mpsc::error::TrySendError::Full(ev)) => {
|
||||
// For non-move events, wait briefly to avoid dropping critical input
|
||||
let tx = self.hid_tx.clone();
|
||||
let send_result = tokio::time::timeout(
|
||||
Duration::from_millis(HID_EVENT_SEND_TIMEOUT_MS),
|
||||
@@ -649,7 +557,10 @@ async fn apply_backend_runtime_state(
|
||||
apply_runtime_state(runtime_state, events, next).await;
|
||||
}
|
||||
|
||||
async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>) {
|
||||
async fn process_hid_event(
|
||||
event: QueuedHidEvent,
|
||||
backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
|
||||
) {
|
||||
let backend_opt = backend.read().await.clone();
|
||||
let backend = match backend_opt {
|
||||
Some(b) => b,
|
||||
@@ -660,10 +571,10 @@ async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
futures::executor::block_on(async move {
|
||||
match event {
|
||||
HidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await,
|
||||
HidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await,
|
||||
HidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await,
|
||||
HidEvent::Reset => backend_for_send.reset().await,
|
||||
QueuedHidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await,
|
||||
QueuedHidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await,
|
||||
QueuedHidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await,
|
||||
QueuedHidEvent::Reset => backend_for_send.reset().await,
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -682,12 +593,6 @@ async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HidController {
|
||||
fn default() -> Self {
|
||||
Self::new(HidBackendType::None, None)
|
||||
}
|
||||
}
|
||||
|
||||
fn device_for_backend_type(backend_type: &HidBackendType) -> Option<String> {
|
||||
match backend_type {
|
||||
HidBackendType::Ch9329 { port, .. } => Some(port.clone()),
|
||||
|
||||
234
src/hid/otg.rs
234
src/hid/otg.rs
@@ -1,28 +1,11 @@
|
||||
//! OTG USB Gadget HID backend
|
||||
//! Linux gadget HID: `/dev/hidg*` opened from [`crate::otg::OtgService`].
|
||||
//! Typical nodes: hidg0 keyboard, hidg1 relative mouse, hidg2 absolute, hidg3 consumer control.
|
||||
//!
|
||||
//! This backend uses Linux USB Gadget API to emulate USB HID devices.
|
||||
//! It opens the HID gadget device nodes created by `OtgService`.
|
||||
//! Depending on the configured OTG profile, this may include:
|
||||
//! - hidg0: Keyboard
|
||||
//! - hidg1: Relative Mouse
|
||||
//! - hidg2: Absolute Mouse
|
||||
//! - hidg3: Consumer Control Keyboard
|
||||
//!
|
||||
//! Requirements:
|
||||
//! - USB OTG/Device controller (UDC)
|
||||
//! - ConfigFS with USB gadget support
|
||||
//! - Root privileges for gadget setup
|
||||
//!
|
||||
//! Error Recovery:
|
||||
//! This module implements automatic device reconnection based on PiKVM's approach.
|
||||
//! When ESHUTDOWN or EAGAIN errors occur (common during MSD operations), the device
|
||||
//! file handles are closed and reopened on the next operation.
|
||||
//! See: https://github.com/raspberrypi/linux/issues/4373
|
||||
//! Polled timed writes (JetKVM-style). Treat `ESHUTDOWN` (108) by closing handles and reopening; keep fd on `EAGAIN` (11). Host/gadget teardown during MSD resembles PiKVM. <https://github.com/raspberrypi/linux/issues/4373>
|
||||
|
||||
use async_trait::async_trait;
|
||||
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Write};
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
@@ -40,9 +23,9 @@ use super::types::{
|
||||
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType,
|
||||
};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::events::LedState;
|
||||
use crate::otg::{wait_for_hid_devices, HidDevicePaths};
|
||||
|
||||
/// Device type for ensure_device operations
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum DeviceType {
|
||||
Keyboard,
|
||||
@@ -51,23 +34,7 @@ enum DeviceType {
|
||||
ConsumerControl,
|
||||
}
|
||||
|
||||
/// Keyboard LED state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
pub struct LedState {
|
||||
/// Num Lock LED
|
||||
pub num_lock: bool,
|
||||
/// Caps Lock LED
|
||||
pub caps_lock: bool,
|
||||
/// Scroll Lock LED
|
||||
pub scroll_lock: bool,
|
||||
/// Compose LED
|
||||
pub compose: bool,
|
||||
/// Kana LED
|
||||
pub kana: bool,
|
||||
}
|
||||
|
||||
impl LedState {
|
||||
/// Create from raw byte
|
||||
pub fn from_byte(b: u8) -> Self {
|
||||
Self {
|
||||
num_lock: b & 0x01 != 0,
|
||||
@@ -78,7 +45,6 @@ impl LedState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to raw byte
|
||||
pub fn to_byte(&self) -> u8 {
|
||||
let mut b = 0u8;
|
||||
if self.num_lock {
|
||||
@@ -100,76 +66,37 @@ impl LedState {
|
||||
}
|
||||
}
|
||||
|
||||
/// OTG HID backend with 4 devices
|
||||
///
|
||||
/// This backend opens HID device files created by OtgService.
|
||||
/// It does NOT manage the USB gadget itself - that's handled by OtgService.
|
||||
///
|
||||
/// ## Error Recovery
|
||||
///
|
||||
/// Based on PiKVM's implementation, this backend automatically handles:
|
||||
/// - EAGAIN (errno 11): Resource temporarily unavailable - just retry later, don't close device
|
||||
/// - ESHUTDOWN (errno 108): Transport endpoint shutdown - close and reopen device
|
||||
///
|
||||
/// When ESHUTDOWN occurs, the device file handle is closed and will be
|
||||
/// reopened on the next operation attempt.
|
||||
/// Opens `/dev/hidg*` nodes provisioned by `OtgService`; gadget lifecycle is not handled here.
|
||||
pub struct OtgBackend {
|
||||
/// Keyboard device path (/dev/hidg0)
|
||||
keyboard_path: Option<PathBuf>,
|
||||
/// Relative mouse device path (/dev/hidg1)
|
||||
mouse_rel_path: Option<PathBuf>,
|
||||
/// Absolute mouse device path (/dev/hidg2)
|
||||
mouse_abs_path: Option<PathBuf>,
|
||||
/// Consumer control device path (/dev/hidg3)
|
||||
consumer_path: Option<PathBuf>,
|
||||
/// Keyboard device file
|
||||
keyboard_dev: Mutex<Option<File>>,
|
||||
/// Relative mouse device file
|
||||
mouse_rel_dev: Mutex<Option<File>>,
|
||||
/// Absolute mouse device file
|
||||
mouse_abs_dev: Mutex<Option<File>>,
|
||||
/// Consumer control device file
|
||||
consumer_dev: Mutex<Option<File>>,
|
||||
/// Whether keyboard LED/status feedback is enabled.
|
||||
keyboard_leds_enabled: bool,
|
||||
/// Current keyboard state
|
||||
keyboard_state: Mutex<KeyboardReport>,
|
||||
/// Current mouse button state
|
||||
mouse_buttons: AtomicU8,
|
||||
/// Last known LED state (using parking_lot::RwLock for sync access)
|
||||
led_state: Arc<parking_lot::RwLock<LedState>>,
|
||||
/// Screen resolution for absolute mouse (using parking_lot::RwLock for sync access)
|
||||
screen_resolution: parking_lot::RwLock<Option<(u32, u32)>>,
|
||||
/// UDC name for state checking (e.g., "fcc00000.usb")
|
||||
udc_name: Arc<parking_lot::RwLock<Option<String>>>,
|
||||
/// Whether the backend has been initialized.
|
||||
initialized: AtomicBool,
|
||||
/// Whether the device is currently online (UDC configured and devices accessible)
|
||||
online: AtomicBool,
|
||||
/// Last backend error state.
|
||||
last_error: parking_lot::RwLock<Option<(String, String)>>,
|
||||
/// Last error log time for throttling (using parking_lot for sync)
|
||||
last_error_log: parking_lot::Mutex<std::time::Instant>,
|
||||
/// Error count since last successful operation (for log throttling)
|
||||
error_count: AtomicU8,
|
||||
/// Consecutive EAGAIN count (for offline threshold detection)
|
||||
eagain_count: AtomicU8,
|
||||
/// Runtime change notifier.
|
||||
runtime_notify_tx: watch::Sender<()>,
|
||||
/// Runtime monitor stop flag.
|
||||
runtime_worker_stop: Arc<AtomicBool>,
|
||||
/// Runtime monitor thread.
|
||||
runtime_worker: Mutex<Option<thread::JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
/// Write timeout in milliseconds (same as JetKVM's hidWriteTimeout)
|
||||
const HID_WRITE_TIMEOUT_MS: i32 = 20;
|
||||
|
||||
impl OtgBackend {
|
||||
/// Create OTG backend from device paths provided by OtgService
|
||||
///
|
||||
/// This is the ONLY way to create an OtgBackend - it no longer manages
|
||||
/// the USB gadget itself. The gadget must already be set up by OtgService.
|
||||
/// Gadget must already exist; paths come from `OtgService`.
|
||||
pub fn from_handles(paths: HidDevicePaths) -> Result<Self> {
|
||||
let (runtime_notify_tx, _runtime_notify_rx) = watch::channel(());
|
||||
Ok(Self {
|
||||
@@ -234,7 +161,6 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Log throttled error message (max once per second)
|
||||
fn log_throttled_error(&self, msg: &str) {
|
||||
let mut last_log = self.last_error_log.lock();
|
||||
let now = std::time::Instant::now();
|
||||
@@ -251,24 +177,17 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset error count on successful operation
|
||||
fn reset_error_count(&self) {
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
// Also reset EAGAIN count - successful operation means device is working
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Write data to HID device with timeout (JetKVM style)
|
||||
///
|
||||
/// Uses poll() to wait for device to be ready for writing.
|
||||
/// If timeout expires, silently drops the data (acceptable for mouse movement).
|
||||
/// Returns Ok(true) if write succeeded, Ok(false) if timed out (silently dropped).
|
||||
/// Poll-based write with `HID_WRITE_TIMEOUT_MS`; timeout → drop (JetKVM-style).
|
||||
fn write_with_timeout(&self, file: &mut File, data: &[u8]) -> std::io::Result<bool> {
|
||||
let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)];
|
||||
|
||||
match poll(&mut pollfd, PollTimeout::from(HID_WRITE_TIMEOUT_MS as u16)) {
|
||||
Ok(1) => {
|
||||
// Device ready, check for errors
|
||||
if let Some(revents) = pollfd[0].revents() {
|
||||
if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP)
|
||||
{
|
||||
@@ -278,12 +197,10 @@ impl OtgBackend {
|
||||
));
|
||||
}
|
||||
}
|
||||
// Write the data
|
||||
file.write_all(data)?;
|
||||
Ok(true)
|
||||
}
|
||||
Ok(0) => {
|
||||
// Timeout - silently drop (JetKVM behavior)
|
||||
trace!("HID write timeout, dropping data");
|
||||
Ok(false)
|
||||
}
|
||||
@@ -292,7 +209,6 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the UDC name for state checking
|
||||
pub fn set_udc_name(&self, udc: &str) {
|
||||
*self.udc_name.write() = Some(udc.to_string());
|
||||
}
|
||||
@@ -324,15 +240,11 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the UDC is in "configured" state
|
||||
///
|
||||
/// This is based on PiKVM's `__is_udc_configured()` method.
|
||||
/// The UDC state file indicates whether the USB host has enumerated and configured the gadget.
|
||||
/// `true` when `/sys/class/udc/<name>/state` reads `configured` (PiKVM-style).
|
||||
pub fn is_udc_configured(&self) -> bool {
|
||||
Self::read_udc_configured(&self.udc_name)
|
||||
}
|
||||
|
||||
/// Find the first available UDC
|
||||
fn find_udc() -> Option<String> {
|
||||
let udc_path = PathBuf::from("/sys/class/udc");
|
||||
if let Ok(entries) = fs::read_dir(&udc_path) {
|
||||
@@ -345,12 +257,7 @@ impl OtgBackend {
|
||||
None
|
||||
}
|
||||
|
||||
/// Ensure a device is open and ready for I/O
|
||||
///
|
||||
/// This method is based on PiKVM's `__ensure_device()` pattern:
|
||||
/// 1. Check if device path exists, close handle if not
|
||||
/// 2. If handle is None but path exists, reopen the device
|
||||
/// 3. Return whether the device is ready for I/O
|
||||
/// PiKVM-style: drop handle if node missing; reopen when path reappears.
|
||||
fn ensure_device(&self, device_type: DeviceType) -> Result<()> {
|
||||
let (path_opt, dev_mutex) = match device_type {
|
||||
DeviceType::Keyboard => (&self.keyboard_path, &self.keyboard_dev),
|
||||
@@ -372,9 +279,7 @@ impl OtgBackend {
|
||||
}
|
||||
};
|
||||
|
||||
// Check if device path exists
|
||||
if !path.exists() {
|
||||
// Close the device if open (device was removed)
|
||||
let mut dev = dev_mutex.lock();
|
||||
if dev.is_some() {
|
||||
debug!(
|
||||
@@ -392,7 +297,6 @@ impl OtgBackend {
|
||||
});
|
||||
}
|
||||
|
||||
// If device is not open, try to open it
|
||||
let mut dev = dev_mutex.lock();
|
||||
if dev.is_none() {
|
||||
match Self::open_device(path) {
|
||||
@@ -415,7 +319,6 @@ impl OtgBackend {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Open a HID device file with read/write access
|
||||
fn open_device(path: &PathBuf) -> Result<File> {
|
||||
OpenOptions::new()
|
||||
.read(true)
|
||||
@@ -431,16 +334,15 @@ impl OtgBackend {
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert I/O error to HidError with appropriate error code
|
||||
fn io_error_code(e: &std::io::Error) -> &'static str {
|
||||
match e.raw_os_error() {
|
||||
Some(32) => "epipe", // EPIPE - broken pipe
|
||||
Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown
|
||||
Some(11) => "eagain", // EAGAIN - resource temporarily unavailable
|
||||
Some(6) => "enxio", // ENXIO - no such device or address
|
||||
Some(19) => "enodev", // ENODEV - no such device
|
||||
Some(5) => "eio", // EIO - I/O error
|
||||
Some(2) => "enoent", // ENOENT - no such file or directory
|
||||
Some(32) => "epipe",
|
||||
Some(108) => "eshutdown",
|
||||
Some(11) => "eagain",
|
||||
Some(6) => "enxio",
|
||||
Some(19) => "enodev",
|
||||
Some(5) => "eio",
|
||||
Some(2) => "enoent",
|
||||
_ => "io_error",
|
||||
}
|
||||
}
|
||||
@@ -455,7 +357,6 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if all HID device files exist
|
||||
pub fn check_devices_exist(&self) -> bool {
|
||||
self.keyboard_path.as_ref().is_none_or(|p| p.exists())
|
||||
&& self.mouse_rel_path.as_ref().is_none_or(|p| p.exists())
|
||||
@@ -463,7 +364,6 @@ impl OtgBackend {
|
||||
&& self.consumer_path.as_ref().is_none_or(|p| p.exists())
|
||||
}
|
||||
|
||||
/// Get list of missing device paths
|
||||
pub fn get_missing_devices(&self) -> Vec<String> {
|
||||
let mut missing = Vec::new();
|
||||
if let Some(ref path) = self.keyboard_path {
|
||||
@@ -484,17 +384,11 @@ impl OtgBackend {
|
||||
missing
|
||||
}
|
||||
|
||||
/// Send keyboard report (8 bytes)
|
||||
///
|
||||
/// This method ensures the device is open before writing, and handles
|
||||
/// ESHUTDOWN errors by closing the device handle for later reconnection.
|
||||
/// Uses write_with_timeout to avoid blocking on busy devices.
|
||||
fn send_keyboard_report(&self, report: &KeyboardReport) -> Result<()> {
|
||||
if self.keyboard_path.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ensure device is ready
|
||||
self.ensure_device(DeviceType::Keyboard)?;
|
||||
|
||||
let mut dev = self.keyboard_dev.lock();
|
||||
@@ -508,7 +402,6 @@ impl OtgBackend {
|
||||
Ok(())
|
||||
}
|
||||
Ok(false) => {
|
||||
// Timeout - silently dropped (JetKVM behavior)
|
||||
self.log_throttled_error("HID keyboard write timeout, dropped");
|
||||
Ok(())
|
||||
}
|
||||
@@ -517,7 +410,6 @@ impl OtgBackend {
|
||||
|
||||
match error_code {
|
||||
Some(108) => {
|
||||
// ESHUTDOWN - endpoint closed, need to reopen device
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
debug!("Keyboard ESHUTDOWN, closing for recovery");
|
||||
*dev = None;
|
||||
@@ -531,7 +423,6 @@ impl OtgBackend {
|
||||
))
|
||||
}
|
||||
Some(11) => {
|
||||
// EAGAIN after poll - should be rare, silently drop
|
||||
trace!("Keyboard EAGAIN after poll, dropping");
|
||||
Ok(())
|
||||
}
|
||||
@@ -559,17 +450,11 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send relative mouse report (4 bytes: buttons, dx, dy, wheel)
|
||||
///
|
||||
/// This method ensures the device is open before writing, and handles
|
||||
/// ESHUTDOWN errors by closing the device handle for later reconnection.
|
||||
/// Uses write_with_timeout to avoid blocking on busy devices.
|
||||
fn send_mouse_report_relative(&self, buttons: u8, dx: i8, dy: i8, wheel: i8) -> Result<()> {
|
||||
if self.mouse_rel_path.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ensure device is ready
|
||||
self.ensure_device(DeviceType::MouseRelative)?;
|
||||
|
||||
let mut dev = self.mouse_rel_dev.lock();
|
||||
@@ -582,10 +467,7 @@ impl OtgBackend {
|
||||
trace!("Sent relative mouse report: {:02X?}", data);
|
||||
Ok(())
|
||||
}
|
||||
Ok(false) => {
|
||||
// Timeout - silently dropped (JetKVM behavior)
|
||||
Ok(())
|
||||
}
|
||||
Ok(false) => Ok(()),
|
||||
Err(e) => {
|
||||
let error_code = e.raw_os_error();
|
||||
|
||||
@@ -603,10 +485,7 @@ impl OtgBackend {
|
||||
"Failed to write mouse report",
|
||||
))
|
||||
}
|
||||
Some(11) => {
|
||||
// EAGAIN after poll - should be rare, silently drop
|
||||
Ok(())
|
||||
}
|
||||
Some(11) => Ok(()),
|
||||
_ => {
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
warn!("Relative mouse write error: {}", e);
|
||||
@@ -631,17 +510,11 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send absolute mouse report (6 bytes: buttons, x_lo, x_hi, y_lo, y_hi, wheel)
|
||||
///
|
||||
/// This method ensures the device is open before writing, and handles
|
||||
/// ESHUTDOWN errors by closing the device handle for later reconnection.
|
||||
/// Uses write_with_timeout to avoid blocking on busy devices.
|
||||
fn send_mouse_report_absolute(&self, buttons: u8, x: u16, y: u16, wheel: i8) -> Result<()> {
|
||||
if self.mouse_abs_path.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ensure device is ready
|
||||
self.ensure_device(DeviceType::MouseAbsolute)?;
|
||||
|
||||
let mut dev = self.mouse_abs_dev.lock();
|
||||
@@ -660,10 +533,7 @@ impl OtgBackend {
|
||||
self.reset_error_count();
|
||||
Ok(())
|
||||
}
|
||||
Ok(false) => {
|
||||
// Timeout - silently dropped (JetKVM behavior)
|
||||
Ok(())
|
||||
}
|
||||
Ok(false) => Ok(()),
|
||||
Err(e) => {
|
||||
let error_code = e.raw_os_error();
|
||||
|
||||
@@ -681,10 +551,7 @@ impl OtgBackend {
|
||||
"Failed to write mouse report",
|
||||
))
|
||||
}
|
||||
Some(11) => {
|
||||
// EAGAIN after poll - should be rare, silently drop
|
||||
Ok(())
|
||||
}
|
||||
Some(11) => Ok(()),
|
||||
_ => {
|
||||
self.eagain_count.store(0, Ordering::Relaxed);
|
||||
warn!("Absolute mouse write error: {}", e);
|
||||
@@ -709,35 +576,27 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send consumer control report (2 bytes: usage_lo, usage_hi)
|
||||
///
|
||||
/// Sends a consumer control usage code and then releases it (sends 0x0000).
|
||||
/// Press (`usage`) then release (`0x0000`).
|
||||
fn send_consumer_report(&self, usage: u16) -> Result<()> {
|
||||
if self.consumer_path.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ensure device is ready
|
||||
self.ensure_device(DeviceType::ConsumerControl)?;
|
||||
|
||||
let mut dev = self.consumer_dev.lock();
|
||||
if let Some(ref mut file) = *dev {
|
||||
// Send the usage code
|
||||
let data = [(usage & 0xFF) as u8, (usage >> 8) as u8];
|
||||
match self.write_with_timeout(file, &data) {
|
||||
Ok(true) => {
|
||||
trace!("Sent consumer report: {:02X?}", data);
|
||||
// Send release (0x0000)
|
||||
let release = [0u8, 0u8];
|
||||
let _ = self.write_with_timeout(file, &release);
|
||||
self.mark_online();
|
||||
self.reset_error_count();
|
||||
Ok(())
|
||||
}
|
||||
Ok(false) => {
|
||||
// Timeout - silently dropped
|
||||
Ok(())
|
||||
}
|
||||
Ok(false) => Ok(()),
|
||||
Err(e) => {
|
||||
let error_code = e.raw_os_error();
|
||||
match error_code {
|
||||
@@ -753,10 +612,7 @@ impl OtgBackend {
|
||||
"Failed to write consumer report",
|
||||
))
|
||||
}
|
||||
Some(11) => {
|
||||
// EAGAIN after poll - silently drop
|
||||
Ok(())
|
||||
}
|
||||
Some(11) => Ok(()),
|
||||
_ => {
|
||||
warn!("Consumer control write error: {}", e);
|
||||
self.record_error(
|
||||
@@ -780,12 +636,10 @@ impl OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
/// Send consumer control event
|
||||
pub fn send_consumer(&self, event: ConsumerEvent) -> Result<()> {
|
||||
self.send_consumer_report(event.usage)
|
||||
}
|
||||
|
||||
/// Get last known LED state
|
||||
pub fn led_state(&self) -> LedState {
|
||||
*self.led_state.read()
|
||||
}
|
||||
@@ -975,7 +829,6 @@ impl HidBackend for OtgBackend {
|
||||
async fn init(&self) -> Result<()> {
|
||||
info!("Initializing OTG HID backend");
|
||||
|
||||
// Auto-detect UDC name for state checking only if OtgService did not provide one
|
||||
if self.udc_name.read().is_none() {
|
||||
if let Some(udc) = Self::find_udc() {
|
||||
info!("Auto-detected UDC: {}", udc);
|
||||
@@ -985,7 +838,6 @@ impl HidBackend for OtgBackend {
|
||||
info!("Using configured UDC: {}", udc);
|
||||
}
|
||||
|
||||
// Wait for devices to appear (they should already exist from OtgService)
|
||||
let mut device_paths = Vec::new();
|
||||
if let Some(ref path) = self.keyboard_path {
|
||||
device_paths.push(path.clone());
|
||||
@@ -1010,7 +862,6 @@ impl HidBackend for OtgBackend {
|
||||
return Err(AppError::Internal("HID devices did not appear".into()));
|
||||
}
|
||||
|
||||
// Open keyboard device
|
||||
if let Some(ref path) = self.keyboard_path {
|
||||
if path.exists() {
|
||||
let file = Self::open_device(path)?;
|
||||
@@ -1021,7 +872,6 @@ impl HidBackend for OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
// Open relative mouse device
|
||||
if let Some(ref path) = self.mouse_rel_path {
|
||||
if path.exists() {
|
||||
let file = Self::open_device(path)?;
|
||||
@@ -1032,7 +882,6 @@ impl HidBackend for OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
// Open absolute mouse device
|
||||
if let Some(ref path) = self.mouse_abs_path {
|
||||
if path.exists() {
|
||||
let file = Self::open_device(path)?;
|
||||
@@ -1043,7 +892,6 @@ impl HidBackend for OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
// Open consumer control device (optional, may not exist on older setups)
|
||||
if let Some(ref path) = self.consumer_path {
|
||||
if path.exists() {
|
||||
let file = Self::open_device(path)?;
|
||||
@@ -1054,7 +902,6 @@ impl HidBackend for OtgBackend {
|
||||
}
|
||||
}
|
||||
|
||||
// Mark as online if all devices opened successfully
|
||||
self.initialized.store(true, Ordering::Relaxed);
|
||||
self.notify_runtime_changed();
|
||||
self.start_runtime_worker();
|
||||
@@ -1066,7 +913,6 @@ impl HidBackend for OtgBackend {
|
||||
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
|
||||
let usb_key = event.key.to_hid_usage();
|
||||
|
||||
// Handle modifier keys separately
|
||||
if event.key.is_modifier() {
|
||||
let mut state = self.keyboard_state.lock();
|
||||
|
||||
@@ -1084,7 +930,6 @@ impl HidBackend for OtgBackend {
|
||||
} else {
|
||||
let mut state = self.keyboard_state.lock();
|
||||
|
||||
// Update modifiers from event
|
||||
state.modifiers = event.modifiers.to_hid_byte();
|
||||
|
||||
match event.event_type {
|
||||
@@ -1110,15 +955,12 @@ impl HidBackend for OtgBackend {
|
||||
|
||||
match event.event_type {
|
||||
MouseEventType::Move => {
|
||||
// Relative movement - use hidg1
|
||||
let dx = event.x.clamp(-127, 127) as i8;
|
||||
let dy = event.y.clamp(-127, 127) as i8;
|
||||
self.send_mouse_report_relative(buttons, dx, dy, 0)?;
|
||||
}
|
||||
MouseEventType::MoveAbs => {
|
||||
// Absolute movement - use hidg2
|
||||
// Frontend sends 0-32767 range directly (standard HID absolute mouse range)
|
||||
// Don't send button state with move - buttons are handled separately on relative device
|
||||
// Coordinates 0–32767; buttons are sent only on the relative endpoint.
|
||||
let x = event.x.clamp(0, 32767) as u16;
|
||||
let y = event.y.clamp(0, 32767) as u16;
|
||||
self.send_mouse_report_absolute(0, x, y, 0)?;
|
||||
@@ -1127,7 +969,6 @@ impl HidBackend for OtgBackend {
|
||||
if let Some(button) = event.button {
|
||||
let bit = button.to_hid_bit();
|
||||
let new_buttons = self.mouse_buttons.fetch_or(bit, Ordering::Relaxed) | bit;
|
||||
// Send on relative device for button clicks
|
||||
self.send_mouse_report_relative(new_buttons, 0, 0, 0)?;
|
||||
}
|
||||
}
|
||||
@@ -1147,7 +988,6 @@ impl HidBackend for OtgBackend {
|
||||
}
|
||||
|
||||
async fn reset(&self) -> Result<()> {
|
||||
// Reset keyboard
|
||||
{
|
||||
let mut state = self.keyboard_state.lock();
|
||||
state.clear();
|
||||
@@ -1156,7 +996,6 @@ impl HidBackend for OtgBackend {
|
||||
self.send_keyboard_report(&report)?;
|
||||
}
|
||||
|
||||
// Reset mouse
|
||||
self.mouse_buttons.store(0, Ordering::Relaxed);
|
||||
self.send_mouse_report_relative(0, 0, 0, 0)?;
|
||||
self.send_mouse_report_absolute(0, 0, 0, 0)?;
|
||||
@@ -1168,16 +1007,13 @@ impl HidBackend for OtgBackend {
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
self.stop_runtime_worker();
|
||||
|
||||
// Reset before closing
|
||||
self.reset().await?;
|
||||
|
||||
// Close devices
|
||||
*self.keyboard_dev.lock() = None;
|
||||
*self.mouse_rel_dev.lock() = None;
|
||||
*self.mouse_abs_dev.lock() = None;
|
||||
*self.consumer_dev.lock() = None;
|
||||
|
||||
// Gadget cleanup is handled by OtgService, not here
|
||||
self.initialized.store(false, Ordering::Relaxed);
|
||||
self.online.store(false, Ordering::Relaxed);
|
||||
self.clear_error();
|
||||
@@ -1199,31 +1035,18 @@ impl HidBackend for OtgBackend {
|
||||
self.send_consumer_report(event.usage)
|
||||
}
|
||||
|
||||
fn set_screen_resolution(&mut self, width: u32, height: u32) {
|
||||
fn set_screen_resolution(&self, width: u32, height: u32) {
|
||||
*self.screen_resolution.write() = Some((width, height));
|
||||
self.notify_runtime_changed();
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if OTG HID gadget is available
|
||||
pub fn is_otg_available() -> bool {
|
||||
// Check for existing HID devices (they should be created by OtgService)
|
||||
let kb = PathBuf::from("/dev/hidg0");
|
||||
let mouse_rel = PathBuf::from("/dev/hidg1");
|
||||
let mouse_abs = PathBuf::from("/dev/hidg2");
|
||||
|
||||
kb.exists() || mouse_rel.exists() || mouse_abs.exists()
|
||||
}
|
||||
|
||||
/// Implement Drop for OtgBackend to close device files
|
||||
impl Drop for OtgBackend {
|
||||
fn drop(&mut self) {
|
||||
self.runtime_worker_stop.store(true, Ordering::Relaxed);
|
||||
if let Some(handle) = self.runtime_worker.get_mut().take() {
|
||||
let _ = handle.join();
|
||||
}
|
||||
// Close device files
|
||||
// Note: Gadget cleanup is handled by OtgService, not here
|
||||
*self.keyboard_dev.lock() = None;
|
||||
*self.mouse_rel_dev.lock() = None;
|
||||
*self.mouse_abs_dev.lock() = None;
|
||||
@@ -1236,12 +1059,6 @@ impl Drop for OtgBackend {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_otg_availability_check() {
|
||||
// This just tests the function runs without panicking
|
||||
let _available = is_otg_available();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_led_state() {
|
||||
let state = LedState::from_byte(0b00000011);
|
||||
@@ -1254,7 +1071,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_report_sizes() {
|
||||
// Keyboard report is 8 bytes
|
||||
let kb_report = KeyboardReport::default();
|
||||
assert_eq!(kb_report.to_bytes().len(), 8);
|
||||
}
|
||||
|
||||
@@ -1,50 +1,37 @@
|
||||
//! HID event types for keyboard and mouse
|
||||
//! Keyboard/mouse/consumer structs (`KeyboardEvent`, `MouseEvent`, …).
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::keyboard::CanonicalKey;
|
||||
|
||||
/// Keyboard event type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum KeyEventType {
|
||||
/// Key pressed down
|
||||
Down,
|
||||
/// Key released
|
||||
Up,
|
||||
}
|
||||
|
||||
/// Keyboard modifier flags
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct KeyboardModifiers {
|
||||
/// Left Control
|
||||
#[serde(default)]
|
||||
pub left_ctrl: bool,
|
||||
/// Left Shift
|
||||
#[serde(default)]
|
||||
pub left_shift: bool,
|
||||
/// Left Alt
|
||||
#[serde(default)]
|
||||
pub left_alt: bool,
|
||||
/// Left Meta (Windows/Super key)
|
||||
#[serde(default)]
|
||||
pub left_meta: bool,
|
||||
/// Right Control
|
||||
#[serde(default)]
|
||||
pub right_ctrl: bool,
|
||||
/// Right Shift
|
||||
#[serde(default)]
|
||||
pub right_shift: bool,
|
||||
/// Right Alt (AltGr)
|
||||
#[serde(default)]
|
||||
pub right_alt: bool,
|
||||
/// Right Meta
|
||||
#[serde(default)]
|
||||
pub right_meta: bool,
|
||||
}
|
||||
|
||||
impl KeyboardModifiers {
|
||||
/// Convert to USB HID modifier byte
|
||||
pub fn to_hid_byte(&self) -> u8 {
|
||||
let mut byte = 0u8;
|
||||
if self.left_ctrl {
|
||||
@@ -74,7 +61,6 @@ impl KeyboardModifiers {
|
||||
byte
|
||||
}
|
||||
|
||||
/// Create from USB HID modifier byte
|
||||
pub fn from_hid_byte(byte: u8) -> Self {
|
||||
Self {
|
||||
left_ctrl: byte & 0x01 != 0,
|
||||
@@ -88,7 +74,6 @@ impl KeyboardModifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if any modifier is active
|
||||
pub fn any(&self) -> bool {
|
||||
self.left_ctrl
|
||||
|| self.left_shift
|
||||
@@ -101,21 +86,16 @@ impl KeyboardModifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// Keyboard event
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyboardEvent {
|
||||
/// Event type (down/up)
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: KeyEventType,
|
||||
/// Canonical keyboard key identifier shared across frontend and backend
|
||||
pub key: CanonicalKey,
|
||||
/// Modifier keys state
|
||||
#[serde(default)]
|
||||
pub modifiers: KeyboardModifiers,
|
||||
}
|
||||
|
||||
impl KeyboardEvent {
|
||||
/// Create a key down event
|
||||
pub fn key_down(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
|
||||
Self {
|
||||
event_type: KeyEventType::Down,
|
||||
@@ -124,7 +104,6 @@ impl KeyboardEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a key up event
|
||||
pub fn key_up(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
|
||||
Self {
|
||||
event_type: KeyEventType::Up,
|
||||
@@ -134,7 +113,6 @@ impl KeyboardEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Mouse button
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MouseButton {
|
||||
@@ -146,7 +124,6 @@ pub enum MouseButton {
|
||||
}
|
||||
|
||||
impl MouseButton {
|
||||
/// Convert to USB HID button bit
|
||||
pub fn to_hid_bit(&self) -> u8 {
|
||||
match self {
|
||||
MouseButton::Left => 0x01,
|
||||
@@ -158,44 +135,31 @@ impl MouseButton {
|
||||
}
|
||||
}
|
||||
|
||||
/// Mouse event type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum MouseEventType {
|
||||
/// Mouse moved (relative movement)
|
||||
Move,
|
||||
/// Mouse moved (absolute position)
|
||||
MoveAbs,
|
||||
/// Button pressed
|
||||
Down,
|
||||
/// Button released
|
||||
Up,
|
||||
/// Mouse wheel scroll
|
||||
Scroll,
|
||||
}
|
||||
|
||||
/// Mouse event
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MouseEvent {
|
||||
/// Event type
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: MouseEventType,
|
||||
/// X coordinate or delta
|
||||
#[serde(default)]
|
||||
pub x: i32,
|
||||
/// Y coordinate or delta
|
||||
#[serde(default)]
|
||||
pub y: i32,
|
||||
/// Button (for down/up events)
|
||||
#[serde(default)]
|
||||
pub button: Option<MouseButton>,
|
||||
/// Scroll delta (for scroll events)
|
||||
#[serde(default)]
|
||||
pub scroll: i8,
|
||||
}
|
||||
|
||||
impl MouseEvent {
|
||||
/// Create a relative move event
|
||||
pub fn move_rel(dx: i32, dy: i32) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Move,
|
||||
@@ -206,7 +170,6 @@ impl MouseEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an absolute move event
|
||||
pub fn move_abs(x: i32, y: i32) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::MoveAbs,
|
||||
@@ -217,7 +180,6 @@ impl MouseEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a button down event
|
||||
pub fn button_down(button: MouseButton) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Down,
|
||||
@@ -228,7 +190,6 @@ impl MouseEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a button up event
|
||||
pub fn button_up(button: MouseButton) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Up,
|
||||
@@ -239,7 +200,6 @@ impl MouseEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a scroll event
|
||||
pub fn scroll(delta: i8) -> Self {
|
||||
Self {
|
||||
event_type: MouseEventType::Scroll,
|
||||
@@ -251,35 +211,19 @@ impl MouseEvent {
|
||||
}
|
||||
}
|
||||
|
||||
/// Combined HID event (keyboard or mouse)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "device", rename_all = "lowercase")]
|
||||
pub enum HidEvent {
|
||||
Keyboard(KeyboardEvent),
|
||||
Mouse(MouseEvent),
|
||||
Consumer(ConsumerEvent),
|
||||
}
|
||||
|
||||
/// Consumer control event (multimedia keys)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ConsumerEvent {
|
||||
/// Consumer control usage code (e.g., 0x00CD for Play/Pause)
|
||||
pub usage: u16,
|
||||
}
|
||||
|
||||
/// USB HID keyboard report (8 bytes)
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct KeyboardReport {
|
||||
/// Modifier byte
|
||||
pub modifiers: u8,
|
||||
/// Reserved byte
|
||||
pub reserved: u8,
|
||||
/// Key codes (up to 6 simultaneous keys)
|
||||
pub keys: [u8; 6],
|
||||
}
|
||||
|
||||
impl KeyboardReport {
|
||||
/// Convert to bytes for USB HID
|
||||
pub fn to_bytes(&self) -> [u8; 8] {
|
||||
[
|
||||
self.modifiers,
|
||||
@@ -293,7 +237,6 @@ impl KeyboardReport {
|
||||
]
|
||||
}
|
||||
|
||||
/// Add a key to the report
|
||||
pub fn add_key(&mut self, key: u8) -> bool {
|
||||
for slot in &mut self.keys {
|
||||
if *slot == 0 {
|
||||
@@ -304,56 +247,21 @@ impl KeyboardReport {
|
||||
false // All slots full
|
||||
}
|
||||
|
||||
/// Remove a key from the report
|
||||
pub fn remove_key(&mut self, key: u8) {
|
||||
for slot in &mut self.keys {
|
||||
if *slot == key {
|
||||
*slot = 0;
|
||||
}
|
||||
}
|
||||
// Compact the array
|
||||
self.keys.sort_by(|a, b| b.cmp(a));
|
||||
}
|
||||
|
||||
/// Clear all keys
|
||||
pub fn clear(&mut self) {
|
||||
self.modifiers = 0;
|
||||
self.keys = [0; 6];
|
||||
}
|
||||
}
|
||||
|
||||
/// USB HID mouse report
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MouseReport {
|
||||
/// Button state
|
||||
pub buttons: u8,
|
||||
/// X movement (-127 to 127)
|
||||
pub x: i8,
|
||||
/// Y movement (-127 to 127)
|
||||
pub y: i8,
|
||||
/// Wheel movement (-127 to 127)
|
||||
pub wheel: i8,
|
||||
}
|
||||
|
||||
impl MouseReport {
|
||||
/// Convert to bytes for USB HID (relative mouse)
|
||||
pub fn to_bytes_relative(&self) -> [u8; 4] {
|
||||
[self.buttons, self.x as u8, self.y as u8, self.wheel as u8]
|
||||
}
|
||||
|
||||
/// Convert to bytes for USB HID (absolute mouse)
|
||||
pub fn to_bytes_absolute(&self, x: u16, y: u16) -> [u8; 6] {
|
||||
[
|
||||
self.buttons,
|
||||
(x & 0xFF) as u8,
|
||||
(x >> 8) as u8,
|
||||
(y & 0xFF) as u8,
|
||||
(y >> 8) as u8,
|
||||
self.wheel as u8,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,13 +1,4 @@
|
||||
//! WebSocket HID channel for HTTP/MJPEG mode
|
||||
//!
|
||||
//! This provides an alternative to WebRTC DataChannel for HID input
|
||||
//! when using MJPEG streaming mode.
|
||||
//!
|
||||
//! Uses binary protocol only (same format as DataChannel):
|
||||
//! - Keyboard: [0x01, event_type, key, modifiers] (4 bytes)
|
||||
//! - Mouse: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, button/scroll] (7 bytes)
|
||||
//!
|
||||
//! See datachannel.rs for detailed protocol specification.
|
||||
//! MJPEG mode: HID over WebSocket — same binary framing as [`super::datachannel`] (`0x01`/`0x02`/`0x03`; layout detailed there).
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
@@ -24,25 +15,20 @@ use super::datachannel::{parse_hid_message, HidChannelEvent};
|
||||
use crate::state::AppState;
|
||||
use crate::utils::LogThrottler;
|
||||
|
||||
/// Binary response codes
|
||||
const RESP_OK: u8 = 0x00;
|
||||
const RESP_ERR_HID_UNAVAILABLE: u8 = 0x01;
|
||||
const RESP_ERR_INVALID_MESSAGE: u8 = 0x02;
|
||||
|
||||
/// WebSocket HID upgrade handler
|
||||
pub async fn ws_hid_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
|
||||
ws.on_upgrade(move |socket| handle_hid_socket(socket, state))
|
||||
}
|
||||
|
||||
/// Handle HID WebSocket connection
|
||||
async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
// Log throttler for error messages (5 second interval)
|
||||
let log_throttler = LogThrottler::with_secs(5);
|
||||
|
||||
info!("WebSocket HID connection established (binary protocol)");
|
||||
|
||||
// Check if HID controller is available and send initial status
|
||||
let hid_available = state.hid.is_available().await;
|
||||
let initial_response = if hid_available {
|
||||
vec![RESP_OK]
|
||||
@@ -59,17 +45,14 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Process incoming messages (binary only)
|
||||
while let Some(msg) = receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Binary(data)) => {
|
||||
// Check HID availability before processing each message
|
||||
let hid_available = state.hid.is_available().await;
|
||||
if !hid_available {
|
||||
if log_throttler.should_log("hid_unavailable") {
|
||||
warn!("HID controller not available, ignoring message");
|
||||
}
|
||||
// Send error response (optional, for client awareness)
|
||||
let _ = sender
|
||||
.send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE].into()))
|
||||
.await;
|
||||
@@ -77,15 +60,12 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
}
|
||||
|
||||
if let Err(e) = handle_binary_message(&data, &state).await {
|
||||
// Log with throttling to avoid spam
|
||||
if log_throttler.should_log("binary_hid_error") {
|
||||
warn!("Binary HID message error: {}", e);
|
||||
}
|
||||
// Don't send error response for every failed message to reduce overhead
|
||||
}
|
||||
}
|
||||
Ok(Message::Text(text)) => {
|
||||
// Text messages are no longer supported
|
||||
if log_throttler.should_log("text_message_rejected") {
|
||||
debug!(
|
||||
"Received text message (not supported): {} bytes",
|
||||
@@ -111,7 +91,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reset HID state to release any held keys/buttons
|
||||
if let Err(e) = state.hid.reset().await {
|
||||
warn!("Failed to reset HID on WebSocket disconnect: {}", e);
|
||||
}
|
||||
@@ -119,7 +98,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
info!("WebSocket HID connection ended");
|
||||
}
|
||||
|
||||
/// Handle binary HID message (same format as DataChannel)
|
||||
async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), String> {
|
||||
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
|
||||
|
||||
@@ -160,12 +138,10 @@ mod tests {
|
||||
assert_eq!(RESP_OK, 0x00);
|
||||
assert_eq!(RESP_ERR_HID_UNAVAILABLE, 0x01);
|
||||
assert_eq!(RESP_ERR_INVALID_MESSAGE, 0x02);
|
||||
// assert_eq!(RESP_ERR_SEND_FAILED, 0x03); // TODO: fix test
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keyboard_message_format() {
|
||||
// Keyboard message: [0x01, event_type, key, modifiers]
|
||||
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // 'A' key with left ctrl
|
||||
let event = parse_hid_message(&data);
|
||||
assert!(event.is_some());
|
||||
@@ -173,7 +149,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_mouse_message_format() {
|
||||
// Mouse message: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, extra]
|
||||
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
|
||||
let event = parse_hid_message(&data);
|
||||
assert!(event.is_some());
|
||||
|
||||
@@ -1,30 +1,27 @@
|
||||
//! One-KVM - Lightweight IP-KVM solution
|
||||
//!
|
||||
//! This crate provides the core functionality for One-KVM,
|
||||
//! a remote KVM (Keyboard, Video, Mouse) solution written in Rust.
|
||||
//! Core library for One-KVM (IP‑KVM: capture, HID, OTG, streaming, Web UI glue).
|
||||
|
||||
pub mod atx;
|
||||
pub mod audio;
|
||||
pub mod auth;
|
||||
pub mod config;
|
||||
pub mod db;
|
||||
pub mod error;
|
||||
pub mod events;
|
||||
pub mod extensions;
|
||||
pub mod hid;
|
||||
pub mod modules;
|
||||
pub mod msd;
|
||||
pub mod otg;
|
||||
pub mod rtsp;
|
||||
pub mod rustdesk;
|
||||
pub mod state;
|
||||
pub mod stream;
|
||||
pub mod stream_encoder;
|
||||
pub mod update;
|
||||
pub mod utils;
|
||||
pub mod video;
|
||||
pub mod web;
|
||||
pub mod webrtc;
|
||||
|
||||
/// Auto-generated secrets module (from secrets.toml at compile time)
|
||||
pub mod secrets {
|
||||
include!(concat!(env!("OUT_DIR"), "/secrets_generated.rs"));
|
||||
}
|
||||
|
||||
207
src/main.rs
207
src/main.rs
@@ -1,7 +1,8 @@
|
||||
use std::collections::HashSet;
|
||||
use std::future::Future;
|
||||
use std::io::Write;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
@@ -15,6 +16,7 @@ use one_kvm::atx::AtxController;
|
||||
use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality};
|
||||
use one_kvm::auth::{SessionStore, UserStore};
|
||||
use one_kvm::config::{self, AppConfig, ConfigStore};
|
||||
use one_kvm::db::DatabasePool;
|
||||
use one_kvm::events::EventBus;
|
||||
use one_kvm::extensions::ExtensionManager;
|
||||
use one_kvm::hid::{HidBackendType, HidController};
|
||||
@@ -33,7 +35,6 @@ use one_kvm::video::{Streamer, VideoStreamManager};
|
||||
use one_kvm::web;
|
||||
use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig};
|
||||
|
||||
/// Log level for the application
|
||||
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
|
||||
enum LogLevel {
|
||||
Error,
|
||||
@@ -45,7 +46,6 @@ enum LogLevel {
|
||||
Trace,
|
||||
}
|
||||
|
||||
/// One-KVM command line arguments
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "one-kvm")]
|
||||
#[command(version, about = "A open and lightweight IP-KVM solution", long_about = None)]
|
||||
@@ -111,37 +111,30 @@ enum UserAction {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Parse command line arguments
|
||||
let args = CliArgs::parse();
|
||||
|
||||
// Initialize logging with CLI arguments
|
||||
init_logging(args.log_level, args.verbose);
|
||||
|
||||
// Install default crypto provider (required by rustls 0.23+)
|
||||
CryptoProvider::install_default(ring::default_provider())
|
||||
.expect("Failed to install rustls crypto provider");
|
||||
|
||||
tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
// Determine data directory (CLI arg takes precedence)
|
||||
let data_dir = args.data_dir.clone().unwrap_or_else(get_data_dir);
|
||||
tracing::info!("Data directory: {}", data_dir.display());
|
||||
|
||||
// Run one-off CLI command and exit.
|
||||
if let Some(command) = args.command {
|
||||
run_cli_command(command, data_dir).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ensure data directory exists
|
||||
tokio::fs::create_dir_all(&data_dir).await?;
|
||||
|
||||
// Initialize configuration store
|
||||
let db_path = data_dir.join("one-kvm.db");
|
||||
let config_store = ConfigStore::new(&db_path).await?;
|
||||
let db = open_database_pool(&data_dir).await?;
|
||||
let config_store = ConfigStore::new(db.clone_pool())?;
|
||||
config_store.load().await?;
|
||||
let mut config = (*config_store.get()).clone();
|
||||
|
||||
// Normalize MSD directory (absolute path under data dir if empty/relative)
|
||||
let mut msd_dir_updated = false;
|
||||
if config.msd.msd_dir.trim().is_empty() {
|
||||
let msd_dir = data_dir.join("msd");
|
||||
@@ -159,8 +152,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
if msd_dir_updated {
|
||||
config_store.set(config.clone()).await?;
|
||||
}
|
||||
|
||||
// Ensure MSD directories exist (msd/images, msd/ventoy)
|
||||
let msd_dir = PathBuf::from(&config.msd.msd_dir);
|
||||
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await {
|
||||
tracing::warn!("Failed to create MSD images directory: {}", e);
|
||||
@@ -169,7 +160,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
tracing::warn!("Failed to create MSD ventoy directory: {}", e);
|
||||
}
|
||||
|
||||
// Apply CLI argument overrides to config (only if explicitly specified)
|
||||
if let Some(addr) = args.address {
|
||||
config.web.bind_address = addr.clone();
|
||||
config.web.bind_addresses = vec![addr];
|
||||
@@ -203,29 +193,21 @@ async fn main() -> anyhow::Result<()> {
|
||||
config.web.http_port
|
||||
};
|
||||
|
||||
// Log final configuration
|
||||
|
||||
for ip in &bind_ips {
|
||||
let addr = SocketAddr::new(*ip, bind_port);
|
||||
tracing::info!("Server will listen on: {}://{}", scheme, addr);
|
||||
}
|
||||
|
||||
// Initialize session store
|
||||
let session_store = SessionStore::new(
|
||||
config_store.pool().clone(),
|
||||
config.auth.session_timeout_secs as i64,
|
||||
);
|
||||
let session_store = SessionStore::new(config.auth.session_timeout_secs as i64);
|
||||
|
||||
// Initialize user store
|
||||
let user_store = UserStore::new(config_store.pool().clone());
|
||||
let user_store = UserStore::new(db.clone_pool());
|
||||
|
||||
// Create shutdown channel
|
||||
let (shutdown_tx, _) = broadcast::channel::<()>(1);
|
||||
|
||||
// Create event bus for real-time notifications
|
||||
let events = Arc::new(EventBus::new());
|
||||
tracing::info!("Event bus initialized");
|
||||
|
||||
// Parse video configuration once (avoid duplication)
|
||||
let (video_format, video_resolution) = parse_video_config(&config);
|
||||
tracing::debug!(
|
||||
"Parsed video config: {} @ {}x{}",
|
||||
@@ -234,7 +216,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
video_resolution.height
|
||||
);
|
||||
|
||||
// Create video streamer and initialize with config if device is set
|
||||
let streamer = Streamer::new();
|
||||
streamer.set_event_bus(events.clone()).await;
|
||||
if let Some(ref device_path) = config.video.device {
|
||||
@@ -262,19 +243,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Create WebRTC streamer
|
||||
let webrtc_streamer = {
|
||||
let webrtc_config = WebRtcStreamerConfig {
|
||||
resolution: video_resolution,
|
||||
input_format: video_format,
|
||||
fps: config.video.fps,
|
||||
bitrate_preset: config.stream.bitrate_preset,
|
||||
encoder_backend: config.stream.encoder.to_backend(),
|
||||
encoder_backend: one_kvm::stream_encoder::encoder_type_to_backend(
|
||||
config.stream.encoder.clone(),
|
||||
),
|
||||
webrtc: {
|
||||
let mut stun_servers = vec![];
|
||||
let mut turn_servers = vec![];
|
||||
|
||||
// Check if user configured custom servers
|
||||
let has_custom_stun = config
|
||||
.stream
|
||||
.stun_server
|
||||
@@ -288,14 +269,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
.map(|s| !s.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
// If no custom servers, use baked-in public STUN
|
||||
if !has_custom_stun && !has_custom_turn {
|
||||
use one_kvm::webrtc::config::public_ice;
|
||||
let stun = public_ice::stun_server().to_string();
|
||||
tracing::info!("Using public STUN server: {}", stun);
|
||||
stun_servers.push(stun);
|
||||
} else {
|
||||
// Use custom servers
|
||||
if let Some(ref stun) = config.stream.stun_server {
|
||||
if !stun.is_empty() {
|
||||
stun_servers.push(stun.clone());
|
||||
@@ -333,16 +312,13 @@ async fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
tracing::info!("WebRTC streamer created");
|
||||
|
||||
// Create OTG Service (single instance for centralized USB gadget management)
|
||||
let otg_service = Arc::new(OtgService::new());
|
||||
tracing::info!("OTG Service created");
|
||||
|
||||
// Reconcile OTG once from the persisted config so controllers only consume its result.
|
||||
if let Err(e) = otg_service.apply_config(&config.hid, &config.msd).await {
|
||||
tracing::warn!("Failed to apply OTG config: {}", e);
|
||||
}
|
||||
|
||||
// Create HID controller based on config
|
||||
let hid_backend = match config.hid.backend {
|
||||
config::HidBackend::Otg => HidBackendType::Otg,
|
||||
config::HidBackend::Ch9329 => HidBackendType::Ch9329 {
|
||||
@@ -353,16 +329,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
let hid = Arc::new(HidController::new(
|
||||
hid_backend,
|
||||
Some(otg_service.clone()), // Always pass OtgService to support hot-reload to OTG
|
||||
Some(otg_service.clone()),
|
||||
));
|
||||
hid.set_event_bus(events.clone()).await;
|
||||
if let Err(e) = hid.init().await {
|
||||
tracing::warn!("Failed to initialize HID backend: {}", e);
|
||||
}
|
||||
|
||||
// Create MSD controller (optional, based on config)
|
||||
let msd = if config.msd.enabled {
|
||||
// `{data_dir}/ventoy`: boot.img, core.img, ventoy.disk.img for ventoy_img
|
||||
let ventoy_resource_dir = data_dir.join("ventoy");
|
||||
if ventoy_resource_dir.exists() {
|
||||
if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) {
|
||||
@@ -393,7 +367,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
// Create ATX controller (optional, based on config)
|
||||
let atx = if config.atx.enabled {
|
||||
let controller_config = config.atx.to_controller_config();
|
||||
let controller = AtxController::new(controller_config);
|
||||
@@ -409,12 +382,21 @@ async fn main() -> anyhow::Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
// Create Audio controller
|
||||
let audio = {
|
||||
let audio_config = AudioControllerConfig {
|
||||
enabled: config.audio.enabled,
|
||||
device: config.audio.device.clone(),
|
||||
quality: AudioQuality::from_str(&config.audio.quality),
|
||||
quality: match config.audio.quality.parse::<AudioQuality>() {
|
||||
Ok(q) => q,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Invalid audio quality in config (value={:?}): {}, using balanced",
|
||||
config.audio.quality,
|
||||
e
|
||||
);
|
||||
AudioQuality::Balanced
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let controller = AudioController::new(audio_config);
|
||||
@@ -426,7 +408,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
config.audio.device,
|
||||
config.audio.quality
|
||||
);
|
||||
// Start audio streaming so WebRTC can subscribe to Opus frames
|
||||
if let Err(e) = controller.start_streaming().await {
|
||||
tracing::warn!("Failed to start audio streaming: {}", e);
|
||||
}
|
||||
@@ -437,16 +418,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
Arc::new(controller)
|
||||
};
|
||||
|
||||
// Create Extension manager (ttyd, gostc, easytier)
|
||||
let extensions = Arc::new(ExtensionManager::new());
|
||||
tracing::info!("Extension manager initialized");
|
||||
|
||||
// Wire up WebRTC streamer with HID controller
|
||||
// This enables WebRTC DataChannel to process HID events
|
||||
webrtc_streamer.set_hid_controller(hid.clone()).await;
|
||||
|
||||
// Wire up WebRTC streamer with Audio controller
|
||||
// This enables WebRTC audio track to receive Opus frames
|
||||
webrtc_streamer.set_audio_controller(audio.clone()).await;
|
||||
if config.audio.enabled {
|
||||
if let Err(e) = webrtc_streamer.set_audio_enabled(true).await {
|
||||
@@ -456,7 +432,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Configure direct capture for WebRTC encoder pipeline
|
||||
let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) =
|
||||
streamer.current_capture_config().await;
|
||||
tracing::info!(
|
||||
@@ -495,14 +470,13 @@ async fn main() -> anyhow::Result<()> {
|
||||
tracing::warn!("No capture device configured for WebRTC");
|
||||
}
|
||||
|
||||
// Create video stream manager (unified MJPEG/WebRTC management)
|
||||
// Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance
|
||||
let stream_manager =
|
||||
VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone());
|
||||
let stream_manager = VideoStreamManager::with_webrtc_streamer(
|
||||
streamer.clone(),
|
||||
webrtc_streamer.clone() as std::sync::Arc<dyn one_kvm::video::traits::VideoOutput>,
|
||||
);
|
||||
stream_manager.set_event_bus(events.clone()).await;
|
||||
stream_manager.set_config_store(config_store.clone()).await;
|
||||
|
||||
// Initialize stream manager with configured mode
|
||||
let initial_mode = config.stream.mode.clone();
|
||||
if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await {
|
||||
tracing::warn!(
|
||||
@@ -517,7 +491,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
}
|
||||
|
||||
// Create RustDesk service (optional, based on config)
|
||||
let rustdesk = if config.rustdesk.is_valid() {
|
||||
tracing::info!(
|
||||
"Initializing RustDesk service: ID={} -> {}",
|
||||
@@ -542,7 +515,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
// Create RTSP service (optional, based on config)
|
||||
let rtsp = if config.rtsp.enabled {
|
||||
tracing::info!(
|
||||
"Initializing RTSP service: rtsp://{}:{}/{}",
|
||||
@@ -557,15 +529,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
// Create application state
|
||||
let update_service = Arc::new(UpdateService::new(data_dir.join("updates")));
|
||||
|
||||
let state = AppState::new(
|
||||
db.clone(),
|
||||
config_store.clone(),
|
||||
session_store,
|
||||
user_store,
|
||||
otg_service,
|
||||
stream_manager,
|
||||
webrtc_streamer.clone(),
|
||||
hid,
|
||||
msd,
|
||||
atx,
|
||||
@@ -581,12 +554,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
extensions.set_event_bus(events.clone()).await;
|
||||
|
||||
// Start RustDesk service if enabled
|
||||
if let Some(ref service) = rustdesk {
|
||||
if let Err(e) = service.start().await {
|
||||
tracing::error!("Failed to start RustDesk service: {}", e);
|
||||
} else {
|
||||
// Save generated keypair and UUID to config
|
||||
if let Some(updated_config) = service.save_credentials() {
|
||||
if let Err(e) = config_store
|
||||
.update(|cfg| {
|
||||
@@ -606,7 +577,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Start RTSP service if enabled
|
||||
if let Some(ref service) = rtsp {
|
||||
if let Err(e) = service.start().await {
|
||||
tracing::error!("Failed to start RTSP service: {}", e);
|
||||
@@ -615,7 +585,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce startup codec constraints (e.g. RTSP/RustDesk locks)
|
||||
{
|
||||
let runtime_config = state.config.get();
|
||||
let constraints = StreamCodecConstraints::from_config(&runtime_config);
|
||||
@@ -630,13 +599,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Start enabled extensions
|
||||
{
|
||||
let ext_config = config_store.get();
|
||||
extensions.start_enabled(&ext_config.extensions).await;
|
||||
}
|
||||
|
||||
// Start extension health check task (every 30 seconds)
|
||||
{
|
||||
let extensions_clone = extensions.clone();
|
||||
let config_store_clone = config_store.clone();
|
||||
@@ -653,17 +620,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
state.publish_device_info().await;
|
||||
|
||||
// Start device info broadcast task
|
||||
// This monitors state change events and broadcasts DeviceInfo to all clients
|
||||
spawn_device_info_broadcaster(state.clone(), events);
|
||||
|
||||
// Create router
|
||||
let app = web::create_router(state.clone());
|
||||
|
||||
// Bind sockets for configured addresses
|
||||
let listeners = bind_tcp_listeners(&bind_ips, bind_port)?;
|
||||
|
||||
// Setup graceful shutdown
|
||||
let shutdown_signal = async move {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
@@ -672,9 +634,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let _ = shutdown_tx.send(());
|
||||
};
|
||||
|
||||
// Start server
|
||||
if config.web.https_enabled {
|
||||
// Generate self-signed certificate if no custom cert provided
|
||||
let tls_config = if let (Some(cert_path), Some(key_path)) =
|
||||
(&config.web.ssl_cert_path, &config.web.ssl_key_path)
|
||||
{
|
||||
@@ -684,7 +644,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cert_path = cert_dir.join("server.crt");
|
||||
let key_path = cert_dir.join("server.key");
|
||||
|
||||
// Check if certificate already exists, only generate if missing
|
||||
if !cert_path.exists() || !key_path.exists() {
|
||||
tracing::info!("Generating new self-signed TLS certificate");
|
||||
let cert = generate_self_signed_cert()?;
|
||||
@@ -698,7 +657,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
|
||||
};
|
||||
|
||||
let mut servers = FuturesUnordered::new();
|
||||
let servers = FuturesUnordered::new();
|
||||
for listener in listeners {
|
||||
let local_addr = listener.local_addr()?;
|
||||
tracing::info!("Starting HTTPS server on {}", local_addr);
|
||||
@@ -708,19 +667,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
servers.push(server);
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_signal => {
|
||||
cleanup(&state).await;
|
||||
}
|
||||
result = servers.next() => {
|
||||
if let Some(Err(e)) = result {
|
||||
tracing::error!("HTTPS server error: {}", e);
|
||||
}
|
||||
cleanup(&state).await;
|
||||
}
|
||||
}
|
||||
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTPS").await;
|
||||
} else {
|
||||
let mut servers = FuturesUnordered::new();
|
||||
let servers = FuturesUnordered::new();
|
||||
for listener in listeners {
|
||||
let local_addr = listener.local_addr()?;
|
||||
tracing::info!("Starting HTTP server on {}", local_addr);
|
||||
@@ -730,26 +679,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
servers.push(async move { server.await });
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_signal => {
|
||||
cleanup(&state).await;
|
||||
}
|
||||
result = servers.next() => {
|
||||
if let Some(Err(e)) = result {
|
||||
tracing::error!("HTTP server error: {}", e);
|
||||
}
|
||||
cleanup(&state).await;
|
||||
}
|
||||
}
|
||||
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTP").await;
|
||||
}
|
||||
|
||||
tracing::info!("Server shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize logging with tracing
|
||||
fn init_logging(level: LogLevel, verbose_count: u8) {
|
||||
// Verbose count overrides log level
|
||||
let effective_level = match verbose_count {
|
||||
0 => level,
|
||||
1 => LogLevel::Verbose,
|
||||
@@ -757,7 +694,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
|
||||
_ => LogLevel::Trace,
|
||||
};
|
||||
|
||||
// Build filter string based on effective level
|
||||
let filter = match effective_level {
|
||||
LogLevel::Error => "one_kvm=error,tower_http=error",
|
||||
LogLevel::Warn => "one_kvm=warn,tower_http=warn",
|
||||
@@ -767,7 +703,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
|
||||
LogLevel::Trace => "one_kvm=trace,tower_http=debug",
|
||||
};
|
||||
|
||||
// Environment variable takes highest priority
|
||||
let env_filter =
|
||||
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| filter.into());
|
||||
|
||||
@@ -780,23 +715,48 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the application data directory
|
||||
fn get_data_dir() -> PathBuf {
|
||||
// Check environment variable first
|
||||
if let Ok(path) = std::env::var("ONE_KVM_DATA_DIR") {
|
||||
return PathBuf::from(path);
|
||||
}
|
||||
|
||||
// Default to system configuration directory
|
||||
PathBuf::from("/etc/one-kvm")
|
||||
}
|
||||
|
||||
async fn open_database_pool(data_dir: &Path) -> anyhow::Result<DatabasePool> {
|
||||
let db_path = data_dir.join("one-kvm.db");
|
||||
let db = DatabasePool::new(&db_path).await?;
|
||||
db.init_schema().await?;
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
async fn run_servers_until_shutdown<F, E>(
|
||||
mut servers: FuturesUnordered<F>,
|
||||
shutdown_signal: impl Future<Output = ()>,
|
||||
state: &Arc<AppState>,
|
||||
protocol: &'static str,
|
||||
) where
|
||||
F: Future<Output = Result<(), E>> + Send,
|
||||
E: std::fmt::Display,
|
||||
{
|
||||
tokio::select! {
|
||||
_ = shutdown_signal => {
|
||||
cleanup(state).await;
|
||||
}
|
||||
result = servers.next() => {
|
||||
if let Some(Err(e)) = result {
|
||||
tracing::error!("{} server error: {}", protocol, e);
|
||||
}
|
||||
cleanup(state).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_cli_command(command: CliCommand, data_dir: PathBuf) -> anyhow::Result<()> {
|
||||
tokio::fs::create_dir_all(&data_dir).await?;
|
||||
let db_path = data_dir.join("one-kvm.db");
|
||||
let config_store = ConfigStore::new(&db_path).await?;
|
||||
let users = UserStore::new(config_store.pool().clone());
|
||||
let sessions = SessionStore::new(config_store.pool().clone(), 0);
|
||||
let db = open_database_pool(&data_dir).await?;
|
||||
let users = UserStore::new(db.clone_pool());
|
||||
let sessions = SessionStore::new(0);
|
||||
|
||||
match command {
|
||||
CliCommand::User(user) => run_user_action(user.action, &users, &sessions).await,
|
||||
@@ -814,15 +774,10 @@ async fn run_user_action(
|
||||
}
|
||||
|
||||
async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow::Result<()> {
|
||||
let all = users.list().await?;
|
||||
let user = match all.len() {
|
||||
0 => anyhow::bail!("No local user exists yet; complete setup in the web UI first."),
|
||||
1 => &all[0],
|
||||
_ => anyhow::bail!(
|
||||
"Expected exactly one local user (single-user design), found {}. Remove extra users from the database or contact support.",
|
||||
all.len()
|
||||
),
|
||||
};
|
||||
let user = users
|
||||
.single_user()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("No local user exists yet; complete setup in the web UI first."))?;
|
||||
|
||||
let new_password = read_new_password_interactive()?;
|
||||
if new_password.len() < 4 {
|
||||
@@ -830,7 +785,7 @@ async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow
|
||||
}
|
||||
|
||||
users.update_password(&user.id, &new_password).await?;
|
||||
let revoked = sessions.delete_by_user_id(&user.id).await?;
|
||||
let revoked = sessions.delete_all().await?;
|
||||
|
||||
tracing::info!(
|
||||
"Password updated for user '{}' and {} sessions revoked",
|
||||
@@ -866,7 +821,6 @@ fn read_new_password_interactive() -> anyhow::Result<String> {
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
/// Resolve bind IPs from config, preferring bind_addresses when set.
|
||||
fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result<Vec<IpAddr>> {
|
||||
let raw_addrs = if !web.bind_addresses.is_empty() {
|
||||
web.bind_addresses.as_slice()
|
||||
@@ -907,7 +861,6 @@ fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result<Vec<std::ne
|
||||
Ok(listeners)
|
||||
}
|
||||
|
||||
/// Parse video format and resolution from config (avoids code duplication)
|
||||
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
|
||||
let format = config
|
||||
.video
|
||||
@@ -919,7 +872,6 @@ fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
|
||||
(format, resolution)
|
||||
}
|
||||
|
||||
/// Generate a self-signed TLS certificate
|
||||
fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyPair>> {
|
||||
use rcgen::generate_simple_self_signed;
|
||||
|
||||
@@ -933,8 +885,6 @@ fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyP
|
||||
Ok(certified_key)
|
||||
}
|
||||
|
||||
/// Spawn a background task that monitors state change events
|
||||
/// and broadcasts DeviceInfo to all WebSocket clients with debouncing
|
||||
fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
@@ -1021,7 +971,6 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
|
||||
let mut pending_broadcast = false;
|
||||
|
||||
loop {
|
||||
// Use timeout to handle pending broadcasts
|
||||
let recv_result = if pending_broadcast {
|
||||
let remaining =
|
||||
DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64);
|
||||
@@ -1046,12 +995,9 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
|
||||
tracing::info!("Event bus closed, stopping DeviceInfo broadcaster");
|
||||
break;
|
||||
}
|
||||
Err(_timeout) => {
|
||||
// Debounce timeout reached, broadcast now
|
||||
}
|
||||
Err(_timeout) => {}
|
||||
}
|
||||
|
||||
// Broadcast if pending and debounce time has passed
|
||||
if pending_broadcast && last_broadcast.elapsed() >= Duration::from_millis(DEBOUNCE_MS) {
|
||||
state.publish_device_info().await;
|
||||
tracing::trace!("Broadcasted DeviceInfo (debounced)");
|
||||
@@ -1067,13 +1013,10 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
|
||||
);
|
||||
}
|
||||
|
||||
/// Clean up subsystems on shutdown
|
||||
async fn cleanup(state: &Arc<AppState>) {
|
||||
// Stop all extensions
|
||||
state.extensions.stop_all().await;
|
||||
tracing::info!("Extensions stopped");
|
||||
|
||||
// Stop RustDesk service
|
||||
if let Some(ref service) = *state.rustdesk.read().await {
|
||||
if let Err(e) = service.stop().await {
|
||||
tracing::warn!("Failed to stop RustDesk service: {}", e);
|
||||
@@ -1082,7 +1025,6 @@ async fn cleanup(state: &Arc<AppState>) {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop RTSP service
|
||||
if let Some(ref service) = *state.rtsp.read().await {
|
||||
if let Err(e) = service.stop().await {
|
||||
tracing::warn!("Failed to stop RTSP service: {}", e);
|
||||
@@ -1091,31 +1033,26 @@ async fn cleanup(state: &Arc<AppState>) {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop video
|
||||
if let Err(e) = state.stream_manager.stop().await {
|
||||
tracing::warn!("Failed to stop streamer: {}", e);
|
||||
}
|
||||
|
||||
// Shutdown HID
|
||||
if let Err(e) = state.hid.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown HID: {}", e);
|
||||
}
|
||||
|
||||
// Shutdown MSD
|
||||
if let Some(msd) = state.msd.write().await.as_mut() {
|
||||
if let Err(e) = msd.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown MSD: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown ATX
|
||||
if let Some(atx) = state.atx.write().await.as_mut() {
|
||||
if let Err(e) = atx.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown ATX: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown Audio
|
||||
if let Err(e) = state.audio.shutdown().await {
|
||||
tracing::warn!("Failed to shutdown audio: {}", e);
|
||||
}
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
//! Module management for One-KVM
|
||||
//!
|
||||
//! This module provides infrastructure for managing feature modules
|
||||
//! (video streaming, HID control, MSD, ATX) as independent async tasks.
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
/// Module status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ModuleStatus {
|
||||
Stopped,
|
||||
Starting,
|
||||
Running,
|
||||
Stopping,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Trait for feature modules
|
||||
pub trait Module: Send + Sync {
|
||||
/// Module name
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Current status
|
||||
fn status(&self) -> ModuleStatus;
|
||||
|
||||
/// Start the module
|
||||
fn start(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
|
||||
|
||||
/// Stop the module
|
||||
fn stop(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
|
||||
}
|
||||
|
||||
/// Module manager for coordinating feature modules
|
||||
pub struct ModuleManager {
|
||||
shutdown_rx: broadcast::Receiver<()>,
|
||||
}
|
||||
|
||||
impl ModuleManager {
|
||||
pub fn new(shutdown_rx: broadcast::Receiver<()>) -> Self {
|
||||
Self { shutdown_rx }
|
||||
}
|
||||
|
||||
/// Wait for shutdown signal
|
||||
pub async fn wait_for_shutdown(&mut self) {
|
||||
let _ = self.shutdown_rx.recv().await;
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,3 @@
|
||||
//! MSD Controller
|
||||
//!
|
||||
//! Manages the mass storage device lifecycle including:
|
||||
//! - Image mounting and unmounting
|
||||
//! - Virtual drive management
|
||||
//! - State tracking
|
||||
//! - Image downloads from URL
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -14,41 +6,25 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::image::ImageManager;
|
||||
use super::monitor::{MsdHealthMonitor, MsdHealthStatus};
|
||||
use super::monitor::MsdHealthMonitor;
|
||||
use super::types::{DownloadProgress, DownloadStatus, DriveInfo, ImageInfo, MsdMode, MsdState};
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::otg::{MsdFunction, MsdLunConfig, OtgService};
|
||||
|
||||
/// MSD Controller
|
||||
pub struct MsdController {
|
||||
/// OTG Service reference
|
||||
otg_service: Arc<OtgService>,
|
||||
/// MSD function manager (provided by OtgService)
|
||||
msd_function: RwLock<Option<MsdFunction>>,
|
||||
/// Current state
|
||||
state: RwLock<MsdState>,
|
||||
/// Images storage path
|
||||
images_path: PathBuf,
|
||||
/// Ventoy directory path
|
||||
ventoy_dir: PathBuf,
|
||||
/// Virtual drive path
|
||||
drive_path: PathBuf,
|
||||
/// Event bus for broadcasting state changes (optional)
|
||||
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
|
||||
/// Active downloads (download_id -> CancellationToken)
|
||||
downloads: Arc<RwLock<HashMap<String, CancellationToken>>>,
|
||||
/// Operation mutex lock (prevents concurrent operations)
|
||||
operation_lock: Arc<RwLock<()>>,
|
||||
/// Health monitor for error tracking and recovery
|
||||
monitor: Arc<MsdHealthMonitor>,
|
||||
}
|
||||
|
||||
impl MsdController {
|
||||
/// Create new MSD controller
|
||||
///
|
||||
/// # Parameters
|
||||
/// * `otg_service` - OTG service for gadget management
|
||||
/// * `msd_dir` - Base directory for MSD storage
|
||||
pub fn new(otg_service: Arc<OtgService>, msd_dir: impl Into<PathBuf>) -> Self {
|
||||
let msd_dir = msd_dir.into();
|
||||
let images_path = msd_dir.join("images");
|
||||
@@ -68,11 +44,9 @@ impl MsdController {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the MSD controller
|
||||
pub async fn init(&self) -> Result<()> {
|
||||
info!("Initializing MSD controller");
|
||||
|
||||
// 1. Ensure images directory exists
|
||||
if let Err(e) = std::fs::create_dir_all(&self.images_path) {
|
||||
warn!("Failed to create images directory: {}", e);
|
||||
}
|
||||
@@ -80,20 +54,16 @@ impl MsdController {
|
||||
warn!("Failed to create ventoy directory: {}", e);
|
||||
}
|
||||
|
||||
// 2. Get active MSD function from OtgService
|
||||
info!("Fetching MSD function from OtgService");
|
||||
let msd_func = self.otg_service.msd_function().await.ok_or_else(|| {
|
||||
AppError::Internal("MSD function is not active in OtgService".to_string())
|
||||
})?;
|
||||
|
||||
// 3. Store function handle
|
||||
*self.msd_function.write().await = Some(msd_func);
|
||||
|
||||
// 4. Update state
|
||||
let mut state = self.state.write().await;
|
||||
state.available = true;
|
||||
|
||||
// 5. Check for existing virtual drive
|
||||
if self.drive_path.exists() {
|
||||
if let Ok(metadata) = std::fs::metadata(&self.drive_path) {
|
||||
state.drive_info = Some(DriveInfo {
|
||||
@@ -114,17 +84,14 @@ impl MsdController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current MSD state
|
||||
pub async fn state(&self) -> MsdState {
|
||||
self.state.read().await.clone()
|
||||
}
|
||||
|
||||
/// Set event bus for broadcasting state changes
|
||||
pub async fn set_event_bus(&self, events: std::sync::Arc<crate::events::EventBus>) {
|
||||
*self.events.write().await = Some(events);
|
||||
}
|
||||
|
||||
/// Publish an event to the event bus
|
||||
async fn publish_event(&self, event: crate::events::SystemEvent) {
|
||||
if let Some(ref bus) = *self.events.read().await {
|
||||
bus.publish(event);
|
||||
@@ -137,43 +104,21 @@ impl MsdController {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if MSD is available
|
||||
pub async fn is_available(&self) -> bool {
|
||||
self.state.read().await.available
|
||||
}
|
||||
|
||||
/// Connect an image file
|
||||
///
|
||||
/// # Parameters
|
||||
/// * `image` - Image info to mount
|
||||
/// * `cdrom` - Mount as CD-ROM (read-only, removable)
|
||||
/// * `read_only` - Mount as read-only
|
||||
pub async fn connect_image(
|
||||
&self,
|
||||
image: &ImageInfo,
|
||||
cdrom: bool,
|
||||
read_only: bool,
|
||||
) -> Result<()> {
|
||||
// Acquire operation lock to prevent concurrent operations
|
||||
let _op_guard = self.operation_lock.write().await;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
if !state.available {
|
||||
let err = AppError::Internal("MSD not available".to_string());
|
||||
self.monitor
|
||||
.report_error("MSD not available", "not_available")
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
self.assert_can_connect(&state).await?;
|
||||
|
||||
if state.connected {
|
||||
return Err(AppError::Internal(
|
||||
"Already connected. Disconnect first.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Verify image exists
|
||||
if !image.path.exists() {
|
||||
let error_msg = format!("Image file not found: {}", image.path.display());
|
||||
self.monitor
|
||||
@@ -182,29 +127,12 @@ impl MsdController {
|
||||
return Err(AppError::Internal(error_msg));
|
||||
}
|
||||
|
||||
// Configure LUN
|
||||
let config = if cdrom {
|
||||
MsdLunConfig::cdrom(image.path.clone())
|
||||
} else {
|
||||
MsdLunConfig::disk(image.path.clone(), read_only)
|
||||
};
|
||||
|
||||
let gadget_path = self.active_gadget_path().await?;
|
||||
if let Some(ref msd) = *self.msd_function.read().await {
|
||||
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
|
||||
let error_msg = format!("Failed to configure LUN: {}", e);
|
||||
self.monitor
|
||||
.report_error(&error_msg, "configfs_error")
|
||||
.await;
|
||||
return Err(e);
|
||||
}
|
||||
} else {
|
||||
let err = AppError::Internal("MSD function not initialized".to_string());
|
||||
self.monitor
|
||||
.report_error("MSD function not initialized", "not_initialized")
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
self.configure_lun_now(&config).await?;
|
||||
|
||||
state.connected = true;
|
||||
state.mode = MsdMode::Image;
|
||||
@@ -215,42 +143,19 @@ impl MsdController {
|
||||
image.name, cdrom, read_only
|
||||
);
|
||||
|
||||
// Release the lock before publishing events
|
||||
drop(state);
|
||||
drop(_op_guard);
|
||||
|
||||
// Report recovery if we were in an error state
|
||||
if self.monitor.is_error().await {
|
||||
self.monitor.report_recovered().await;
|
||||
}
|
||||
|
||||
self.mark_device_info_dirty().await;
|
||||
|
||||
self.finish_connect_success().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Connect the virtual drive
|
||||
pub async fn connect_drive(&self) -> Result<()> {
|
||||
// Acquire operation lock to prevent concurrent operations
|
||||
let _op_guard = self.operation_lock.write().await;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
if !state.available {
|
||||
let err = AppError::Internal("MSD not available".to_string());
|
||||
self.monitor
|
||||
.report_error("MSD not available", "not_available")
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
self.assert_can_connect(&state).await?;
|
||||
|
||||
if state.connected {
|
||||
return Err(AppError::Internal(
|
||||
"Already connected. Disconnect first.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check drive exists
|
||||
if !self.drive_path.exists() {
|
||||
let err =
|
||||
AppError::Internal("Virtual drive not initialized. Call init first.".to_string());
|
||||
@@ -260,25 +165,8 @@ impl MsdController {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Configure LUN as read-write disk
|
||||
let config = MsdLunConfig::disk(self.drive_path.clone(), false);
|
||||
|
||||
let gadget_path = self.active_gadget_path().await?;
|
||||
if let Some(ref msd) = *self.msd_function.read().await {
|
||||
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
|
||||
let error_msg = format!("Failed to configure LUN: {}", e);
|
||||
self.monitor
|
||||
.report_error(&error_msg, "configfs_error")
|
||||
.await;
|
||||
return Err(e);
|
||||
}
|
||||
} else {
|
||||
let err = AppError::Internal("MSD function not initialized".to_string());
|
||||
self.monitor
|
||||
.report_error("MSD function not initialized", "not_initialized")
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
self.configure_lun_now(&config).await?;
|
||||
|
||||
state.connected = true;
|
||||
state.mode = MsdMode::Drive;
|
||||
@@ -286,23 +174,57 @@ impl MsdController {
|
||||
|
||||
info!("Connected virtual drive: {}", self.drive_path.display());
|
||||
|
||||
// Release the lock before publishing event
|
||||
drop(state);
|
||||
drop(_op_guard);
|
||||
|
||||
// Report recovery if we were in an error state
|
||||
if self.monitor.is_error().await {
|
||||
self.monitor.report_recovered().await;
|
||||
}
|
||||
|
||||
self.mark_device_info_dirty().await;
|
||||
|
||||
self.finish_connect_success().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect current storage
|
||||
async fn assert_can_connect(&self, state: &MsdState) -> Result<()> {
|
||||
if !state.available {
|
||||
self.monitor
|
||||
.report_error("MSD not available", "not_available")
|
||||
.await;
|
||||
return Err(AppError::Internal("MSD not available".to_string()));
|
||||
}
|
||||
if state.connected {
|
||||
return Err(AppError::Internal(
|
||||
"Already connected. Disconnect first.".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn configure_lun_now(&self, config: &MsdLunConfig) -> Result<()> {
|
||||
let gadget_path = self.active_gadget_path().await?;
|
||||
let msd_hold = self.msd_function.read().await;
|
||||
let Some(ref msd) = *msd_hold else {
|
||||
self.monitor
|
||||
.report_error("MSD function not initialized", "not_initialized")
|
||||
.await;
|
||||
return Err(AppError::Internal(
|
||||
"MSD function not initialized".to_string(),
|
||||
));
|
||||
};
|
||||
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, config).await {
|
||||
let error_msg = format!("Failed to configure LUN: {}", e);
|
||||
self.monitor
|
||||
.report_error(&error_msg, "configfs_error")
|
||||
.await;
|
||||
return Err(e);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn finish_connect_success(&self) {
|
||||
if self.monitor.is_error().await {
|
||||
self.monitor.report_recovered().await;
|
||||
}
|
||||
self.mark_device_info_dirty().await;
|
||||
}
|
||||
|
||||
pub async fn disconnect(&self) -> Result<()> {
|
||||
// Acquire operation lock to prevent concurrent operations
|
||||
let _op_guard = self.operation_lock.write().await;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
@@ -323,7 +245,6 @@ impl MsdController {
|
||||
|
||||
info!("Disconnected storage");
|
||||
|
||||
// Release the lock before publishing events
|
||||
drop(state);
|
||||
drop(_op_guard);
|
||||
|
||||
@@ -332,41 +253,31 @@ impl MsdController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get images storage path
|
||||
pub fn images_path(&self) -> &PathBuf {
|
||||
&self.images_path
|
||||
}
|
||||
|
||||
/// Get ventoy directory path
|
||||
pub fn ventoy_dir(&self) -> &PathBuf {
|
||||
&self.ventoy_dir
|
||||
}
|
||||
|
||||
/// Get virtual drive path
|
||||
pub fn drive_path(&self) -> &PathBuf {
|
||||
&self.drive_path
|
||||
}
|
||||
|
||||
/// Check if currently connected
|
||||
pub async fn is_connected(&self) -> bool {
|
||||
self.state.read().await.connected
|
||||
}
|
||||
|
||||
/// Get current mode
|
||||
pub async fn mode(&self) -> MsdMode {
|
||||
self.state.read().await.mode.clone()
|
||||
}
|
||||
|
||||
/// Update drive info
|
||||
pub async fn update_drive_info(&self, info: DriveInfo) {
|
||||
let mut state = self.state.write().await;
|
||||
state.drive_info = Some(info);
|
||||
}
|
||||
|
||||
/// Start downloading an image from URL
|
||||
///
|
||||
/// Returns the download_id that can be used to track or cancel the download.
|
||||
/// Progress is reported via MsdDownloadProgress events.
|
||||
pub async fn download_image(
|
||||
&self,
|
||||
url: String,
|
||||
@@ -375,18 +286,15 @@ impl MsdController {
|
||||
let download_id = uuid::Uuid::new_v4().to_string();
|
||||
let cancel_token = CancellationToken::new();
|
||||
|
||||
// Register download
|
||||
{
|
||||
let mut downloads = self.downloads.write().await;
|
||||
downloads.insert(download_id.clone(), cancel_token.clone());
|
||||
}
|
||||
|
||||
// Extract filename for initial response
|
||||
let display_filename = filename
|
||||
.clone()
|
||||
.unwrap_or_else(|| url.rsplit('/').next().unwrap_or("download").to_string());
|
||||
|
||||
// Create initial progress
|
||||
let initial_progress = DownloadProgress {
|
||||
download_id: download_id.clone(),
|
||||
url: url.clone(),
|
||||
@@ -398,7 +306,6 @@ impl MsdController {
|
||||
error: None,
|
||||
};
|
||||
|
||||
// Publish started event
|
||||
self.publish_event(crate::events::SystemEvent::MsdDownloadProgress {
|
||||
download_id: download_id.clone(),
|
||||
url: url.clone(),
|
||||
@@ -410,18 +317,15 @@ impl MsdController {
|
||||
})
|
||||
.await;
|
||||
|
||||
// Clone what we need for the spawned task
|
||||
let images_path = self.images_path.clone();
|
||||
let events = self.events.read().await.clone();
|
||||
let downloads = self.downloads.clone();
|
||||
let download_id_clone = download_id.clone();
|
||||
let url_clone = url.clone();
|
||||
|
||||
// Spawn download task
|
||||
tokio::spawn(async move {
|
||||
let manager = ImageManager::new(images_path);
|
||||
|
||||
// Create progress callback
|
||||
let events_for_callback = events.clone();
|
||||
let download_id_for_callback = download_id_clone.clone();
|
||||
let url_for_callback = url_clone.clone();
|
||||
@@ -443,18 +347,15 @@ impl MsdController {
|
||||
}
|
||||
};
|
||||
|
||||
// Run download
|
||||
let result = manager
|
||||
.download_from_url(&url_clone, filename, progress_callback)
|
||||
.await;
|
||||
|
||||
// Remove from active downloads
|
||||
{
|
||||
let mut downloads_guard = downloads.write().await;
|
||||
downloads_guard.remove(&download_id_clone);
|
||||
}
|
||||
|
||||
// Publish completion event
|
||||
match result {
|
||||
Ok(image_info) => {
|
||||
if let Some(ref bus) = events {
|
||||
@@ -489,7 +390,6 @@ impl MsdController {
|
||||
Ok(initial_progress)
|
||||
}
|
||||
|
||||
/// Cancel an active download
|
||||
pub async fn cancel_download(&self, download_id: &str) -> Result<()> {
|
||||
let mut downloads = self.downloads.write().await;
|
||||
|
||||
@@ -505,12 +405,6 @@ impl MsdController {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get list of active download IDs
|
||||
pub async fn active_downloads(&self) -> Vec<String> {
|
||||
let downloads = self.downloads.read().await;
|
||||
downloads.keys().cloned().collect()
|
||||
}
|
||||
|
||||
async fn active_gadget_path(&self) -> Result<PathBuf> {
|
||||
self.otg_service
|
||||
.gadget_path()
|
||||
@@ -518,16 +412,13 @@ impl MsdController {
|
||||
.ok_or_else(|| AppError::Internal("OTG gadget path is not available".to_string()))
|
||||
}
|
||||
|
||||
/// Shutdown the controller
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down MSD controller");
|
||||
|
||||
// 1. Disconnect if connected
|
||||
if let Err(e) = self.disconnect().await {
|
||||
warn!("Error disconnecting during shutdown: {}", e);
|
||||
}
|
||||
|
||||
// 2. Clear local state
|
||||
*self.msd_function.write().await = None;
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
@@ -537,27 +428,9 @@ impl MsdController {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the health monitor reference
|
||||
pub fn monitor(&self) -> &Arc<MsdHealthMonitor> {
|
||||
&self.monitor
|
||||
}
|
||||
|
||||
/// Get current health status
|
||||
pub async fn health_status(&self) -> MsdHealthStatus {
|
||||
self.monitor.status().await
|
||||
}
|
||||
|
||||
/// Check if the MSD is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.monitor.is_healthy().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MsdController {
|
||||
fn drop(&mut self) {
|
||||
// Cleanup is handled by OtgGadgetManager when the gadget is torn down
|
||||
// Individual controllers don't need to cleanup the ConfigFS
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -573,7 +446,6 @@ mod tests {
|
||||
|
||||
let controller = MsdController::new(otg_service, &msd_dir);
|
||||
|
||||
// Check that MSD is not initialized (msd_function is None)
|
||||
let state = controller.state().await;
|
||||
assert!(!state.available);
|
||||
assert!(controller.images_path.ends_with("images"));
|
||||
|
||||
183
src/msd/image.rs
183
src/msd/image.rs
@@ -1,53 +1,37 @@
|
||||
//! Image file manager
|
||||
//!
|
||||
//! Handles ISO/IMG image file operations:
|
||||
//! - List available images
|
||||
//! - Upload new images
|
||||
//! - Delete images
|
||||
//! - Metadata management
|
||||
//! - Download from URL
|
||||
|
||||
use futures::StreamExt;
|
||||
use time::OffsetDateTime;
|
||||
use std::fs::{self, File};
|
||||
use std::io::{self, Read, Write};
|
||||
use std::fs;
|
||||
#[cfg(test)]
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, Instant};
|
||||
use time::OffsetDateTime;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::info;
|
||||
|
||||
use super::types::ImageInfo;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Maximum image size (32 GB)
|
||||
const MAX_IMAGE_SIZE: u64 = 32 * 1024 * 1024 * 1024;
|
||||
|
||||
/// Progress report throttle interval (milliseconds)
|
||||
const PROGRESS_THROTTLE_MS: u64 = 200;
|
||||
|
||||
/// Progress report throttle bytes threshold (512 KB)
|
||||
const PROGRESS_THROTTLE_BYTES: u64 = 512 * 1024;
|
||||
|
||||
/// Image Manager
|
||||
pub struct ImageManager {
|
||||
/// Images storage directory
|
||||
images_path: PathBuf,
|
||||
}
|
||||
|
||||
impl ImageManager {
|
||||
/// Create a new image manager
|
||||
pub fn new(images_path: PathBuf) -> Self {
|
||||
Self { images_path }
|
||||
}
|
||||
|
||||
/// Ensure images directory exists
|
||||
pub fn ensure_dir(&self) -> Result<()> {
|
||||
fs::create_dir_all(&self.images_path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create images directory: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all available images
|
||||
pub fn list(&self) -> Result<Vec<ImageInfo>> {
|
||||
self.ensure_dir()?;
|
||||
|
||||
@@ -68,19 +52,16 @@ impl ImageManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by creation time (newest first)
|
||||
images.sort_by(|a, b| b.created_at.cmp(&a.created_at));
|
||||
|
||||
Ok(images)
|
||||
}
|
||||
|
||||
/// Get image info from path
|
||||
fn get_image_info(&self, path: &Path) -> Option<ImageInfo> {
|
||||
let metadata = fs::metadata(path).ok()?;
|
||||
let name = path.file_name()?.to_string_lossy().to_string();
|
||||
|
||||
// Use filename hash as ID (stable across restarts)
|
||||
let id = format!("{:x}", md5_hash(&name));
|
||||
let id = stable_image_id_from_filename(&name);
|
||||
|
||||
let created_at = metadata
|
||||
.created()
|
||||
@@ -101,7 +82,6 @@ impl ImageManager {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get image by ID
|
||||
pub fn get(&self, id: &str) -> Result<ImageInfo> {
|
||||
for image in self.list()? {
|
||||
if image.id == id {
|
||||
@@ -111,24 +91,21 @@ impl ImageManager {
|
||||
Err(AppError::NotFound(format!("Image not found: {}", id)))
|
||||
}
|
||||
|
||||
/// Get image by name
|
||||
pub fn get_by_name(&self, name: &str) -> Result<ImageInfo> {
|
||||
let path = self.images_path.join(name);
|
||||
self.get_image_info(&path)
|
||||
.ok_or_else(|| AppError::NotFound(format!("Image not found: {}", name)))
|
||||
}
|
||||
|
||||
/// Create a new image from bytes
|
||||
pub fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
|
||||
#[cfg(test)]
|
||||
fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
|
||||
self.ensure_dir()?;
|
||||
|
||||
// Validate name
|
||||
let name = sanitize_filename(name);
|
||||
if name.is_empty() {
|
||||
return Err(AppError::Internal("Invalid filename".to_string()));
|
||||
}
|
||||
|
||||
// Check size
|
||||
if data.len() as u64 > MAX_IMAGE_SIZE {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image too large. Maximum size: {} GB",
|
||||
@@ -136,7 +113,6 @@ impl ImageManager {
|
||||
)));
|
||||
}
|
||||
|
||||
// Write file
|
||||
let path = self.images_path.join(&name);
|
||||
if path.exists() {
|
||||
return Err(AppError::Internal(format!(
|
||||
@@ -145,11 +121,10 @@ impl ImageManager {
|
||||
)));
|
||||
}
|
||||
|
||||
let mut file = File::create(&path)
|
||||
let mut file = fs::File::create(&path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
|
||||
|
||||
file.write_all(data).map_err(|e| {
|
||||
// Try to clean up on error
|
||||
let _ = fs::remove_file(&path);
|
||||
AppError::Internal(format!("Failed to write image data: {}", e))
|
||||
})?;
|
||||
@@ -159,55 +134,6 @@ impl ImageManager {
|
||||
self.get_by_name(&name)
|
||||
}
|
||||
|
||||
/// Create a new image from a file stream (for chunked uploads)
|
||||
pub fn create_from_stream<R: Read>(
|
||||
&self,
|
||||
name: &str,
|
||||
reader: &mut R,
|
||||
expected_size: Option<u64>,
|
||||
) -> Result<ImageInfo> {
|
||||
self.ensure_dir()?;
|
||||
|
||||
let name = sanitize_filename(name);
|
||||
if name.is_empty() {
|
||||
return Err(AppError::Internal("Invalid filename".to_string()));
|
||||
}
|
||||
|
||||
if let Some(size) = expected_size {
|
||||
if size > MAX_IMAGE_SIZE {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image too large. Maximum size: {} GB",
|
||||
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let path = self.images_path.join(&name);
|
||||
if path.exists() {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image already exists: {}",
|
||||
name
|
||||
)));
|
||||
}
|
||||
|
||||
// Create file and copy data
|
||||
let mut file = File::create(&path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
|
||||
|
||||
let bytes_written = io::copy(reader, &mut file).map_err(|e| {
|
||||
let _ = fs::remove_file(&path);
|
||||
AppError::Internal(format!("Failed to write image data: {}", e))
|
||||
})?;
|
||||
|
||||
info!("Created image: {} ({} bytes)", name, bytes_written);
|
||||
|
||||
self.get_by_name(&name)
|
||||
}
|
||||
|
||||
/// Create a new image from an async multipart field (streaming, memory-efficient)
|
||||
///
|
||||
/// This method streams data directly to disk without buffering the entire file in memory,
|
||||
/// making it suitable for large files (multi-GB ISOs).
|
||||
pub async fn create_from_multipart_field(
|
||||
&self,
|
||||
name: &str,
|
||||
@@ -220,12 +146,10 @@ impl ImageManager {
|
||||
return Err(AppError::Internal("Invalid filename".to_string()));
|
||||
}
|
||||
|
||||
// Use a temporary file during upload
|
||||
let temp_name = format!(".upload_{}", uuid::Uuid::new_v4());
|
||||
let temp_path = self.images_path.join(&temp_name);
|
||||
let final_path = self.images_path.join(&name);
|
||||
|
||||
// Check if final file already exists
|
||||
if final_path.exists() {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Image already exists: {}",
|
||||
@@ -233,23 +157,19 @@ impl ImageManager {
|
||||
)));
|
||||
}
|
||||
|
||||
// Create temp file
|
||||
let mut file = tokio::fs::File::create(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
|
||||
|
||||
let mut bytes_written: u64 = 0;
|
||||
|
||||
// Stream chunks directly to disk
|
||||
while let Some(chunk) = field
|
||||
.chunk()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read upload chunk: {}", e)))?
|
||||
{
|
||||
// Check size limit
|
||||
bytes_written += chunk.len() as u64;
|
||||
if bytes_written > MAX_IMAGE_SIZE {
|
||||
// Cleanup and return error
|
||||
drop(file);
|
||||
let _ = tokio::fs::remove_file(&temp_path).await;
|
||||
return Err(AppError::Internal(format!(
|
||||
@@ -258,19 +178,16 @@ impl ImageManager {
|
||||
)));
|
||||
}
|
||||
|
||||
// Write chunk to file
|
||||
file.write_all(&chunk)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?;
|
||||
}
|
||||
|
||||
// Flush and close file
|
||||
file.flush()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
|
||||
drop(file);
|
||||
|
||||
// Move temp file to final location
|
||||
tokio::fs::rename(&temp_path, &final_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -286,7 +203,6 @@ impl ImageManager {
|
||||
self.get_by_name(&name)
|
||||
}
|
||||
|
||||
/// Delete an image by ID
|
||||
pub fn delete(&self, id: &str) -> Result<()> {
|
||||
let image = self.get(id)?;
|
||||
|
||||
@@ -297,45 +213,6 @@ impl ImageManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete an image by name
|
||||
pub fn delete_by_name(&self, name: &str) -> Result<()> {
|
||||
let path = self.images_path.join(name);
|
||||
|
||||
if !path.exists() {
|
||||
return Err(AppError::NotFound(format!("Image not found: {}", name)));
|
||||
}
|
||||
|
||||
fs::remove_file(&path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to delete image: {}", e)))?;
|
||||
|
||||
info!("Deleted image: {}", name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get total storage used
|
||||
pub fn used_space(&self) -> u64 {
|
||||
self.list()
|
||||
.map(|images| images.iter().map(|i| i.size).sum())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Check if storage has space for new image
|
||||
pub fn has_space(&self, size: u64) -> bool {
|
||||
// For now, just check against max size
|
||||
// In the future, could check disk space
|
||||
size <= MAX_IMAGE_SIZE
|
||||
}
|
||||
|
||||
/// Download image from URL with progress callback
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `url` - The URL to download from
|
||||
/// * `filename` - Optional custom filename (extracted from URL or Content-Disposition if not provided)
|
||||
/// * `progress_callback` - Callback function called with (bytes_downloaded, total_bytes)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(ImageInfo)` - The downloaded image info
|
||||
/// * `Err(AppError)` - If download fails
|
||||
pub async fn download_from_url<F>(
|
||||
&self,
|
||||
url: &str,
|
||||
@@ -347,20 +224,17 @@ impl ImageManager {
|
||||
{
|
||||
self.ensure_dir()?;
|
||||
|
||||
// Validate URL
|
||||
let parsed_url = reqwest::Url::parse(url)
|
||||
.map_err(|e| AppError::BadRequest(format!("Invalid URL: {}", e)))?;
|
||||
|
||||
info!("Starting download from: {}", url);
|
||||
|
||||
// Create HTTP client with timeout
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(3600)) // 1 hour timeout for large files
|
||||
.timeout(std::time::Duration::from_secs(3600))
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create HTTP client: {}", e)))?;
|
||||
|
||||
// Send HEAD request first to get content info
|
||||
let head_response = client
|
||||
.head(url)
|
||||
.send()
|
||||
@@ -380,7 +254,6 @@ impl ImageManager {
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok());
|
||||
|
||||
// Check file size
|
||||
if let Some(size) = total_size {
|
||||
if size > MAX_IMAGE_SIZE {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
@@ -391,11 +264,9 @@ impl ImageManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Determine filename
|
||||
let final_filename = if let Some(name) = filename {
|
||||
sanitize_filename(&name)
|
||||
} else {
|
||||
// Try Content-Disposition header first
|
||||
let from_header = head_response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_DISPOSITION)
|
||||
@@ -405,7 +276,6 @@ impl ImageManager {
|
||||
if let Some(name) = from_header {
|
||||
sanitize_filename(&name)
|
||||
} else {
|
||||
// Fall back to URL path
|
||||
let path = parsed_url.path();
|
||||
let name = path.rsplit('/').next().unwrap_or("download");
|
||||
let name = urlencoding::decode(name).unwrap_or_else(|_| name.into());
|
||||
@@ -419,7 +289,6 @@ impl ImageManager {
|
||||
));
|
||||
}
|
||||
|
||||
// Check if file already exists
|
||||
let final_path = self.images_path.join(&final_filename);
|
||||
if final_path.exists() {
|
||||
return Err(AppError::BadRequest(format!(
|
||||
@@ -428,11 +297,9 @@ impl ImageManager {
|
||||
)));
|
||||
}
|
||||
|
||||
// Create temporary file for download
|
||||
let temp_filename = format!(".download_{}", uuid::Uuid::new_v4());
|
||||
let temp_path = self.images_path.join(&temp_filename);
|
||||
|
||||
// Start actual download
|
||||
let response = client
|
||||
.get(url)
|
||||
.send()
|
||||
@@ -446,7 +313,6 @@ impl ImageManager {
|
||||
)));
|
||||
}
|
||||
|
||||
// Get actual content length from response (may differ from HEAD)
|
||||
let content_length = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_LENGTH)
|
||||
@@ -454,19 +320,16 @@ impl ImageManager {
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.or(total_size);
|
||||
|
||||
// Create temp file
|
||||
let mut file = tokio::fs::File::create(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
|
||||
|
||||
// Stream download with progress (throttled)
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut last_report_time = Instant::now();
|
||||
let mut last_reported_bytes: u64 = 0;
|
||||
let throttle_interval = Duration::from_millis(PROGRESS_THROTTLE_MS);
|
||||
|
||||
// Report initial progress
|
||||
progress_callback(0, content_length);
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
@@ -474,14 +337,12 @@ impl ImageManager {
|
||||
chunk_result.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?;
|
||||
|
||||
file.write_all(&chunk).await.map_err(|e| {
|
||||
// Cleanup on error
|
||||
let _ = std::fs::remove_file(&temp_path);
|
||||
AppError::Internal(format!("Failed to write data: {}", e))
|
||||
})?;
|
||||
|
||||
downloaded += chunk.len() as u64;
|
||||
|
||||
// Throttled progress reporting: report if enough time or bytes have passed
|
||||
let now = Instant::now();
|
||||
let time_elapsed = now.duration_since(last_report_time) >= throttle_interval;
|
||||
let bytes_elapsed = downloaded - last_reported_bytes >= PROGRESS_THROTTLE_BYTES;
|
||||
@@ -493,18 +354,15 @@ impl ImageManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Always report final progress
|
||||
if downloaded != last_reported_bytes {
|
||||
progress_callback(downloaded, content_length);
|
||||
}
|
||||
|
||||
// Ensure all data is flushed
|
||||
file.flush()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
|
||||
drop(file);
|
||||
|
||||
// Verify downloaded size
|
||||
let metadata = tokio::fs::metadata(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read file metadata: {}", e)))?;
|
||||
@@ -520,7 +378,6 @@ impl ImageManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Move temp file to final location
|
||||
tokio::fs::rename(&temp_path, &final_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -534,35 +391,29 @@ impl ImageManager {
|
||||
metadata.len()
|
||||
);
|
||||
|
||||
// Return image info
|
||||
self.get_by_name(&final_filename)
|
||||
}
|
||||
|
||||
/// Get images storage path
|
||||
pub fn images_path(&self) -> &PathBuf {
|
||||
&self.images_path
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple hash function for generating stable IDs
|
||||
fn md5_hash(s: &str) -> u64 {
|
||||
fn stable_image_id_from_filename(name: &str) -> String {
|
||||
let mut hash: u64 = 0;
|
||||
for (i, byte) in s.bytes().enumerate() {
|
||||
for (i, byte) in name.bytes().enumerate() {
|
||||
hash = hash.wrapping_add((byte as u64).wrapping_mul((i as u64).wrapping_add(1)));
|
||||
hash = hash.wrapping_mul(31);
|
||||
}
|
||||
hash
|
||||
format!("{:x}", hash)
|
||||
}
|
||||
|
||||
/// Sanitize filename to prevent path traversal
|
||||
fn sanitize_filename(name: &str) -> String {
|
||||
let name = name.trim();
|
||||
let name = name.replace(['/', '\\', '\0', ':', '*', '?', '"', '<', '>', '|'], "_");
|
||||
|
||||
// Remove leading dots (hidden files)
|
||||
let name = name.trim_start_matches('.');
|
||||
|
||||
// Limit length
|
||||
if name.len() > 255 {
|
||||
name[..255].to_string()
|
||||
} else {
|
||||
@@ -570,17 +421,10 @@ fn sanitize_filename(name: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract filename from Content-Disposition header
|
||||
fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
|
||||
// Handle both:
|
||||
// Content-Disposition: attachment; filename="example.iso"
|
||||
// Content-Disposition: attachment; filename*=UTF-8''example.iso
|
||||
|
||||
// Try filename* first (RFC 5987)
|
||||
if let Some(pos) = header.find("filename*=") {
|
||||
let start = pos + 10;
|
||||
let value = &header[start..];
|
||||
// Format: charset'language'value
|
||||
if let Some(quote_start) = value.find("''") {
|
||||
let encoded = value[quote_start + 2..].split(';').next()?;
|
||||
let decoded = urlencoding::decode(encoded.trim()).ok()?;
|
||||
@@ -591,7 +435,6 @@ fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
// Try filename next
|
||||
if let Some(pos) = header.find("filename=") {
|
||||
let start = pos + 9;
|
||||
let value = &header[start..];
|
||||
@@ -613,7 +456,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_sanitize_filename() {
|
||||
assert_eq!(sanitize_filename("test.iso"), "test.iso");
|
||||
assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); // .. becomes empty after trim_start_matches('.')
|
||||
assert_eq!(sanitize_filename("../test.iso"), "_test.iso");
|
||||
assert_eq!(sanitize_filename("test/file.iso"), "test_file.iso");
|
||||
assert_eq!(sanitize_filename(".hidden.iso"), "hidden.iso");
|
||||
}
|
||||
|
||||
@@ -1,19 +1,3 @@
|
||||
//! MSD (Mass Storage Device) module
|
||||
//!
|
||||
//! Provides virtual USB storage functionality with two modes:
|
||||
//! - Image mounting: Mount ISO/IMG files for system installation
|
||||
//! - Ventoy drive: Bootable exFAT drive for multiple ISO files
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! Web API --> MSD Controller --> ConfigFS Mass Storage --> Target PC
|
||||
//! |
|
||||
//! ┌──────┴──────┐
|
||||
//! │ │
|
||||
//! Image Manager Ventoy Drive
|
||||
//! (ISO/IMG) (Bootable exFAT)
|
||||
//! ```
|
||||
|
||||
pub mod controller;
|
||||
pub mod image;
|
||||
pub mod monitor;
|
||||
@@ -22,12 +6,11 @@ pub mod ventoy_drive;
|
||||
|
||||
pub use controller::MsdController;
|
||||
pub use image::ImageManager;
|
||||
pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig};
|
||||
pub use monitor::MsdHealthMonitor;
|
||||
pub use types::{
|
||||
DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest,
|
||||
ImageInfo, MsdConnectRequest, MsdMode, MsdState,
|
||||
};
|
||||
pub use ventoy_drive::VentoyDrive;
|
||||
|
||||
// Re-export from otg module for backward compatibility
|
||||
pub use crate::otg::{MsdFunction, MsdLunConfig};
|
||||
|
||||
@@ -1,99 +1,46 @@
|
||||
//! MSD (Mass Storage Device) health monitoring
|
||||
//!
|
||||
//! This module provides health monitoring for MSD operations, including:
|
||||
//! - ConfigFS operation error tracking
|
||||
//! - Image mount/unmount error tracking
|
||||
//! - Error state tracking
|
||||
//! - Log throttling to prevent log flooding
|
||||
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::utils::LogThrottler;
|
||||
|
||||
/// MSD health status
|
||||
const LOG_THROTTLE_SECS: u64 = 5;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Default)]
|
||||
pub enum MsdHealthStatus {
|
||||
/// Device is healthy and operational
|
||||
pub(crate) enum MsdHealthStatus {
|
||||
#[default]
|
||||
Healthy,
|
||||
/// Device has an error
|
||||
Error {
|
||||
/// Human-readable error reason
|
||||
reason: String,
|
||||
/// Error code for programmatic handling
|
||||
error_code: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// MSD health monitor configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MsdMonitorConfig {
|
||||
/// Log throttle interval in seconds
|
||||
pub log_throttle_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for MsdMonitorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
log_throttle_secs: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD health monitor
|
||||
///
|
||||
/// Monitors MSD operation health and manages error state.
|
||||
pub struct MsdHealthMonitor {
|
||||
/// Current health status
|
||||
status: RwLock<MsdHealthStatus>,
|
||||
/// Log throttler to prevent log flooding
|
||||
throttler: LogThrottler,
|
||||
/// Error count (for tracking)
|
||||
error_count: AtomicU32,
|
||||
/// Last error code (for change detection)
|
||||
last_error_code: RwLock<Option<String>>,
|
||||
}
|
||||
|
||||
impl MsdHealthMonitor {
|
||||
/// Create a new MSD health monitor with the specified configuration
|
||||
pub fn new(config: MsdMonitorConfig) -> Self {
|
||||
let throttle_secs = config.log_throttle_secs;
|
||||
pub fn with_defaults() -> Self {
|
||||
Self {
|
||||
status: RwLock::new(MsdHealthStatus::Healthy),
|
||||
throttler: LogThrottler::with_secs(throttle_secs),
|
||||
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
|
||||
error_count: AtomicU32::new(0),
|
||||
last_error_code: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new MSD health monitor with default configuration
|
||||
pub fn with_defaults() -> Self {
|
||||
Self::new(MsdMonitorConfig::default())
|
||||
}
|
||||
|
||||
/// Report an error from MSD operations
|
||||
///
|
||||
/// This method is called when an MSD operation fails. It:
|
||||
/// 1. Updates the health status
|
||||
/// 2. Logs the error (with throttling)
|
||||
/// 3. Updates in-memory error state
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `reason` - Human-readable error description
|
||||
/// * `error_code` - Error code for programmatic handling
|
||||
pub async fn report_error(&self, reason: &str, error_code: &str) {
|
||||
let count = self.error_count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if error code changed
|
||||
let error_changed = {
|
||||
let last = self.last_error_code.read().await;
|
||||
last.as_ref().map(|s| s.as_str()) != Some(error_code)
|
||||
};
|
||||
|
||||
// Log with throttling (always log if error type changed)
|
||||
let throttle_key = format!("msd_{}", error_code);
|
||||
if error_changed || self.throttler.should_log(&throttle_key) {
|
||||
warn!(
|
||||
@@ -102,29 +49,21 @@ impl MsdHealthMonitor {
|
||||
);
|
||||
}
|
||||
|
||||
// Update last error code
|
||||
*self.last_error_code.write().await = Some(error_code.to_string());
|
||||
|
||||
// Update status
|
||||
*self.status.write().await = MsdHealthStatus::Error {
|
||||
reason: reason.to_string(),
|
||||
error_code: error_code.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
/// Report that the MSD has recovered from error
|
||||
///
|
||||
/// This method is called when an MSD operation succeeds after errors.
|
||||
/// It resets the error state.
|
||||
pub async fn report_recovered(&self) {
|
||||
let prev_status = self.status.read().await.clone();
|
||||
|
||||
// Only report recovery if we were in an error state
|
||||
if prev_status != MsdHealthStatus::Healthy {
|
||||
let error_count = self.error_count.load(Ordering::Relaxed);
|
||||
info!("MSD recovered after {} errors", error_count);
|
||||
|
||||
// Reset state
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
self.throttler.clear_all();
|
||||
*self.last_error_code.write().await = None;
|
||||
@@ -132,29 +71,25 @@ impl MsdHealthMonitor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current health status
|
||||
pub async fn status(&self) -> MsdHealthStatus {
|
||||
#[cfg(test)]
|
||||
pub(crate) async fn status(&self) -> MsdHealthStatus {
|
||||
self.status.read().await.clone()
|
||||
}
|
||||
|
||||
/// Get the current error count
|
||||
pub fn error_count(&self) -> u32 {
|
||||
#[cfg(test)]
|
||||
pub(crate) fn error_count(&self) -> u32 {
|
||||
self.error_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Check if the monitor is in an error state
|
||||
pub async fn is_error(&self) -> bool {
|
||||
matches!(*self.status.read().await, MsdHealthStatus::Error { .. })
|
||||
}
|
||||
|
||||
/// Check if the monitor is healthy
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
#[cfg(test)]
|
||||
pub(crate) async fn is_healthy(&self) -> bool {
|
||||
matches!(*self.status.read().await, MsdHealthStatus::Healthy)
|
||||
}
|
||||
|
||||
/// Reset the monitor to healthy state without publishing events
|
||||
///
|
||||
/// This is useful during initialization.
|
||||
pub async fn reset(&self) {
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
*self.last_error_code.write().await = None;
|
||||
@@ -162,7 +97,6 @@ impl MsdHealthMonitor {
|
||||
self.throttler.clear_all();
|
||||
}
|
||||
|
||||
/// Get the current error message if in error state
|
||||
pub async fn error_message(&self) -> Option<String> {
|
||||
match &*self.status.read().await {
|
||||
MsdHealthStatus::Error { reason, .. } => Some(reason.clone()),
|
||||
@@ -212,13 +146,11 @@ mod tests {
|
||||
async fn test_report_recovered() {
|
||||
let monitor = MsdHealthMonitor::with_defaults();
|
||||
|
||||
// First report an error
|
||||
monitor
|
||||
.report_error("Image not found", "image_not_found")
|
||||
.await;
|
||||
assert!(monitor.is_error().await);
|
||||
|
||||
// Then report recovery
|
||||
monitor.report_recovered().await;
|
||||
assert!(monitor.is_healthy().await);
|
||||
assert_eq!(monitor.error_count(), 0);
|
||||
|
||||
@@ -1,42 +1,28 @@
|
||||
//! MSD data types and structures
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
/// MSD operating mode
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[derive(Default)]
|
||||
pub enum MsdMode {
|
||||
/// No storage connected
|
||||
#[default]
|
||||
None,
|
||||
/// Image file mounted (ISO/IMG)
|
||||
Image,
|
||||
/// Virtual drive (FAT32) connected
|
||||
Drive,
|
||||
}
|
||||
|
||||
/// Image file metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageInfo {
|
||||
/// Unique image ID
|
||||
pub id: String,
|
||||
/// Display name
|
||||
pub name: String,
|
||||
/// File path on disk
|
||||
#[serde(skip_serializing)]
|
||||
pub path: PathBuf,
|
||||
/// File size in bytes
|
||||
pub size: u64,
|
||||
/// Creation timestamp
|
||||
#[serde(with = "time::serde::rfc3339")]
|
||||
pub created_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
impl ImageInfo {
|
||||
/// Create new image info
|
||||
pub fn new(id: String, name: String, path: PathBuf, size: u64) -> Self {
|
||||
Self {
|
||||
id,
|
||||
@@ -47,7 +33,6 @@ impl ImageInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// Format size for display
|
||||
pub fn size_display(&self) -> String {
|
||||
const KB: u64 = 1024;
|
||||
const MB: u64 = KB * 1024;
|
||||
@@ -65,18 +50,12 @@ impl ImageInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD state information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MsdState {
|
||||
/// Whether MSD feature is available
|
||||
pub available: bool,
|
||||
/// Current mode
|
||||
pub mode: MsdMode,
|
||||
/// Whether storage is connected to target
|
||||
pub connected: bool,
|
||||
/// Currently mounted image (if mode is Image)
|
||||
pub current_image: Option<ImageInfo>,
|
||||
/// Virtual drive info (if mode is Drive)
|
||||
pub drive_info: Option<DriveInfo>,
|
||||
}
|
||||
|
||||
@@ -92,24 +71,17 @@ impl Default for MsdState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Virtual drive information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DriveInfo {
|
||||
/// Drive size in bytes
|
||||
pub size: u64,
|
||||
/// Used space in bytes
|
||||
pub used: u64,
|
||||
/// Free space in bytes
|
||||
pub free: u64,
|
||||
/// Whether drive is initialized
|
||||
pub initialized: bool,
|
||||
/// Drive file path
|
||||
#[serde(skip_serializing)]
|
||||
pub path: PathBuf,
|
||||
}
|
||||
|
||||
impl DriveInfo {
|
||||
/// Create new drive info
|
||||
pub fn new(path: PathBuf, size: u64) -> Self {
|
||||
Self {
|
||||
size,
|
||||
@@ -121,92 +93,60 @@ impl DriveInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// File entry in virtual drive
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DriveFile {
|
||||
/// File name
|
||||
pub name: String,
|
||||
/// Relative path from drive root
|
||||
pub path: String,
|
||||
/// File size in bytes (0 for directories)
|
||||
pub size: u64,
|
||||
/// Whether this is a directory
|
||||
pub is_dir: bool,
|
||||
/// Last modified timestamp
|
||||
#[serde(with = "time::serde::rfc3339::option")]
|
||||
pub modified: Option<OffsetDateTime>,
|
||||
}
|
||||
|
||||
/// MSD connect request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MsdConnectRequest {
|
||||
/// Connection mode: "image" or "drive"
|
||||
pub mode: MsdMode,
|
||||
/// Image ID to mount (required for image mode)
|
||||
pub image_id: Option<String>,
|
||||
/// Mount as CD-ROM (optional, defaults based on image type)
|
||||
#[serde(default)]
|
||||
pub cdrom: Option<bool>,
|
||||
/// Mount as read-only
|
||||
#[serde(default)]
|
||||
pub read_only: Option<bool>,
|
||||
}
|
||||
|
||||
/// Virtual drive init request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DriveInitRequest {
|
||||
/// Drive size in megabytes (defaults to 16GB)
|
||||
#[serde(default = "default_drive_size")]
|
||||
pub size_mb: u32,
|
||||
/// Optional custom path for Ventoy installation
|
||||
pub ventoy_path: Option<String>,
|
||||
}
|
||||
|
||||
fn default_drive_size() -> u32 {
|
||||
16 * 1024 // 16GB
|
||||
16 * 1024
|
||||
}
|
||||
|
||||
/// Image download request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ImageDownloadRequest {
|
||||
/// URL to download from
|
||||
pub url: String,
|
||||
/// Optional custom filename
|
||||
pub filename: Option<String>,
|
||||
}
|
||||
|
||||
/// Download status
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum DownloadStatus {
|
||||
/// Download has started
|
||||
Started,
|
||||
/// Download is in progress
|
||||
InProgress,
|
||||
/// Download completed successfully
|
||||
Completed,
|
||||
/// Download failed
|
||||
Failed,
|
||||
}
|
||||
|
||||
/// Download progress information
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct DownloadProgress {
|
||||
/// Unique download ID
|
||||
pub download_id: String,
|
||||
/// Source URL
|
||||
pub url: String,
|
||||
/// Target filename
|
||||
pub filename: String,
|
||||
/// Bytes downloaded so far
|
||||
pub bytes_downloaded: u64,
|
||||
/// Total file size (None if unknown)
|
||||
pub total_bytes: Option<u64>,
|
||||
/// Progress percentage (0.0 - 100.0, None if total unknown)
|
||||
pub progress_pct: Option<f32>,
|
||||
/// Download status
|
||||
pub status: DownloadStatus,
|
||||
/// Error message if failed
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
@@ -220,7 +160,7 @@ mod tests {
|
||||
"test".into(),
|
||||
"test.iso".into(),
|
||||
PathBuf::from("/tmp/test.iso"),
|
||||
1024 * 1024 * 1024 * 2, // 2 GB
|
||||
1024 * 1024 * 1024 * 2,
|
||||
);
|
||||
assert!(info.size_display().contains("GB"));
|
||||
}
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
//! Ventoy Virtual Drive
|
||||
//!
|
||||
//! Replaces FAT32 VirtualDrive with a Ventoy bootable image.
|
||||
//! Provides a bootable USB with exFAT data partition for ISO files.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
@@ -13,33 +8,20 @@ use ventoy_img::{FileInfo as VentoyFileInfo, VentoyError, VentoyImage};
|
||||
use super::types::{DriveFile, DriveInfo};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Chunk size for streaming reads (64 KB)
|
||||
const STREAM_CHUNK_SIZE: usize = 64 * 1024;
|
||||
|
||||
/// Minimum drive size (1 GB) - Ventoy requires space for boot partition
|
||||
const MIN_DRIVE_SIZE_MB: u32 = 1024;
|
||||
|
||||
/// Maximum drive size (128 GB)
|
||||
const MAX_DRIVE_SIZE_MB: u32 = 128 * 1024;
|
||||
|
||||
/// Default drive label
|
||||
const DEFAULT_LABEL: &str = "ONE-KVM";
|
||||
|
||||
/// Ventoy Drive Manager
|
||||
///
|
||||
/// Thread-safe wrapper around VentoyImage providing async file operations.
|
||||
/// Uses spawn_blocking for all ventoy-img-rs operations since they are synchronous.
|
||||
/// Uses RwLock to allow concurrent read operations while serializing writes.
|
||||
pub struct VentoyDrive {
|
||||
/// Drive image path
|
||||
path: PathBuf,
|
||||
/// RwLock for concurrent reads, exclusive writes
|
||||
/// (ventoy-img-rs operations are synchronous and not thread-safe)
|
||||
lock: Arc<RwLock<()>>,
|
||||
}
|
||||
|
||||
impl VentoyDrive {
|
||||
/// Create new Ventoy drive manager
|
||||
pub fn new(path: PathBuf) -> Self {
|
||||
Self {
|
||||
path,
|
||||
@@ -47,40 +29,32 @@ impl VentoyDrive {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if drive image exists
|
||||
pub fn exists(&self) -> bool {
|
||||
self.path.exists()
|
||||
}
|
||||
|
||||
/// Get drive path
|
||||
pub fn path(&self) -> &PathBuf {
|
||||
&self.path
|
||||
}
|
||||
|
||||
/// Initialize a new Ventoy drive image
|
||||
///
|
||||
/// Creates a bootable Ventoy image with the specified size.
|
||||
/// The image includes boot partitions and an exFAT data partition.
|
||||
pub async fn init(&self, size_mb: u32) -> Result<DriveInfo> {
|
||||
let size_mb = size_mb.clamp(MIN_DRIVE_SIZE_MB, MAX_DRIVE_SIZE_MB);
|
||||
let size_str = format!("{}M", size_mb);
|
||||
let path = self.path.clone();
|
||||
let _lock = self.lock.write().await; // Write lock for initialization
|
||||
let _lock = self.lock.write().await;
|
||||
|
||||
info!("Creating {} MB Ventoy drive at {}", size_mb, path.display());
|
||||
|
||||
// Run Ventoy creation in blocking task
|
||||
let info = tokio::task::spawn_blocking(move || {
|
||||
VentoyImage::create(&path, &size_str, DEFAULT_LABEL).map_err(ventoy_to_app_error)?;
|
||||
|
||||
// Get file metadata for DriveInfo
|
||||
let metadata = std::fs::metadata(&path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
|
||||
|
||||
Ok::<DriveInfo, AppError>(DriveInfo {
|
||||
size: metadata.len(),
|
||||
used: 0,
|
||||
free: metadata.len(), // Approximate - exFAT overhead not calculated
|
||||
free: metadata.len(),
|
||||
initialized: true,
|
||||
path,
|
||||
})
|
||||
@@ -92,20 +66,18 @@ impl VentoyDrive {
|
||||
Ok(info)
|
||||
}
|
||||
|
||||
/// Get drive information
|
||||
pub async fn info(&self) -> Result<DriveInfo> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
let path = self.path.clone();
|
||||
let _lock = self.lock.read().await; // Read lock for info query
|
||||
let _lock = self.lock.read().await;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let metadata = std::fs::metadata(&path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
|
||||
|
||||
// Open image to get file list and calculate used space
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
let files = image.list_files_recursive().map_err(ventoy_to_app_error)?;
|
||||
@@ -116,7 +88,6 @@ impl VentoyDrive {
|
||||
.map(|f| f.size)
|
||||
.sum();
|
||||
|
||||
// Note: This is approximate since we don't have exact exFAT overhead
|
||||
let size = metadata.len();
|
||||
let free = size.saturating_sub(used);
|
||||
|
||||
@@ -132,7 +103,6 @@ impl VentoyDrive {
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// List files at a given path (or root if empty/"/")
|
||||
pub async fn list_files(&self, dir_path: &str) -> Result<Vec<DriveFile>> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
@@ -140,7 +110,7 @@ impl VentoyDrive {
|
||||
|
||||
let path = self.path.clone();
|
||||
let dir_path = dir_path.to_string();
|
||||
let _lock = self.lock.read().await; // Read lock for listing
|
||||
let _lock = self.lock.read().await;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
@@ -161,9 +131,6 @@ impl VentoyDrive {
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Write a file to the drive from multipart upload (streaming)
|
||||
///
|
||||
/// Streams the file directly into the Ventoy image's exFAT partition.
|
||||
pub async fn write_file_from_multipart_field(
|
||||
&self,
|
||||
file_path: &str,
|
||||
@@ -173,12 +140,10 @@ impl VentoyDrive {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
// First, stream to a temporary file (to get the size)
|
||||
let temp_dir = self.path.parent().unwrap_or(Path::new("/tmp"));
|
||||
let temp_name = format!(".upload_ventoy_{}", uuid::Uuid::new_v4());
|
||||
let temp_path = temp_dir.join(&temp_name);
|
||||
|
||||
// Stream upload to temp file
|
||||
let mut temp_file = tokio::fs::File::create(&temp_path)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
|
||||
@@ -201,23 +166,16 @@ impl VentoyDrive {
|
||||
.map_err(|e| AppError::Internal(format!("Failed to flush temp file: {}", e)))?;
|
||||
drop(temp_file);
|
||||
|
||||
// Now copy from temp file to Ventoy image
|
||||
let path = self.path.clone();
|
||||
let file_path = file_path.to_string();
|
||||
let temp_path_clone = temp_path.clone();
|
||||
let _lock = self.lock.write().await; // Write lock for file write
|
||||
let _lock = self.lock.write().await;
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
// Use add_file_to_path which handles streaming internally
|
||||
image
|
||||
.add_file_to_path(
|
||||
&temp_path_clone,
|
||||
&file_path,
|
||||
true, // create_parents
|
||||
true, // overwrite
|
||||
)
|
||||
.add_file_to_path(&temp_path_clone, &file_path, true, true)
|
||||
.map_err(ventoy_to_app_error)?;
|
||||
|
||||
Ok::<(), AppError>(())
|
||||
@@ -225,14 +183,13 @@ impl VentoyDrive {
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?;
|
||||
|
||||
// Cleanup temp file
|
||||
let _ = tokio::fs::remove_file(&temp_path).await;
|
||||
|
||||
result?;
|
||||
Ok(bytes_written)
|
||||
}
|
||||
|
||||
/// Read a file from the drive (for download)
|
||||
#[cfg(test)]
|
||||
pub async fn read_file(&self, file_path: &str) -> Result<Vec<u8>> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
@@ -240,7 +197,7 @@ impl VentoyDrive {
|
||||
|
||||
let path = self.path.clone();
|
||||
let file_path = file_path.to_string();
|
||||
let _lock = self.lock.read().await; // Read lock for file read
|
||||
let _lock = self.lock.read().await;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
@@ -251,10 +208,6 @@ impl VentoyDrive {
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Get file information without reading content
|
||||
///
|
||||
/// Returns file size, name, and other metadata.
|
||||
/// Returns None if the file doesn't exist.
|
||||
pub async fn get_file_info(&self, file_path: &str) -> Result<Option<DriveFile>> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
@@ -262,7 +215,7 @@ impl VentoyDrive {
|
||||
|
||||
let path = self.path.clone();
|
||||
let file_path_owned = file_path.to_string();
|
||||
let _lock = self.lock.read().await; // Read lock for file info
|
||||
let _lock = self.lock.read().await;
|
||||
|
||||
let info = tokio::task::spawn_blocking(move || {
|
||||
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
@@ -282,10 +235,6 @@ impl VentoyDrive {
|
||||
}))
|
||||
}
|
||||
|
||||
/// Read a file from the drive as a stream (for large file downloads)
|
||||
///
|
||||
/// Returns an async channel receiver that yields chunks of file data.
|
||||
/// This avoids loading the entire file into memory.
|
||||
pub async fn read_file_stream(
|
||||
&self,
|
||||
file_path: &str,
|
||||
@@ -297,7 +246,6 @@ impl VentoyDrive {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
}
|
||||
|
||||
// First, get the file size
|
||||
let file_info = self
|
||||
.get_file_info(file_path)
|
||||
.await?
|
||||
@@ -315,15 +263,12 @@ impl VentoyDrive {
|
||||
let file_path_owned = file_path.to_string();
|
||||
let lock = self.lock.clone();
|
||||
|
||||
// Create a channel for streaming data
|
||||
let (tx, rx) =
|
||||
tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, std::io::Error>>(8);
|
||||
|
||||
// Spawn blocking task to read and send chunks
|
||||
tokio::task::spawn_blocking(move || {
|
||||
// Hold read lock for the entire read operation
|
||||
let rt = tokio::runtime::Handle::current();
|
||||
let _lock = rt.block_on(lock.read()); // Read lock for streaming
|
||||
let _lock = rt.block_on(lock.read());
|
||||
|
||||
let image = match VentoyImage::open(&path) {
|
||||
Ok(img) => img,
|
||||
@@ -333,10 +278,8 @@ impl VentoyDrive {
|
||||
}
|
||||
};
|
||||
|
||||
// Create a channel writer that sends chunks
|
||||
let mut chunk_writer = ChannelWriter::new(tx.clone(), rt.clone());
|
||||
|
||||
// Stream the file through the writer
|
||||
if let Err(e) = image.read_file_to_writer(&file_path_owned, &mut chunk_writer) {
|
||||
let _ = rt.block_on(tx.send(Err(std::io::Error::other(e.to_string()))));
|
||||
}
|
||||
@@ -345,7 +288,6 @@ impl VentoyDrive {
|
||||
Ok((file_size, rx))
|
||||
}
|
||||
|
||||
/// Create a directory
|
||||
pub async fn mkdir(&self, dir_path: &str) -> Result<()> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
@@ -353,7 +295,7 @@ impl VentoyDrive {
|
||||
|
||||
let path = self.path.clone();
|
||||
let dir_path = dir_path.to_string();
|
||||
let _lock = self.lock.write().await; // Write lock for mkdir
|
||||
let _lock = self.lock.write().await;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
@@ -366,7 +308,6 @@ impl VentoyDrive {
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Delete a file or directory
|
||||
pub async fn delete(&self, path_to_delete: &str) -> Result<()> {
|
||||
if !self.exists() {
|
||||
return Err(AppError::Internal("Drive not initialized".to_string()));
|
||||
@@ -374,12 +315,11 @@ impl VentoyDrive {
|
||||
|
||||
let path = self.path.clone();
|
||||
let path_to_delete = path_to_delete.to_string();
|
||||
let _lock = self.lock.write().await; // Write lock for delete
|
||||
let _lock = self.lock.write().await;
|
||||
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
|
||||
|
||||
// Use recursive delete to handle directories
|
||||
image
|
||||
.remove_recursive(&path_to_delete)
|
||||
.map_err(ventoy_to_app_error)
|
||||
@@ -389,7 +329,6 @@ impl VentoyDrive {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert VentoyError to AppError
|
||||
fn ventoy_to_app_error(err: VentoyError) -> AppError {
|
||||
match err {
|
||||
VentoyError::Io(e) => AppError::Io(e),
|
||||
@@ -405,7 +344,6 @@ fn ventoy_to_app_error(err: VentoyError) -> AppError {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert VentoyFileInfo to DriveFile
|
||||
fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFile {
|
||||
let full_path = if parent_path.is_empty() || parent_path == "/" {
|
||||
format!("/{}", info.name)
|
||||
@@ -418,13 +356,10 @@ fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFi
|
||||
path: full_path,
|
||||
size: info.size,
|
||||
is_dir: info.is_directory,
|
||||
modified: None, // Ventoy FileInfo doesn't include timestamps
|
||||
modified: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// A writer that sends chunks to an async channel
|
||||
///
|
||||
/// This bridges the sync Write trait with async channels for streaming.
|
||||
struct ChannelWriter {
|
||||
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
|
||||
rt: tokio::runtime::Handle,
|
||||
@@ -484,7 +419,6 @@ impl std::io::Write for ChannelWriter {
|
||||
|
||||
impl Drop for ChannelWriter {
|
||||
fn drop(&mut self) {
|
||||
// Flush any remaining data when the writer is dropped
|
||||
let _ = self.flush_buffer();
|
||||
}
|
||||
}
|
||||
@@ -496,16 +430,13 @@ mod tests {
|
||||
use std::sync::OnceLock;
|
||||
use tempfile::TempDir;
|
||||
|
||||
/// Path to ventoy resources directory
|
||||
static RESOURCE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../ventoy-img-rs/resources");
|
||||
|
||||
/// Initialize ventoy resources once
|
||||
fn init_ventoy_resources() -> bool {
|
||||
static INIT: OnceLock<bool> = OnceLock::new();
|
||||
*INIT.get_or_init(|| {
|
||||
let resource_path = std::path::Path::new(RESOURCE_DIR);
|
||||
|
||||
// Decompress xz files if needed
|
||||
let core_xz = resource_path.join("core.img.xz");
|
||||
let core_img = resource_path.join("core.img");
|
||||
if core_xz.exists() && !core_img.exists() {
|
||||
@@ -524,7 +455,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize resources
|
||||
if let Err(e) = ventoy_img::resources::init_resources(resource_path) {
|
||||
eprintln!("Failed to init ventoy resources: {}", e);
|
||||
return false;
|
||||
@@ -534,7 +464,6 @@ mod tests {
|
||||
})
|
||||
}
|
||||
|
||||
/// Decompress xz file using system command
|
||||
fn decompress_xz(src: &std::path::Path, dst: &std::path::Path) -> std::io::Result<()> {
|
||||
let output = Command::new("xz")
|
||||
.args(["-d", "-k", "-c", src.to_str().unwrap()])
|
||||
@@ -551,7 +480,6 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensure resources are initialized, skip test if failed
|
||||
fn ensure_resources() -> bool {
|
||||
if !init_ventoy_resources() {
|
||||
eprintln!("Skipping test: ventoy resources not available");
|
||||
@@ -602,15 +530,12 @@ mod tests {
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Write a test file
|
||||
let test_content = b"Hello, Ventoy!";
|
||||
let test_file_path = temp_dir.path().join("test.txt");
|
||||
std::fs::write(&test_file_path, test_content).unwrap();
|
||||
|
||||
// Add file to drive using ventoy-img directly
|
||||
let path = drive.path().clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).unwrap();
|
||||
@@ -619,7 +544,6 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Read file from drive
|
||||
let read_data = drive.read_file("/test.txt").await.unwrap();
|
||||
assert_eq!(read_data, test_content);
|
||||
}
|
||||
@@ -633,18 +557,14 @@ mod tests {
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Create a directory
|
||||
drive.mkdir("/mydir").await.unwrap();
|
||||
|
||||
// Write a test file
|
||||
let test_content = b"Test file content for info check";
|
||||
let test_file_path = temp_dir.path().join("info_test.txt");
|
||||
std::fs::write(&test_file_path, test_content).unwrap();
|
||||
|
||||
// Add file to drive
|
||||
let path = drive.path().clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).unwrap();
|
||||
@@ -653,7 +573,6 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Test get_file_info for file
|
||||
let file_info = drive.get_file_info("/info_test.txt").await.unwrap();
|
||||
assert!(file_info.is_some());
|
||||
let file_info = file_info.unwrap();
|
||||
@@ -661,14 +580,12 @@ mod tests {
|
||||
assert_eq!(file_info.size, test_content.len() as u64);
|
||||
assert!(!file_info.is_dir);
|
||||
|
||||
// Test get_file_info for directory
|
||||
let dir_info = drive.get_file_info("/mydir").await.unwrap();
|
||||
assert!(dir_info.is_some());
|
||||
let dir_info = dir_info.unwrap();
|
||||
assert_eq!(dir_info.name, "mydir");
|
||||
assert!(dir_info.is_dir);
|
||||
|
||||
// Test get_file_info for non-existent file
|
||||
let not_found = drive.get_file_info("/nonexistent.txt").await.unwrap();
|
||||
assert!(not_found.is_none());
|
||||
}
|
||||
@@ -682,16 +599,13 @@ mod tests {
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Create test data that spans multiple chunks (>64KB)
|
||||
let test_size = 200 * 1024; // 200 KB
|
||||
let test_size = 200 * 1024;
|
||||
let test_content: Vec<u8> = (0..test_size).map(|i| (i % 256) as u8).collect();
|
||||
let test_file_path = temp_dir.path().join("large_file.bin");
|
||||
std::fs::write(&test_file_path, &test_content).unwrap();
|
||||
|
||||
// Add file to drive
|
||||
let path = drive.path().clone();
|
||||
let file_path_clone = test_file_path.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
@@ -701,18 +615,15 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Stream read the file
|
||||
let (file_size, mut rx) = drive.read_file_stream("/large_file.bin").await.unwrap();
|
||||
assert_eq!(file_size, test_size as u64);
|
||||
|
||||
// Collect all chunks
|
||||
let mut received_data = Vec::new();
|
||||
while let Some(chunk_result) = rx.recv().await {
|
||||
let chunk = chunk_result.expect("Chunk should not be an error");
|
||||
received_data.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
// Verify data matches
|
||||
assert_eq!(received_data.len(), test_content.len());
|
||||
assert_eq!(received_data, test_content);
|
||||
}
|
||||
@@ -726,15 +637,12 @@ mod tests {
|
||||
let drive_path = temp_dir.path().join("test_ventoy.img");
|
||||
let drive = VentoyDrive::new(drive_path.clone());
|
||||
|
||||
// Initialize drive
|
||||
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
|
||||
|
||||
// Create a small test file
|
||||
let test_content = b"Small file for streaming test";
|
||||
let test_file_path = temp_dir.path().join("small.txt");
|
||||
std::fs::write(&test_file_path, test_content).unwrap();
|
||||
|
||||
// Add file to drive
|
||||
let path = drive.path().clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let mut image = VentoyImage::open(&path).unwrap();
|
||||
@@ -743,18 +651,15 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Stream read the file
|
||||
let (file_size, mut rx) = drive.read_file_stream("/small.txt").await.unwrap();
|
||||
assert_eq!(file_size, test_content.len() as u64);
|
||||
|
||||
// Collect all chunks
|
||||
let mut received_data = Vec::new();
|
||||
while let Some(chunk_result) = rx.recv().await {
|
||||
let chunk = chunk_result.expect("Chunk should not be an error");
|
||||
received_data.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
// Verify data matches
|
||||
assert_eq!(received_data.as_slice(), test_content);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! ConfigFS file operations for USB Gadget
|
||||
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
@@ -7,34 +5,18 @@ use std::process::Command;
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// ConfigFS base path for USB gadgets
|
||||
pub const CONFIGFS_PATH: &str = "/sys/kernel/config/usb_gadget";
|
||||
|
||||
/// Default gadget name
|
||||
pub const DEFAULT_GADGET_NAME: &str = "one-kvm";
|
||||
|
||||
/// USB Vendor ID (Linux Foundation) - default value
|
||||
pub const DEFAULT_USB_VENDOR_ID: u16 = 0x1d6b;
|
||||
|
||||
/// USB Product ID (Multifunction Composite Gadget) - default value
|
||||
pub const DEFAULT_USB_PRODUCT_ID: u16 = 0x0104;
|
||||
|
||||
/// USB device version - default value
|
||||
pub const DEFAULT_USB_BCD_DEVICE: u16 = 0x0100;
|
||||
|
||||
/// USB spec version (USB 2.0)
|
||||
pub const USB_BCD_USB: u16 = 0x0200;
|
||||
|
||||
/// Check if ConfigFS is available
|
||||
pub fn is_configfs_available() -> bool {
|
||||
Path::new(CONFIGFS_PATH).exists()
|
||||
}
|
||||
|
||||
/// Ensure libcomposite support is available for USB gadget operations.
|
||||
///
|
||||
/// This is a best-effort runtime fallback for systems where `libcomposite`
|
||||
/// is built as a module and not loaded yet. It does not try to mount configfs;
|
||||
/// mounting remains an explicit system responsibility.
|
||||
/// Loads `libcomposite` if needed; does not mount configfs.
|
||||
pub fn ensure_libcomposite_loaded() -> Result<()> {
|
||||
if is_configfs_available() {
|
||||
return Ok(());
|
||||
@@ -66,7 +48,6 @@ pub fn ensure_libcomposite_loaded() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Find available UDC (USB Device Controller)
|
||||
pub fn find_udc() -> Option<String> {
|
||||
let udc_path = Path::new("/sys/class/udc");
|
||||
if !udc_path.exists() {
|
||||
@@ -80,40 +61,17 @@ pub fn find_udc() -> Option<String> {
|
||||
.next()
|
||||
}
|
||||
|
||||
/// Check if UDC is known to have low endpoint resources
|
||||
pub fn is_low_endpoint_udc(name: &str) -> bool {
|
||||
let name = name.to_ascii_lowercase();
|
||||
name.contains("musb") || name.contains("musb-hdrc")
|
||||
}
|
||||
|
||||
/// Resolve preferred UDC name if available, otherwise auto-detect
|
||||
pub fn resolve_udc_name(preferred: Option<&str>) -> Option<String> {
|
||||
if let Some(name) = preferred {
|
||||
let path = Path::new("/sys/class/udc").join(name);
|
||||
if path.exists() {
|
||||
return Some(name.to_string());
|
||||
}
|
||||
}
|
||||
find_udc()
|
||||
}
|
||||
|
||||
/// Write string content to a file
|
||||
///
|
||||
/// For sysfs files, this function appends a newline and flushes
|
||||
/// to ensure the kernel processes the write immediately.
|
||||
///
|
||||
/// IMPORTANT: sysfs attributes require a single atomic write() syscall.
|
||||
/// The kernel processes the value on the first write(), so we must
|
||||
/// build the complete buffer (including newline) before writing.
|
||||
/// Sysfs/configfs: one write syscall with final buffer (incl. newline when needed).
|
||||
pub fn write_file(path: &Path, content: &str) -> Result<()> {
|
||||
// For sysfs files (especially write-only ones like forced_eject),
|
||||
// we need to use simple O_WRONLY without O_TRUNC
|
||||
// O_TRUNC may fail on special files or require read permission
|
||||
let mut file = OpenOptions::new()
|
||||
.write(true)
|
||||
.open(path)
|
||||
.or_else(|e| {
|
||||
// If open fails, try create (for regular files)
|
||||
if path.exists() {
|
||||
Err(e)
|
||||
} else {
|
||||
@@ -122,9 +80,6 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
|
||||
})
|
||||
.map_err(|e| AppError::Internal(format!("Failed to open {}: {}", path.display(), e)))?;
|
||||
|
||||
// Build complete buffer with newline, then write in single syscall.
|
||||
// This is critical for sysfs - multiple write() calls may cause
|
||||
// the kernel to only process partial data or return EINVAL.
|
||||
let data: std::borrow::Cow<[u8]> = if content.ends_with('\n') {
|
||||
content.as_bytes().into()
|
||||
} else {
|
||||
@@ -136,14 +91,12 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
|
||||
file.write_all(&data)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
|
||||
|
||||
// Explicitly flush to ensure sysfs processes the write
|
||||
file.flush()
|
||||
.map_err(|e| AppError::Internal(format!("Failed to flush {}: {}", path.display(), e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write binary content to a file
|
||||
pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
|
||||
let mut file = File::create(path)
|
||||
.map_err(|e| AppError::Internal(format!("Failed to create {}: {}", path.display(), e)))?;
|
||||
@@ -154,14 +107,6 @@ pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read string content from a file
|
||||
pub fn read_file(path: &Path) -> Result<String> {
|
||||
fs::read_to_string(path)
|
||||
.map(|s| s.trim().to_string())
|
||||
.map_err(|e| AppError::Internal(format!("Failed to read {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
/// Create directory if not exists
|
||||
pub fn create_dir(path: &Path) -> Result<()> {
|
||||
fs::create_dir_all(path).map_err(|e| {
|
||||
AppError::Internal(format!(
|
||||
@@ -172,7 +117,6 @@ pub fn create_dir(path: &Path) -> Result<()> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Remove directory
|
||||
pub fn remove_dir(path: &Path) -> Result<()> {
|
||||
if path.exists() {
|
||||
fs::remove_dir(path).map_err(|e| {
|
||||
@@ -186,7 +130,6 @@ pub fn remove_dir(path: &Path) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove file
|
||||
pub fn remove_file(path: &Path) -> Result<()> {
|
||||
if path.exists() {
|
||||
fs::remove_file(path).map_err(|e| {
|
||||
@@ -196,7 +139,6 @@ pub fn remove_file(path: &Path) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create symlink
|
||||
pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> {
|
||||
std::os::unix::fs::symlink(src, dest).map_err(|e| {
|
||||
AppError::Internal(format!(
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
//! USB Endpoint allocation management
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// Default maximum endpoints for typical UDC
|
||||
pub const DEFAULT_MAX_ENDPOINTS: u8 = 16;
|
||||
|
||||
/// Endpoint allocator - manages UDC endpoint resources
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EndpointAllocator {
|
||||
max_endpoints: u8,
|
||||
@@ -13,7 +9,6 @@ pub struct EndpointAllocator {
|
||||
}
|
||||
|
||||
impl EndpointAllocator {
|
||||
/// Create a new endpoint allocator
|
||||
pub fn new(max_endpoints: u8) -> Self {
|
||||
Self {
|
||||
max_endpoints,
|
||||
@@ -21,7 +16,6 @@ impl EndpointAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate endpoints for a function
|
||||
pub fn allocate(&mut self, count: u8) -> Result<()> {
|
||||
if self.used_endpoints + count > self.max_endpoints {
|
||||
return Err(AppError::Internal(format!(
|
||||
@@ -34,27 +28,22 @@ impl EndpointAllocator {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Release endpoints
|
||||
pub fn release(&mut self, count: u8) {
|
||||
self.used_endpoints = self.used_endpoints.saturating_sub(count);
|
||||
}
|
||||
|
||||
/// Get available endpoint count
|
||||
pub fn available(&self) -> u8 {
|
||||
self.max_endpoints.saturating_sub(self.used_endpoints)
|
||||
}
|
||||
|
||||
/// Get used endpoint count
|
||||
pub fn used(&self) -> u8 {
|
||||
self.used_endpoints
|
||||
}
|
||||
|
||||
/// Get maximum endpoint count
|
||||
pub fn max(&self) -> u8 {
|
||||
self.max_endpoints
|
||||
}
|
||||
|
||||
/// Check if can allocate
|
||||
pub fn can_allocate(&self, count: u8) -> bool {
|
||||
self.available() >= count
|
||||
}
|
||||
@@ -82,7 +71,6 @@ mod tests {
|
||||
alloc.allocate(4).unwrap();
|
||||
assert_eq!(alloc.available(), 2);
|
||||
|
||||
// Should fail - not enough endpoints
|
||||
assert!(alloc.allocate(3).is_err());
|
||||
|
||||
alloc.release(2);
|
||||
|
||||
@@ -1,42 +1,17 @@
|
||||
//! USB Gadget Function trait definition
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
/// Function metadata
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FunctionMeta {
|
||||
/// Function name (e.g., "hid.usb0")
|
||||
pub name: String,
|
||||
/// Human-readable description
|
||||
pub description: String,
|
||||
/// Number of endpoints used
|
||||
pub endpoints: u8,
|
||||
/// Whether the function is enabled
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// USB Gadget Function trait
|
||||
pub trait GadgetFunction: Send + Sync {
|
||||
/// Get function name (e.g., "hid.usb0", "mass_storage.usb0")
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Get number of endpoints required
|
||||
fn endpoints_required(&self) -> u8;
|
||||
|
||||
/// Get function metadata
|
||||
fn meta(&self) -> FunctionMeta;
|
||||
|
||||
/// Create function directory and configuration in ConfigFS
|
||||
fn create(&self, gadget_path: &Path) -> Result<()>;
|
||||
|
||||
/// Link function to configuration
|
||||
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()>;
|
||||
|
||||
/// Unlink function from configuration
|
||||
fn unlink(&self, config_path: &Path) -> Result<()>;
|
||||
|
||||
/// Cleanup function directory
|
||||
fn cleanup(&self, gadget_path: &Path) -> Result<()>;
|
||||
}
|
||||
|
||||
@@ -1,35 +1,24 @@
|
||||
//! HID Function implementation for USB Gadget
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::debug;
|
||||
|
||||
use super::configfs::{
|
||||
create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file,
|
||||
};
|
||||
use super::function::{FunctionMeta, GadgetFunction};
|
||||
use super::function::GadgetFunction;
|
||||
use super::report_desc::{
|
||||
CONSUMER_CONTROL, KEYBOARD, KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE,
|
||||
};
|
||||
use crate::error::Result;
|
||||
|
||||
/// HID function type
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum HidFunctionType {
|
||||
/// Keyboard
|
||||
Keyboard,
|
||||
/// Relative mouse (traditional mouse movement)
|
||||
/// Uses 1 endpoint: IN
|
||||
MouseRelative,
|
||||
/// Absolute mouse (touchscreen-like positioning)
|
||||
/// Uses 1 endpoint: IN
|
||||
MouseAbsolute,
|
||||
/// Consumer control (multimedia keys)
|
||||
/// Uses 1 endpoint: IN
|
||||
ConsumerControl,
|
||||
}
|
||||
|
||||
impl HidFunctionType {
|
||||
/// Get the base endpoint cost for this function type.
|
||||
pub fn endpoints(&self) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 1,
|
||||
@@ -39,27 +28,24 @@ impl HidFunctionType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get HID protocol
|
||||
pub fn protocol(&self) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 1, // Keyboard
|
||||
HidFunctionType::MouseRelative => 2, // Mouse
|
||||
HidFunctionType::MouseAbsolute => 2, // Mouse
|
||||
HidFunctionType::ConsumerControl => 0, // None
|
||||
HidFunctionType::Keyboard => 1,
|
||||
HidFunctionType::MouseRelative => 2,
|
||||
HidFunctionType::MouseAbsolute => 2,
|
||||
HidFunctionType::ConsumerControl => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get HID subclass
|
||||
pub fn subclass(&self) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 1, // Boot interface
|
||||
HidFunctionType::MouseRelative => 1, // Boot interface
|
||||
HidFunctionType::MouseAbsolute => 0, // No boot interface
|
||||
HidFunctionType::ConsumerControl => 0, // No boot interface
|
||||
HidFunctionType::Keyboard => 1,
|
||||
HidFunctionType::MouseRelative => 1,
|
||||
HidFunctionType::MouseAbsolute => 0,
|
||||
HidFunctionType::ConsumerControl => 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get report length in bytes
|
||||
pub fn report_length(&self, _keyboard_leds: bool) -> u8 {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => 8,
|
||||
@@ -69,7 +55,6 @@ impl HidFunctionType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get report descriptor
|
||||
pub fn report_desc(&self, keyboard_leds: bool) -> &'static [u8] {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => {
|
||||
@@ -84,33 +69,17 @@ impl HidFunctionType {
|
||||
HidFunctionType::ConsumerControl => CONSUMER_CONTROL,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get description
|
||||
pub fn description(&self) -> &'static str {
|
||||
match self {
|
||||
HidFunctionType::Keyboard => "Keyboard",
|
||||
HidFunctionType::MouseRelative => "Relative Mouse",
|
||||
HidFunctionType::MouseAbsolute => "Absolute Mouse",
|
||||
HidFunctionType::ConsumerControl => "Consumer Control",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HID Function for USB Gadget
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HidFunction {
|
||||
/// Instance number (usb0, usb1, ...)
|
||||
instance: u8,
|
||||
/// Function type
|
||||
func_type: HidFunctionType,
|
||||
/// Cached function name (avoids repeated allocation)
|
||||
name: String,
|
||||
/// Whether keyboard LED/status feedback is enabled.
|
||||
keyboard_leds: bool,
|
||||
}
|
||||
|
||||
impl HidFunction {
|
||||
/// Create a keyboard function
|
||||
pub fn keyboard(instance: u8, keyboard_leds: bool) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
@@ -120,7 +89,6 @@ impl HidFunction {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a relative mouse function
|
||||
pub fn mouse_relative(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
@@ -130,7 +98,6 @@ impl HidFunction {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an absolute mouse function
|
||||
pub fn mouse_absolute(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
@@ -140,7 +107,6 @@ impl HidFunction {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a consumer control function
|
||||
pub fn consumer_control(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
@@ -150,12 +116,10 @@ impl HidFunction {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get function path in gadget
|
||||
fn function_path(&self, gadget_path: &Path) -> PathBuf {
|
||||
gadget_path.join("functions").join(self.name())
|
||||
}
|
||||
|
||||
/// Get expected device path (e.g., /dev/hidg0)
|
||||
pub fn device_path(&self) -> PathBuf {
|
||||
PathBuf::from(format!("/dev/hidg{}", self.instance))
|
||||
}
|
||||
@@ -170,20 +134,10 @@ impl GadgetFunction for HidFunction {
|
||||
self.func_type.endpoints()
|
||||
}
|
||||
|
||||
fn meta(&self) -> FunctionMeta {
|
||||
FunctionMeta {
|
||||
name: self.name().to_string(),
|
||||
description: self.func_type.description().to_string(),
|
||||
endpoints: self.endpoints_required(),
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn create(&self, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
create_dir(&func_path)?;
|
||||
|
||||
// Set HID parameters
|
||||
write_file(
|
||||
&func_path.join("protocol"),
|
||||
&self.func_type.protocol().to_string(),
|
||||
@@ -197,7 +151,6 @@ impl GadgetFunction for HidFunction {
|
||||
&self.func_type.report_length(self.keyboard_leds).to_string(),
|
||||
)?;
|
||||
|
||||
// Write report descriptor
|
||||
write_bytes(
|
||||
&func_path.join("report_desc"),
|
||||
self.func_type.report_desc(self.keyboard_leds),
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
//! OTG Gadget Manager - unified management for USB Gadget functions
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -11,14 +8,13 @@ use super::configfs::{
|
||||
DEFAULT_USB_VENDOR_ID, USB_BCD_USB,
|
||||
};
|
||||
use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS};
|
||||
use super::function::{FunctionMeta, GadgetFunction};
|
||||
use super::function::GadgetFunction;
|
||||
use super::hid::HidFunction;
|
||||
use super::msd::MsdFunction;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
const REBIND_DELAY_MS: u64 = 300;
|
||||
|
||||
/// USB Gadget device descriptor configuration
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct GadgetDescriptor {
|
||||
pub vendor_id: u16,
|
||||
@@ -42,44 +38,28 @@ impl Default for GadgetDescriptor {
|
||||
}
|
||||
}
|
||||
|
||||
/// OTG Gadget Manager - unified management for HID and MSD
|
||||
pub struct OtgGadgetManager {
|
||||
/// Gadget name
|
||||
gadget_name: String,
|
||||
/// Gadget path in ConfigFS
|
||||
gadget_path: PathBuf,
|
||||
/// Configuration path
|
||||
config_path: PathBuf,
|
||||
/// Device descriptor
|
||||
descriptor: GadgetDescriptor,
|
||||
/// Endpoint allocator
|
||||
endpoint_allocator: EndpointAllocator,
|
||||
/// HID instance counter
|
||||
hid_instance: u8,
|
||||
/// MSD instance counter
|
||||
msd_instance: u8,
|
||||
/// Registered functions
|
||||
functions: Vec<Box<dyn GadgetFunction>>,
|
||||
/// Function metadata
|
||||
meta: HashMap<String, FunctionMeta>,
|
||||
/// Bound UDC name
|
||||
bound_udc: Option<String>,
|
||||
/// Whether gadget was created by us
|
||||
created_by_us: bool,
|
||||
}
|
||||
|
||||
impl OtgGadgetManager {
|
||||
/// Create a new gadget manager with default settings
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(DEFAULT_GADGET_NAME, DEFAULT_MAX_ENDPOINTS)
|
||||
}
|
||||
|
||||
/// Create a new gadget manager with custom configuration
|
||||
pub fn with_config(gadget_name: &str, max_endpoints: u8) -> Self {
|
||||
Self::with_descriptor(gadget_name, max_endpoints, GadgetDescriptor::default())
|
||||
}
|
||||
|
||||
/// Create a new gadget manager with custom descriptor
|
||||
pub fn with_descriptor(
|
||||
gadget_name: &str,
|
||||
max_endpoints: u8,
|
||||
@@ -96,30 +76,24 @@ impl OtgGadgetManager {
|
||||
endpoint_allocator: EndpointAllocator::new(max_endpoints),
|
||||
hid_instance: 0,
|
||||
msd_instance: 0,
|
||||
// Pre-allocate for typical use: 3 HID (keyboard, rel mouse, abs mouse) + 1 MSD
|
||||
functions: Vec::with_capacity(4),
|
||||
meta: HashMap::with_capacity(4),
|
||||
bound_udc: None,
|
||||
created_by_us: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if ConfigFS is available
|
||||
pub fn is_available() -> bool {
|
||||
is_configfs_available()
|
||||
}
|
||||
|
||||
/// Find available UDC
|
||||
pub fn find_udc() -> Option<String> {
|
||||
find_udc()
|
||||
}
|
||||
|
||||
/// Check if gadget exists
|
||||
pub fn gadget_exists(&self) -> bool {
|
||||
self.gadget_path.exists()
|
||||
}
|
||||
|
||||
/// Check if gadget is bound to UDC
|
||||
pub fn is_bound(&self) -> bool {
|
||||
let udc_file = self.gadget_path.join("UDC");
|
||||
if let Ok(content) = fs::read_to_string(&udc_file) {
|
||||
@@ -129,8 +103,6 @@ impl OtgGadgetManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Add keyboard function
|
||||
/// Returns the expected device path (e.g., /dev/hidg0)
|
||||
pub fn add_keyboard(&mut self, keyboard_leds: bool) -> Result<PathBuf> {
|
||||
let func = HidFunction::keyboard(self.hid_instance, keyboard_leds);
|
||||
let device_path = func.device_path();
|
||||
@@ -139,7 +111,6 @@ impl OtgGadgetManager {
|
||||
Ok(device_path)
|
||||
}
|
||||
|
||||
/// Add relative mouse function
|
||||
pub fn add_mouse_relative(&mut self) -> Result<PathBuf> {
|
||||
let func = HidFunction::mouse_relative(self.hid_instance);
|
||||
let device_path = func.device_path();
|
||||
@@ -148,7 +119,6 @@ impl OtgGadgetManager {
|
||||
Ok(device_path)
|
||||
}
|
||||
|
||||
/// Add absolute mouse function
|
||||
pub fn add_mouse_absolute(&mut self) -> Result<PathBuf> {
|
||||
let func = HidFunction::mouse_absolute(self.hid_instance);
|
||||
let device_path = func.device_path();
|
||||
@@ -157,7 +127,6 @@ impl OtgGadgetManager {
|
||||
Ok(device_path)
|
||||
}
|
||||
|
||||
/// Add consumer control function (multimedia keys)
|
||||
pub fn add_consumer_control(&mut self) -> Result<PathBuf> {
|
||||
let func = HidFunction::consumer_control(self.hid_instance);
|
||||
let device_path = func.device_path();
|
||||
@@ -166,7 +135,6 @@ impl OtgGadgetManager {
|
||||
Ok(device_path)
|
||||
}
|
||||
|
||||
/// Add MSD function (returns MsdFunction handle for LUN configuration)
|
||||
pub fn add_msd(&mut self) -> Result<MsdFunction> {
|
||||
let func = MsdFunction::new(self.msd_instance);
|
||||
let func_clone = func.clone();
|
||||
@@ -175,11 +143,9 @@ impl OtgGadgetManager {
|
||||
Ok(func_clone)
|
||||
}
|
||||
|
||||
/// Add a generic function
|
||||
fn add_function(&mut self, func: Box<dyn GadgetFunction>) -> Result<()> {
|
||||
let endpoints = func.endpoints_required();
|
||||
|
||||
// Check endpoint availability
|
||||
if !self.endpoint_allocator.can_allocate(endpoints) {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Not enough endpoints for function {}: need {}, available {}",
|
||||
@@ -189,30 +155,22 @@ impl OtgGadgetManager {
|
||||
)));
|
||||
}
|
||||
|
||||
// Allocate endpoints
|
||||
self.endpoint_allocator.allocate(endpoints)?;
|
||||
|
||||
// Store metadata
|
||||
self.meta.insert(func.name().to_string(), func.meta());
|
||||
|
||||
// Store function
|
||||
self.functions.push(func);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Setup the gadget (create directories and configure)
|
||||
pub fn setup(&mut self) -> Result<()> {
|
||||
info!("Setting up OTG USB Gadget: {}", self.gadget_name);
|
||||
|
||||
// Check ConfigFS availability
|
||||
if !Self::is_available() {
|
||||
return Err(AppError::Internal(
|
||||
"ConfigFS not available. Is it mounted at /sys/kernel/config?".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check if gadget already exists and is bound
|
||||
if self.gadget_exists() {
|
||||
if self.is_bound() {
|
||||
info!("Gadget already exists and is bound, skipping setup");
|
||||
@@ -222,20 +180,15 @@ impl OtgGadgetManager {
|
||||
self.cleanup()?;
|
||||
}
|
||||
|
||||
// Create gadget directory
|
||||
create_dir(&self.gadget_path)?;
|
||||
self.created_by_us = true;
|
||||
|
||||
// Set device descriptors
|
||||
self.set_device_descriptors()?;
|
||||
|
||||
// Create strings
|
||||
self.create_strings()?;
|
||||
|
||||
// Create configuration
|
||||
self.create_configuration()?;
|
||||
|
||||
// Create and link all functions
|
||||
for func in &self.functions {
|
||||
func.create(&self.gadget_path)?;
|
||||
func.link(&self.config_path, &self.gadget_path)?;
|
||||
@@ -245,9 +198,7 @@ impl OtgGadgetManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Bind gadget to a specific UDC
|
||||
pub fn bind(&mut self, udc: &str) -> Result<()> {
|
||||
// Recreate config symlinks before binding to avoid kernel gadget issues after rebind
|
||||
if let Err(e) = self.recreate_config_links() {
|
||||
warn!("Failed to recreate gadget config links before bind: {}", e);
|
||||
}
|
||||
@@ -260,7 +211,6 @@ impl OtgGadgetManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Unbind gadget from UDC
|
||||
pub fn unbind(&mut self) -> Result<()> {
|
||||
if self.is_bound() {
|
||||
write_file(&self.gadget_path.join("UDC"), "")?;
|
||||
@@ -271,7 +221,6 @@ impl OtgGadgetManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cleanup all resources
|
||||
pub fn cleanup(&mut self) -> Result<()> {
|
||||
if !self.gadget_exists() {
|
||||
return Ok(());
|
||||
@@ -279,29 +228,23 @@ impl OtgGadgetManager {
|
||||
|
||||
info!("Cleaning up OTG USB Gadget: {}", self.gadget_name);
|
||||
|
||||
// Unbind from UDC first
|
||||
let _ = self.unbind();
|
||||
|
||||
// Unlink and cleanup functions
|
||||
for func in self.functions.iter().rev() {
|
||||
let _ = func.unlink(&self.config_path);
|
||||
}
|
||||
|
||||
// Remove config strings
|
||||
let config_strings = self.config_path.join("strings/0x409");
|
||||
let _ = remove_dir(&config_strings);
|
||||
let _ = remove_dir(&self.config_path);
|
||||
|
||||
// Cleanup functions
|
||||
for func in self.functions.iter().rev() {
|
||||
let _ = func.cleanup(&self.gadget_path);
|
||||
}
|
||||
|
||||
// Remove gadget strings
|
||||
let gadget_strings = self.gadget_path.join("strings/0x409");
|
||||
let _ = remove_dir(&gadget_strings);
|
||||
|
||||
// Remove gadget directory
|
||||
if let Err(e) = remove_dir(&self.gadget_path) {
|
||||
warn!("Could not remove gadget directory: {}", e);
|
||||
}
|
||||
@@ -311,7 +254,6 @@ impl OtgGadgetManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set USB device descriptors
|
||||
fn set_device_descriptors(&self) -> Result<()> {
|
||||
write_file(
|
||||
&self.gadget_path.join("idVendor"),
|
||||
@@ -329,14 +271,13 @@ impl OtgGadgetManager {
|
||||
&self.gadget_path.join("bcdUSB"),
|
||||
&format!("0x{:04x}", USB_BCD_USB),
|
||||
)?;
|
||||
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; // Composite device
|
||||
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?;
|
||||
write_file(&self.gadget_path.join("bDeviceSubClass"), "0x00")?;
|
||||
write_file(&self.gadget_path.join("bDeviceProtocol"), "0x00")?;
|
||||
debug!("Set device descriptors");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create USB strings
|
||||
fn create_strings(&self) -> Result<()> {
|
||||
let strings_path = self.gadget_path.join("strings/0x409");
|
||||
create_dir(&strings_path)?;
|
||||
@@ -354,41 +295,23 @@ impl OtgGadgetManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create configuration
|
||||
fn create_configuration(&self) -> Result<()> {
|
||||
create_dir(&self.config_path)?;
|
||||
|
||||
// Create config strings
|
||||
let strings_path = self.config_path.join("strings/0x409");
|
||||
create_dir(&strings_path)?;
|
||||
write_file(&strings_path.join("configuration"), "Config 1: HID + MSD")?;
|
||||
|
||||
// Set max power (500mA)
|
||||
write_file(&self.config_path.join("MaxPower"), "500")?;
|
||||
|
||||
debug!("Created configuration c.1");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get function metadata
|
||||
pub fn get_meta(&self) -> &HashMap<String, FunctionMeta> {
|
||||
&self.meta
|
||||
}
|
||||
|
||||
/// Get endpoint usage info
|
||||
pub fn endpoint_info(&self) -> (u8, u8) {
|
||||
(
|
||||
self.endpoint_allocator.used(),
|
||||
self.endpoint_allocator.max(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Get gadget path
|
||||
pub fn gadget_path(&self) -> &PathBuf {
|
||||
&self.gadget_path
|
||||
}
|
||||
|
||||
/// Recreate config symlinks from functions directory
|
||||
fn recreate_config_links(&self) -> Result<()> {
|
||||
let functions_path = self.gadget_path.join("functions");
|
||||
if !functions_path.exists() || !self.config_path.exists() {
|
||||
@@ -450,15 +373,10 @@ impl Drop for OtgGadgetManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for HID devices to become available
|
||||
///
|
||||
/// Uses exponential backoff starting from 10ms, capped at 100ms,
|
||||
/// to reduce CPU usage while still providing fast response.
|
||||
pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> bool {
|
||||
let start = std::time::Instant::now();
|
||||
let timeout = std::time::Duration::from_millis(timeout_ms);
|
||||
|
||||
// Exponential backoff: start at 10ms, double each time, cap at 100ms
|
||||
let mut delay_ms = 10u64;
|
||||
const MAX_DELAY_MS: u64 = 100;
|
||||
|
||||
@@ -467,7 +385,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) ->
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calculate remaining time to avoid overshooting timeout
|
||||
let remaining = timeout.saturating_sub(start.elapsed());
|
||||
let sleep_duration = std::time::Duration::from_millis(delay_ms).min(remaining);
|
||||
|
||||
@@ -477,7 +394,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) ->
|
||||
|
||||
tokio::time::sleep(sleep_duration).await;
|
||||
|
||||
// Exponential backoff with cap
|
||||
delay_ms = (delay_ms * 2).min(MAX_DELAY_MS);
|
||||
}
|
||||
|
||||
@@ -492,18 +408,16 @@ mod tests {
|
||||
fn test_manager_creation() {
|
||||
let manager = OtgGadgetManager::new();
|
||||
assert_eq!(manager.gadget_name, DEFAULT_GADGET_NAME);
|
||||
assert!(!manager.gadget_exists()); // Won't exist in test environment
|
||||
assert!(!manager.gadget_exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_endpoint_tracking() {
|
||||
let mut manager = OtgGadgetManager::with_config("test", 8);
|
||||
|
||||
// Keyboard uses 1 endpoint
|
||||
let _ = manager.add_keyboard(false);
|
||||
assert_eq!(manager.endpoint_allocator.used(), 1);
|
||||
|
||||
// Mouse uses 1 endpoint each
|
||||
let _ = manager.add_mouse_relative();
|
||||
let _ = manager.add_mouse_absolute();
|
||||
assert_eq!(manager.endpoint_allocator.used(), 3);
|
||||
|
||||
@@ -1,21 +1,4 @@
|
||||
//! OTG USB Gadget unified management module
|
||||
//!
|
||||
//! This module provides unified management for USB Gadget functions:
|
||||
//! - HID (Keyboard, Mouse)
|
||||
//! - MSD (Mass Storage Device)
|
||||
//!
|
||||
//! Architecture:
|
||||
//! ```text
|
||||
//! OtgService (high-level coordination)
|
||||
//! └── OtgGadgetManager (gadget lifecycle)
|
||||
//! ├── EndpointAllocator (manages UDC endpoints)
|
||||
//! ├── HidFunction (keyboard, mouse_rel, mouse_abs)
|
||||
//! └── MsdFunction (mass storage)
|
||||
//! ```
|
||||
//!
|
||||
//! The recommended way to use this module is through `OtgService`, which provides
|
||||
//! a high-level interface for enabling/disabling HID and MSD functions independently.
|
||||
//! Both `HidController` and `MsdController` should share the same `OtgService` instance.
|
||||
//! USB OTG composite gadget (HID + MSD).
|
||||
|
||||
pub mod configfs;
|
||||
pub mod endpoint;
|
||||
@@ -26,10 +9,6 @@ pub mod msd;
|
||||
pub mod report_desc;
|
||||
pub mod service;
|
||||
|
||||
pub use endpoint::EndpointAllocator;
|
||||
pub use function::{FunctionMeta, GadgetFunction};
|
||||
pub use hid::{HidFunction, HidFunctionType};
|
||||
pub use manager::{wait_for_hid_devices, OtgGadgetManager};
|
||||
pub use msd::{MsdFunction, MsdLunConfig};
|
||||
pub use report_desc::{KEYBOARD, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
|
||||
pub use service::{HidDevicePaths, OtgDesiredState, OtgService, OtgServiceState};
|
||||
pub use service::{HidDevicePaths, OtgService};
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
//! MSD (Mass Storage Device) Function implementation for USB Gadget
|
||||
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_file};
|
||||
use super::function::{FunctionMeta, GadgetFunction};
|
||||
use super::function::GadgetFunction;
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// MSD LUN configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MsdLunConfig {
|
||||
/// File/image path to expose
|
||||
pub file: PathBuf,
|
||||
/// Mount as CD-ROM
|
||||
pub cdrom: bool,
|
||||
/// Read-only mode
|
||||
pub ro: bool,
|
||||
/// Removable media
|
||||
pub removable: bool,
|
||||
/// Disable Force Unit Access
|
||||
pub nofua: bool,
|
||||
}
|
||||
|
||||
@@ -36,7 +28,6 @@ impl Default for MsdLunConfig {
|
||||
}
|
||||
|
||||
impl MsdLunConfig {
|
||||
/// Create CD-ROM configuration
|
||||
pub fn cdrom(file: PathBuf) -> Self {
|
||||
Self {
|
||||
file,
|
||||
@@ -47,7 +38,6 @@ impl MsdLunConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create disk configuration
|
||||
pub fn disk(file: PathBuf, read_only: bool) -> Self {
|
||||
Self {
|
||||
file,
|
||||
@@ -59,38 +49,26 @@ impl MsdLunConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// MSD Function for USB Gadget
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MsdFunction {
|
||||
/// Instance number (usb0, usb1, ...)
|
||||
instance: u8,
|
||||
/// Cached function name (avoids repeated allocation)
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl MsdFunction {
|
||||
/// Create a new MSD function
|
||||
pub fn new(instance: u8) -> Self {
|
||||
Self {
|
||||
instance,
|
||||
name: format!("mass_storage.usb{}", instance),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get function path in gadget
|
||||
fn function_path(&self, gadget_path: &Path) -> PathBuf {
|
||||
gadget_path.join("functions").join(self.name())
|
||||
}
|
||||
|
||||
/// Get LUN path
|
||||
fn lun_path(&self, gadget_path: &Path, lun: u8) -> PathBuf {
|
||||
self.function_path(gadget_path).join(format!("lun.{}", lun))
|
||||
}
|
||||
|
||||
/// Configure a LUN with specified settings (async version)
|
||||
///
|
||||
/// This is the preferred method for async contexts. It runs the blocking
|
||||
/// file I/O and USB timing operations in a separate thread pool.
|
||||
pub async fn configure_lun_async(
|
||||
&self,
|
||||
gadget_path: &Path,
|
||||
@@ -106,17 +84,6 @@ impl MsdFunction {
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Configure a LUN with specified settings
|
||||
/// Note: This should be called after the gadget is set up
|
||||
///
|
||||
/// This implementation is based on PiKVM's MSD drive configuration.
|
||||
/// Key improvements:
|
||||
/// - Uses forced_eject when available (safer than clearing file directly)
|
||||
/// - Reduced sleep times to minimize HID interference
|
||||
/// - Better retry logic for EBUSY errors
|
||||
///
|
||||
/// **Note**: This is a blocking function. In async contexts, prefer
|
||||
/// `configure_lun_async` to avoid blocking the runtime.
|
||||
pub fn configure_lun(&self, gadget_path: &Path, lun: u8, config: &MsdLunConfig) -> Result<()> {
|
||||
let lun_path = self.lun_path(gadget_path, lun);
|
||||
|
||||
@@ -124,7 +91,6 @@ impl MsdFunction {
|
||||
create_dir(&lun_path)?;
|
||||
}
|
||||
|
||||
// Batch read all current values to minimize syscalls
|
||||
let read_attr = |attr: &str| -> String {
|
||||
fs::read_to_string(lun_path.join(attr))
|
||||
.unwrap_or_default()
|
||||
@@ -137,28 +103,21 @@ impl MsdFunction {
|
||||
let current_removable = read_attr("removable");
|
||||
let current_nofua = read_attr("nofua");
|
||||
|
||||
// Prepare new values
|
||||
let new_cdrom = if config.cdrom { "1" } else { "0" };
|
||||
let new_ro = if config.ro { "1" } else { "0" };
|
||||
let new_removable = if config.removable { "1" } else { "0" };
|
||||
let new_nofua = if config.nofua { "1" } else { "0" };
|
||||
|
||||
// Disconnect current file first using forced_eject if available (PiKVM approach)
|
||||
let forced_eject_path = lun_path.join("forced_eject");
|
||||
if forced_eject_path.exists() {
|
||||
// forced_eject is safer - it forcibly detaches regardless of host state
|
||||
debug!("Using forced_eject to clear LUN {}", lun);
|
||||
let _ = write_file(&forced_eject_path, "1");
|
||||
} else {
|
||||
// Fallback to clearing file directly
|
||||
let _ = write_file(&lun_path.join("file"), "");
|
||||
}
|
||||
|
||||
// Brief yield to allow USB stack to process the disconnect
|
||||
// Reduced from 200ms to 50ms - let USB protocol handle timing
|
||||
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||
|
||||
// Write only changed attributes
|
||||
let cdrom_changed = current_cdrom != new_cdrom;
|
||||
if cdrom_changed {
|
||||
debug!(
|
||||
@@ -186,13 +145,11 @@ impl MsdFunction {
|
||||
write_file(&lun_path.join("nofua"), new_nofua)?;
|
||||
}
|
||||
|
||||
// If cdrom mode changed, brief yield for USB host
|
||||
if cdrom_changed {
|
||||
debug!("CDROM mode changed, brief yield for USB host");
|
||||
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||
}
|
||||
|
||||
// Set file path (this triggers the actual mount) - with retry on EBUSY
|
||||
if config.file.exists() {
|
||||
let file_path = config.file.to_string_lossy();
|
||||
let mut last_error = None;
|
||||
@@ -210,7 +167,6 @@ impl MsdFunction {
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
// Check if it's EBUSY (error code 16)
|
||||
let is_busy = e.to_string().contains("Device or resource busy")
|
||||
|| e.to_string().contains("os error 16");
|
||||
|
||||
@@ -220,7 +176,6 @@ impl MsdFunction {
|
||||
lun,
|
||||
attempt + 1
|
||||
);
|
||||
// Exponential backoff: 50, 100, 200, 400ms
|
||||
std::thread::sleep(std::time::Duration::from_millis(50 << attempt));
|
||||
last_error = Some(e);
|
||||
continue;
|
||||
@@ -231,7 +186,6 @@ impl MsdFunction {
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, all retries failed
|
||||
if let Some(e) = last_error {
|
||||
return Err(e);
|
||||
}
|
||||
@@ -242,9 +196,6 @@ impl MsdFunction {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect LUN (async version)
|
||||
///
|
||||
/// Preferred for async contexts.
|
||||
pub async fn disconnect_lun_async(&self, gadget_path: &Path, lun: u8) -> Result<()> {
|
||||
let gadget_path = gadget_path.to_path_buf();
|
||||
let this = self.clone();
|
||||
@@ -254,17 +205,10 @@ impl MsdFunction {
|
||||
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
|
||||
}
|
||||
|
||||
/// Disconnect LUN (clear file)
|
||||
///
|
||||
/// This method uses forced_eject when available, which is safer than
|
||||
/// directly clearing the file path. Based on PiKVM's implementation.
|
||||
/// See: https://docs.kernel.org/usb/mass-storage.html
|
||||
pub fn disconnect_lun(&self, gadget_path: &Path, lun: u8) -> Result<()> {
|
||||
let lun_path = self.lun_path(gadget_path, lun);
|
||||
|
||||
if lun_path.exists() {
|
||||
// Prefer forced_eject if available (PiKVM approach)
|
||||
// forced_eject forcibly detaches the backing file regardless of host state
|
||||
let forced_eject_path = lun_path.join("forced_eject");
|
||||
if forced_eject_path.exists() {
|
||||
debug!(
|
||||
@@ -282,7 +226,6 @@ impl MsdFunction {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to clearing file directly
|
||||
write_file(&lun_path.join("file"), "")?;
|
||||
}
|
||||
info!("LUN {} disconnected", lun);
|
||||
@@ -291,7 +234,6 @@ impl MsdFunction {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get current LUN file path
|
||||
pub fn get_lun_file(&self, gadget_path: &Path, lun: u8) -> Option<PathBuf> {
|
||||
let lun_path = self.lun_path(gadget_path, lun);
|
||||
let file_path = lun_path.join("file");
|
||||
@@ -306,7 +248,6 @@ impl MsdFunction {
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if LUN is connected
|
||||
pub fn is_lun_connected(&self, gadget_path: &Path, lun: u8) -> bool {
|
||||
self.get_lun_file(gadget_path, lun).is_some()
|
||||
}
|
||||
@@ -318,39 +259,23 @@ impl GadgetFunction for MsdFunction {
|
||||
}
|
||||
|
||||
fn endpoints_required(&self) -> u8 {
|
||||
2 // IN + OUT for bulk transfers
|
||||
}
|
||||
|
||||
fn meta(&self) -> FunctionMeta {
|
||||
FunctionMeta {
|
||||
name: self.name().to_string(),
|
||||
description: if self.instance == 0 {
|
||||
"Mass Storage Drive".to_string()
|
||||
} else {
|
||||
format!("Extra Drive #{}", self.instance)
|
||||
},
|
||||
endpoints: self.endpoints_required(),
|
||||
enabled: true,
|
||||
}
|
||||
2
|
||||
}
|
||||
|
||||
fn create(&self, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
create_dir(&func_path)?;
|
||||
|
||||
// Set stall to 0 (workaround for some hosts)
|
||||
let stall_path = func_path.join("stall");
|
||||
if stall_path.exists() {
|
||||
let _ = write_file(&stall_path, "0");
|
||||
}
|
||||
|
||||
// LUN 0 is created automatically, but ensure it exists
|
||||
let lun0_path = func_path.join("lun.0");
|
||||
if !lun0_path.exists() {
|
||||
create_dir(&lun0_path)?;
|
||||
}
|
||||
|
||||
// Set default LUN 0 parameters
|
||||
let _ = write_file(&lun0_path.join("cdrom"), "0");
|
||||
let _ = write_file(&lun0_path.join("ro"), "0");
|
||||
let _ = write_file(&lun0_path.join("removable"), "1");
|
||||
@@ -382,12 +307,10 @@ impl GadgetFunction for MsdFunction {
|
||||
fn cleanup(&self, gadget_path: &Path) -> Result<()> {
|
||||
let func_path = self.function_path(gadget_path);
|
||||
|
||||
// Disconnect all LUNs first
|
||||
for lun in 0..8 {
|
||||
let _ = self.disconnect_lun(gadget_path, lun);
|
||||
}
|
||||
|
||||
// Remove function directory
|
||||
if let Err(e) = remove_dir(&func_path) {
|
||||
warn!("Could not remove MSD function directory: {}", e);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
//! HID Report Descriptors
|
||||
|
||||
/// Keyboard HID Report Descriptor (no LED output)
|
||||
/// Report format (8 bytes input):
|
||||
/// [0] Modifier keys (8 bits)
|
||||
/// [1] Reserved
|
||||
/// [2-7] Key codes (6 keys)
|
||||
pub const KEYBOARD: &[u8] = &[
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x06, // Usage (Keyboard)
|
||||
@@ -34,13 +27,6 @@ pub const KEYBOARD: &[u8] = &[
|
||||
0xC0, // End Collection
|
||||
];
|
||||
|
||||
/// Keyboard HID Report Descriptor with LED output support.
|
||||
/// Input report format (8 bytes):
|
||||
/// [0] Modifier keys (8 bits)
|
||||
/// [1] Reserved
|
||||
/// [2-7] Key codes (6 keys)
|
||||
/// Output report format (1 byte):
|
||||
/// [0] Num Lock / Caps Lock / Scroll Lock / Compose / Kana
|
||||
pub const KEYBOARD_WITH_LED: &[u8] = &[
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x06, // Usage (Keyboard)
|
||||
@@ -81,12 +67,6 @@ pub const KEYBOARD_WITH_LED: &[u8] = &[
|
||||
0xC0, // End Collection
|
||||
];
|
||||
|
||||
/// Relative Mouse HID Report Descriptor (4 bytes report)
|
||||
/// Report format:
|
||||
/// [0] Buttons (5 bits) + padding (3 bits)
|
||||
/// [1] X movement (signed 8-bit)
|
||||
/// [2] Y movement (signed 8-bit)
|
||||
/// [3] Wheel (signed 8-bit)
|
||||
pub const MOUSE_RELATIVE: &[u8] = &[
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x02, // Usage (Mouse)
|
||||
@@ -126,12 +106,6 @@ pub const MOUSE_RELATIVE: &[u8] = &[
|
||||
0xC0, // End Collection
|
||||
];
|
||||
|
||||
/// Absolute Mouse HID Report Descriptor (6 bytes report)
|
||||
/// Report format:
|
||||
/// [0] Buttons (5 bits) + padding (3 bits)
|
||||
/// [1-2] X position (16-bit, 0-32767)
|
||||
/// [3-4] Y position (16-bit, 0-32767)
|
||||
/// [5] Wheel (signed 8-bit)
|
||||
pub const MOUSE_ABSOLUTE: &[u8] = &[
|
||||
0x05, 0x01, // Usage Page (Generic Desktop)
|
||||
0x09, 0x02, // Usage (Mouse)
|
||||
@@ -177,10 +151,6 @@ pub const MOUSE_ABSOLUTE: &[u8] = &[
|
||||
0xC0, // End Collection
|
||||
];
|
||||
|
||||
/// Consumer Control HID Report Descriptor (2 bytes report)
|
||||
/// Report format:
|
||||
/// [0-1] Consumer Control Usage (16-bit little-endian)
|
||||
/// Supports: Play/Pause, Stop, Next/Prev Track, Mute, Volume Up/Down, etc.
|
||||
pub const CONSUMER_CONTROL: &[u8] = &[
|
||||
0x05, 0x0C, // Usage Page (Consumer)
|
||||
0x09, 0x01, // Usage (Consumer Control)
|
||||
|
||||
@@ -1,9 +1,3 @@
|
||||
//! OTG Service - unified gadget lifecycle management
|
||||
//!
|
||||
//! This module provides centralized management for USB OTG gadget functions.
|
||||
//! It is the single owner of the USB gadget desired state and reconciles
|
||||
//! ConfigFS to match that state.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -13,7 +7,6 @@ use super::msd::MsdFunction;
|
||||
use crate::config::{HidBackend, HidConfig, MsdConfig, OtgDescriptorConfig, OtgHidFunctions};
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
/// HID device paths
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct HidDevicePaths {
|
||||
pub keyboard: Option<PathBuf>,
|
||||
@@ -26,26 +19,20 @@ pub struct HidDevicePaths {
|
||||
|
||||
impl HidDevicePaths {
|
||||
pub fn existing_paths(&self) -> Vec<PathBuf> {
|
||||
let mut paths = Vec::new();
|
||||
if let Some(ref p) = self.keyboard {
|
||||
paths.push(p.clone());
|
||||
}
|
||||
if let Some(ref p) = self.mouse_relative {
|
||||
paths.push(p.clone());
|
||||
}
|
||||
if let Some(ref p) = self.mouse_absolute {
|
||||
paths.push(p.clone());
|
||||
}
|
||||
if let Some(ref p) = self.consumer {
|
||||
paths.push(p.clone());
|
||||
}
|
||||
paths
|
||||
[
|
||||
&self.keyboard,
|
||||
&self.mouse_relative,
|
||||
&self.mouse_absolute,
|
||||
&self.consumer,
|
||||
]
|
||||
.into_iter()
|
||||
.filter_map(|p| p.as_ref().cloned())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Desired OTG gadget state derived from configuration.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OtgDesiredState {
|
||||
pub(crate) struct OtgDesiredState {
|
||||
pub udc: Option<String>,
|
||||
pub descriptor: GadgetDescriptor,
|
||||
pub hid_functions: Option<OtgHidFunctions>,
|
||||
@@ -68,7 +55,7 @@ impl Default for OtgDesiredState {
|
||||
}
|
||||
|
||||
impl OtgDesiredState {
|
||||
pub fn from_config(hid: &HidConfig, msd: &MsdConfig) -> Result<Self> {
|
||||
pub(crate) fn from_config(hid: &HidConfig, msd: &MsdConfig) -> Result<Self> {
|
||||
let hid_functions = if hid.backend == HidBackend::Otg {
|
||||
let functions = hid.constrained_otg_functions();
|
||||
Some(functions)
|
||||
@@ -96,45 +83,28 @@ impl OtgDesiredState {
|
||||
}
|
||||
}
|
||||
|
||||
/// OTG Service state
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct OtgServiceState {
|
||||
/// Whether the gadget is created and bound
|
||||
struct OtgServiceState {
|
||||
pub gadget_active: bool,
|
||||
/// Whether HID functions are enabled
|
||||
pub hid_enabled: bool,
|
||||
/// Whether MSD function is enabled
|
||||
pub msd_enabled: bool,
|
||||
/// Bound UDC name
|
||||
pub configured_udc: Option<String>,
|
||||
/// HID device paths (set after gadget setup)
|
||||
pub hid_paths: Option<HidDevicePaths>,
|
||||
/// HID function selection (set after gadget setup)
|
||||
pub hid_functions: Option<OtgHidFunctions>,
|
||||
/// Whether keyboard LED/status feedback is enabled.
|
||||
pub keyboard_leds_enabled: bool,
|
||||
/// Applied endpoint budget.
|
||||
pub max_endpoints: u8,
|
||||
/// Applied descriptor configuration
|
||||
pub descriptor: Option<GadgetDescriptor>,
|
||||
/// Error message if setup failed
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// OTG Service - unified gadget lifecycle management
|
||||
pub struct OtgService {
|
||||
/// The underlying gadget manager
|
||||
manager: Mutex<Option<OtgGadgetManager>>,
|
||||
/// Current state
|
||||
state: RwLock<OtgServiceState>,
|
||||
/// MSD function handle (for runtime LUN configuration)
|
||||
msd_function: RwLock<Option<MsdFunction>>,
|
||||
/// Desired OTG state
|
||||
desired: RwLock<OtgDesiredState>,
|
||||
}
|
||||
|
||||
impl OtgService {
|
||||
/// Create a new OTG service
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
manager: Mutex::new(None),
|
||||
@@ -144,55 +114,29 @@ impl OtgService {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if OTG is available on this system
|
||||
pub fn is_available() -> bool {
|
||||
OtgGadgetManager::is_available() && OtgGadgetManager::find_udc().is_some()
|
||||
}
|
||||
|
||||
/// Get current service state
|
||||
pub async fn state(&self) -> OtgServiceState {
|
||||
self.state.read().await.clone()
|
||||
}
|
||||
|
||||
/// Check if gadget is active
|
||||
pub async fn is_gadget_active(&self) -> bool {
|
||||
self.state.read().await.gadget_active
|
||||
}
|
||||
|
||||
/// Check if HID is enabled
|
||||
pub async fn is_hid_enabled(&self) -> bool {
|
||||
self.state.read().await.hid_enabled
|
||||
}
|
||||
|
||||
/// Check if MSD is enabled
|
||||
pub async fn is_msd_enabled(&self) -> bool {
|
||||
self.state.read().await.msd_enabled
|
||||
}
|
||||
|
||||
/// Get gadget path (for MSD LUN configuration)
|
||||
pub async fn gadget_path(&self) -> Option<PathBuf> {
|
||||
let manager = self.manager.lock().await;
|
||||
manager.as_ref().map(|m| m.gadget_path().clone())
|
||||
}
|
||||
|
||||
/// Get HID device paths
|
||||
pub async fn hid_device_paths(&self) -> Option<HidDevicePaths> {
|
||||
self.state.read().await.hid_paths.clone()
|
||||
}
|
||||
|
||||
/// Get MSD function handle (for LUN configuration)
|
||||
pub async fn msd_function(&self) -> Option<MsdFunction> {
|
||||
self.msd_function.read().await.clone()
|
||||
}
|
||||
|
||||
/// Apply desired OTG state derived from the current application config.
|
||||
pub async fn apply_config(&self, hid: &HidConfig, msd: &MsdConfig) -> Result<()> {
|
||||
let desired = OtgDesiredState::from_config(hid, msd)?;
|
||||
self.apply_desired_state(desired).await
|
||||
}
|
||||
|
||||
/// Apply a fully materialized desired OTG state.
|
||||
pub async fn apply_desired_state(&self, desired: OtgDesiredState) -> Result<()> {
|
||||
pub(crate) async fn apply_desired_state(&self, desired: OtgDesiredState) -> Result<()> {
|
||||
{
|
||||
let mut current = self.desired.write().await;
|
||||
*current = desired;
|
||||
@@ -392,7 +336,6 @@ impl OtgService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown the OTG service and cleanup all resources
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
info!("Shutting down OTG service");
|
||||
|
||||
@@ -425,12 +368,6 @@ impl Default for OtgService {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OtgService {
|
||||
fn drop(&mut self) {
|
||||
debug!("OtgService dropping");
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&OtgDescriptorConfig> for GadgetDescriptor {
|
||||
fn from(config: &OtgDescriptorConfig) -> Self {
|
||||
Self {
|
||||
@@ -452,17 +389,8 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_service_creation() {
|
||||
fn service_new_and_availability_probe() {
|
||||
let _service = OtgService::new();
|
||||
let _ = OtgService::is_available();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initial_state() {
|
||||
let service = OtgService::new();
|
||||
let state = service.state().await;
|
||||
assert!(!state.gadget_active);
|
||||
assert!(!state.hid_enabled);
|
||||
assert!(!state.msd_enabled);
|
||||
}
|
||||
}
|
||||
|
||||
73
src/rtsp/auth.rs
Normal file
73
src/rtsp/auth.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use base64::Engine;
|
||||
|
||||
use crate::config::RtspConfig;
|
||||
|
||||
use super::types::RtspRequest;
|
||||
|
||||
pub(crate) fn extract_basic_auth(req: &RtspRequest) -> Option<(String, String)> {
|
||||
let value = req.headers.get("authorization")?;
|
||||
let mut parts = value.split_whitespace();
|
||||
let scheme = parts.next()?;
|
||||
if !scheme.eq_ignore_ascii_case("basic") {
|
||||
return None;
|
||||
}
|
||||
let b64 = parts.next()?;
|
||||
let decoded = base64::engine::general_purpose::STANDARD.decode(b64).ok()?;
|
||||
let raw = String::from_utf8(decoded).ok()?;
|
||||
let (user, pass) = raw.split_once(':')?;
|
||||
Some((user.to_string(), pass.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) fn rtsp_auth_credentials(config: &RtspConfig) -> Option<(String, String)> {
|
||||
let username = config.username.as_ref()?.trim();
|
||||
if username.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some((
|
||||
username.to_string(),
|
||||
config.password.clone().unwrap_or_default(),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rtsp_types as rtsp;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn rtsp_auth_requires_non_empty_username() {
|
||||
let mut config = RtspConfig::default();
|
||||
config.password = Some("secret".to_string());
|
||||
assert!(rtsp_auth_credentials(&config).is_none());
|
||||
|
||||
config.username = Some("".to_string());
|
||||
assert!(rtsp_auth_credentials(&config).is_none());
|
||||
|
||||
config.username = Some("user".to_string());
|
||||
let credentials = rtsp_auth_credentials(&config).expect("expected credentials");
|
||||
assert_eq!(credentials, ("user".to_string(), "secret".to_string()));
|
||||
|
||||
config.password = None;
|
||||
let credentials = rtsp_auth_credentials(&config).expect("expected credentials");
|
||||
assert_eq!(credentials, ("user".to_string(), "".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_basic_auth_roundtrip() {
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(b"alice:pwd");
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("authorization".to_string(), format!("Basic {}", encoded));
|
||||
let req = RtspRequest {
|
||||
method: rtsp::Method::Options,
|
||||
uri: "*".to_string(),
|
||||
version: rtsp::Version::V1_0,
|
||||
headers,
|
||||
};
|
||||
assert_eq!(
|
||||
extract_basic_auth(&req),
|
||||
Some(("alice".to_string(), "pwd".to_string()))
|
||||
);
|
||||
}
|
||||
}
|
||||
96
src/rtsp/bitstream.rs
Normal file
96
src/rtsp/bitstream.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::video::encoder::registry::VideoEncoderType;
|
||||
use crate::video::shared_video_pipeline::EncodedVideoFrame;
|
||||
|
||||
use super::state::ParameterSets;
|
||||
|
||||
pub(crate) fn update_parameter_sets(params: &mut ParameterSets, frame: &EncodedVideoFrame) {
|
||||
let nal_units = split_annexb_nal_units(frame.data.as_ref());
|
||||
|
||||
match frame.codec {
|
||||
VideoEncoderType::H264 => {
|
||||
for nal in nal_units {
|
||||
match h264_nal_type(nal) {
|
||||
Some(7) => params.h264_sps = Some(Bytes::copy_from_slice(nal)),
|
||||
Some(8) => params.h264_pps = Some(Bytes::copy_from_slice(nal)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
VideoEncoderType::H265 => {
|
||||
for nal in nal_units {
|
||||
match h265_nal_type(nal) {
|
||||
Some(32) => params.h265_vps = Some(Bytes::copy_from_slice(nal)),
|
||||
Some(33) => params.h265_sps = Some(Bytes::copy_from_slice(nal)),
|
||||
Some(34) => params.h265_pps = Some(Bytes::copy_from_slice(nal)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn split_annexb_nal_units(data: &[u8]) -> Vec<&[u8]> {
|
||||
let mut nal_units = Vec::new();
|
||||
let mut cursor = 0usize;
|
||||
|
||||
while let Some((start, start_code_len)) = find_annexb_start_code(data, cursor) {
|
||||
let nal_start = start + start_code_len;
|
||||
if nal_start >= data.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let next_start = find_annexb_start_code(data, nal_start)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(data.len());
|
||||
|
||||
let mut nal_end = next_start;
|
||||
while nal_end > nal_start && data[nal_end - 1] == 0 {
|
||||
nal_end -= 1;
|
||||
}
|
||||
|
||||
if nal_end > nal_start {
|
||||
nal_units.push(&data[nal_start..nal_end]);
|
||||
}
|
||||
|
||||
cursor = next_start;
|
||||
}
|
||||
|
||||
nal_units
|
||||
}
|
||||
|
||||
fn find_annexb_start_code(data: &[u8], from: usize) -> Option<(usize, usize)> {
|
||||
if from >= data.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut i = from;
|
||||
while i + 3 <= data.len() {
|
||||
if i + 4 <= data.len()
|
||||
&& data[i] == 0
|
||||
&& data[i + 1] == 0
|
||||
&& data[i + 2] == 0
|
||||
&& data[i + 3] == 1
|
||||
{
|
||||
return Some((i, 4));
|
||||
}
|
||||
|
||||
if data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 {
|
||||
return Some((i, 3));
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn h264_nal_type(nal: &[u8]) -> Option<u8> {
|
||||
nal.first().map(|value| value & 0x1f)
|
||||
}
|
||||
|
||||
fn h265_nal_type(nal: &[u8]) -> Option<u8> {
|
||||
nal.first().map(|value| (value >> 1) & 0x3f)
|
||||
}
|
||||
9
src/rtsp/codec.rs
Normal file
9
src/rtsp/codec.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use crate::config::RtspCodec;
|
||||
use crate::video::encoder::VideoCodecType;
|
||||
|
||||
pub(crate) fn rtsp_codec_to_video(codec: RtspCodec) -> VideoCodecType {
|
||||
match codec {
|
||||
RtspCodec::H264 => VideoCodecType::H264,
|
||||
RtspCodec::H265 => VideoCodecType::H265,
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,14 @@
|
||||
pub mod service;
|
||||
//! RTSP TCP server exposing H.264/H.265 video from [`VideoStreamManager`](crate::video::VideoStreamManager).
|
||||
|
||||
mod auth;
|
||||
mod bitstream;
|
||||
mod codec;
|
||||
mod protocol;
|
||||
mod response;
|
||||
mod sdp;
|
||||
mod service;
|
||||
mod state;
|
||||
mod streaming;
|
||||
mod types;
|
||||
|
||||
pub use service::{RtspService, RtspServiceStatus};
|
||||
|
||||
193
src/rtsp/protocol.rs
Normal file
193
src/rtsp/protocol.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use rtsp_types as rtsp;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::types::RtspRequest;
|
||||
|
||||
pub(crate) const OPTIONS_PUBLIC_CAPABILITIES: &str =
|
||||
"OPTIONS, DESCRIBE, SETUP, PLAY, GET_PARAMETER, SET_PARAMETER, TEARDOWN";
|
||||
|
||||
pub(crate) fn strip_interleaved_frames_prefix(buffer: &mut Vec<u8>) -> bool {
|
||||
if buffer.len() < 4 || buffer[0] != b'$' {
|
||||
return false;
|
||||
}
|
||||
|
||||
let payload_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize;
|
||||
let frame_len = 4 + payload_len;
|
||||
if buffer.len() < frame_len {
|
||||
return false;
|
||||
}
|
||||
|
||||
buffer.drain(0..frame_len);
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) fn take_rtsp_request_from_buffer(buffer: &mut Vec<u8>) -> Option<String> {
|
||||
let delimiter = b"\r\n\r\n";
|
||||
let pos = find_bytes(buffer, delimiter)?;
|
||||
let req_end = pos + delimiter.len();
|
||||
let req_bytes: Vec<u8> = buffer.drain(0..req_end).collect();
|
||||
Some(String::from_utf8_lossy(&req_bytes).to_string())
|
||||
}
|
||||
|
||||
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
|
||||
haystack
|
||||
.windows(needle.len())
|
||||
.position(|window| window == needle)
|
||||
}
|
||||
|
||||
pub(crate) fn parse_rtsp_request(raw: &str) -> Option<RtspRequest> {
|
||||
let (message, consumed): (rtsp::Message<Vec<u8>>, usize) =
|
||||
rtsp::Message::parse(raw.as_bytes()).ok()?;
|
||||
if consumed != raw.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let request = match message {
|
||||
rtsp::Message::Request(req) => req,
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let uri = request
|
||||
.request_uri()
|
||||
.map(|value| value.as_str().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut headers = HashMap::new();
|
||||
for (name, value) in request.headers() {
|
||||
headers.insert(name.to_string().to_ascii_lowercase(), value.to_string());
|
||||
}
|
||||
|
||||
Some(RtspRequest {
|
||||
method: request.method().clone(),
|
||||
uri,
|
||||
version: request.version(),
|
||||
headers,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn parse_interleaved_channel(transport: &str) -> Option<u8> {
|
||||
let lower = transport.to_ascii_lowercase();
|
||||
if let Some((_, v)) = lower.split_once("interleaved=") {
|
||||
let head = v.split(';').next().unwrap_or(v);
|
||||
let first = head.split('-').next().unwrap_or(head).trim();
|
||||
return first.parse::<u8>().ok();
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn is_tcp_transport_request(transport: &str) -> bool {
|
||||
transport
|
||||
.split(',')
|
||||
.map(str::trim)
|
||||
.map(str::to_ascii_lowercase)
|
||||
.any(|item| item.contains("rtp/avp/tcp") || item.contains("interleaved="))
|
||||
}
|
||||
|
||||
pub(crate) fn is_valid_rtsp_path(method: &rtsp::Method, uri: &str, configured_path: &str) -> bool {
|
||||
if matches!(method, rtsp::Method::Options) && uri.trim() == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
let normalized_cfg = configured_path.trim_matches('/');
|
||||
if normalized_cfg.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let request_path = extract_rtsp_path(uri);
|
||||
|
||||
if request_path == normalized_cfg {
|
||||
return true;
|
||||
}
|
||||
|
||||
if !matches!(method, rtsp::Method::Setup | rtsp::Method::Teardown) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let control_track_path = format!("{}/trackID=0", normalized_cfg);
|
||||
request_path == "trackID=0" || request_path == control_track_path
|
||||
}
|
||||
|
||||
fn extract_rtsp_path(uri: &str) -> String {
|
||||
let raw_path = if let Some((_, remainder)) = uri.split_once("://") {
|
||||
match remainder.find('/') {
|
||||
Some(idx) => &remainder[idx..],
|
||||
None => "/",
|
||||
}
|
||||
} else {
|
||||
uri
|
||||
};
|
||||
|
||||
raw_path
|
||||
.split('?')
|
||||
.next()
|
||||
.unwrap_or(raw_path)
|
||||
.split('#')
|
||||
.next()
|
||||
.unwrap_or(raw_path)
|
||||
.trim_matches('/')
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn rtsp_path_matching_follows_sdp_control_rules() {
|
||||
assert!(is_valid_rtsp_path(
|
||||
&rtsp::Method::Describe,
|
||||
"rtsp://127.0.0.1/live",
|
||||
"live"
|
||||
));
|
||||
assert!(is_valid_rtsp_path(
|
||||
&rtsp::Method::Describe,
|
||||
"rtsp://127.0.0.1/live/?token=1",
|
||||
"/live/"
|
||||
));
|
||||
assert!(!is_valid_rtsp_path(
|
||||
&rtsp::Method::Describe,
|
||||
"rtsp://127.0.0.1/live2",
|
||||
"live"
|
||||
));
|
||||
assert!(!is_valid_rtsp_path(
|
||||
&rtsp::Method::Describe,
|
||||
"rtsp://127.0.0.1/",
|
||||
"/"
|
||||
));
|
||||
|
||||
assert!(is_valid_rtsp_path(
|
||||
&rtsp::Method::Setup,
|
||||
"rtsp://127.0.0.1/live/trackID=0",
|
||||
"live"
|
||||
));
|
||||
assert!(is_valid_rtsp_path(
|
||||
&rtsp::Method::Setup,
|
||||
"rtsp://127.0.0.1/trackID=0",
|
||||
"live"
|
||||
));
|
||||
assert!(!is_valid_rtsp_path(
|
||||
&rtsp::Method::Describe,
|
||||
"rtsp://127.0.0.1/live/trackID=0",
|
||||
"live"
|
||||
));
|
||||
|
||||
assert!(is_valid_rtsp_path(&rtsp::Method::Options, "*", "live"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_parsing_detects_tcp_interleaved_requests() {
|
||||
assert!(is_tcp_transport_request(
|
||||
"RTP/AVP/TCP;unicast;interleaved=0-1"
|
||||
));
|
||||
assert!(is_tcp_transport_request("RTP/AVP;unicast;interleaved=2-3"));
|
||||
assert!(!is_tcp_transport_request(
|
||||
"RTP/AVP;unicast;client_port=8000-8001"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn options_public_includes_standard_methods() {
|
||||
assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("GET_PARAMETER"));
|
||||
assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("TEARDOWN"));
|
||||
}
|
||||
}
|
||||
81
src/rtsp/response.rs
Normal file
81
src/rtsp/response.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use rtsp_types as rtsp;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
|
||||
use super::types::RtspRequest;
|
||||
|
||||
async fn serialize_and_write<W: AsyncWrite + Unpin>(
|
||||
stream: &mut W,
|
||||
response: rtsp::Response<Vec<u8>>,
|
||||
) -> Result<()> {
|
||||
let mut data = Vec::new();
|
||||
response
|
||||
.write(&mut data)
|
||||
.map_err(|e| AppError::BadRequest(format!("failed to serialize RTSP response: {}", e)))?;
|
||||
stream.write_all(&data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn send_simple_response<W: AsyncWrite + Unpin>(
|
||||
stream: &mut W,
|
||||
code: u16,
|
||||
_reason: &str,
|
||||
cseq: Option<&str>,
|
||||
body: &str,
|
||||
) -> Result<()> {
|
||||
let mut builder = rtsp::Response::builder(rtsp::Version::V1_0, status_code_from_u16(code));
|
||||
if let Some(cseq) = cseq {
|
||||
builder = builder.header(rtsp::headers::CSEQ, cseq);
|
||||
}
|
||||
|
||||
let response = builder.build(body.as_bytes().to_vec());
|
||||
serialize_and_write(stream, response).await
|
||||
}
|
||||
|
||||
pub(crate) async fn send_response<W: AsyncWrite + Unpin>(
|
||||
stream: &mut W,
|
||||
req: &RtspRequest,
|
||||
code: u16,
|
||||
_reason: &str,
|
||||
extra_headers: Vec<(String, String)>,
|
||||
body: &str,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
let cseq = req
|
||||
.headers
|
||||
.get("cseq")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "1".to_string());
|
||||
|
||||
let mut builder = rtsp::Response::builder(req.version, status_code_from_u16(code))
|
||||
.header(rtsp::headers::CSEQ, cseq.as_str());
|
||||
|
||||
if !session_id.is_empty() {
|
||||
builder = builder.header(rtsp::headers::SESSION, session_id);
|
||||
}
|
||||
|
||||
for (name, value) in extra_headers {
|
||||
let header_name = rtsp::HeaderName::try_from(name.as_str()).map_err(|e| {
|
||||
AppError::BadRequest(format!("invalid RTSP header name {}: {}", name, e))
|
||||
})?;
|
||||
builder = builder.header(header_name, value);
|
||||
}
|
||||
|
||||
let response = builder.build(body.as_bytes().to_vec());
|
||||
serialize_and_write(stream, response).await
|
||||
}
|
||||
|
||||
pub(crate) fn status_code_from_u16(code: u16) -> rtsp::StatusCode {
|
||||
match code {
|
||||
200 => rtsp::StatusCode::Ok,
|
||||
400 => rtsp::StatusCode::BadRequest,
|
||||
401 => rtsp::StatusCode::Unauthorized,
|
||||
404 => rtsp::StatusCode::NotFound,
|
||||
405 => rtsp::StatusCode::MethodNotAllowed,
|
||||
453 => rtsp::StatusCode::NotEnoughBandwidth,
|
||||
455 => rtsp::StatusCode::MethodNotValidInThisState,
|
||||
461 => rtsp::StatusCode::UnsupportedTransport,
|
||||
_ => rtsp::StatusCode::InternalServerError,
|
||||
}
|
||||
}
|
||||
224
src/rtsp/sdp.rs
Normal file
224
src/rtsp/sdp.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
use base64::Engine;
|
||||
use sdp_types as sdp;
|
||||
|
||||
use crate::config::RtspConfig;
|
||||
use crate::video::encoder::VideoCodecType;
|
||||
use crate::webrtc::rtp::parse_profile_level_id_from_sps;
|
||||
|
||||
use super::state::ParameterSets;
|
||||
|
||||
pub(crate) fn build_h264_fmtp(payload_type: u8, params: &ParameterSets) -> String {
|
||||
let mut attrs = vec!["packetization-mode=1".to_string()];
|
||||
|
||||
if let Some(sps) = params.h264_sps.as_ref() {
|
||||
if let Some(profile_level_id) = parse_profile_level_id_from_sps(sps) {
|
||||
attrs.push(format!("profile-level-id={}", profile_level_id));
|
||||
}
|
||||
} else {
|
||||
attrs.push("profile-level-id=42e01f".to_string());
|
||||
}
|
||||
|
||||
if let (Some(sps), Some(pps)) = (params.h264_sps.as_ref(), params.h264_pps.as_ref()) {
|
||||
let sps_b64 = base64::engine::general_purpose::STANDARD.encode(sps.as_ref());
|
||||
let pps_b64 = base64::engine::general_purpose::STANDARD.encode(pps.as_ref());
|
||||
attrs.push(format!("sprop-parameter-sets={},{}", sps_b64, pps_b64));
|
||||
}
|
||||
|
||||
format!("{} {}", payload_type, attrs.join(";"))
|
||||
}
|
||||
|
||||
pub(crate) fn build_h265_fmtp(payload_type: u8, params: &ParameterSets) -> String {
|
||||
let mut attrs = Vec::new();
|
||||
|
||||
if let Some(vps) = params.h265_vps.as_ref() {
|
||||
attrs.push(format!(
|
||||
"sprop-vps={}",
|
||||
base64::engine::general_purpose::STANDARD.encode(vps.as_ref())
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(sps) = params.h265_sps.as_ref() {
|
||||
attrs.push(format!(
|
||||
"sprop-sps={}",
|
||||
base64::engine::general_purpose::STANDARD.encode(sps.as_ref())
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(pps) = params.h265_pps.as_ref() {
|
||||
attrs.push(format!(
|
||||
"sprop-pps={}",
|
||||
base64::engine::general_purpose::STANDARD.encode(pps.as_ref())
|
||||
));
|
||||
}
|
||||
|
||||
if attrs.is_empty() {
|
||||
format!("{} profile-id=1", payload_type)
|
||||
} else {
|
||||
format!("{} {}", payload_type, attrs.join(";"))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build_sdp(
|
||||
config: &RtspConfig,
|
||||
codec: VideoCodecType,
|
||||
params: &ParameterSets,
|
||||
) -> String {
|
||||
let (payload_type, codec_name, fmtp_value) = match codec {
|
||||
VideoCodecType::H264 => (96u8, "H264", build_h264_fmtp(96, params)),
|
||||
VideoCodecType::H265 => (99u8, "H265", build_h265_fmtp(99, params)),
|
||||
_ => {
|
||||
tracing::warn!("RTSP SDP: unexpected VideoCodecType, falling back to H264");
|
||||
(96u8, "H264", build_h264_fmtp(96, params))
|
||||
}
|
||||
};
|
||||
|
||||
let session = sdp::Session {
|
||||
origin: sdp::Origin {
|
||||
username: Some("-".to_string()),
|
||||
sess_id: "0".to_string(),
|
||||
sess_version: 0,
|
||||
nettype: "IN".to_string(),
|
||||
addrtype: "IP4".to_string(),
|
||||
unicast_address: config.bind.clone(),
|
||||
},
|
||||
session_name: "One-KVM RTSP Stream".to_string(),
|
||||
session_description: None,
|
||||
uri: None,
|
||||
emails: Vec::new(),
|
||||
phones: Vec::new(),
|
||||
connection: Some(sdp::Connection {
|
||||
nettype: "IN".to_string(),
|
||||
addrtype: "IP4".to_string(),
|
||||
connection_address: "0.0.0.0".to_string(),
|
||||
}),
|
||||
bandwidths: Vec::new(),
|
||||
times: vec![sdp::Time {
|
||||
start_time: 0,
|
||||
stop_time: 0,
|
||||
repeats: Vec::new(),
|
||||
}],
|
||||
time_zones: Vec::new(),
|
||||
key: None,
|
||||
attributes: vec![sdp::Attribute {
|
||||
attribute: "control".to_string(),
|
||||
value: Some("*".to_string()),
|
||||
}],
|
||||
medias: vec![sdp::Media {
|
||||
media: "video".to_string(),
|
||||
port: 0,
|
||||
num_ports: None,
|
||||
proto: "RTP/AVP".to_string(),
|
||||
fmt: payload_type.to_string(),
|
||||
media_title: None,
|
||||
connections: Vec::new(),
|
||||
bandwidths: Vec::new(),
|
||||
key: None,
|
||||
attributes: vec![
|
||||
sdp::Attribute {
|
||||
attribute: "rtpmap".to_string(),
|
||||
value: Some(format!("{} {}/90000", payload_type, codec_name)),
|
||||
},
|
||||
sdp::Attribute {
|
||||
attribute: "fmtp".to_string(),
|
||||
value: Some(fmtp_value),
|
||||
},
|
||||
sdp::Attribute {
|
||||
attribute: "control".to_string(),
|
||||
value: Some("trackID=0".to_string()),
|
||||
},
|
||||
],
|
||||
}],
|
||||
};
|
||||
|
||||
let mut output = Vec::new();
|
||||
if let Err(err) = session.write(&mut output) {
|
||||
tracing::warn!("Failed to serialize SDP with sdp-types: {}", err);
|
||||
return String::new();
|
||||
}
|
||||
|
||||
match String::from_utf8(output) {
|
||||
Ok(sdp_text) => sdp_text,
|
||||
Err(err) => {
|
||||
tracing::warn!("Failed to convert SDP bytes to UTF-8: {}", err);
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::RtspConfig;
|
||||
use bytes::Bytes;
|
||||
|
||||
#[test]
|
||||
fn build_sdp_h264_is_parseable_with_expected_video_attributes() {
|
||||
let config = RtspConfig::default();
|
||||
let mut params = ParameterSets::default();
|
||||
params.h264_sps = Some(Bytes::from_static(&[0x67, 0x42, 0xe0, 0x1f, 0x96, 0x54]));
|
||||
params.h264_pps = Some(Bytes::from_static(&[0x68, 0xce, 0x06, 0xe2]));
|
||||
|
||||
let sdp_text = build_sdp(&config, VideoCodecType::H264, ¶ms);
|
||||
assert!(!sdp_text.is_empty());
|
||||
|
||||
let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed");
|
||||
assert_eq!(session.session_name, "One-KVM RTSP Stream");
|
||||
assert_eq!(session.medias.len(), 1);
|
||||
|
||||
let media = &session.medias[0];
|
||||
assert_eq!(media.media, "video");
|
||||
assert_eq!(media.proto, "RTP/AVP");
|
||||
assert_eq!(media.fmt, "96");
|
||||
|
||||
let has_rtpmap = media.attributes.iter().any(|attr| {
|
||||
attr.attribute == "rtpmap" && attr.value.as_deref() == Some("96 H264/90000")
|
||||
});
|
||||
assert!(has_rtpmap);
|
||||
|
||||
let fmtp_value = media
|
||||
.attributes
|
||||
.iter()
|
||||
.find(|attr| attr.attribute == "fmtp")
|
||||
.and_then(|attr| attr.value.as_deref())
|
||||
.expect("missing fmtp value");
|
||||
assert!(fmtp_value.starts_with("96 "));
|
||||
assert!(fmtp_value.contains("packetization-mode=1"));
|
||||
assert!(fmtp_value.contains("sprop-parameter-sets="));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_sdp_h265_is_parseable_with_expected_video_attributes() {
|
||||
let config = RtspConfig::default();
|
||||
let mut params = ParameterSets::default();
|
||||
params.h265_vps = Some(Bytes::from_static(&[0x40, 0x01, 0x0c, 0x01]));
|
||||
params.h265_sps = Some(Bytes::from_static(&[0x42, 0x01, 0x01, 0x60]));
|
||||
params.h265_pps = Some(Bytes::from_static(&[0x44, 0x01, 0xc0, 0x73]));
|
||||
|
||||
let sdp_text = build_sdp(&config, VideoCodecType::H265, ¶ms);
|
||||
assert!(!sdp_text.is_empty());
|
||||
|
||||
let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed");
|
||||
assert_eq!(session.medias.len(), 1);
|
||||
|
||||
let media = &session.medias[0];
|
||||
assert_eq!(media.media, "video");
|
||||
assert_eq!(media.proto, "RTP/AVP");
|
||||
assert_eq!(media.fmt, "99");
|
||||
|
||||
let has_rtpmap = media.attributes.iter().any(|attr| {
|
||||
attr.attribute == "rtpmap" && attr.value.as_deref() == Some("99 H265/90000")
|
||||
});
|
||||
assert!(has_rtpmap);
|
||||
|
||||
let fmtp_value = media
|
||||
.attributes
|
||||
.iter()
|
||||
.find(|attr| attr.attribute == "fmtp")
|
||||
.and_then(|attr| attr.value.as_deref())
|
||||
.expect("missing fmtp value");
|
||||
assert!(fmtp_value.starts_with("99 "));
|
||||
assert!(fmtp_value.contains("sprop-vps="));
|
||||
assert!(fmtp_value.contains("sprop-sps="));
|
||||
assert!(fmtp_value.contains("sprop-pps="));
|
||||
}
|
||||
}
|
||||
1034
src/rtsp/service.rs
1034
src/rtsp/service.rs
File diff suppressed because it is too large
Load Diff
28
src/rtsp/state.rs
Normal file
28
src/rtsp/state.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use bytes::Bytes;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub(crate) struct ParameterSets {
|
||||
pub h264_sps: Option<Bytes>,
|
||||
pub h264_pps: Option<Bytes>,
|
||||
pub h265_vps: Option<Bytes>,
|
||||
pub h265_sps: Option<Bytes>,
|
||||
pub h265_pps: Option<Bytes>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SharedRtspState {
|
||||
pub active_client: Arc<Mutex<Option<SocketAddr>>>,
|
||||
pub parameter_sets: Arc<RwLock<ParameterSets>>,
|
||||
}
|
||||
|
||||
impl SharedRtspState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active_client: Arc::new(Mutex::new(None)),
|
||||
parameter_sets: Arc::new(RwLock::new(ParameterSets::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
367
src/rtsp/streaming.rs
Normal file
367
src/rtsp/streaming.rs
Normal file
@@ -0,0 +1,367 @@
|
||||
use bytes::Bytes;
|
||||
use rand::Rng;
|
||||
use rtp::packet::Packet;
|
||||
use rtp::packetizer::Payloader;
|
||||
use rtsp_types as rtsp;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use webrtc::util::{Marshal, MarshalSize};
|
||||
|
||||
use crate::config::RtspCodec;
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::video::encoder::registry::VideoEncoderType;
|
||||
use crate::video::shared_video_pipeline::EncodedVideoFrame;
|
||||
use crate::video::VideoStreamManager;
|
||||
use crate::webrtc::h265_payloader::H265Payloader;
|
||||
|
||||
use super::bitstream::update_parameter_sets;
|
||||
use super::protocol::{
|
||||
parse_rtsp_request, strip_interleaved_frames_prefix, take_rtsp_request_from_buffer,
|
||||
};
|
||||
use super::response::send_response;
|
||||
use super::state::SharedRtspState;
|
||||
use super::types::RtspRequest;
|
||||
|
||||
pub(crate) const RTP_CLOCK_RATE: u32 = 90_000;
|
||||
pub(crate) const RTP_MTU: usize = 1200;
|
||||
pub(crate) const RTSP_BUF_SIZE: usize = 8192;
|
||||
const RTSP_RESUBSCRIBE_DELAY_MS: u64 = 300;
|
||||
|
||||
pub(crate) async fn stream_video_interleaved(
|
||||
stream: TcpStream,
|
||||
video_manager: &Arc<VideoStreamManager>,
|
||||
rtsp_codec: RtspCodec,
|
||||
channel: u8,
|
||||
shared: SharedRtspState,
|
||||
session_id: String,
|
||||
) -> Result<()> {
|
||||
let (mut reader, mut writer) = stream.into_split();
|
||||
|
||||
let mut rx = video_manager
|
||||
.subscribe_encoded_frames()
|
||||
.await
|
||||
.ok_or_else(|| {
|
||||
AppError::VideoError("RTSP failed to subscribe encoded frames".to_string())
|
||||
})?;
|
||||
|
||||
video_manager.request_keyframe().await.ok();
|
||||
|
||||
let payload_type = match rtsp_codec {
|
||||
RtspCodec::H264 => 96,
|
||||
RtspCodec::H265 => 99,
|
||||
};
|
||||
let mut sequence_number: u16 = rand::rng().random();
|
||||
let ssrc: u32 = rand::rng().random();
|
||||
|
||||
let mut h264_payloader = rtp::codecs::h264::H264Payloader::default();
|
||||
let mut h265_payloader = H265Payloader::new();
|
||||
let mut ctrl_read_buf = [0u8; RTSP_BUF_SIZE];
|
||||
let mut ctrl_buffer = Vec::with_capacity(RTSP_BUF_SIZE);
|
||||
// 4-byte interleaved prefix + RTP header + payload shard (≤ RTP_MTU)
|
||||
let mut interleaved_rtp_buf = Vec::with_capacity(4 + RTP_MTU + 96);
|
||||
let mut last_rtp_timestamp: u32 = 0;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_frame = rx.recv() => {
|
||||
let Some(frame) = maybe_frame else {
|
||||
tracing::warn!("RTSP encoded frame subscription ended, attempting to restart pipeline");
|
||||
|
||||
if let Some(new_rx) = video_manager.subscribe_encoded_frames().await {
|
||||
rx = new_rx;
|
||||
let _ = video_manager.request_keyframe().await;
|
||||
tracing::info!("RTSP frame subscription recovered");
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"RTSP failed to resubscribe encoded frames, retrying in {}ms",
|
||||
RTSP_RESUBSCRIBE_DELAY_MS
|
||||
);
|
||||
sleep(Duration::from_millis(RTSP_RESUBSCRIBE_DELAY_MS)).await;
|
||||
}
|
||||
|
||||
continue;
|
||||
};
|
||||
|
||||
if !is_frame_codec_match(&frame, &rtsp_codec) {
|
||||
continue;
|
||||
}
|
||||
|
||||
{
|
||||
let mut params = shared.parameter_sets.write().await;
|
||||
update_parameter_sets(&mut params, &frame);
|
||||
}
|
||||
|
||||
let rtp_timestamp = monotonic_rtp_timestamp(
|
||||
frame.pts_ms,
|
||||
&mut last_rtp_timestamp,
|
||||
frame.duration,
|
||||
);
|
||||
|
||||
let payloads: Vec<Bytes> = match rtsp_codec {
|
||||
RtspCodec::H264 => h264_payloader
|
||||
.payload(RTP_MTU, &frame.data)
|
||||
.map_err(|e| AppError::VideoError(format!("H264 payload failed: {}", e)))?,
|
||||
RtspCodec::H265 => h265_payloader.payload(RTP_MTU, &frame.data),
|
||||
};
|
||||
|
||||
if payloads.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let total_payloads = payloads.len();
|
||||
for (idx, payload) in payloads.into_iter().enumerate() {
|
||||
let marker = idx == total_payloads.saturating_sub(1);
|
||||
let packet = Packet {
|
||||
header: rtp::header::Header {
|
||||
version: 2,
|
||||
padding: false,
|
||||
extension: false,
|
||||
marker,
|
||||
payload_type,
|
||||
sequence_number,
|
||||
timestamp: rtp_timestamp,
|
||||
ssrc,
|
||||
..Default::default()
|
||||
},
|
||||
payload,
|
||||
};
|
||||
|
||||
sequence_number = sequence_number.wrapping_add(1);
|
||||
send_interleaved_rtp(&mut writer, channel, &packet, &mut interleaved_rtp_buf)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if frame.is_keyframe {
|
||||
tracing::debug!("RTSP keyframe sent");
|
||||
}
|
||||
}
|
||||
read_res = reader.read(&mut ctrl_read_buf) => {
|
||||
let n = read_res?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
ctrl_buffer.extend_from_slice(&ctrl_read_buf[..n]);
|
||||
|
||||
while strip_interleaved_frames_prefix(&mut ctrl_buffer) {}
|
||||
|
||||
while let Some(raw_req) = take_rtsp_request_from_buffer(&mut ctrl_buffer) {
|
||||
let Some(req) = parse_rtsp_request(&raw_req) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if handle_play_control_request(&mut writer, &req, &session_id).await? {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
while strip_interleaved_frames_prefix(&mut ctrl_buffer) {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn send_interleaved_rtp<W: AsyncWrite + Unpin>(
|
||||
stream: &mut W,
|
||||
channel: u8,
|
||||
packet: &Packet,
|
||||
marshal_buf: &mut Vec<u8>,
|
||||
) -> Result<()> {
|
||||
let rtp_len = packet.marshal_size();
|
||||
let rtp_len_u16 = u16::try_from(rtp_len).map_err(|_| {
|
||||
AppError::VideoError(format!(
|
||||
"RTP packet too large for interleaved framing: {} bytes",
|
||||
rtp_len
|
||||
))
|
||||
})?;
|
||||
|
||||
marshal_buf.clear();
|
||||
marshal_buf.reserve(4 + rtp_len);
|
||||
marshal_buf.extend_from_slice(&[b'$', channel, (rtp_len_u16 >> 8) as u8, rtp_len_u16 as u8]);
|
||||
let body_off = marshal_buf.len();
|
||||
marshal_buf.resize(body_off + rtp_len, 0);
|
||||
|
||||
let written = packet
|
||||
.marshal_to(&mut marshal_buf[body_off..])
|
||||
.map_err(|e| AppError::VideoError(format!("RTP marshal failed: {}", e)))?;
|
||||
if written != rtp_len {
|
||||
return Err(AppError::VideoError(format!(
|
||||
"RTP marshal size mismatch: wrote {written}, expected {rtp_len}"
|
||||
)));
|
||||
}
|
||||
|
||||
stream.write_all(marshal_buf).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_play_control_request<W: AsyncWrite + Unpin>(
|
||||
stream: &mut W,
|
||||
req: &RtspRequest,
|
||||
session_id: &str,
|
||||
) -> Result<bool> {
|
||||
use super::protocol::OPTIONS_PUBLIC_CAPABILITIES;
|
||||
|
||||
match &req.method {
|
||||
rtsp::Method::Teardown => {
|
||||
send_response(stream, req, 200, "OK", vec![], "", session_id).await?;
|
||||
Ok(true)
|
||||
}
|
||||
rtsp::Method::Options => {
|
||||
send_response(
|
||||
stream,
|
||||
req,
|
||||
200,
|
||||
"OK",
|
||||
vec![(
|
||||
"Public".to_string(),
|
||||
OPTIONS_PUBLIC_CAPABILITIES.to_string(),
|
||||
)],
|
||||
"",
|
||||
session_id,
|
||||
)
|
||||
.await?;
|
||||
Ok(false)
|
||||
}
|
||||
rtsp::Method::GetParameter | rtsp::Method::SetParameter => {
|
||||
send_response(stream, req, 200, "OK", vec![], "", session_id).await?;
|
||||
Ok(false)
|
||||
}
|
||||
_ => {
|
||||
send_response(
|
||||
stream,
|
||||
req,
|
||||
405,
|
||||
"Method Not Allowed",
|
||||
vec![],
|
||||
"",
|
||||
session_id,
|
||||
)
|
||||
.await?;
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn pts_to_rtp_timestamp(pts_ms: i64) -> u32 {
|
||||
if pts_ms <= 0 {
|
||||
return 0;
|
||||
}
|
||||
((pts_ms as u64 * RTP_CLOCK_RATE as u64) / 1000) as u32
|
||||
}
|
||||
|
||||
fn rtp_timestamp_increment(frame_duration: Duration) -> u32 {
|
||||
let inc = (frame_duration.as_secs_f64() * f64::from(RTP_CLOCK_RATE)).round() as u32;
|
||||
inc.max(1)
|
||||
}
|
||||
|
||||
fn monotonic_rtp_timestamp(pts_ms: i64, last: &mut u32, frame_duration: Duration) -> u32 {
|
||||
let from_pts = pts_to_rtp_timestamp(pts_ms);
|
||||
let inc = rtp_timestamp_increment(frame_duration);
|
||||
let ts = if from_pts > *last {
|
||||
from_pts
|
||||
} else {
|
||||
last.wrapping_add(inc)
|
||||
};
|
||||
*last = ts;
|
||||
ts
|
||||
}
|
||||
|
||||
fn is_frame_codec_match(frame: &EncodedVideoFrame, codec: &RtspCodec) -> bool {
|
||||
matches!(
|
||||
(frame.codec, codec),
|
||||
(VideoEncoderType::H264, RtspCodec::H264) | (VideoEncoderType::H265, RtspCodec::H265)
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use tokio::io::{duplex, AsyncReadExt};
|
||||
|
||||
fn make_test_request(method: rtsp::Method) -> RtspRequest {
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("cseq".to_string(), "7".to_string());
|
||||
RtspRequest {
|
||||
method,
|
||||
uri: "rtsp://127.0.0.1/live".to_string(),
|
||||
version: rtsp::Version::V1_0,
|
||||
headers,
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_response_from_duplex(
|
||||
mut client: tokio::io::DuplexStream,
|
||||
) -> rtsp::Response<Vec<u8>> {
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = client
|
||||
.read(&mut buf)
|
||||
.await
|
||||
.expect("failed to read rtsp response");
|
||||
assert!(n > 0);
|
||||
let (message, consumed): (rtsp::Message<Vec<u8>>, usize) =
|
||||
rtsp::Message::parse(&buf[..n]).expect("failed to parse rtsp response");
|
||||
assert_eq!(consumed, n);
|
||||
|
||||
match message {
|
||||
rtsp::Message::Response(response) => response,
|
||||
_ => panic!("expected RTSP response"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn play_control_teardown_returns_ok_and_stop() {
|
||||
let req = make_test_request(rtsp::Method::Teardown);
|
||||
let (client, mut server) = duplex(4096);
|
||||
|
||||
let should_stop = handle_play_control_request(&mut server, &req, "session-1")
|
||||
.await
|
||||
.expect("control handling failed");
|
||||
assert!(should_stop);
|
||||
|
||||
drop(server);
|
||||
let response = read_response_from_duplex(client).await;
|
||||
assert_eq!(response.status(), rtsp::StatusCode::Ok);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn play_control_pause_returns_method_not_allowed() {
|
||||
let req = make_test_request(rtsp::Method::Pause);
|
||||
let (client, mut server) = duplex(4096);
|
||||
|
||||
let should_stop = handle_play_control_request(&mut server, &req, "session-1")
|
||||
.await
|
||||
.expect("control handling failed");
|
||||
assert!(!should_stop);
|
||||
|
||||
drop(server);
|
||||
let response = read_response_from_duplex(client).await;
|
||||
assert_eq!(response.status(), rtsp::StatusCode::MethodNotAllowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn monotonic_rtp_timestamp_steps_when_pts_stays_zero() {
|
||||
let d = Duration::from_millis(33);
|
||||
let mut last = 0u32;
|
||||
let a = monotonic_rtp_timestamp(0, &mut last, d);
|
||||
let b = monotonic_rtp_timestamp(0, &mut last, d);
|
||||
let c = monotonic_rtp_timestamp(0, &mut last, d);
|
||||
assert!(a > 0);
|
||||
assert!(b > a);
|
||||
assert!(c > b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn monotonic_rtp_timestamp_uses_pts_when_it_advances() {
|
||||
let d = Duration::from_millis(33);
|
||||
let mut last = 0u32;
|
||||
let a = monotonic_rtp_timestamp(1000, &mut last, d);
|
||||
assert_eq!(a, 90_000);
|
||||
let b = monotonic_rtp_timestamp(2000, &mut last, d);
|
||||
assert_eq!(b, 180_000);
|
||||
}
|
||||
}
|
||||
53
src/rtsp/types.rs
Normal file
53
src/rtsp/types.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use rand::Rng;
|
||||
use rtsp_types as rtsp;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum RtspServiceStatus {
|
||||
Stopped,
|
||||
Starting,
|
||||
Running,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for RtspServiceStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Stopped => write!(f, "stopped"),
|
||||
Self::Starting => write!(f, "starting"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Error(err) => write!(f, "error: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct RtspRequest {
|
||||
pub method: rtsp::Method,
|
||||
pub uri: String,
|
||||
pub version: rtsp::Version,
|
||||
pub headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
pub(crate) struct RtspConnectionState {
|
||||
pub session_id: String,
|
||||
pub setup_done: bool,
|
||||
pub interleaved_channel: u8,
|
||||
}
|
||||
|
||||
impl RtspConnectionState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
session_id: generate_session_id(),
|
||||
setup_done: false,
|
||||
interleaved_channel: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn generate_session_id() -> String {
|
||||
let mut rng = rand::rng();
|
||||
let value: u64 = rng.random();
|
||||
format!("{:016x}", value)
|
||||
}
|
||||
@@ -1,21 +1,11 @@
|
||||
//! RustDesk BytesCodec - Variable-length framing for TCP messages
|
||||
//!
|
||||
//! RustDesk uses a custom variable-length encoding for message framing:
|
||||
//! - Length <= 0x3F (63): 1-byte header, format `(len << 2)`
|
||||
//! - Length <= 0x3FFF (16383): 2-byte LE header, format `(len << 2) | 0x1`
|
||||
//! - Length <= 0x3FFFFF (4194303): 3-byte LE header, format `(len << 2) | 0x2`
|
||||
//! - Length <= 0x3FFFFFFF (1073741823): 4-byte LE header, format `(len << 2) | 0x3`
|
||||
//!
|
||||
//! The low 2 bits of the first byte indicate the header length (+1).
|
||||
//! Variable-length TCP framing (RustDesk wire format).
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
/// Maximum packet length (1GB)
|
||||
const MAX_PACKET_LENGTH: usize = 0x3FFFFFFF;
|
||||
|
||||
/// Encode a message with RustDesk's variable-length framing
|
||||
pub fn encode_frame(data: &[u8]) -> io::Result<Vec<u8>> {
|
||||
let len = data.len();
|
||||
let mut buf = Vec::with_capacity(len + 4);
|
||||
@@ -44,8 +34,6 @@ pub fn encode_frame(data: &[u8]) -> io::Result<Vec<u8>> {
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Decode the header to get message length
|
||||
/// Returns (header_length, message_length)
|
||||
fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) {
|
||||
let head_len = ((first_byte & 0x3) + 1) as usize;
|
||||
|
||||
@@ -64,21 +52,17 @@ fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) {
|
||||
(head_len, msg_len)
|
||||
}
|
||||
|
||||
/// Read a single framed message from an async reader
|
||||
pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<BytesMut> {
|
||||
// Read first byte to determine header length
|
||||
let mut first_byte = [0u8; 1];
|
||||
reader.read_exact(&mut first_byte).await?;
|
||||
|
||||
let head_len = ((first_byte[0] & 0x3) + 1) as usize;
|
||||
|
||||
// Read remaining header bytes if needed
|
||||
let mut header_rest = [0u8; 3];
|
||||
if head_len > 1 {
|
||||
reader.read_exact(&mut header_rest[..head_len - 1]).await?;
|
||||
}
|
||||
|
||||
// Calculate message length
|
||||
let (_, msg_len) = decode_header(first_byte[0], &header_rest);
|
||||
|
||||
if msg_len > MAX_PACKET_LENGTH {
|
||||
@@ -88,7 +72,6 @@ pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Byte
|
||||
));
|
||||
}
|
||||
|
||||
// Read message body
|
||||
let mut buf = BytesMut::with_capacity(msg_len);
|
||||
buf.resize(msg_len, 0);
|
||||
reader.read_exact(&mut buf).await?;
|
||||
@@ -96,7 +79,6 @@ pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Byte
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Write a framed message to an async writer
|
||||
pub async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, data: &[u8]) -> io::Result<()> {
|
||||
let frame = encode_frame(data)?;
|
||||
writer.write_all(&frame).await?;
|
||||
@@ -104,10 +86,6 @@ pub async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, data: &[u8]) ->
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write a framed message using a reusable buffer (reduces allocations)
|
||||
///
|
||||
/// This version reuses the provided BytesMut buffer to avoid allocation on each call.
|
||||
/// The buffer is cleared before use and will grow as needed.
|
||||
pub async fn write_frame_buffered<W: AsyncWrite + Unpin>(
|
||||
writer: &mut W,
|
||||
data: &[u8],
|
||||
@@ -120,11 +98,9 @@ pub async fn write_frame_buffered<W: AsyncWrite + Unpin>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encode a message with RustDesk's variable-length framing into an existing buffer
|
||||
pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
let len = data.len();
|
||||
|
||||
// Reserve space for header (max 4 bytes) + data
|
||||
buf.reserve(4 + len);
|
||||
|
||||
if len <= 0x3F {
|
||||
@@ -149,7 +125,7 @@ pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// BytesCodec for stateful decoding (compatible with tokio-util codec)
|
||||
/// Stateful decoder for `Framed`.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct BytesCodec {
|
||||
state: DecodeState,
|
||||
@@ -180,7 +156,6 @@ impl BytesCodec {
|
||||
self.max_packet_length = n;
|
||||
}
|
||||
|
||||
/// Decode from a BytesMut buffer (for use with Framed)
|
||||
pub fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
|
||||
let n = match self.state {
|
||||
DecodeState::Head => match self.decode_head(src)? {
|
||||
@@ -242,7 +217,6 @@ impl BytesCodec {
|
||||
Ok(Some(src.split_to(n)))
|
||||
}
|
||||
|
||||
/// Encode a message into a BytesMut buffer
|
||||
pub fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> io::Result<()> {
|
||||
let len = data.len();
|
||||
|
||||
@@ -276,7 +250,7 @@ mod tests {
|
||||
fn test_encode_decode_small() {
|
||||
let data = vec![1u8; 63];
|
||||
let encoded = encode_frame(&data).unwrap();
|
||||
assert_eq!(encoded.len(), 63 + 1); // 1 byte header
|
||||
assert_eq!(encoded.len(), 63 + 1);
|
||||
|
||||
let mut codec = BytesCodec::new();
|
||||
let mut buf = BytesMut::from(&encoded[..]);
|
||||
@@ -288,7 +262,7 @@ mod tests {
|
||||
fn test_encode_decode_medium() {
|
||||
let data = vec![2u8; 1000];
|
||||
let encoded = encode_frame(&data).unwrap();
|
||||
assert_eq!(encoded.len(), 1000 + 2); // 2 byte header
|
||||
assert_eq!(encoded.len(), 1000 + 2);
|
||||
|
||||
let mut codec = BytesCodec::new();
|
||||
let mut buf = BytesMut::from(&encoded[..]);
|
||||
@@ -300,7 +274,7 @@ mod tests {
|
||||
fn test_encode_decode_large() {
|
||||
let data = vec![3u8; 100000];
|
||||
let encoded = encode_frame(&data).unwrap();
|
||||
assert_eq!(encoded.len(), 100000 + 3); // 3 byte header
|
||||
assert_eq!(encoded.len(), 100000 + 3);
|
||||
|
||||
let mut codec = BytesCodec::new();
|
||||
let mut buf = BytesMut::from(&encoded[..]);
|
||||
|
||||
@@ -1,57 +1,26 @@
|
||||
//! RustDesk Configuration
|
||||
//!
|
||||
//! Configuration types for the RustDesk protocol integration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use typeshare::typeshare;
|
||||
|
||||
/// RustDesk configuration
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct RustDeskConfig {
|
||||
/// Enable RustDesk protocol
|
||||
pub enabled: bool,
|
||||
|
||||
/// Rendezvous server address (hbbs), e.g., "rs.example.com" or "192.168.1.100:21116"
|
||||
/// Required for RustDesk to function
|
||||
pub rendezvous_server: String,
|
||||
|
||||
/// Relay server address (hbbr), if different from rendezvous server
|
||||
/// Usually the same host as rendezvous server but different port (21117)
|
||||
pub relay_server: Option<String>,
|
||||
|
||||
/// Relay server authentication key (licence_key)
|
||||
/// Required if the relay server is configured with -k option
|
||||
#[typeshare(skip)]
|
||||
pub relay_key: Option<String>,
|
||||
|
||||
/// Device ID (9-digit number), auto-generated if empty
|
||||
pub device_id: String,
|
||||
|
||||
/// Device password for client authentication
|
||||
#[typeshare(skip)]
|
||||
pub device_password: String,
|
||||
|
||||
/// Public key for encryption (Curve25519, base64 encoded), auto-generated
|
||||
#[typeshare(skip)]
|
||||
pub public_key: Option<String>,
|
||||
|
||||
/// Private key for encryption (Curve25519, base64 encoded), auto-generated
|
||||
#[typeshare(skip)]
|
||||
pub private_key: Option<String>,
|
||||
|
||||
/// Signing public key (Ed25519, base64 encoded), auto-generated
|
||||
/// Used for SignedId verification by clients
|
||||
#[typeshare(skip)]
|
||||
pub signing_public_key: Option<String>,
|
||||
|
||||
/// Signing private key (Ed25519, base64 encoded), auto-generated
|
||||
/// Used for signing SignedId messages
|
||||
#[typeshare(skip)]
|
||||
pub signing_private_key: Option<String>,
|
||||
|
||||
/// UUID for rendezvous server registration (persisted to avoid UUID_MISMATCH)
|
||||
#[typeshare(skip)]
|
||||
pub uuid: Option<String>,
|
||||
}
|
||||
@@ -75,8 +44,6 @@ impl Default for RustDeskConfig {
|
||||
}
|
||||
|
||||
impl RustDeskConfig {
|
||||
/// Check if the configuration is valid for starting the service
|
||||
/// Returns true if enabled and has a valid server
|
||||
pub fn is_valid(&self) -> bool {
|
||||
self.enabled
|
||||
&& !self.rendezvous_server.is_empty()
|
||||
@@ -84,44 +51,35 @@ impl RustDeskConfig {
|
||||
&& !self.device_password.is_empty()
|
||||
}
|
||||
|
||||
/// Get the rendezvous server (user-configured)
|
||||
pub fn effective_rendezvous_server(&self) -> &str {
|
||||
&self.rendezvous_server
|
||||
}
|
||||
|
||||
/// Generate a new random device ID
|
||||
pub fn generate_device_id() -> String {
|
||||
generate_device_id()
|
||||
}
|
||||
|
||||
/// Generate a new random password
|
||||
pub fn generate_password() -> String {
|
||||
generate_random_password()
|
||||
}
|
||||
|
||||
/// Get or generate the UUID for rendezvous registration
|
||||
/// Returns (uuid_bytes, is_new) where is_new indicates if a new UUID was generated
|
||||
pub fn ensure_uuid(&mut self) -> ([u8; 16], bool) {
|
||||
if let Some(ref uuid_str) = self.uuid {
|
||||
// Try to parse existing UUID
|
||||
if let Ok(uuid) = uuid::Uuid::parse_str(uuid_str) {
|
||||
return (*uuid.as_bytes(), false);
|
||||
}
|
||||
}
|
||||
// Generate new UUID
|
||||
let new_uuid = uuid::Uuid::new_v4();
|
||||
self.uuid = Some(new_uuid.to_string());
|
||||
(*new_uuid.as_bytes(), true)
|
||||
}
|
||||
|
||||
/// Get the UUID bytes (returns None if not set)
|
||||
pub fn get_uuid_bytes(&self) -> Option<[u8; 16]> {
|
||||
self.uuid
|
||||
.as_ref()
|
||||
.and_then(|s| uuid::Uuid::parse_str(s).ok().map(|u| *u.as_bytes()))
|
||||
}
|
||||
|
||||
/// Get the rendezvous server address with default port
|
||||
pub fn rendezvous_addr(&self) -> String {
|
||||
let server = &self.rendezvous_server;
|
||||
if server.is_empty() {
|
||||
@@ -133,7 +91,6 @@ impl RustDeskConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the relay server address with default port
|
||||
pub fn relay_addr(&self) -> Option<String> {
|
||||
self.relay_server
|
||||
.as_ref()
|
||||
@@ -145,7 +102,6 @@ impl RustDeskConfig {
|
||||
}
|
||||
})
|
||||
.or_else(|| {
|
||||
// Default: same host as rendezvous server
|
||||
let server = &self.rendezvous_server;
|
||||
if !server.is_empty() {
|
||||
let host = server.split(':').next().unwrap_or("");
|
||||
@@ -161,7 +117,6 @@ impl RustDeskConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random 9-digit device ID
|
||||
pub fn generate_device_id() -> String {
|
||||
use rand::Rng;
|
||||
let mut rng = rand::rng();
|
||||
@@ -169,7 +124,6 @@ pub fn generate_device_id() -> String {
|
||||
id.to_string()
|
||||
}
|
||||
|
||||
/// Generate a random 8-character password
|
||||
pub fn generate_random_password() -> String {
|
||||
use rand::Rng;
|
||||
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
|
||||
@@ -212,7 +166,6 @@ mod tests {
|
||||
config.rendezvous_server = "example.com:21116".to_string();
|
||||
assert_eq!(config.rendezvous_addr(), "example.com:21116");
|
||||
|
||||
// Empty server returns empty string
|
||||
config.rendezvous_server = String::new();
|
||||
assert_eq!(config.rendezvous_addr(), "");
|
||||
}
|
||||
@@ -224,17 +177,14 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Rendezvous server configured, relay defaults to same host
|
||||
assert_eq!(config.relay_addr(), Some("example.com:21117".to_string()));
|
||||
|
||||
// Explicit relay server
|
||||
config.relay_server = Some("relay.example.com".to_string());
|
||||
assert_eq!(
|
||||
config.relay_addr(),
|
||||
Some("relay.example.com:21117".to_string())
|
||||
);
|
||||
|
||||
// No rendezvous server, relay is None
|
||||
config.rendezvous_server = String::new();
|
||||
config.relay_server = None;
|
||||
assert_eq!(config.relay_addr(), None);
|
||||
@@ -247,10 +197,8 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// When user sets a server, use it
|
||||
assert_eq!(config.effective_rendezvous_server(), "custom.example.com");
|
||||
|
||||
// When empty, returns empty
|
||||
config.rendezvous_server = String::new();
|
||||
assert_eq!(config.effective_rendezvous_server(), "");
|
||||
}
|
||||
|
||||
@@ -1,12 +1,4 @@
|
||||
//! RustDesk Connection Handler
|
||||
//!
|
||||
//! This module handles incoming connections from RustDesk clients.
|
||||
//! It manages the connection lifecycle including:
|
||||
//! - Connection establishment (P2P or via relay)
|
||||
//! - Encrypted handshake
|
||||
//! - Authentication
|
||||
//! - Message routing (video, audio, input)
|
||||
//! - Video frame streaming (shared with WebRTC)
|
||||
//! Incoming RustDesk TCP sessions (handshake, AV, input).
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
@@ -23,6 +15,7 @@ use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::audio::AudioController;
|
||||
use crate::hid::{CanonicalKey, HidController, KeyEventType, KeyboardEvent, KeyboardModifiers};
|
||||
use crate::utils::hostname_from_etc;
|
||||
use crate::video::codec_constraints::{
|
||||
encoder_codec_to_id, encoder_codec_to_video_codec, video_codec_to_encoder_codec,
|
||||
};
|
||||
@@ -94,13 +87,6 @@ impl InputThrottler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get system hostname
|
||||
fn get_hostname() -> String {
|
||||
std::fs::read_to_string("/etc/hostname")
|
||||
.map(|s| s.trim().to_string())
|
||||
.unwrap_or_else(|_| "One-KVM".to_string())
|
||||
}
|
||||
|
||||
/// Connection state
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ConnectionState {
|
||||
@@ -1165,7 +1151,7 @@ impl Connection {
|
||||
|
||||
let mut peer_info = PeerInfo::new();
|
||||
peer_info.username = "one-kvm".to_string();
|
||||
peer_info.hostname = get_hostname();
|
||||
peer_info.hostname = hostname_from_etc();
|
||||
peer_info.platform = RUSTDESK_COMPAT_PLATFORM.to_string();
|
||||
peer_info.displays.push(display_info);
|
||||
peer_info.current_display = 0;
|
||||
@@ -1786,7 +1772,7 @@ async fn run_audio_streaming(
|
||||
}
|
||||
|
||||
// Subscribe to the audio Opus stream
|
||||
let mut opus_rx = match audio_controller.subscribe_opus_async().await {
|
||||
let mut opus_rx = match audio_controller.subscribe_opus().await {
|
||||
Some(rx) => rx,
|
||||
None => {
|
||||
// Audio not available, wait and retry
|
||||
|
||||
@@ -1,10 +1,4 @@
|
||||
//! RustDesk Cryptography
|
||||
//!
|
||||
//! This module implements the NaCl-based cryptography used by RustDesk:
|
||||
//! - Curve25519 for key exchange
|
||||
//! - XSalsa20-Poly1305 for authenticated encryption
|
||||
//! - Ed25519 for signatures
|
||||
//! - Ed25519 to Curve25519 key conversion for unified keypair usage
|
||||
//! NaCl crypto (RustDesk-compatible).
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
|
||||
use sodiumoxide::crypto::box_::{self, Nonce, PublicKey, SecretKey};
|
||||
@@ -12,7 +6,6 @@ use sodiumoxide::crypto::secretbox;
|
||||
use sodiumoxide::crypto::sign::{self, ed25519};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Cryptography errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CryptoError {
|
||||
#[error("Failed to initialize sodiumoxide")]
|
||||
@@ -31,13 +24,10 @@ pub enum CryptoError {
|
||||
KeyConversionFailed,
|
||||
}
|
||||
|
||||
/// Initialize the cryptography library
|
||||
/// Must be called before using any crypto functions
|
||||
pub fn init() -> Result<(), CryptoError> {
|
||||
sodiumoxide::init().map_err(|_| CryptoError::InitError)
|
||||
}
|
||||
|
||||
/// A keypair for asymmetric encryption
|
||||
#[derive(Clone)]
|
||||
pub struct KeyPair {
|
||||
pub public_key: PublicKey,
|
||||
@@ -45,7 +35,6 @@ pub struct KeyPair {
|
||||
}
|
||||
|
||||
impl KeyPair {
|
||||
/// Generate a new random keypair
|
||||
pub fn generate() -> Self {
|
||||
let (public_key, secret_key) = box_::gen_keypair();
|
||||
Self {
|
||||
@@ -54,7 +43,6 @@ impl KeyPair {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from existing keys
|
||||
pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result<Self, CryptoError> {
|
||||
let pk = PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?;
|
||||
let sk = SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?;
|
||||
@@ -64,27 +52,22 @@ impl KeyPair {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get public key as bytes
|
||||
pub fn public_key_bytes(&self) -> &[u8] {
|
||||
self.public_key.as_ref()
|
||||
}
|
||||
|
||||
/// Get secret key as bytes
|
||||
pub fn secret_key_bytes(&self) -> &[u8] {
|
||||
self.secret_key.as_ref()
|
||||
}
|
||||
|
||||
/// Encode public key as base64
|
||||
pub fn public_key_base64(&self) -> String {
|
||||
BASE64.encode(self.public_key_bytes())
|
||||
}
|
||||
|
||||
/// Encode secret key as base64
|
||||
pub fn secret_key_base64(&self) -> String {
|
||||
BASE64.encode(self.secret_key_bytes())
|
||||
}
|
||||
|
||||
/// Create from base64-encoded keys
|
||||
pub fn from_base64(public_key: &str, secret_key: &str) -> Result<Self, CryptoError> {
|
||||
let pk_bytes = BASE64
|
||||
.decode(public_key)
|
||||
@@ -96,15 +79,10 @@ impl KeyPair {
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random nonce for box encryption
|
||||
pub fn generate_nonce() -> Nonce {
|
||||
box_::gen_nonce()
|
||||
}
|
||||
|
||||
/// Encrypt data using public-key cryptography (NaCl box)
|
||||
///
|
||||
/// Uses the sender's secret key and receiver's public key for encryption.
|
||||
/// Returns (nonce, ciphertext).
|
||||
pub fn encrypt_box(
|
||||
data: &[u8],
|
||||
their_public_key: &PublicKey,
|
||||
@@ -115,7 +93,6 @@ pub fn encrypt_box(
|
||||
(nonce, ciphertext)
|
||||
}
|
||||
|
||||
/// Decrypt data using public-key cryptography (NaCl box)
|
||||
pub fn decrypt_box(
|
||||
ciphertext: &[u8],
|
||||
nonce: &Nonce,
|
||||
@@ -126,14 +103,12 @@ pub fn decrypt_box(
|
||||
.map_err(|_| CryptoError::DecryptionFailed)
|
||||
}
|
||||
|
||||
/// Encrypt data with a precomputed shared key
|
||||
pub fn encrypt_with_key(data: &[u8], key: &secretbox::Key) -> (secretbox::Nonce, Vec<u8>) {
|
||||
let nonce = secretbox::gen_nonce();
|
||||
let ciphertext = secretbox::seal(data, &nonce, key);
|
||||
(nonce, ciphertext)
|
||||
}
|
||||
|
||||
/// Decrypt data with a precomputed shared key
|
||||
pub fn decrypt_with_key(
|
||||
ciphertext: &[u8],
|
||||
nonce: &secretbox::Nonce,
|
||||
@@ -142,8 +117,6 @@ pub fn decrypt_with_key(
|
||||
secretbox::open(ciphertext, nonce, key).map_err(|_| CryptoError::DecryptionFailed)
|
||||
}
|
||||
|
||||
/// Compute a shared symmetric key from public/private keypair
|
||||
/// This is the precomputed key for the NaCl box
|
||||
pub fn precompute_key(
|
||||
their_public_key: &PublicKey,
|
||||
our_secret_key: &SecretKey,
|
||||
@@ -151,23 +124,18 @@ pub fn precompute_key(
|
||||
box_::precompute(their_public_key, our_secret_key)
|
||||
}
|
||||
|
||||
/// Create a symmetric key from raw bytes
|
||||
pub fn symmetric_key_from_slice(key: &[u8]) -> Result<secretbox::Key, CryptoError> {
|
||||
secretbox::Key::from_slice(key).ok_or(CryptoError::InvalidKeyLength)
|
||||
}
|
||||
|
||||
/// Parse a nonce from bytes
|
||||
pub fn nonce_from_slice(bytes: &[u8]) -> Result<Nonce, CryptoError> {
|
||||
Nonce::from_slice(bytes).ok_or(CryptoError::InvalidNonce)
|
||||
}
|
||||
|
||||
/// Parse a public key from bytes
|
||||
pub fn public_key_from_slice(bytes: &[u8]) -> Result<PublicKey, CryptoError> {
|
||||
PublicKey::from_slice(bytes).ok_or(CryptoError::InvalidKeyLength)
|
||||
}
|
||||
|
||||
/// Hash a password for storage/comparison
|
||||
/// RustDesk uses simple SHA256 for password hashing
|
||||
pub fn hash_password(password: &str, salt: &str) -> Vec<u8> {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
@@ -176,35 +144,24 @@ pub fn hash_password(password: &str, salt: &str) -> Vec<u8> {
|
||||
hasher.finalize().to_vec()
|
||||
}
|
||||
|
||||
/// RustDesk double hash for password verification
|
||||
/// Client calculates: SHA256(SHA256(password + salt) + challenge)
|
||||
/// This matches what the client sends in LoginRequest
|
||||
pub fn hash_password_double(password: &str, salt: &str, challenge: &str) -> Vec<u8> {
|
||||
use sha2::{Digest, Sha256};
|
||||
// First hash: SHA256(password + salt)
|
||||
let mut hasher1 = Sha256::new();
|
||||
hasher1.update(password.as_bytes());
|
||||
hasher1.update(salt.as_bytes());
|
||||
let first_hash = hasher1.finalize();
|
||||
|
||||
// Second hash: SHA256(first_hash + challenge)
|
||||
let mut hasher2 = Sha256::new();
|
||||
hasher2.update(first_hash);
|
||||
hasher2.update(challenge.as_bytes());
|
||||
hasher2.finalize().to_vec()
|
||||
}
|
||||
|
||||
/// Verify a password hash
|
||||
pub fn verify_password(password: &str, salt: &str, expected_hash: &[u8]) -> bool {
|
||||
let computed = hash_password(password, salt);
|
||||
// Constant-time comparison would be better, but for our use case this is acceptable
|
||||
computed == expected_hash
|
||||
}
|
||||
|
||||
/// Decrypt symmetric key using Curve25519 secret key directly
|
||||
///
|
||||
/// This is used when we have a fresh Curve25519 keypair for the connection
|
||||
/// (as per RustDesk protocol - each connection generates a new keypair)
|
||||
pub fn decrypt_symmetric_key(
|
||||
their_temp_public_key: &[u8],
|
||||
sealed_symmetric_key: &[u8],
|
||||
@@ -217,7 +174,6 @@ pub fn decrypt_symmetric_key(
|
||||
let their_pk =
|
||||
PublicKey::from_slice(their_temp_public_key).ok_or(CryptoError::InvalidKeyLength)?;
|
||||
|
||||
// Use zero nonce as per RustDesk protocol
|
||||
let nonce = box_::Nonce([0u8; box_::NONCEBYTES]);
|
||||
|
||||
let key_bytes = box_::open(sealed_symmetric_key, &nonce, &their_pk, our_secret_key)
|
||||
@@ -226,11 +182,7 @@ pub fn decrypt_symmetric_key(
|
||||
secretbox::Key::from_slice(&key_bytes).ok_or(CryptoError::InvalidKeyLength)
|
||||
}
|
||||
|
||||
/// Encrypt a message using the negotiated symmetric key
|
||||
///
|
||||
/// RustDesk uses a specific nonce format for session encryption
|
||||
pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) -> Vec<u8> {
|
||||
// Create nonce from counter (little-endian, padded to 24 bytes)
|
||||
let mut nonce_bytes = [0u8; secretbox::NONCEBYTES];
|
||||
nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes());
|
||||
let nonce = secretbox::Nonce(nonce_bytes);
|
||||
@@ -238,13 +190,11 @@ pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) ->
|
||||
secretbox::seal(data, &nonce, key)
|
||||
}
|
||||
|
||||
/// Decrypt a message using the negotiated symmetric key
|
||||
pub fn decrypt_message(
|
||||
ciphertext: &[u8],
|
||||
key: &secretbox::Key,
|
||||
nonce_counter: u64,
|
||||
) -> Result<Vec<u8>, CryptoError> {
|
||||
// Create nonce from counter (little-endian, padded to 24 bytes)
|
||||
let mut nonce_bytes = [0u8; secretbox::NONCEBYTES];
|
||||
nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes());
|
||||
let nonce = secretbox::Nonce(nonce_bytes);
|
||||
@@ -252,7 +202,6 @@ pub fn decrypt_message(
|
||||
secretbox::open(ciphertext, &nonce, key).map_err(|_| CryptoError::DecryptionFailed)
|
||||
}
|
||||
|
||||
/// Ed25519 signing keypair for RustDesk SignedId messages
|
||||
#[derive(Clone)]
|
||||
pub struct SigningKeyPair {
|
||||
pub public_key: sign::PublicKey,
|
||||
@@ -260,7 +209,6 @@ pub struct SigningKeyPair {
|
||||
}
|
||||
|
||||
impl SigningKeyPair {
|
||||
/// Generate a new random signing keypair
|
||||
pub fn generate() -> Self {
|
||||
let (public_key, secret_key) = sign::gen_keypair();
|
||||
Self {
|
||||
@@ -269,7 +217,6 @@ impl SigningKeyPair {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from existing keys
|
||||
pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result<Self, CryptoError> {
|
||||
let pk = sign::PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?;
|
||||
let sk = sign::SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?;
|
||||
@@ -279,27 +226,22 @@ impl SigningKeyPair {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get public key as bytes
|
||||
pub fn public_key_bytes(&self) -> &[u8] {
|
||||
self.public_key.as_ref()
|
||||
}
|
||||
|
||||
/// Get secret key as bytes
|
||||
pub fn secret_key_bytes(&self) -> &[u8] {
|
||||
self.secret_key.as_ref()
|
||||
}
|
||||
|
||||
/// Encode public key as base64
|
||||
pub fn public_key_base64(&self) -> String {
|
||||
BASE64.encode(self.public_key_bytes())
|
||||
}
|
||||
|
||||
/// Encode secret key as base64
|
||||
pub fn secret_key_base64(&self) -> String {
|
||||
BASE64.encode(self.secret_key_bytes())
|
||||
}
|
||||
|
||||
/// Create from base64-encoded keys
|
||||
pub fn from_base64(public_key: &str, secret_key: &str) -> Result<Self, CryptoError> {
|
||||
let pk_bytes = BASE64
|
||||
.decode(public_key)
|
||||
@@ -310,42 +252,27 @@ impl SigningKeyPair {
|
||||
Self::from_keys(&pk_bytes, &sk_bytes)
|
||||
}
|
||||
|
||||
/// Sign a message
|
||||
/// Returns the signature prepended to the message (as per RustDesk protocol)
|
||||
pub fn sign(&self, message: &[u8]) -> Vec<u8> {
|
||||
sign::sign(message, &self.secret_key)
|
||||
}
|
||||
|
||||
/// Sign a message and return just the signature (64 bytes)
|
||||
pub fn sign_detached(&self, message: &[u8]) -> [u8; 64] {
|
||||
let sig = sign::sign_detached(message, &self.secret_key);
|
||||
// Use as_ref() to access the signature bytes since the inner field is private
|
||||
let sig_bytes: &[u8] = sig.as_ref();
|
||||
let mut result = [0u8; 64];
|
||||
result.copy_from_slice(sig_bytes);
|
||||
result
|
||||
}
|
||||
|
||||
/// Convert Ed25519 public key to Curve25519 public key for encryption
|
||||
///
|
||||
/// This allows using the same keypair for both signing and encryption,
|
||||
/// which is required by RustDesk's protocol where clients encrypt the
|
||||
/// symmetric key using the public key from IdPk.
|
||||
pub fn to_curve25519_pk(&self) -> Result<PublicKey, CryptoError> {
|
||||
ed25519::to_curve25519_pk(&self.public_key).map_err(|_| CryptoError::KeyConversionFailed)
|
||||
}
|
||||
|
||||
/// Convert Ed25519 secret key to Curve25519 secret key for decryption
|
||||
///
|
||||
/// This allows decrypting messages that were encrypted using the
|
||||
/// converted public key.
|
||||
pub fn to_curve25519_sk(&self) -> Result<SecretKey, CryptoError> {
|
||||
ed25519::to_curve25519_sk(&self.secret_key).map_err(|_| CryptoError::KeyConversionFailed)
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify a signed message
|
||||
/// Returns the original message if signature is valid
|
||||
pub fn verify_signed(
|
||||
signed_message: &[u8],
|
||||
public_key: &sign::PublicKey,
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
//! RustDesk Frame Adapters
|
||||
//!
|
||||
//! Converts One-KVM video/audio frames to RustDesk protocol format.
|
||||
//! Optimized for zero-copy where possible and buffer reuse.
|
||||
|
||||
use bytes::Bytes;
|
||||
use protobuf::Message as ProtobufMessage;
|
||||
|
||||
@@ -11,7 +6,6 @@ use super::protocol::hbb::message::{
|
||||
CursorData, CursorPosition, EncodedVideoFrame, EncodedVideoFrames, Message, Misc, VideoFrame,
|
||||
};
|
||||
|
||||
/// Video codec type for RustDesk
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum VideoCodec {
|
||||
H264,
|
||||
@@ -22,7 +16,6 @@ pub enum VideoCodec {
|
||||
}
|
||||
|
||||
impl VideoCodec {
|
||||
/// Get the codec ID for the RustDesk protocol
|
||||
pub fn to_codec_id(self) -> i32 {
|
||||
match self {
|
||||
VideoCodec::H264 => 0,
|
||||
@@ -34,21 +27,15 @@ impl VideoCodec {
|
||||
}
|
||||
}
|
||||
|
||||
/// Video frame adapter for converting to RustDesk format
|
||||
pub struct VideoFrameAdapter {
|
||||
/// Current codec
|
||||
codec: VideoCodec,
|
||||
/// Frame sequence number
|
||||
seq: u32,
|
||||
/// Timestamp offset
|
||||
timestamp_base: u64,
|
||||
/// Cached H264 SPS/PPS (Annex B NAL without start code)
|
||||
h264_sps: Option<Bytes>,
|
||||
h264_pps: Option<Bytes>,
|
||||
}
|
||||
|
||||
impl VideoFrameAdapter {
|
||||
/// Create a new video frame adapter
|
||||
pub fn new(codec: VideoCodec) -> Self {
|
||||
Self {
|
||||
codec,
|
||||
@@ -59,14 +46,10 @@ impl VideoFrameAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set codec type
|
||||
pub fn set_codec(&mut self, codec: VideoCodec) {
|
||||
self.codec = codec;
|
||||
}
|
||||
|
||||
/// Convert encoded video data to RustDesk Message (zero-copy version)
|
||||
///
|
||||
/// This version takes Bytes directly to avoid copying the frame data.
|
||||
pub fn encode_frame_from_bytes(
|
||||
&mut self,
|
||||
data: Bytes,
|
||||
@@ -74,7 +57,6 @@ impl VideoFrameAdapter {
|
||||
timestamp_ms: u64,
|
||||
) -> Message {
|
||||
let data = self.prepare_h264_frame(data, is_keyframe);
|
||||
// Calculate relative timestamp
|
||||
if self.seq == 0 {
|
||||
self.timestamp_base = timestamp_ms;
|
||||
}
|
||||
@@ -87,11 +69,9 @@ impl VideoFrameAdapter {
|
||||
|
||||
self.seq = self.seq.wrapping_add(1);
|
||||
|
||||
// Wrap in EncodedVideoFrames container
|
||||
let mut frames = EncodedVideoFrames::new();
|
||||
frames.frames.push(frame);
|
||||
|
||||
// Create the appropriate VideoFrame variant based on codec
|
||||
let mut video_frame = VideoFrame::new();
|
||||
match self.codec {
|
||||
VideoCodec::H264 => video_frame.union = Some(vf_union::Union::H264s(frames)),
|
||||
@@ -111,7 +91,6 @@ impl VideoFrameAdapter {
|
||||
return data;
|
||||
}
|
||||
|
||||
// Parse SPS/PPS from Annex B data (without start codes)
|
||||
let (sps, pps) = crate::webrtc::rtp::extract_sps_pps(&data);
|
||||
let mut has_sps = false;
|
||||
let mut has_pps = false;
|
||||
@@ -125,7 +104,6 @@ impl VideoFrameAdapter {
|
||||
has_pps = true;
|
||||
}
|
||||
|
||||
// Inject cached SPS/PPS before IDR when missing
|
||||
if is_keyframe && (!has_sps || !has_pps) {
|
||||
if let (Some(sps), Some(pps)) = (self.h264_sps.as_ref(), self.h264_pps.as_ref()) {
|
||||
let mut out = Vec::with_capacity(8 + sps.len() + pps.len() + data.len());
|
||||
@@ -141,14 +119,10 @@ impl VideoFrameAdapter {
|
||||
data
|
||||
}
|
||||
|
||||
/// Convert encoded video data to RustDesk Message
|
||||
pub fn encode_frame(&mut self, data: &[u8], is_keyframe: bool, timestamp_ms: u64) -> Message {
|
||||
self.encode_frame_from_bytes(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms)
|
||||
}
|
||||
|
||||
/// Encode frame to bytes for sending (zero-copy version)
|
||||
///
|
||||
/// Takes Bytes directly to avoid copying the frame data.
|
||||
pub fn encode_frame_bytes_zero_copy(
|
||||
&mut self,
|
||||
data: Bytes,
|
||||
@@ -159,7 +133,6 @@ impl VideoFrameAdapter {
|
||||
Bytes::from(msg.write_to_bytes().unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Encode frame to bytes for sending
|
||||
pub fn encode_frame_bytes(
|
||||
&mut self,
|
||||
data: &[u8],
|
||||
@@ -169,24 +142,18 @@ impl VideoFrameAdapter {
|
||||
self.encode_frame_bytes_zero_copy(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms)
|
||||
}
|
||||
|
||||
/// Get current sequence number
|
||||
pub fn seq(&self) -> u32 {
|
||||
self.seq
|
||||
}
|
||||
}
|
||||
|
||||
/// Audio frame adapter for converting to RustDesk format
|
||||
pub struct AudioFrameAdapter {
|
||||
/// Sample rate
|
||||
sample_rate: u32,
|
||||
/// Channels
|
||||
channels: u8,
|
||||
/// Format sent flag
|
||||
format_sent: bool,
|
||||
}
|
||||
|
||||
impl AudioFrameAdapter {
|
||||
/// Create a new audio frame adapter
|
||||
pub fn new(sample_rate: u32, channels: u8) -> Self {
|
||||
Self {
|
||||
sample_rate,
|
||||
@@ -195,7 +162,6 @@ impl AudioFrameAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create audio format message (should be sent once before audio frames)
|
||||
pub fn create_format_message(&mut self) -> Message {
|
||||
self.format_sent = true;
|
||||
|
||||
@@ -211,12 +177,10 @@ impl AudioFrameAdapter {
|
||||
msg
|
||||
}
|
||||
|
||||
/// Check if format message has been sent
|
||||
pub fn format_sent(&self) -> bool {
|
||||
self.format_sent
|
||||
}
|
||||
|
||||
/// Convert Opus audio data to RustDesk Message
|
||||
pub fn encode_opus_frame(&self, data: &[u8]) -> Message {
|
||||
let mut frame = AudioFrame::new();
|
||||
frame.data = Bytes::copy_from_slice(data);
|
||||
@@ -226,23 +190,19 @@ impl AudioFrameAdapter {
|
||||
msg
|
||||
}
|
||||
|
||||
/// Encode Opus frame to bytes for sending
|
||||
pub fn encode_opus_bytes(&self, data: &[u8]) -> Bytes {
|
||||
let msg = self.encode_opus_frame(data);
|
||||
Bytes::from(msg.write_to_bytes().unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Reset state (call when restarting audio stream)
|
||||
pub fn reset(&mut self) {
|
||||
self.format_sent = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Cursor data adapter
|
||||
pub struct CursorAdapter;
|
||||
|
||||
impl CursorAdapter {
|
||||
/// Create cursor data message
|
||||
pub fn encode_cursor(
|
||||
id: u64,
|
||||
hotx: i32,
|
||||
@@ -264,7 +224,6 @@ impl CursorAdapter {
|
||||
msg
|
||||
}
|
||||
|
||||
/// Create cursor position message
|
||||
pub fn encode_position(x: i32, y: i32) -> Message {
|
||||
let mut pos = CursorPosition::new();
|
||||
pos.x = x;
|
||||
@@ -284,7 +243,6 @@ mod tests {
|
||||
fn test_video_frame_encoding() {
|
||||
let mut adapter = VideoFrameAdapter::new(VideoCodec::H264);
|
||||
|
||||
// Encode a keyframe
|
||||
let data = vec![0x00, 0x00, 0x00, 0x01, 0x67]; // H264 SPS NAL
|
||||
let msg = adapter.encode_frame(&data, true, 0);
|
||||
|
||||
@@ -324,7 +282,6 @@ mod tests {
|
||||
fn test_audio_frame_encoding() {
|
||||
let adapter = AudioFrameAdapter::new(48000, 2);
|
||||
|
||||
// Encode an Opus frame
|
||||
let opus_data = vec![0xFC, 0x01, 0x02]; // Fake Opus data
|
||||
let msg = adapter.encode_opus_frame(&opus_data);
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
//! RustDesk HID Adapter
|
||||
//!
|
||||
//! Converts RustDesk HID events (KeyEvent, MouseEvent) to One-KVM HID events.
|
||||
|
||||
use super::protocol::hbb::message::key_event as ke_union;
|
||||
use super::protocol::{ControlKey, KeyEvent, MouseEvent};
|
||||
use crate::hid::{
|
||||
@@ -10,8 +6,6 @@ use crate::hid::{
|
||||
};
|
||||
use protobuf::Enum;
|
||||
|
||||
/// Mouse event types from RustDesk protocol
|
||||
/// mask = (button << 3) | event_type
|
||||
pub mod mouse_type {
|
||||
pub const MOVE: i32 = 0;
|
||||
pub const DOWN: i32 = 1;
|
||||
@@ -21,7 +15,6 @@ pub mod mouse_type {
|
||||
pub const MOVE_RELATIVE: i32 = 5;
|
||||
}
|
||||
|
||||
/// Mouse button IDs from RustDesk protocol (before left shift by 3)
|
||||
pub mod mouse_button {
|
||||
pub const LEFT: i32 = 0x01;
|
||||
pub const RIGHT: i32 = 0x02;
|
||||
@@ -30,9 +23,6 @@ pub mod mouse_button {
|
||||
pub const FORWARD: i32 = 0x10;
|
||||
}
|
||||
|
||||
/// Convert RustDesk MouseEvent to One-KVM MouseEvent(s)
|
||||
/// Returns a Vec because a single RustDesk event may need multiple One-KVM events
|
||||
/// (e.g., move + button + scroll)
|
||||
pub fn convert_mouse_event(
|
||||
event: &MouseEvent,
|
||||
screen_width: u32,
|
||||
@@ -41,23 +31,18 @@ pub fn convert_mouse_event(
|
||||
) -> Vec<OneKvmMouseEvent> {
|
||||
let mut events = Vec::new();
|
||||
|
||||
// Parse RustDesk mask format: (button << 3) | event_type
|
||||
let event_type = event.mask & 0x07;
|
||||
let button_id = event.mask >> 3;
|
||||
let include_abs_move = !relative_mode;
|
||||
|
||||
match event_type {
|
||||
mouse_type::MOVE => {
|
||||
// RustDesk uses absolute coordinates
|
||||
let x = event.x.max(0) as u32;
|
||||
let y = event.y.max(0) as u32;
|
||||
|
||||
// Normalize to 0-32767 range for absolute mouse (USB HID standard)
|
||||
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
|
||||
let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32;
|
||||
|
||||
// Move event - may have button held down (button_id > 0 means dragging)
|
||||
// Just send move, button state is tracked separately by HID backend
|
||||
events.push(OneKvmMouseEvent {
|
||||
event_type: MouseEventType::MoveAbs,
|
||||
x: abs_x,
|
||||
@@ -67,7 +52,6 @@ pub fn convert_mouse_event(
|
||||
});
|
||||
}
|
||||
mouse_type::MOVE_RELATIVE => {
|
||||
// Relative movement uses delta values directly (dx, dy).
|
||||
events.push(OneKvmMouseEvent {
|
||||
event_type: MouseEventType::Move,
|
||||
x: event.x,
|
||||
@@ -78,7 +62,6 @@ pub fn convert_mouse_event(
|
||||
}
|
||||
mouse_type::DOWN => {
|
||||
if include_abs_move {
|
||||
// Button down - first move, then press
|
||||
let x = event.x.max(0) as u32;
|
||||
let y = event.y.max(0) as u32;
|
||||
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
|
||||
@@ -104,7 +87,6 @@ pub fn convert_mouse_event(
|
||||
}
|
||||
mouse_type::UP => {
|
||||
if include_abs_move {
|
||||
// Button up - first move, then release
|
||||
let x = event.x.max(0) as u32;
|
||||
let y = event.y.max(0) as u32;
|
||||
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
|
||||
@@ -130,7 +112,6 @@ pub fn convert_mouse_event(
|
||||
}
|
||||
mouse_type::WHEEL => {
|
||||
if include_abs_move {
|
||||
// Scroll event - move first, then scroll
|
||||
let x = event.x.max(0) as u32;
|
||||
let y = event.y.max(0) as u32;
|
||||
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
|
||||
@@ -144,9 +125,6 @@ pub fn convert_mouse_event(
|
||||
});
|
||||
}
|
||||
|
||||
// RustDesk encodes scroll direction in the y coordinate
|
||||
// Positive y = scroll up, Negative y = scroll down
|
||||
// The button_id field is not used for direction
|
||||
let scroll = if event.y > 0 { 1i8 } else { -1i8 };
|
||||
events.push(OneKvmMouseEvent {
|
||||
event_type: MouseEventType::Scroll,
|
||||
@@ -158,7 +136,6 @@ pub fn convert_mouse_event(
|
||||
}
|
||||
_ => {
|
||||
if include_abs_move {
|
||||
// Unknown event type, just move
|
||||
let x = event.x.max(0) as u32;
|
||||
let y = event.y.max(0) as u32;
|
||||
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
|
||||
@@ -177,7 +154,6 @@ pub fn convert_mouse_event(
|
||||
events
|
||||
}
|
||||
|
||||
/// Convert RustDesk button ID to One-KVM MouseButton
|
||||
fn button_id_to_button(button_id: i32) -> Option<MouseButton> {
|
||||
match button_id {
|
||||
mouse_button::LEFT => Some(MouseButton::Left),
|
||||
@@ -187,34 +163,19 @@ fn button_id_to_button(button_id: i32) -> Option<MouseButton> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert RustDesk KeyEvent to One-KVM KeyboardEvent
|
||||
///
|
||||
/// RustDesk KeyEvent has two modes:
|
||||
/// - down=true/false: Key state (pressed/released)
|
||||
/// - press=true: Complete key press (down + up), used for typing
|
||||
///
|
||||
/// For press=true events, we only send Down and let the caller handle
|
||||
/// the timing for Up event if needed. Most systems handle this correctly.
|
||||
pub fn convert_key_event(event: &KeyEvent) -> Option<KeyboardEvent> {
|
||||
// Determine if this is a key down or key up event
|
||||
// press=true means "key was pressed" (down event)
|
||||
// down=true means key is currently held down
|
||||
// down=false with press=false means key was released
|
||||
let event_type = if event.down || event.press {
|
||||
KeyEventType::Down
|
||||
} else {
|
||||
KeyEventType::Up
|
||||
};
|
||||
|
||||
// For modifier keys sent as ControlKey, don't include them in modifiers
|
||||
// to avoid double-pressing. The modifier will be tracked by HID state.
|
||||
let modifiers = if is_modifier_control_key(event) {
|
||||
KeyboardModifiers::default()
|
||||
} else {
|
||||
parse_modifiers(event)
|
||||
};
|
||||
|
||||
// Handle control keys
|
||||
if let Some(ke_union::Union::ControlKey(ck)) = &event.union {
|
||||
if let Some(key) = control_key_to_hid(ck.value()) {
|
||||
let key = CanonicalKey::from_hid_usage(key)?;
|
||||
@@ -226,9 +187,7 @@ pub fn convert_key_event(event: &KeyEvent) -> Option<KeyboardEvent> {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle character keys (chr field contains platform-specific keycode)
|
||||
if let Some(ke_union::Union::Chr(chr)) = &event.union {
|
||||
// chr contains USB HID scancode on Windows, X11 keycode on Linux
|
||||
if let Some(key) = keycode_to_hid(*chr) {
|
||||
let key = CanonicalKey::from_hid_usage(key)?;
|
||||
return Some(KeyboardEvent {
|
||||
@@ -239,13 +198,9 @@ pub fn convert_key_event(event: &KeyEvent) -> Option<KeyboardEvent> {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle unicode (for text input, we'd need to convert to scancodes)
|
||||
// Unicode input requires more complex handling, skip for now
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if the event is a modifier key sent as ControlKey
|
||||
fn is_modifier_control_key(event: &KeyEvent) -> bool {
|
||||
if let Some(ke_union::Union::ControlKey(ck)) = &event.union {
|
||||
let val = ck.value();
|
||||
@@ -260,7 +215,6 @@ fn is_modifier_control_key(event: &KeyEvent) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Parse modifier keys from RustDesk KeyEvent into KeyboardModifiers
|
||||
fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers {
|
||||
let mut modifiers = KeyboardModifiers::default();
|
||||
|
||||
@@ -281,7 +235,6 @@ fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers {
|
||||
modifiers
|
||||
}
|
||||
|
||||
/// Convert RustDesk ControlKey to USB HID usage code
|
||||
fn control_key_to_hid(key: i32) -> Option<u8> {
|
||||
match key {
|
||||
x if x == ControlKey::Alt as i32 => Some(0xE2), // Left Alt
|
||||
@@ -342,67 +295,47 @@ fn control_key_to_hid(key: i32) -> Option<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert platform keycode to USB HID usage code
|
||||
/// Handles Windows Virtual Key Codes, X11 keycodes, and ASCII codes
|
||||
fn keycode_to_hid(keycode: u32) -> Option<u8> {
|
||||
// First try ASCII code mapping (RustDesk often sends ASCII codes)
|
||||
if let Some(hid) = ascii_to_hid(keycode) {
|
||||
return Some(hid);
|
||||
}
|
||||
// Then try Windows Virtual Key Code mapping
|
||||
if let Some(hid) = windows_vk_to_hid(keycode) {
|
||||
return Some(hid);
|
||||
}
|
||||
// Fall back to X11 keycode mapping for Linux clients
|
||||
x11_keycode_to_hid(keycode)
|
||||
}
|
||||
|
||||
/// Convert ASCII code to USB HID usage code
|
||||
fn ascii_to_hid(ascii: u32) -> Option<u8> {
|
||||
match ascii {
|
||||
// Lowercase letters a-z (ASCII 97-122)
|
||||
97..=122 => {
|
||||
// USB HID: a=0x04, b=0x05, ..., z=0x1D
|
||||
Some((ascii - 97 + 0x04) as u8)
|
||||
}
|
||||
// Uppercase letters A-Z (ASCII 65-90)
|
||||
65..=90 => {
|
||||
// USB HID: A=0x04, B=0x05, ..., Z=0x1D (same as lowercase)
|
||||
Some((ascii - 65 + 0x04) as u8)
|
||||
}
|
||||
// Numbers 0-9 (ASCII 48-57)
|
||||
97..=122 => Some((ascii - 97 + 0x04) as u8),
|
||||
65..=90 => Some((ascii - 65 + 0x04) as u8),
|
||||
48 => Some(0x27), // 0
|
||||
49..=57 => Some((ascii - 49 + 0x1E) as u8), // 1-9
|
||||
// Common punctuation
|
||||
32 => Some(0x2C), // Space
|
||||
13 => Some(0x28), // Enter (CR)
|
||||
10 => Some(0x28), // Enter (LF)
|
||||
9 => Some(0x2B), // Tab
|
||||
27 => Some(0x29), // Escape
|
||||
8 => Some(0x2A), // Backspace
|
||||
127 => Some(0x4C), // Delete
|
||||
// Symbols (US keyboard layout)
|
||||
45 => Some(0x2D), // -
|
||||
61 => Some(0x2E), // =
|
||||
91 => Some(0x2F), // [
|
||||
93 => Some(0x30), // ]
|
||||
92 => Some(0x31), // \
|
||||
59 => Some(0x33), // ;
|
||||
39 => Some(0x34), // '
|
||||
96 => Some(0x35), // `
|
||||
44 => Some(0x36), // ,
|
||||
46 => Some(0x37), // .
|
||||
47 => Some(0x38), // /
|
||||
32 => Some(0x2C), // Space
|
||||
13 => Some(0x28), // Enter (CR)
|
||||
10 => Some(0x28), // Enter (LF)
|
||||
9 => Some(0x2B), // Tab
|
||||
27 => Some(0x29), // Escape
|
||||
8 => Some(0x2A), // Backspace
|
||||
127 => Some(0x4C), // Delete
|
||||
45 => Some(0x2D), // -
|
||||
61 => Some(0x2E), // =
|
||||
91 => Some(0x2F), // [
|
||||
93 => Some(0x30), // ]
|
||||
92 => Some(0x31), // \
|
||||
59 => Some(0x33), // ;
|
||||
39 => Some(0x34), // '
|
||||
96 => Some(0x35), // `
|
||||
44 => Some(0x36), // ,
|
||||
46 => Some(0x37), // .
|
||||
47 => Some(0x38), // /
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Windows Virtual Key Code to USB HID usage code
|
||||
fn windows_vk_to_hid(vk: u32) -> Option<u8> {
|
||||
match vk {
|
||||
// Letters A-Z (VK_A=0x41 to VK_Z=0x5A)
|
||||
0x41..=0x5A => {
|
||||
// USB HID: A=0x04, B=0x05, ..., Z=0x1D
|
||||
let letter = (vk - 0x41) as u8;
|
||||
Some(match letter {
|
||||
0 => 0x04, // A
|
||||
@@ -434,21 +367,16 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
// Numbers 0-9 (VK_0=0x30 to VK_9=0x39)
|
||||
0x30 => Some(0x27), // 0
|
||||
0x31..=0x39 => Some((vk - 0x31 + 0x1E) as u8), // 1-9
|
||||
// Numpad 0-9 (VK_NUMPAD0=0x60 to VK_NUMPAD9=0x69)
|
||||
0x60 => Some(0x62), // Numpad 0
|
||||
0x61..=0x69 => Some((vk - 0x61 + 0x59) as u8), // Numpad 1-9
|
||||
// Numpad operators
|
||||
0x6A => Some(0x55), // Numpad *
|
||||
0x6B => Some(0x57), // Numpad +
|
||||
0x6D => Some(0x56), // Numpad -
|
||||
0x6E => Some(0x63), // Numpad .
|
||||
0x6F => Some(0x54), // Numpad /
|
||||
// Function keys F1-F12 (VK_F1=0x70 to VK_F12=0x7B)
|
||||
0x6A => Some(0x55), // Numpad *
|
||||
0x6B => Some(0x57), // Numpad +
|
||||
0x6D => Some(0x56), // Numpad -
|
||||
0x6E => Some(0x63), // Numpad .
|
||||
0x6F => Some(0x54), // Numpad /
|
||||
0x70..=0x7B => Some((vk - 0x70 + 0x3A) as u8),
|
||||
// Special keys
|
||||
0x08 => Some(0x2A), // Backspace
|
||||
0x09 => Some(0x2B), // Tab
|
||||
0x0D => Some(0x28), // Enter
|
||||
@@ -464,7 +392,6 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
|
||||
0x28 => Some(0x51), // Down Arrow
|
||||
0x2D => Some(0x49), // Insert
|
||||
0x2E => Some(0x4C), // Delete
|
||||
// OEM keys (US keyboard layout)
|
||||
0xBA => Some(0x33), // ; :
|
||||
0xBB => Some(0x2E), // = +
|
||||
0xBC => Some(0x36), // , <
|
||||
@@ -476,66 +403,56 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
|
||||
0xDC => Some(0x31), // \ |
|
||||
0xDD => Some(0x30), // ] }
|
||||
0xDE => Some(0x34), // ' "
|
||||
// Lock keys
|
||||
0x14 => Some(0x39), // Caps Lock
|
||||
0x90 => Some(0x53), // Num Lock
|
||||
0x91 => Some(0x47), // Scroll Lock
|
||||
// Print Screen, Pause
|
||||
0x2C => Some(0x46), // Print Screen
|
||||
0x13 => Some(0x48), // Pause
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert X11 keycode to USB HID usage code (for Linux clients)
|
||||
fn x11_keycode_to_hid(keycode: u32) -> Option<u8> {
|
||||
match keycode {
|
||||
// Numbers: X11 keycode 10="1", 11="2", ..., 18="9", 19="0"
|
||||
10..=18 => Some((keycode - 10 + 0x1E) as u8), // 1-9
|
||||
19 => Some(0x27), // 0
|
||||
// Punctuation
|
||||
20 => Some(0x2D), // -
|
||||
21 => Some(0x2E), // =
|
||||
34 => Some(0x2F), // [
|
||||
35 => Some(0x30), // ]
|
||||
// Letters (X11 keycodes are row-based)
|
||||
// Row 1: q(24) w(25) e(26) r(27) t(28) y(29) u(30) i(31) o(32) p(33)
|
||||
24 => Some(0x14), // q
|
||||
25 => Some(0x1A), // w
|
||||
26 => Some(0x08), // e
|
||||
27 => Some(0x15), // r
|
||||
28 => Some(0x17), // t
|
||||
29 => Some(0x1C), // y
|
||||
30 => Some(0x18), // u
|
||||
31 => Some(0x0C), // i
|
||||
32 => Some(0x12), // o
|
||||
33 => Some(0x13), // p
|
||||
// Row 2: a(38) s(39) d(40) f(41) g(42) h(43) j(44) k(45) l(46)
|
||||
38 => Some(0x04), // a
|
||||
39 => Some(0x16), // s
|
||||
40 => Some(0x07), // d
|
||||
41 => Some(0x09), // f
|
||||
42 => Some(0x0A), // g
|
||||
43 => Some(0x0B), // h
|
||||
44 => Some(0x0D), // j
|
||||
45 => Some(0x0E), // k
|
||||
46 => Some(0x0F), // l
|
||||
47 => Some(0x33), // ;
|
||||
48 => Some(0x34), // '
|
||||
49 => Some(0x35), // `
|
||||
51 => Some(0x31), // \
|
||||
// Row 3: z(52) x(53) c(54) v(55) b(56) n(57) m(58)
|
||||
52 => Some(0x1D), // z
|
||||
53 => Some(0x1B), // x
|
||||
54 => Some(0x06), // c
|
||||
55 => Some(0x19), // v
|
||||
56 => Some(0x05), // b
|
||||
57 => Some(0x11), // n
|
||||
58 => Some(0x10), // m
|
||||
59 => Some(0x36), // ,
|
||||
60 => Some(0x37), // .
|
||||
61 => Some(0x38), // /
|
||||
// Space
|
||||
20 => Some(0x2D), // -
|
||||
21 => Some(0x2E), // =
|
||||
34 => Some(0x2F), // [
|
||||
35 => Some(0x30), // ]
|
||||
24 => Some(0x14), // q
|
||||
25 => Some(0x1A), // w
|
||||
26 => Some(0x08), // e
|
||||
27 => Some(0x15), // r
|
||||
28 => Some(0x17), // t
|
||||
29 => Some(0x1C), // y
|
||||
30 => Some(0x18), // u
|
||||
31 => Some(0x0C), // i
|
||||
32 => Some(0x12), // o
|
||||
33 => Some(0x13), // p
|
||||
38 => Some(0x04), // a
|
||||
39 => Some(0x16), // s
|
||||
40 => Some(0x07), // d
|
||||
41 => Some(0x09), // f
|
||||
42 => Some(0x0A), // g
|
||||
43 => Some(0x0B), // h
|
||||
44 => Some(0x0D), // j
|
||||
45 => Some(0x0E), // k
|
||||
46 => Some(0x0F), // l
|
||||
47 => Some(0x33), // ;
|
||||
48 => Some(0x34), // '
|
||||
49 => Some(0x35), // `
|
||||
51 => Some(0x31), // \
|
||||
52 => Some(0x1D), // z
|
||||
53 => Some(0x1B), // x
|
||||
54 => Some(0x06), // c
|
||||
55 => Some(0x19), // v
|
||||
56 => Some(0x05), // b
|
||||
57 => Some(0x11), // n
|
||||
58 => Some(0x10), // m
|
||||
59 => Some(0x36), // ,
|
||||
60 => Some(0x37), // .
|
||||
61 => Some(0x38), // /
|
||||
65 => Some(0x2C),
|
||||
_ => None,
|
||||
}
|
||||
@@ -573,7 +490,6 @@ mod tests {
|
||||
|
||||
let events = convert_mouse_event(&event, 1920, 1080, false);
|
||||
assert!(events.len() >= 2);
|
||||
// Should have a button down event
|
||||
assert!(events
|
||||
.iter()
|
||||
.any(|e| e.event_type == MouseEventType::Down && e.button == Some(MouseButton::Left)));
|
||||
|
||||
@@ -1,17 +1,4 @@
|
||||
//! RustDesk Protocol Integration Module
|
||||
//!
|
||||
//! This module implements the RustDesk client protocol, enabling One-KVM devices
|
||||
//! to be accessed via standard RustDesk clients through existing hbbs/hbbr servers.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! - `config`: Configuration types for RustDesk settings
|
||||
//! - `protocol`: Protobuf message wrappers and serialization
|
||||
//! - `crypto`: NaCl cryptography (key generation, encryption, signatures)
|
||||
//! - `rendezvous`: Communication with hbbs rendezvous server
|
||||
//! - `connection`: Client session handling
|
||||
//! - `frame_adapters`: Video/audio frame conversion to RustDesk format
|
||||
//! - `hid_adapter`: RustDesk HID events to One-KVM conversion
|
||||
//! RustDesk peer protocol (hbbs / hbbr).
|
||||
|
||||
pub mod bytes_codec;
|
||||
pub mod config;
|
||||
@@ -44,19 +31,13 @@ use self::connection::ConnectionManager;
|
||||
use self::protocol::{make_local_addr, make_relay_response, make_request_relay};
|
||||
use self::rendezvous::{AddrMangle, RendezvousMediator, RendezvousStatus};
|
||||
|
||||
/// Relay connection timeout
|
||||
const RELAY_CONNECT_TIMEOUT_MS: u64 = 10_000;
|
||||
|
||||
/// RustDesk service status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum ServiceStatus {
|
||||
/// Service is stopped
|
||||
Stopped,
|
||||
/// Service is starting
|
||||
Starting,
|
||||
/// Service is running and registered with rendezvous server
|
||||
Running,
|
||||
/// Service encountered an error
|
||||
Error(String),
|
||||
}
|
||||
|
||||
@@ -71,15 +52,8 @@ impl std::fmt::Display for ServiceStatus {
|
||||
}
|
||||
}
|
||||
|
||||
/// Default port for direct TCP connections (same as RustDesk)
|
||||
const DIRECT_LISTEN_PORT: u16 = 21118;
|
||||
|
||||
/// RustDesk Service
|
||||
///
|
||||
/// Manages the RustDesk protocol integration, including:
|
||||
/// - Registration with hbbs rendezvous server
|
||||
/// - Accepting connections from RustDesk clients
|
||||
/// - Streaming video/audio and receiving HID input
|
||||
pub struct RustDeskService {
|
||||
config: Arc<RwLock<RustDeskConfig>>,
|
||||
status: Arc<RwLock<ServiceStatus>>,
|
||||
@@ -95,7 +69,6 @@ pub struct RustDeskService {
|
||||
}
|
||||
|
||||
impl RustDeskService {
|
||||
/// Create a new RustDesk service instance
|
||||
pub fn new(
|
||||
config: RustDeskConfig,
|
||||
video_manager: Arc<VideoStreamManager>,
|
||||
@@ -120,42 +93,34 @@ impl RustDeskService {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the port for direct TCP connections
|
||||
pub fn listen_port(&self) -> u16 {
|
||||
*self.listen_port.read()
|
||||
}
|
||||
|
||||
/// Get current service status
|
||||
pub fn status(&self) -> ServiceStatus {
|
||||
self.status.read().clone()
|
||||
}
|
||||
|
||||
/// Get current configuration
|
||||
pub fn config(&self) -> RustDeskConfig {
|
||||
self.config.read().clone()
|
||||
}
|
||||
|
||||
/// Update configuration
|
||||
pub fn update_config(&self, config: RustDeskConfig) {
|
||||
*self.config.write() = config;
|
||||
}
|
||||
|
||||
/// Get rendezvous status
|
||||
pub fn rendezvous_status(&self) -> Option<RendezvousStatus> {
|
||||
self.rendezvous.read().as_ref().map(|r| r.status())
|
||||
}
|
||||
|
||||
/// Get device ID
|
||||
pub fn device_id(&self) -> String {
|
||||
self.config.read().device_id.clone()
|
||||
}
|
||||
|
||||
/// Get connection count
|
||||
pub fn connection_count(&self) -> usize {
|
||||
self.connection_manager.connection_count()
|
||||
}
|
||||
|
||||
/// Start the RustDesk service
|
||||
pub async fn start(&self) -> anyhow::Result<()> {
|
||||
let config = self.config.read().clone();
|
||||
|
||||
@@ -181,74 +146,44 @@ impl RustDeskService {
|
||||
config.rendezvous_addr()
|
||||
);
|
||||
|
||||
// Initialize crypto
|
||||
if let Err(e) = crypto::init() {
|
||||
error!("Failed to initialize crypto: {}", e);
|
||||
*self.status.write() = ServiceStatus::Error(e.to_string());
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
// Create and start rendezvous mediator with relay callback
|
||||
let mediator = Arc::new(RendezvousMediator::new(config.clone()));
|
||||
|
||||
// Set the keypair on connection manager (Curve25519 for encryption)
|
||||
let keypair = mediator.ensure_keypair();
|
||||
self.connection_manager.set_keypair(keypair);
|
||||
|
||||
// Set the signing keypair on connection manager (Ed25519 for SignedId)
|
||||
let signing_keypair = mediator.ensure_signing_keypair();
|
||||
self.connection_manager.set_signing_keypair(signing_keypair);
|
||||
|
||||
// Set the HID controller on connection manager
|
||||
self.connection_manager.set_hid(self.hid.clone());
|
||||
|
||||
// Set the audio controller on connection manager for audio streaming
|
||||
self.connection_manager.set_audio(self.audio.clone());
|
||||
|
||||
// Set the video manager on connection manager for video streaming
|
||||
self.connection_manager
|
||||
.set_video_manager(self.video_manager.clone());
|
||||
|
||||
*self.rendezvous.write() = Some(mediator.clone());
|
||||
|
||||
// Start TCP listener BEFORE the rendezvous mediator to ensure port is set correctly
|
||||
// This prevents race condition where mediator starts registration with wrong port
|
||||
let (tcp_handles, listen_port) = self.start_tcp_listener_with_port().await?;
|
||||
*self.tcp_listener_handle.write() = Some(tcp_handles);
|
||||
|
||||
// Set the listen port on mediator before starting the registration loop
|
||||
mediator.set_listen_port(listen_port);
|
||||
|
||||
// Create relay request handler
|
||||
let connection_manager = self.connection_manager.clone();
|
||||
let video_manager = self.video_manager.clone();
|
||||
let hid = self.hid.clone();
|
||||
let audio = self.audio.clone();
|
||||
let service_config = self.config.clone();
|
||||
|
||||
// Set the punch callback on the mediator (tries P2P first, then relay)
|
||||
let connection_manager_punch = self.connection_manager.clone();
|
||||
let video_manager_punch = self.video_manager.clone();
|
||||
let hid_punch = self.hid.clone();
|
||||
let audio_punch = self.audio.clone();
|
||||
let service_config_punch = self.config.clone();
|
||||
|
||||
mediator.set_punch_callback(Arc::new(
|
||||
mediator.set_punch_callback(Arc::new({
|
||||
let connection_manager = connection_manager.clone();
|
||||
let service_config = service_config.clone();
|
||||
move |peer_addr, rendezvous_addr, relay_server, uuid, socket_addr, device_id| {
|
||||
let conn_mgr = connection_manager_punch.clone();
|
||||
let video = video_manager_punch.clone();
|
||||
let hid = hid_punch.clone();
|
||||
let audio = audio_punch.clone();
|
||||
let config = service_config_punch.clone();
|
||||
|
||||
let conn_mgr = connection_manager.clone();
|
||||
let config = service_config.clone();
|
||||
tokio::spawn(async move {
|
||||
// Get relay_key from config (no public server fallback)
|
||||
let relay_key = {
|
||||
let cfg = config.read();
|
||||
cfg.relay_key.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
// Try P2P direct connection first
|
||||
if let Some(addr) = peer_addr {
|
||||
info!("Attempting P2P direct connection to {}", addr);
|
||||
match punch::try_direct_connection(addr).await {
|
||||
@@ -265,7 +200,7 @@ impl RustDeskService {
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to relay
|
||||
let relay_key = rustdesk_relay_key(&config);
|
||||
if let Err(e) = handle_relay_request(
|
||||
&rendezvous_addr,
|
||||
&relay_server,
|
||||
@@ -274,34 +209,23 @@ impl RustDeskService {
|
||||
&device_id,
|
||||
&relay_key,
|
||||
conn_mgr,
|
||||
video,
|
||||
hid,
|
||||
audio,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Failed to handle relay request: {}", e);
|
||||
}
|
||||
});
|
||||
},
|
||||
));
|
||||
}
|
||||
}));
|
||||
|
||||
// Set the relay callback on the mediator
|
||||
mediator.set_relay_callback(Arc::new(
|
||||
mediator.set_relay_callback(Arc::new({
|
||||
let connection_manager = connection_manager.clone();
|
||||
let service_config = service_config.clone();
|
||||
move |rendezvous_addr, relay_server, uuid, socket_addr, device_id| {
|
||||
let conn_mgr = connection_manager.clone();
|
||||
let video = video_manager.clone();
|
||||
let hid = hid.clone();
|
||||
let audio = audio.clone();
|
||||
let config = service_config.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Get relay_key from config (no public server fallback)
|
||||
let relay_key = {
|
||||
let cfg = config.read();
|
||||
cfg.relay_key.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
let relay_key = rustdesk_relay_key(&config);
|
||||
if let Err(e) = handle_relay_request(
|
||||
&rendezvous_addr,
|
||||
&relay_server,
|
||||
@@ -310,19 +234,15 @@ impl RustDeskService {
|
||||
&device_id,
|
||||
&relay_key,
|
||||
conn_mgr,
|
||||
video,
|
||||
hid,
|
||||
audio,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Failed to handle relay request: {}", e);
|
||||
}
|
||||
});
|
||||
},
|
||||
));
|
||||
}
|
||||
}));
|
||||
|
||||
// Set the intranet callback on the mediator for same-LAN connections
|
||||
let connection_manager2 = self.connection_manager.clone();
|
||||
mediator.set_intranet_callback(Arc::new(
|
||||
move |rendezvous_addr, peer_socket_addr, local_addr, relay_server, device_id| {
|
||||
@@ -345,7 +265,6 @@ impl RustDeskService {
|
||||
},
|
||||
));
|
||||
|
||||
// Spawn rendezvous task
|
||||
let status = self.status.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
@@ -357,7 +276,6 @@ impl RustDeskService {
|
||||
Err(e) => {
|
||||
error!("Rendezvous mediator error: {}", e);
|
||||
*status.write() = ServiceStatus::Error(e.to_string());
|
||||
// Wait before retry
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
*status.write() = ServiceStatus::Starting;
|
||||
}
|
||||
@@ -372,10 +290,7 @@ impl RustDeskService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start TCP listener for direct peer connections
|
||||
/// Returns the join handle and the port that was bound
|
||||
async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(Vec<JoinHandle<()>>, u16)> {
|
||||
// Try to bind to the default port, or find an available port
|
||||
let (listeners, listen_port) = match self.bind_direct_listeners(DIRECT_LISTEN_PORT) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
@@ -453,7 +368,6 @@ impl RustDeskService {
|
||||
Ok((listeners, listen_port))
|
||||
}
|
||||
|
||||
/// Stop the RustDesk service
|
||||
pub async fn stop(&self) -> anyhow::Result<()> {
|
||||
if self.status() == ServiceStatus::Stopped {
|
||||
return Ok(());
|
||||
@@ -461,23 +375,18 @@ impl RustDeskService {
|
||||
|
||||
info!("Stopping RustDesk service");
|
||||
|
||||
// Send shutdown signal (this will stop the TCP listener)
|
||||
let _ = self.shutdown_tx.send(());
|
||||
|
||||
// Close all connections
|
||||
self.connection_manager.close_all();
|
||||
|
||||
// Stop rendezvous mediator
|
||||
if let Some(mediator) = self.rendezvous.read().as_ref() {
|
||||
mediator.stop();
|
||||
}
|
||||
|
||||
// Wait for rendezvous task to finish
|
||||
if let Some(handle) = self.rendezvous_handle.write().take() {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
// Wait for TCP listener task to finish
|
||||
if let Some(handles) = self.tcp_listener_handle.write().take() {
|
||||
for handle in handles {
|
||||
handle.abort();
|
||||
@@ -490,15 +399,12 @@ impl RustDeskService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Restart the service with new configuration
|
||||
pub async fn restart(&self, config: RustDeskConfig) -> anyhow::Result<()> {
|
||||
self.stop().await?;
|
||||
self.update_config(config);
|
||||
self.start().await
|
||||
}
|
||||
|
||||
/// Save keypair and UUID to config
|
||||
/// Returns the updated config if changes were made
|
||||
pub fn save_credentials(&self) -> Option<RustDeskConfig> {
|
||||
if let Some(mediator) = self.rendezvous.read().as_ref() {
|
||||
let kp = mediator.ensure_keypair();
|
||||
@@ -506,7 +412,6 @@ impl RustDeskService {
|
||||
let mut config = self.config.write();
|
||||
let mut changed = false;
|
||||
|
||||
// Save encryption keypair (Curve25519)
|
||||
let pk = kp.public_key_base64();
|
||||
let sk = kp.secret_key_base64();
|
||||
if config.public_key.as_ref() != Some(&pk) || config.private_key.as_ref() != Some(&sk) {
|
||||
@@ -515,7 +420,6 @@ impl RustDeskService {
|
||||
changed = true;
|
||||
}
|
||||
|
||||
// Save signing keypair (Ed25519)
|
||||
let signing_pk = skp.public_key_base64();
|
||||
let signing_sk = skp.secret_key_base64();
|
||||
if config.signing_public_key.as_ref() != Some(&signing_pk)
|
||||
@@ -526,7 +430,6 @@ impl RustDeskService {
|
||||
changed = true;
|
||||
}
|
||||
|
||||
// Save UUID if it was newly generated
|
||||
if mediator.uuid_needs_save() {
|
||||
let mediator_config = mediator.config();
|
||||
if let Some(uuid) = mediator_config.uuid {
|
||||
@@ -545,21 +448,16 @@ impl RustDeskService {
|
||||
None
|
||||
}
|
||||
|
||||
/// Save keypair to config (deprecated, use save_credentials instead)
|
||||
#[deprecated(note = "Use save_credentials instead")]
|
||||
pub fn save_keypair(&self) {
|
||||
let _ = self.save_credentials();
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle relay request from rendezvous server
|
||||
///
|
||||
/// Correct flow based on RustDesk protocol:
|
||||
/// 1. Connect to RENDEZVOUS server (not relay!)
|
||||
/// 2. Send RelayResponse with client's socket_addr
|
||||
/// 3. Connect to RELAY server
|
||||
/// 4. Accept connection without waiting for response
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn rustdesk_relay_key(config: &Arc<RwLock<RustDeskConfig>>) -> String {
|
||||
config.read().relay_key.clone().unwrap_or_default()
|
||||
}
|
||||
|
||||
async fn handle_relay_request(
|
||||
rendezvous_addr: &str,
|
||||
relay_server: &str,
|
||||
@@ -568,16 +466,12 @@ async fn handle_relay_request(
|
||||
device_id: &str,
|
||||
relay_key: &str,
|
||||
connection_manager: Arc<ConnectionManager>,
|
||||
_video_manager: Arc<VideoStreamManager>,
|
||||
_hid: Arc<HidController>,
|
||||
_audio: Arc<AudioController>,
|
||||
) -> anyhow::Result<()> {
|
||||
info!(
|
||||
"Handling relay request: rendezvous={}, relay={}, uuid={}",
|
||||
rendezvous_addr, relay_server, uuid
|
||||
);
|
||||
|
||||
// Step 1: Connect to RENDEZVOUS server and send RelayResponse
|
||||
let rendezvous_socket_addr: SocketAddr = tokio::net::lookup_host(rendezvous_addr)
|
||||
.await?
|
||||
.next()
|
||||
@@ -597,8 +491,7 @@ async fn handle_relay_request(
|
||||
rendezvous_socket_addr
|
||||
);
|
||||
|
||||
// Send RelayResponse to rendezvous server with client's socket_addr
|
||||
// IMPORTANT: Include our device ID so rendezvous server can look up and sign our public key
|
||||
// Rendezvous looks up our pk by device id (must set `id`, not raw pk on wire).
|
||||
let relay_response = make_relay_response(uuid, socket_addr, relay_server, device_id);
|
||||
let bytes = relay_response
|
||||
.write_to_bytes()
|
||||
@@ -606,10 +499,8 @@ async fn handle_relay_request(
|
||||
bytes_codec::write_frame(&mut rendezvous_stream, &bytes).await?;
|
||||
debug!("Sent RelayResponse to rendezvous server for uuid={}", uuid);
|
||||
|
||||
// Close rendezvous connection - we don't need to wait for response
|
||||
drop(rendezvous_stream);
|
||||
|
||||
// Step 2: Connect to RELAY server and send RequestRelay to identify ourselves
|
||||
let relay_addr: SocketAddr = tokio::net::lookup_host(relay_server)
|
||||
.await?
|
||||
.next()
|
||||
@@ -624,9 +515,7 @@ async fn handle_relay_request(
|
||||
|
||||
info!("Connected to relay server at {}", relay_addr);
|
||||
|
||||
// Send RequestRelay to relay server with our uuid, licence_key, and peer's socket_addr
|
||||
// The licence_key is required if the relay server is configured with -k option
|
||||
// The socket_addr is CRITICAL - the relay server uses it to match us with the peer
|
||||
// Relay pairs peers by uuid + mangled peer socket_addr (required when hbbr uses -k).
|
||||
let request_relay = make_request_relay(uuid, relay_key, socket_addr);
|
||||
let bytes = request_relay
|
||||
.write_to_bytes()
|
||||
@@ -634,10 +523,8 @@ async fn handle_relay_request(
|
||||
bytes_codec::write_frame(&mut stream, &bytes).await?;
|
||||
debug!("Sent RequestRelay to relay server for uuid={}", uuid);
|
||||
|
||||
// Decode peer address for logging
|
||||
let peer_addr = rendezvous::AddrMangle::decode(socket_addr).unwrap_or(relay_addr);
|
||||
|
||||
// Step 3: Accept connection - relay server bridges the connection
|
||||
connection_manager
|
||||
.accept_connection(stream, peer_addr)
|
||||
.await?;
|
||||
@@ -649,14 +536,6 @@ async fn handle_relay_request(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle intranet/same-LAN connection request
|
||||
///
|
||||
/// When the server determines that the client and peer are on the same intranet
|
||||
/// (same public IP or both on LAN), it sends FetchLocalAddr to the peer.
|
||||
/// The peer must:
|
||||
/// 1. Open a TCP connection to the rendezvous server
|
||||
/// 2. Send LocalAddr with our local address
|
||||
/// 3. Accept the peer connection over that same TCP stream
|
||||
async fn handle_intranet_request(
|
||||
rendezvous_addr: &str,
|
||||
peer_socket_addr: &[u8],
|
||||
@@ -670,11 +549,9 @@ async fn handle_intranet_request(
|
||||
rendezvous_addr, local_addr, device_id
|
||||
);
|
||||
|
||||
// Decode peer address for logging
|
||||
let peer_addr = AddrMangle::decode(peer_socket_addr);
|
||||
debug!("Peer address from FetchLocalAddr: {:?}", peer_addr);
|
||||
|
||||
// Connect to rendezvous server via TCP with timeout
|
||||
let mut stream =
|
||||
tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(rendezvous_addr))
|
||||
.await
|
||||
@@ -685,7 +562,6 @@ async fn handle_intranet_request(
|
||||
rendezvous_addr
|
||||
);
|
||||
|
||||
// Build LocalAddr message with our local address (mangled)
|
||||
let local_addr_bytes = AddrMangle::encode(local_addr);
|
||||
let msg = make_local_addr(
|
||||
peer_socket_addr,
|
||||
@@ -698,24 +574,16 @@ async fn handle_intranet_request(
|
||||
.write_to_bytes()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?;
|
||||
|
||||
// Send LocalAddr using RustDesk's variable-length framing
|
||||
bytes_codec::write_frame(&mut stream, &bytes).await?;
|
||||
|
||||
info!("Sent LocalAddr to rendezvous server, waiting for peer connection");
|
||||
|
||||
// Now the rendezvous server will forward this to the client,
|
||||
// and the client will connect to us through this same TCP stream.
|
||||
// The server proxies the connection between client and peer.
|
||||
|
||||
// Get peer address for logging/connection tracking
|
||||
let effective_peer_addr = peer_addr.unwrap_or_else(|| {
|
||||
// If we can't decode the peer address, use the rendezvous server address
|
||||
rendezvous_addr
|
||||
.parse()
|
||||
.unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap())
|
||||
});
|
||||
|
||||
// Accept the connection - the stream is now a proxied connection to the client
|
||||
connection_manager
|
||||
.accept_connection(stream, effective_peer_addr)
|
||||
.await?;
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
//! RustDesk Protocol Messages
|
||||
//!
|
||||
//! This module provides the compiled protobuf messages for the RustDesk protocol.
|
||||
//! Messages are generated from rendezvous.proto and message.proto at build time.
|
||||
//! Uses protobuf-rust (same as RustDesk server) for full compatibility.
|
||||
//! Protobuf wrappers (`protos/` → `OUT_DIR`).
|
||||
|
||||
use protobuf::Message;
|
||||
|
||||
// Include the generated protobuf code
|
||||
#[path = ""]
|
||||
pub mod hbb {
|
||||
include!(concat!(env!("OUT_DIR"), "/protos/mod.rs"));
|
||||
}
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use hbb::rendezvous::{
|
||||
punch_hole_response, relay_response, rendezvous_message, ConfigUpdate, ConnType,
|
||||
FetchLocalAddr, HealthCheck, KeyExchange, LocalAddr, NatType, OnlineRequest, OnlineResponse,
|
||||
@@ -21,7 +15,6 @@ pub use hbb::rendezvous::{
|
||||
RequestRelay, SoftwareUpdate, TestNatRequest, TestNatResponse,
|
||||
};
|
||||
|
||||
// Re-export message.proto types
|
||||
pub use hbb::message::{
|
||||
key_event, login_response, message, misc, AudioFormat, AudioFrame, Auth2FA, Clipboard,
|
||||
ControlKey, CursorData, CursorPosition, DisplayInfo, EncodedVideoFrame, EncodedVideoFrames,
|
||||
@@ -30,7 +23,6 @@ pub use hbb::message::{
|
||||
SupportedResolutions, TestDelay, VideoFrame, WindowsSessions,
|
||||
};
|
||||
|
||||
/// Helper to create a RendezvousMessage with RegisterPeer
|
||||
pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage {
|
||||
let mut rp = RegisterPeer::new();
|
||||
rp.id = id.to_string();
|
||||
@@ -41,7 +33,6 @@ pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage {
|
||||
msg
|
||||
}
|
||||
|
||||
/// Helper to create a RendezvousMessage with RegisterPk
|
||||
pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> RendezvousMessage {
|
||||
let mut rpk = RegisterPk::new();
|
||||
rpk.id = id.to_string();
|
||||
@@ -54,7 +45,6 @@ pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> Rende
|
||||
msg
|
||||
}
|
||||
|
||||
/// Helper to create a PunchHoleSent message
|
||||
pub fn make_punch_hole_sent(
|
||||
socket_addr: &[u8],
|
||||
id: &str,
|
||||
@@ -74,10 +64,7 @@ pub fn make_punch_hole_sent(
|
||||
msg
|
||||
}
|
||||
|
||||
/// Helper to create a RelayResponse message (sent to rendezvous server)
|
||||
/// IMPORTANT: The union field should be `Id` (our device ID), NOT `Pk`.
|
||||
/// The rendezvous server will look up our registered public key using this ID,
|
||||
/// sign it with the server's private key, and set the `pk` field before forwarding to client.
|
||||
/// Use `id` (device id), not raw `pk`; hbbs fills `pk` when forwarding.
|
||||
pub fn make_relay_response(
|
||||
uuid: &str,
|
||||
socket_addr: &[u8],
|
||||
@@ -96,13 +83,7 @@ pub fn make_relay_response(
|
||||
msg
|
||||
}
|
||||
|
||||
/// Helper to create a RequestRelay message (sent to relay server to identify ourselves)
|
||||
///
|
||||
/// The `licence_key` is required if the relay server is configured with a key.
|
||||
/// If the key doesn't match, the relay server will silently reject the connection.
|
||||
///
|
||||
/// IMPORTANT: `socket_addr` is the peer's encoded socket address (from FetchLocalAddr/RelayResponse).
|
||||
/// The relay server uses this to match the two peers connecting to the same relay session.
|
||||
/// `socket_addr` must be the peer's mangled addr; `licence_key` required if hbbr uses `-k`.
|
||||
pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) -> RendezvousMessage {
|
||||
let mut rr = RequestRelay::new();
|
||||
rr.uuid = uuid.to_string();
|
||||
@@ -114,8 +95,6 @@ pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) ->
|
||||
msg
|
||||
}
|
||||
|
||||
/// Helper to create a LocalAddr response message
|
||||
/// This is sent in response to FetchLocalAddr when a peer on the same LAN wants to connect
|
||||
pub fn make_local_addr(
|
||||
socket_addr: &[u8],
|
||||
local_addr: &[u8],
|
||||
@@ -135,12 +114,10 @@ pub fn make_local_addr(
|
||||
msg
|
||||
}
|
||||
|
||||
/// Decode a RendezvousMessage from bytes
|
||||
pub fn decode_rendezvous_message(buf: &[u8]) -> Result<RendezvousMessage, protobuf::Error> {
|
||||
RendezvousMessage::parse_from_bytes(buf)
|
||||
}
|
||||
|
||||
/// Decode a Message (session message) from bytes
|
||||
pub fn decode_message(buf: &[u8]) -> Result<hbb::message::Message, protobuf::Error> {
|
||||
hbb::message::Message::parse_from_bytes(buf)
|
||||
}
|
||||
|
||||
@@ -1,36 +1,19 @@
|
||||
//! P2P Punch Hole Implementation
|
||||
//!
|
||||
//! This module implements TCP direct connection attempts with relay fallback.
|
||||
//! When a PunchHole request is received, we try to connect directly to the peer.
|
||||
//! If the direct connection fails (timeout), we fall back to relay.
|
||||
//! Direct TCP attempt before relay fallback.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{debug, info, warn};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::connection::ConnectionManager;
|
||||
|
||||
/// Timeout for direct TCP connection attempt
|
||||
const DIRECT_CONNECT_TIMEOUT_MS: u64 = 3000;
|
||||
|
||||
/// Result of a punch hole attempt
|
||||
#[derive(Debug)]
|
||||
pub enum PunchResult {
|
||||
/// Direct connection succeeded
|
||||
DirectConnection(TcpStream),
|
||||
/// Direct connection failed, should use relay
|
||||
NeedRelay,
|
||||
}
|
||||
|
||||
/// Attempt direct TCP connection to peer
|
||||
///
|
||||
/// This is a simplified P2P approach:
|
||||
/// 1. Try to connect directly to the peer's address
|
||||
/// 2. If successful within timeout, use direct connection
|
||||
/// 3. If failed or timeout, fall back to relay
|
||||
pub async fn try_direct_connection(peer_addr: SocketAddr) -> PunchResult {
|
||||
info!("Attempting direct TCP connection to {}", peer_addr);
|
||||
|
||||
@@ -54,76 +37,3 @@ pub async fn try_direct_connection(peer_addr: SocketAddr) -> PunchResult {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Punch hole handler that tries direct connection first, then falls back to relay
|
||||
pub struct PunchHoleHandler {
|
||||
connection_manager: Arc<ConnectionManager>,
|
||||
}
|
||||
|
||||
impl PunchHoleHandler {
|
||||
pub fn new(connection_manager: Arc<ConnectionManager>) -> Self {
|
||||
Self { connection_manager }
|
||||
}
|
||||
|
||||
/// Handle punch hole request
|
||||
///
|
||||
/// Tries direct connection first, falls back to relay if needed.
|
||||
/// Returns true if direct connection succeeded, false if relay is needed.
|
||||
pub async fn handle_punch_hole(&self, peer_addr: Option<SocketAddr>) -> bool {
|
||||
let peer_addr = match peer_addr {
|
||||
Some(addr) => addr,
|
||||
None => {
|
||||
warn!("No peer address available for punch hole");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
match try_direct_connection(peer_addr).await {
|
||||
PunchResult::DirectConnection(stream) => {
|
||||
// Direct connection succeeded, accept it
|
||||
match self
|
||||
.connection_manager
|
||||
.accept_connection(stream, peer_addr)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
info!("P2P direct connection established with {}", peer_addr);
|
||||
true
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to accept direct connection: {}", e);
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
PunchResult::NeedRelay => {
|
||||
debug!("Direct connection failed, need relay for {}", peer_addr);
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn a punch hole attempt with relay fallback
|
||||
///
|
||||
/// This function spawns an async task that:
|
||||
/// 1. Tries direct TCP connection to peer
|
||||
/// 2. If successful, accepts the connection
|
||||
/// 3. If failed, calls the relay callback
|
||||
pub fn spawn_punch_with_fallback<F>(
|
||||
connection_manager: Arc<ConnectionManager>,
|
||||
peer_addr: Option<SocketAddr>,
|
||||
relay_callback: F,
|
||||
) where
|
||||
F: FnOnce() + Send + 'static,
|
||||
{
|
||||
tokio::spawn(async move {
|
||||
let handler = PunchHoleHandler::new(connection_manager);
|
||||
|
||||
if !handler.handle_punch_hole(peer_addr).await {
|
||||
// Direct connection failed, use relay
|
||||
info!("Falling back to relay connection");
|
||||
relay_callback();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
//! RustDesk Rendezvous Mediator
|
||||
//!
|
||||
//! This module handles communication with the hbbs rendezvous server.
|
||||
//! It registers the device ID and public key, handles punch hole requests,
|
||||
//! and relay requests.
|
||||
//! HBBS UDP registration; punch / relay / intranet callbacks.
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
@@ -24,19 +20,14 @@ use super::protocol::{
|
||||
rendezvous_message, NatType, RendezvousMessage,
|
||||
};
|
||||
|
||||
/// Registration interval in milliseconds
|
||||
const REG_INTERVAL_MS: u64 = 12_000;
|
||||
|
||||
/// Minimum registration timeout
|
||||
const MIN_REG_TIMEOUT_MS: u64 = 3_000;
|
||||
|
||||
/// Maximum registration timeout
|
||||
const MAX_REG_TIMEOUT_MS: u64 = 30_000;
|
||||
|
||||
/// Timer interval for checking registration status
|
||||
const TIMER_INTERVAL_MS: u64 = 300;
|
||||
|
||||
/// Rendezvous mediator status
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum RendezvousStatus {
|
||||
Disconnected,
|
||||
@@ -58,44 +49,13 @@ impl std::fmt::Display for RendezvousStatus {
|
||||
}
|
||||
}
|
||||
|
||||
/// Callback for handling incoming connection requests
|
||||
pub type ConnectionCallback = Arc<dyn Fn(ConnectionRequest) + Send + Sync>;
|
||||
|
||||
/// Incoming connection request from a RustDesk client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectionRequest {
|
||||
/// Peer socket address (encoded)
|
||||
pub socket_addr: Vec<u8>,
|
||||
/// Relay server to use
|
||||
pub relay_server: String,
|
||||
/// NAT type
|
||||
pub nat_type: NatType,
|
||||
/// Connection UUID
|
||||
pub uuid: String,
|
||||
/// Whether to use secure connection
|
||||
pub secure: bool,
|
||||
}
|
||||
|
||||
/// Callback type for relay requests
|
||||
/// Parameters: rendezvous_addr, relay_server, uuid, socket_addr (client's mangled address), device_id
|
||||
pub type RelayCallback = Arc<dyn Fn(String, String, String, Vec<u8>, String) + Send + Sync>;
|
||||
|
||||
/// Callback type for P2P punch hole requests
|
||||
/// Parameters: peer_addr (decoded), relay_callback_params (rendezvous_addr, relay_server, uuid, socket_addr, device_id)
|
||||
/// Returns: should call relay callback if P2P fails
|
||||
pub type PunchCallback =
|
||||
Arc<dyn Fn(Option<SocketAddr>, String, String, String, Vec<u8>, String) + Send + Sync>;
|
||||
|
||||
/// Callback type for intranet/local address connections
|
||||
/// Parameters: rendezvous_addr, peer_socket_addr (mangled), local_addr, relay_server, device_id
|
||||
pub type IntranetCallback = Arc<dyn Fn(String, Vec<u8>, SocketAddr, String, String) + Send + Sync>;
|
||||
|
||||
/// Rendezvous Mediator
|
||||
///
|
||||
/// Handles communication with hbbs rendezvous server:
|
||||
/// - Registers device ID and public key
|
||||
/// - Maintains keep-alive with server
|
||||
/// - Handles punch hole and relay requests
|
||||
pub struct RendezvousMediator {
|
||||
config: Arc<RwLock<RustDeskConfig>>,
|
||||
keypair: Arc<RwLock<Option<KeyPair>>>,
|
||||
@@ -114,11 +74,9 @@ pub struct RendezvousMediator {
|
||||
}
|
||||
|
||||
impl RendezvousMediator {
|
||||
/// Create a new rendezvous mediator
|
||||
pub fn new(mut config: RustDeskConfig) -> Self {
|
||||
let (shutdown_tx, _) = broadcast::channel(1);
|
||||
|
||||
// Get or generate UUID from config (persisted)
|
||||
let (uuid, uuid_needs_save) = config.ensure_uuid();
|
||||
|
||||
Self {
|
||||
@@ -139,88 +97,71 @@ impl RendezvousMediator {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the TCP listen port for direct connections
|
||||
pub fn set_listen_port(&self, port: u16) {
|
||||
let old_port = *self.listen_port.read();
|
||||
if old_port != port {
|
||||
*self.listen_port.write() = port;
|
||||
// Port changed, increment serial to notify server
|
||||
self.increment_serial();
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the TCP listen port
|
||||
pub fn listen_port(&self) -> u16 {
|
||||
*self.listen_port.read()
|
||||
}
|
||||
|
||||
/// Increment the serial number to indicate local state change
|
||||
pub fn increment_serial(&self) {
|
||||
let mut serial = self.serial.write();
|
||||
*serial = serial.wrapping_add(1);
|
||||
debug!("Serial incremented to {}", *serial);
|
||||
}
|
||||
|
||||
/// Get current serial number
|
||||
pub fn serial(&self) -> i32 {
|
||||
*self.serial.read()
|
||||
}
|
||||
|
||||
/// Check if UUID needs to be saved to persistent storage
|
||||
pub fn uuid_needs_save(&self) -> bool {
|
||||
*self.uuid_needs_save.read()
|
||||
}
|
||||
|
||||
/// Get the current config (with UUID set)
|
||||
pub fn config(&self) -> RustDeskConfig {
|
||||
self.config.read().clone()
|
||||
}
|
||||
|
||||
/// Mark UUID as saved
|
||||
pub fn mark_uuid_saved(&self) {
|
||||
*self.uuid_needs_save.write() = false;
|
||||
}
|
||||
|
||||
/// Set the callback for relay requests
|
||||
pub fn set_relay_callback(&self, callback: RelayCallback) {
|
||||
*self.relay_callback.write() = Some(callback);
|
||||
}
|
||||
|
||||
/// Set the callback for P2P punch hole requests
|
||||
pub fn set_punch_callback(&self, callback: PunchCallback) {
|
||||
*self.punch_callback.write() = Some(callback);
|
||||
}
|
||||
|
||||
/// Set the callback for intranet/local address connections
|
||||
pub fn set_intranet_callback(&self, callback: IntranetCallback) {
|
||||
*self.intranet_callback.write() = Some(callback);
|
||||
}
|
||||
|
||||
/// Get current status
|
||||
pub fn status(&self) -> RendezvousStatus {
|
||||
self.status.read().clone()
|
||||
}
|
||||
|
||||
/// Update configuration
|
||||
pub fn update_config(&self, config: RustDeskConfig) {
|
||||
*self.config.write() = config;
|
||||
// Config changed, increment serial to notify server
|
||||
self.increment_serial();
|
||||
}
|
||||
|
||||
/// Initialize or get keypair (Curve25519 for encryption)
|
||||
pub fn ensure_keypair(&self) -> KeyPair {
|
||||
let mut keypair_guard = self.keypair.write();
|
||||
if keypair_guard.is_none() {
|
||||
let config = self.config.read();
|
||||
// Try to load from config first
|
||||
if let (Some(pk), Some(sk)) = (&config.public_key, &config.private_key) {
|
||||
if let Ok(kp) = KeyPair::from_base64(pk, sk) {
|
||||
*keypair_guard = Some(kp.clone());
|
||||
return kp;
|
||||
}
|
||||
}
|
||||
// Generate new keypair
|
||||
let kp = KeyPair::generate();
|
||||
*keypair_guard = Some(kp.clone());
|
||||
kp
|
||||
@@ -229,12 +170,10 @@ impl RendezvousMediator {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize or get signing keypair (Ed25519 for SignedId)
|
||||
pub fn ensure_signing_keypair(&self) -> SigningKeyPair {
|
||||
let mut signing_guard = self.signing_keypair.write();
|
||||
if signing_guard.is_none() {
|
||||
let config = self.config.read();
|
||||
// Try to load from config first
|
||||
if let (Some(pk), Some(sk)) = (&config.signing_public_key, &config.signing_private_key)
|
||||
{
|
||||
if let Ok(skp) = SigningKeyPair::from_base64(pk, sk) {
|
||||
@@ -245,7 +184,6 @@ impl RendezvousMediator {
|
||||
warn!("Failed to decode signing keypair from config, generating new one");
|
||||
}
|
||||
}
|
||||
// Generate new signing keypair
|
||||
let skp = SigningKeyPair::generate();
|
||||
debug!("Generated new signing keypair");
|
||||
*signing_guard = Some(skp.clone());
|
||||
@@ -255,12 +193,10 @@ impl RendezvousMediator {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the device ID
|
||||
pub fn device_id(&self) -> String {
|
||||
self.config.read().device_id.clone()
|
||||
}
|
||||
|
||||
/// Start the rendezvous mediator
|
||||
pub async fn start(&self) -> anyhow::Result<()> {
|
||||
let config = self.config.read().clone();
|
||||
let effective_server = config.effective_rendezvous_server();
|
||||
@@ -284,13 +220,11 @@ impl RendezvousMediator {
|
||||
config.device_id, addr
|
||||
);
|
||||
|
||||
// Resolve server address
|
||||
let server_addr: SocketAddr = tokio::net::lookup_host(&addr)
|
||||
.await?
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to resolve {}", addr))?;
|
||||
|
||||
// Create UDP socket (match address family, enforce IPV6_V6ONLY)
|
||||
let bind_addr = match server_addr {
|
||||
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
|
||||
@@ -302,11 +236,9 @@ impl RendezvousMediator {
|
||||
info!("Connected to rendezvous server at {}", server_addr);
|
||||
*self.status.write() = RendezvousStatus::Connected;
|
||||
|
||||
// Start registration loop
|
||||
self.registration_loop(socket).await
|
||||
}
|
||||
|
||||
/// Main registration loop
|
||||
async fn registration_loop(&self, socket: UdpSocket) -> anyhow::Result<()> {
|
||||
let mut timer = interval(Duration::from_millis(TIMER_INTERVAL_MS));
|
||||
let mut recv_buf = vec![0u8; 65535];
|
||||
@@ -318,7 +250,6 @@ impl RendezvousMediator {
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Handle incoming messages
|
||||
result = socket.recv(&mut recv_buf) => {
|
||||
match result {
|
||||
Ok(len) => {
|
||||
@@ -336,7 +267,6 @@ impl RendezvousMediator {
|
||||
}
|
||||
}
|
||||
|
||||
// Periodic registration
|
||||
_ = timer.tick() => {
|
||||
let now = Instant::now();
|
||||
let expired = last_register_resp
|
||||
@@ -360,7 +290,6 @@ impl RendezvousMediator {
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown signal
|
||||
_ = shutdown_rx.recv() => {
|
||||
info!("Rendezvous mediator shutting down");
|
||||
break;
|
||||
@@ -372,20 +301,16 @@ impl RendezvousMediator {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send registration message
|
||||
async fn send_register(&self, socket: &UdpSocket) -> anyhow::Result<()> {
|
||||
let key_confirmed = *self.key_confirmed.read();
|
||||
|
||||
if !key_confirmed {
|
||||
// Send RegisterPk with public key
|
||||
self.send_register_pk(socket).await
|
||||
} else {
|
||||
// Send RegisterPeer heartbeat
|
||||
self.send_register_peer(socket).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Send RegisterPeer message
|
||||
async fn send_register_peer(&self, socket: &UdpSocket) -> anyhow::Result<()> {
|
||||
let id = self.device_id();
|
||||
let serial = *self.serial.read();
|
||||
@@ -398,12 +323,8 @@ impl RendezvousMediator {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send RegisterPk message
|
||||
/// Uses the Ed25519 signing public key for registration
|
||||
async fn send_register_pk(&self, socket: &UdpSocket) -> anyhow::Result<()> {
|
||||
let id = self.device_id();
|
||||
// Use signing public key (Ed25519) for RegisterPk
|
||||
// This is what clients will use to verify our SignedId signature
|
||||
let signing_keypair = self.ensure_signing_keypair();
|
||||
let pk = signing_keypair.public_key_bytes();
|
||||
let uuid = *self.uuid.read();
|
||||
@@ -417,12 +338,6 @@ impl RendezvousMediator {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle FetchLocalAddr - send to callback for proper TCP handling
|
||||
///
|
||||
/// The intranet callback will:
|
||||
/// 1. Open a TCP connection to the rendezvous server
|
||||
/// 2. Send LocalAddr message
|
||||
/// 3. Accept the peer connection over that same TCP stream
|
||||
async fn send_local_addr(
|
||||
&self,
|
||||
_udp_socket: &UdpSocket,
|
||||
@@ -431,21 +346,17 @@ impl RendezvousMediator {
|
||||
) -> anyhow::Result<()> {
|
||||
let id = self.device_id();
|
||||
|
||||
// Get our actual local IP addresses for same-LAN connection
|
||||
let local_addrs = get_local_addresses();
|
||||
if local_addrs.is_empty() {
|
||||
debug!("No local addresses available for LocalAddr response");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Get the rendezvous server address for TCP connection
|
||||
let config = self.config.read().clone();
|
||||
let rendezvous_addr = config.rendezvous_addr();
|
||||
|
||||
// Use TCP listen port for direct connections
|
||||
let listen_port = self.listen_port();
|
||||
|
||||
// Use the first local IP
|
||||
let local_ip = local_addrs[0];
|
||||
let local_sock_addr = SocketAddr::new(local_ip, listen_port);
|
||||
|
||||
@@ -454,7 +365,6 @@ impl RendezvousMediator {
|
||||
local_sock_addr, rendezvous_addr
|
||||
);
|
||||
|
||||
// Call the intranet callback if set
|
||||
if let Some(callback) = self.intranet_callback.read().as_ref() {
|
||||
callback(
|
||||
rendezvous_addr,
|
||||
@@ -470,7 +380,6 @@ impl RendezvousMediator {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle response from rendezvous server
|
||||
async fn handle_response(
|
||||
&self,
|
||||
socket: &UdpSocket,
|
||||
@@ -486,7 +395,6 @@ impl RendezvousMediator {
|
||||
match msg.union {
|
||||
Some(rendezvous_message::Union::RegisterPeerResponse(rpr)) => {
|
||||
if rpr.request_pk {
|
||||
// Server wants us to register our public key
|
||||
info!("Server requested public key registration");
|
||||
*self.key_confirmed.write() = false;
|
||||
self.send_register_pk(socket).await?;
|
||||
@@ -497,30 +405,24 @@ impl RendezvousMediator {
|
||||
info!("Received RegisterPkResponse: result={:?}", rpr.result);
|
||||
match rpr.result.value() {
|
||||
0 => {
|
||||
// OK
|
||||
info!("✓ Public key registered successfully with server");
|
||||
*self.key_confirmed.write() = true;
|
||||
// Increment serial after successful registration
|
||||
self.increment_serial();
|
||||
*self.status.write() = RendezvousStatus::Registered;
|
||||
}
|
||||
2 => {
|
||||
// UUID_MISMATCH
|
||||
warn!("UUID mismatch, need to re-register");
|
||||
*self.key_confirmed.write() = false;
|
||||
}
|
||||
3 => {
|
||||
// ID_EXISTS
|
||||
error!("Device ID already exists on server");
|
||||
*self.status.write() =
|
||||
RendezvousStatus::Error("Device ID already exists".to_string());
|
||||
}
|
||||
4 => {
|
||||
// TOO_FREQUENT
|
||||
warn!("Registration too frequent");
|
||||
}
|
||||
5 => {
|
||||
// INVALID_ID_FORMAT
|
||||
error!("Invalid device ID format");
|
||||
*self.status.write() =
|
||||
RendezvousStatus::Error("Invalid ID format".to_string());
|
||||
@@ -540,7 +442,6 @@ impl RendezvousMediator {
|
||||
let effective_relay_server =
|
||||
select_relay_server(config.relay_server.as_deref(), &ph.relay_server);
|
||||
|
||||
// Decode the peer's socket address
|
||||
let peer_addr = if !ph.socket_addr.is_empty() {
|
||||
AddrMangle::decode(&ph.socket_addr)
|
||||
} else {
|
||||
@@ -556,9 +457,7 @@ impl RendezvousMediator {
|
||||
ph.nat_type
|
||||
);
|
||||
|
||||
// Send PunchHoleSent to acknowledge
|
||||
// IMPORTANT: socket_addr in PunchHoleSent should be the PEER's address (from PunchHole),
|
||||
// not our own address. This is how RustDesk protocol works.
|
||||
let id = self.device_id();
|
||||
|
||||
info!(
|
||||
@@ -586,16 +485,11 @@ impl RendezvousMediator {
|
||||
info!("Sent PunchHoleSent response successfully");
|
||||
}
|
||||
|
||||
// Try P2P direct connection first, fall back to relay if needed
|
||||
if let Some(relay_server) = effective_relay_server {
|
||||
// Generate a standard UUID v4 for relay pairing
|
||||
// This must match the format used by RustDesk client
|
||||
let uuid = uuid::Uuid::new_v4().to_string();
|
||||
let rendezvous_addr = config.rendezvous_addr();
|
||||
let device_id = config.device_id.clone();
|
||||
|
||||
// Use punch callback if set (tries P2P first, then relay)
|
||||
// Otherwise fall back to relay callback directly
|
||||
if let Some(callback) = self.punch_callback.read().as_ref() {
|
||||
callback(
|
||||
peer_addr,
|
||||
@@ -630,7 +524,6 @@ impl RendezvousMediator {
|
||||
rr.uuid,
|
||||
rr.secure
|
||||
);
|
||||
// Call the relay callback to handle the connection
|
||||
if let Some(callback) = self.relay_callback.read().as_ref() {
|
||||
if let Some(relay_server) = effective_relay_server {
|
||||
let rendezvous_addr = config.rendezvous_addr();
|
||||
@@ -653,7 +546,6 @@ impl RendezvousMediator {
|
||||
select_relay_server(config.relay_server.as_deref(), &fla.relay_server)
|
||||
.unwrap_or_default();
|
||||
|
||||
// Decode the peer address for logging
|
||||
let peer_addr = AddrMangle::decode(&fla.socket_addr);
|
||||
info!(
|
||||
"Received FetchLocalAddr request: peer_addr={:?}, socket_addr_len={}, relay_server={}, effective_relay_server={}",
|
||||
@@ -662,7 +554,6 @@ impl RendezvousMediator {
|
||||
fla.relay_server,
|
||||
effective_relay_server
|
||||
);
|
||||
// Respond with our local address for same-LAN direct connection
|
||||
self.send_local_addr(socket, &fla.socket_addr, &effective_relay_server)
|
||||
.await?;
|
||||
}
|
||||
@@ -671,7 +562,6 @@ impl RendezvousMediator {
|
||||
*self.serial.write() = cu.serial;
|
||||
}
|
||||
Some(other) => {
|
||||
// Log the actual message type for debugging
|
||||
let type_name = match other {
|
||||
rendezvous_message::Union::PunchHoleRequest(_) => "PunchHoleRequest",
|
||||
rendezvous_message::Union::PunchHoleResponse(_) => "PunchHoleResponse",
|
||||
@@ -696,23 +586,18 @@ impl RendezvousMediator {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the rendezvous mediator
|
||||
pub fn stop(&self) {
|
||||
info!("Stopping rendezvous mediator");
|
||||
let _ = self.shutdown_tx.send(());
|
||||
*self.status.write() = RendezvousStatus::Disconnected;
|
||||
}
|
||||
|
||||
/// Get a shutdown receiver
|
||||
pub fn shutdown_rx(&self) -> broadcast::Receiver<()> {
|
||||
self.shutdown_tx.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
/// AddrMangle - RustDesk's address encoding scheme
|
||||
///
|
||||
/// Certain routers and firewalls scan packets and modify IP addresses.
|
||||
/// This encoding mangles the address to avoid detection.
|
||||
/// RustDesk mangled socket encoding.
|
||||
pub struct AddrMangle;
|
||||
|
||||
fn normalize_relay_server(server: &str) -> Option<String> {
|
||||
@@ -735,9 +620,7 @@ fn select_relay_server(local_relay: Option<&str>, server_relay: &str) -> Option<
|
||||
}
|
||||
|
||||
impl AddrMangle {
|
||||
/// Encode a SocketAddr to bytes using RustDesk's mangle algorithm
|
||||
pub fn encode(addr: SocketAddr) -> Vec<u8> {
|
||||
// Try to convert IPv6-mapped IPv4 to plain IPv4
|
||||
let addr = try_into_v4(addr);
|
||||
|
||||
match addr {
|
||||
@@ -753,7 +636,6 @@ impl AddrMangle {
|
||||
let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF));
|
||||
let bytes = v.to_le_bytes();
|
||||
|
||||
// Remove trailing zeros
|
||||
let mut n_padding = 0;
|
||||
for i in bytes.iter().rev() {
|
||||
if *i == 0u8 {
|
||||
@@ -774,13 +656,11 @@ impl AddrMangle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode bytes to SocketAddr using RustDesk's mangle algorithm
|
||||
pub fn decode(bytes: &[u8]) -> Option<SocketAddr> {
|
||||
use std::convert::TryInto;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4};
|
||||
|
||||
if bytes.len() > 16 {
|
||||
// IPv6 format: 16 bytes IP + 2 bytes port
|
||||
if bytes.len() != 18 {
|
||||
return None;
|
||||
}
|
||||
@@ -791,7 +671,6 @@ impl AddrMangle {
|
||||
return Some(SocketAddr::new(std::net::IpAddr::V6(ip), port));
|
||||
}
|
||||
|
||||
// IPv4 mangled format
|
||||
let mut padded = [0u8; 16];
|
||||
padded[..bytes.len()].copy_from_slice(bytes);
|
||||
let number = u128::from_le_bytes(padded);
|
||||
@@ -805,7 +684,6 @@ impl AddrMangle {
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to convert IPv6-mapped IPv4 address to plain IPv4
|
||||
fn try_into_v4(addr: SocketAddr) -> SocketAddr {
|
||||
match addr {
|
||||
SocketAddr::V6(v6) if !addr.ip().is_loopback() => {
|
||||
@@ -818,41 +696,30 @@ fn try_into_v4(addr: SocketAddr) -> SocketAddr {
|
||||
addr
|
||||
}
|
||||
|
||||
/// Check if an interface name belongs to Docker or other virtual networks
|
||||
fn is_virtual_interface(name: &str) -> bool {
|
||||
// Docker interfaces
|
||||
name.starts_with("docker")
|
||||
|| name.starts_with("br-")
|
||||
|| name.starts_with("veth")
|
||||
// Kubernetes/container interfaces
|
||||
|| name.starts_with("cni")
|
||||
|| name.starts_with("flannel")
|
||||
|| name.starts_with("calico")
|
||||
|| name.starts_with("weave")
|
||||
// Virtual bridge interfaces
|
||||
|| name.starts_with("virbr")
|
||||
|| name.starts_with("lxcbr")
|
||||
|| name.starts_with("lxdbr")
|
||||
// VPN interfaces (usually not useful for LAN discovery)
|
||||
|| name.starts_with("tun")
|
||||
|| name.starts_with("tap")
|
||||
}
|
||||
|
||||
/// Check if an IP address is in a Docker/container private range
|
||||
fn is_docker_ip(ip: &std::net::IpAddr) -> bool {
|
||||
if let std::net::IpAddr::V4(ipv4) = ip {
|
||||
let octets = ipv4.octets();
|
||||
// Docker default bridge: 172.17.0.0/16
|
||||
if octets[0] == 172 && octets[1] == 17 {
|
||||
return true;
|
||||
}
|
||||
// Docker user-defined networks: 172.18-31.0.0/16
|
||||
if octets[0] == 172 && (18..=31).contains(&octets[1]) {
|
||||
return true;
|
||||
}
|
||||
// Docker overlay networks: 10.0.0.0/8 (common range)
|
||||
// Note: 10.x.x.x is also used for corporate LANs, so we only filter
|
||||
// specific Docker-like patterns (10.0.x.x with small third octet)
|
||||
if octets[0] == 10 && octets[1] == 0 && octets[2] < 10 {
|
||||
return true;
|
||||
}
|
||||
@@ -860,22 +727,18 @@ fn is_docker_ip(ip: &std::net::IpAddr) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Get local IP addresses (non-loopback, non-Docker)
|
||||
fn get_local_addresses() -> Vec<std::net::IpAddr> {
|
||||
let mut addrs = Vec::new();
|
||||
|
||||
// Use pnet or network-interface crate if available, otherwise use simple method
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
if let Ok(interfaces) = std::fs::read_dir("/sys/class/net") {
|
||||
for entry in interfaces.flatten() {
|
||||
let iface_name = entry.file_name().to_string_lossy().to_string();
|
||||
// Skip loopback and virtual interfaces
|
||||
if iface_name == "lo" || is_virtual_interface(&iface_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to get IP via command (simple approach)
|
||||
if let Ok(output) = std::process::Command::new("ip")
|
||||
.args(["-4", "addr", "show", &iface_name])
|
||||
.output()
|
||||
@@ -886,7 +749,6 @@ fn get_local_addresses() -> Vec<std::net::IpAddr> {
|
||||
let ip_part = &line[inet_pos + 5..];
|
||||
if let Some(slash_pos) = ip_part.find('/') {
|
||||
if let Ok(ip) = ip_part[..slash_pos].parse::<std::net::IpAddr>() {
|
||||
// Skip loopback and Docker IPs
|
||||
if !ip.is_loopback() && !is_docker_ip(&ip) {
|
||||
addrs.push(ip);
|
||||
}
|
||||
@@ -899,15 +761,11 @@ fn get_local_addresses() -> Vec<std::net::IpAddr> {
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try to get default route interface IP
|
||||
if addrs.is_empty() {
|
||||
// Try using DNS lookup to get local IP (connects to external server)
|
||||
if let Ok(socket) = std::net::UdpSocket::bind("0.0.0.0:0") {
|
||||
// Connect to a public DNS server (doesn't actually send data)
|
||||
if socket.connect("8.8.8.8:53").is_ok() {
|
||||
if let Ok(local_addr) = socket.local_addr() {
|
||||
let ip = local_addr.ip();
|
||||
// Skip loopback and Docker IPs
|
||||
if !ip.is_loopback() && !is_docker_ip(&ip) {
|
||||
addrs.push(ip);
|
||||
}
|
||||
|
||||
71
src/state.rs
71
src/state.rs
@@ -5,8 +5,9 @@ use crate::atx::AtxController;
|
||||
use crate::audio::AudioController;
|
||||
use crate::auth::{SessionStore, UserStore};
|
||||
use crate::config::ConfigStore;
|
||||
use crate::db::DatabasePool;
|
||||
use crate::events::{
|
||||
AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, MsdDeviceInfo, SystemEvent,
|
||||
AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, LedState, MsdDeviceInfo, SystemEvent,
|
||||
TtydDeviceInfo, VideoDeviceInfo,
|
||||
};
|
||||
use crate::extensions::{ExtensionId, ExtensionManager};
|
||||
@@ -17,68 +18,42 @@ use crate::rtsp::RtspService;
|
||||
use crate::rustdesk::RustDeskService;
|
||||
use crate::update::UpdateService;
|
||||
use crate::video::VideoStreamManager;
|
||||
use crate::webrtc::WebRtcStreamer;
|
||||
|
||||
/// Application-wide state shared across handlers
|
||||
///
|
||||
/// # Video Streaming
|
||||
///
|
||||
/// All video operations should go through `stream_manager`:
|
||||
/// - `stream_manager.webrtc_streamer()` - WebRTC streaming (H264, extensible to VP8/VP9/H265)
|
||||
/// - `stream_manager.mjpeg_handler()` - MJPEG stream handler
|
||||
/// - `stream_manager.streamer()` - Low-level video capture
|
||||
/// - `stream_manager.start()` / `stream_manager.stop()` - Stream control
|
||||
/// - `stream_manager.stats()` - Stream statistics
|
||||
/// - `stream_manager.list_devices()` - List video devices
|
||||
/// Shared Axum/App state: video flows through [`VideoStreamManager`]; WebRTC SDP/ICE/sessions on [`WebRtcStreamer`].
|
||||
pub struct AppState {
|
||||
/// Configuration store
|
||||
pub db: DatabasePool,
|
||||
pub config: ConfigStore,
|
||||
/// Session store
|
||||
pub sessions: SessionStore,
|
||||
/// User store
|
||||
pub users: UserStore,
|
||||
/// OTG Service - centralized USB gadget lifecycle management
|
||||
/// This is the single owner of OtgGadgetManager, coordinating HID and MSD functions
|
||||
pub otg_service: Arc<OtgService>,
|
||||
/// Video stream manager (unified MJPEG/WebRTC management)
|
||||
/// This is the single entry point for all video operations.
|
||||
pub stream_manager: Arc<VideoStreamManager>,
|
||||
/// HID controller
|
||||
pub webrtc: Arc<WebRtcStreamer>,
|
||||
pub hid: Arc<HidController>,
|
||||
/// MSD controller (optional, may not be initialized)
|
||||
pub msd: Arc<RwLock<Option<MsdController>>>,
|
||||
/// ATX controller (optional, may not be initialized)
|
||||
pub atx: Arc<RwLock<Option<AtxController>>>,
|
||||
/// Audio controller
|
||||
pub audio: Arc<AudioController>,
|
||||
/// RustDesk remote access service (optional)
|
||||
pub rustdesk: Arc<RwLock<Option<Arc<RustDeskService>>>>,
|
||||
/// RTSP streaming service (optional)
|
||||
pub rtsp: Arc<RwLock<Option<Arc<RtspService>>>>,
|
||||
/// Extension manager (ttyd, gostc, easytier)
|
||||
pub extensions: Arc<ExtensionManager>,
|
||||
/// Event bus for real-time notifications
|
||||
pub events: Arc<EventBus>,
|
||||
/// Latest device info snapshot for WebSocket clients
|
||||
device_info_tx: watch::Sender<Option<SystemEvent>>,
|
||||
/// Online update service
|
||||
pub update: Arc<UpdateService>,
|
||||
/// Shutdown signal sender
|
||||
pub shutdown_tx: broadcast::Sender<()>,
|
||||
/// Recently revoked session IDs (for client kick detection)
|
||||
pub revoked_sessions: Arc<RwLock<VecDeque<String>>>,
|
||||
/// Data directory path
|
||||
data_dir: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
/// Create new application state
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
db: DatabasePool,
|
||||
config: ConfigStore,
|
||||
sessions: SessionStore,
|
||||
users: UserStore,
|
||||
otg_service: Arc<OtgService>,
|
||||
stream_manager: Arc<VideoStreamManager>,
|
||||
webrtc: Arc<WebRtcStreamer>,
|
||||
hid: Arc<HidController>,
|
||||
msd: Option<MsdController>,
|
||||
atx: Option<AtxController>,
|
||||
@@ -94,11 +69,13 @@ impl AppState {
|
||||
let (device_info_tx, _device_info_rx) = watch::channel(None);
|
||||
|
||||
Arc::new(Self {
|
||||
db,
|
||||
config,
|
||||
sessions,
|
||||
users,
|
||||
otg_service,
|
||||
stream_manager,
|
||||
webrtc,
|
||||
hid,
|
||||
msd: Arc::new(RwLock::new(msd)),
|
||||
atx: Arc::new(RwLock::new(atx)),
|
||||
@@ -115,22 +92,18 @@ impl AppState {
|
||||
})
|
||||
}
|
||||
|
||||
/// Get data directory path
|
||||
pub fn data_dir(&self) -> &std::path::PathBuf {
|
||||
&self.data_dir
|
||||
}
|
||||
|
||||
/// Subscribe to shutdown signal
|
||||
pub fn shutdown_signal(&self) -> broadcast::Receiver<()> {
|
||||
self.shutdown_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Subscribe to the latest device info snapshot.
|
||||
pub fn subscribe_device_info(&self) -> watch::Receiver<Option<SystemEvent>> {
|
||||
self.device_info_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Record revoked session IDs (bounded queue)
|
||||
pub async fn remember_revoked_sessions(&self, session_ids: Vec<String>) {
|
||||
if session_ids.is_empty() {
|
||||
return;
|
||||
@@ -144,19 +117,12 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a session ID was revoked (kicked)
|
||||
pub async fn is_session_revoked(&self, session_id: &str) -> bool {
|
||||
let guard = self.revoked_sessions.read().await;
|
||||
guard.iter().any(|id| id == session_id)
|
||||
}
|
||||
|
||||
/// Get complete device information for WebSocket clients
|
||||
///
|
||||
/// This method collects the current state of all devices (video, HID, MSD, ATX, Audio)
|
||||
/// and returns a DeviceInfo event that can be sent to clients.
|
||||
/// Uses tokio::join! to collect all device info in parallel for better performance.
|
||||
pub async fn get_device_info(&self) -> SystemEvent {
|
||||
// Collect all device info in parallel
|
||||
let (video, hid, msd, atx, audio, ttyd) = tokio::join!(
|
||||
self.collect_video_info(),
|
||||
self.collect_hid_info(),
|
||||
@@ -176,19 +142,15 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Publish DeviceInfo event to all connected WebSocket clients
|
||||
pub async fn publish_device_info(&self) {
|
||||
let device_info = self.get_device_info().await;
|
||||
let _ = self.device_info_tx.send(Some(device_info));
|
||||
}
|
||||
|
||||
/// Collect video device information
|
||||
async fn collect_video_info(&self) -> VideoDeviceInfo {
|
||||
// Use stream_manager to get video info (includes stream_mode)
|
||||
self.stream_manager.get_video_info().await
|
||||
}
|
||||
|
||||
/// Collect HID device information
|
||||
async fn collect_hid_info(&self) -> HidDeviceInfo {
|
||||
let state = self.hid.snapshot().await;
|
||||
|
||||
@@ -199,14 +161,19 @@ impl AppState {
|
||||
online: state.online,
|
||||
supports_absolute_mouse: state.supports_absolute_mouse,
|
||||
keyboard_leds_enabled: state.keyboard_leds_enabled,
|
||||
led_state: state.led_state,
|
||||
led_state: LedState {
|
||||
num_lock: state.led_state.num_lock,
|
||||
caps_lock: state.led_state.caps_lock,
|
||||
scroll_lock: state.led_state.scroll_lock,
|
||||
compose: state.led_state.compose,
|
||||
kana: state.led_state.kana,
|
||||
},
|
||||
device: state.device,
|
||||
error: state.error,
|
||||
error_code: state.error_code,
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect MSD device information (optional)
|
||||
async fn collect_msd_info(&self) -> Option<MsdDeviceInfo> {
|
||||
let msd_guard = self.msd.read().await;
|
||||
let msd = msd_guard.as_ref()?;
|
||||
@@ -227,9 +194,7 @@ impl AppState {
|
||||
})
|
||||
}
|
||||
|
||||
/// Collect ATX device information (optional)
|
||||
async fn collect_atx_info(&self) -> Option<AtxDeviceInfo> {
|
||||
// Predefined backend strings to avoid repeated allocations
|
||||
const BACKEND_POWER_ONLY: &str = "power: configured, reset: none";
|
||||
const BACKEND_RESET_ONLY: &str = "power: none, reset: configured";
|
||||
const BACKEND_BOTH: &str = "power: configured, reset: configured";
|
||||
@@ -254,7 +219,6 @@ impl AppState {
|
||||
})
|
||||
}
|
||||
|
||||
/// Collect Audio device information (optional)
|
||||
async fn collect_audio_info(&self) -> Option<AudioDeviceInfo> {
|
||||
let status = self.audio.status().await;
|
||||
|
||||
@@ -267,7 +231,6 @@ impl AppState {
|
||||
})
|
||||
}
|
||||
|
||||
/// Collect ttyd status information
|
||||
async fn collect_ttyd_info(&self) -> TtydDeviceInfo {
|
||||
let status = self.extensions.status(ExtensionId::Ttyd).await;
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
//! MJPEG stream handler
|
||||
//!
|
||||
//! Manages video frame distribution and per-client statistics.
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use parking_lot::Mutex as ParkingMutex;
|
||||
use parking_lot::RwLock as ParkingRwLock;
|
||||
@@ -17,28 +13,18 @@ use crate::video::encoder::JpegEncoder;
|
||||
use crate::video::format::PixelFormat;
|
||||
use crate::video::VideoFrame;
|
||||
|
||||
// No placeholder JPEGs: capture calls `set_offline()`; UI uses `stream.state_changed`.
|
||||
|
||||
/// Client ID type (UUID string)
|
||||
pub type ClientId = String;
|
||||
|
||||
/// Per-client session information
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientSession {
|
||||
/// Unique client ID
|
||||
pub id: ClientId,
|
||||
/// Connection timestamp
|
||||
pub connected_at: Instant,
|
||||
/// Last activity timestamp (frame sent)
|
||||
pub last_activity: Instant,
|
||||
/// Frames sent to this client
|
||||
pub frames_sent: u64,
|
||||
/// FPS calculator (1-second rolling window)
|
||||
pub fps_calculator: FpsCalculator,
|
||||
}
|
||||
|
||||
impl ClientSession {
|
||||
/// Create a new client session
|
||||
pub fn new(id: ClientId) -> Self {
|
||||
let now = Instant::now();
|
||||
Self {
|
||||
@@ -50,44 +36,31 @@ impl ClientSession {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get connection duration
|
||||
pub fn connected_duration(&self) -> Duration {
|
||||
self.last_activity.duration_since(self.connected_at)
|
||||
}
|
||||
|
||||
/// Get idle duration
|
||||
pub fn idle_duration(&self) -> Duration {
|
||||
Instant::now().duration_since(self.last_activity)
|
||||
pub fn connected_elapsed(&self) -> Duration {
|
||||
self.connected_at.elapsed()
|
||||
}
|
||||
}
|
||||
|
||||
/// Rolling window FPS calculator
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FpsCalculator {
|
||||
/// Frame timestamps in last window
|
||||
frame_times: VecDeque<Instant>,
|
||||
/// Window duration (default 1 second)
|
||||
window: Duration,
|
||||
/// Cached count of frames in current window (optimization to avoid O(n) filtering)
|
||||
count_in_window: usize,
|
||||
}
|
||||
|
||||
impl FpsCalculator {
|
||||
/// Create a new FPS calculator with 1-second window
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
frame_times: VecDeque::with_capacity(120), // Max 120fps tracking
|
||||
frame_times: VecDeque::with_capacity(120),
|
||||
window: Duration::from_secs(1),
|
||||
count_in_window: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a frame sent
|
||||
pub fn record_frame(&mut self) {
|
||||
let now = Instant::now();
|
||||
self.frame_times.push_back(now);
|
||||
|
||||
// Remove frames outside window and maintain count
|
||||
let cutoff = now - self.window;
|
||||
while let Some(&oldest) = self.frame_times.front() {
|
||||
if oldest < cutoff {
|
||||
@@ -97,31 +70,18 @@ impl FpsCalculator {
|
||||
}
|
||||
}
|
||||
|
||||
// Update cached count
|
||||
self.count_in_window = self.frame_times.len();
|
||||
}
|
||||
|
||||
/// Calculate current FPS (frames in last 1 second window)
|
||||
pub fn current_fps(&self) -> u32 {
|
||||
// Return cached count instead of filtering entire deque (O(1) instead of O(n))
|
||||
self.count_in_window as u32
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FpsCalculator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Auto-pause configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoPauseConfig {
|
||||
/// Enable auto-pause when no clients
|
||||
pub enabled: bool,
|
||||
/// Delay before pausing (default 10s)
|
||||
pub shutdown_delay_secs: u64,
|
||||
/// Client timeout for cleanup (default 30s)
|
||||
pub client_timeout_secs: u64,
|
||||
}
|
||||
|
||||
@@ -135,43 +95,27 @@ impl Default for AutoPauseConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// MJPEG stream handler
|
||||
/// Manages video frame distribution to HTTP clients
|
||||
pub struct MjpegStreamHandler {
|
||||
/// Current frame (latest) - using ArcSwap for lock-free reads
|
||||
current_frame: ArcSwap<Option<VideoFrame>>,
|
||||
/// Frame update notification
|
||||
frame_notify: broadcast::Sender<()>,
|
||||
/// Whether stream is online
|
||||
online: AtomicBool,
|
||||
/// Frame sequence counter
|
||||
sequence: AtomicU64,
|
||||
/// Per-client sessions (ClientId -> ClientSession)
|
||||
/// Use parking_lot::RwLock for better performance
|
||||
clients: ParkingRwLock<HashMap<ClientId, ClientSession>>,
|
||||
/// Auto-pause configuration
|
||||
auto_pause_config: ParkingRwLock<AutoPauseConfig>,
|
||||
/// Last frame timestamp
|
||||
last_frame_ts: ParkingRwLock<Option<Instant>>,
|
||||
/// Dropped same frames count
|
||||
dropped_same_frames: AtomicU64,
|
||||
/// Maximum consecutive same frames to drop (0 = disabled)
|
||||
max_drop_same_frames: AtomicU64,
|
||||
/// JPEG encoder for non-JPEG input formats
|
||||
jpeg_encoder: ParkingMutex<Option<JpegEncoder>>,
|
||||
/// JPEG quality for software encoding (1-100)
|
||||
jpeg_quality: AtomicU64,
|
||||
}
|
||||
|
||||
impl MjpegStreamHandler {
|
||||
/// Create a new MJPEG stream handler
|
||||
pub fn new() -> Self {
|
||||
Self::with_drop_limit(100) // Default: drop up to 100 same frames
|
||||
Self::with_drop_limit(100)
|
||||
}
|
||||
|
||||
/// Create handler with custom drop limit
|
||||
pub fn with_drop_limit(max_drop: u64) -> Self {
|
||||
let (frame_notify, _) = broadcast::channel(16); // Buffer size 16 for low latency
|
||||
let (frame_notify, _) = broadcast::channel(16);
|
||||
Self {
|
||||
current_frame: ArcSwap::from_pointee(None),
|
||||
frame_notify,
|
||||
@@ -187,16 +131,12 @@ impl MjpegStreamHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set JPEG quality for software encoding (1-100)
|
||||
pub fn set_jpeg_quality(&self, quality: u8) {
|
||||
let clamped = quality.clamp(1, 100) as u64;
|
||||
self.jpeg_quality.store(clamped, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Update current frame
|
||||
pub fn update_frame(&self, frame: VideoFrame) {
|
||||
// Fast path: if no MJPEG clients are connected, do minimal bookkeeping and avoid
|
||||
// expensive work (JPEG encoding and per-frame dedup hashing).
|
||||
let has_clients = !self.clients.read().is_empty();
|
||||
if !has_clients {
|
||||
self.dropped_same_frames.store(0, Ordering::Relaxed);
|
||||
@@ -204,8 +144,6 @@ impl MjpegStreamHandler {
|
||||
self.online.store(frame.online, Ordering::SeqCst);
|
||||
*self.last_frame_ts.write() = Some(Instant::now());
|
||||
|
||||
// Keep the latest compressed frame for "instant first frame" when a client connects.
|
||||
// Avoid retaining large raw buffers when there are no MJPEG clients.
|
||||
if frame.format.is_compressed() {
|
||||
self.current_frame.store(Arc::new(Some(frame)));
|
||||
} else {
|
||||
@@ -214,7 +152,6 @@ impl MjpegStreamHandler {
|
||||
return;
|
||||
}
|
||||
|
||||
// If frame is not JPEG, encode it
|
||||
let frame = if !frame.format.is_compressed() {
|
||||
match self.encode_to_jpeg(&frame) {
|
||||
Ok(jpeg_frame) => jpeg_frame,
|
||||
@@ -227,17 +164,13 @@ impl MjpegStreamHandler {
|
||||
frame
|
||||
};
|
||||
|
||||
// Frame deduplication (ustreamer-style)
|
||||
// Check if this frame is identical to the previous one
|
||||
let max_drop = self.max_drop_same_frames.load(Ordering::Relaxed);
|
||||
if max_drop > 0 && frame.online {
|
||||
let current = self.current_frame.load();
|
||||
if let Some(ref prev_frame) = **current {
|
||||
let dropped_count = self.dropped_same_frames.load(Ordering::Relaxed);
|
||||
|
||||
// Check if we should drop this frame
|
||||
if dropped_count < max_drop && frames_are_identical(prev_frame, &frame) {
|
||||
// Check last frame timestamp to ensure minimum 1fps
|
||||
let last_ts = *self.last_frame_ts.read();
|
||||
let should_force_send = if let Some(ts) = last_ts {
|
||||
ts.elapsed() >= Duration::from_secs(1)
|
||||
@@ -246,16 +179,13 @@ impl MjpegStreamHandler {
|
||||
};
|
||||
|
||||
if !should_force_send {
|
||||
// Drop this duplicate frame
|
||||
self.dropped_same_frames.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
// If more than 1 second since last frame, force send even if identical
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Frame is different or limit reached or forced by 1fps guarantee, update
|
||||
self.dropped_same_frames.store(0, Ordering::Relaxed);
|
||||
|
||||
self.sequence.fetch_add(1, Ordering::Relaxed);
|
||||
@@ -263,17 +193,14 @@ impl MjpegStreamHandler {
|
||||
*self.last_frame_ts.write() = Some(Instant::now());
|
||||
self.current_frame.store(Arc::new(Some(frame)));
|
||||
|
||||
// Notify waiting clients
|
||||
let _ = self.frame_notify.send(());
|
||||
}
|
||||
|
||||
/// Encode a non-JPEG frame to JPEG
|
||||
fn encode_to_jpeg(&self, frame: &VideoFrame) -> Result<VideoFrame, String> {
|
||||
let resolution = frame.resolution;
|
||||
let sequence = self.sequence.load(Ordering::Relaxed);
|
||||
let desired_quality = self.jpeg_quality.load(Ordering::Relaxed) as u32;
|
||||
|
||||
// Get or create encoder
|
||||
let mut encoder_guard = self.jpeg_encoder.lock();
|
||||
let encoder = encoder_guard.get_or_insert_with(|| {
|
||||
let config = EncoderConfig::jpeg(resolution, desired_quality);
|
||||
@@ -286,15 +213,12 @@ impl MjpegStreamHandler {
|
||||
enc
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create JPEG encoder: {}, using default", e);
|
||||
// Try with default config
|
||||
JpegEncoder::new(EncoderConfig::jpeg(resolution, desired_quality))
|
||||
.expect("Failed to create default JPEG encoder")
|
||||
warn!("Failed to create JPEG encoder: {}", e);
|
||||
panic!("Failed to create JPEG encoder");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Check if resolution changed
|
||||
if encoder.config().resolution != resolution {
|
||||
debug!(
|
||||
"Resolution changed, recreating JPEG encoder: {}x{}",
|
||||
@@ -312,7 +236,6 @@ impl MjpegStreamHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// Encode based on input format
|
||||
let encoded = match frame.format {
|
||||
PixelFormat::Yuyv => encoder
|
||||
.encode_yuyv(frame.data(), sequence)
|
||||
@@ -343,38 +266,32 @@ impl MjpegStreamHandler {
|
||||
}
|
||||
};
|
||||
|
||||
// Create new VideoFrame with JPEG data (zero-copy: Bytes -> Arc<Bytes>)
|
||||
Ok(VideoFrame::new(
|
||||
encoded.data,
|
||||
resolution,
|
||||
PixelFormat::Mjpeg,
|
||||
0, // stride not relevant for JPEG
|
||||
0,
|
||||
sequence,
|
||||
))
|
||||
}
|
||||
|
||||
/// Marks offline; clients exit their read loop. UI overlay comes from `stream.state_changed`.
|
||||
pub fn set_offline(&self) {
|
||||
self.online.store(false, Ordering::SeqCst);
|
||||
let _ = self.frame_notify.send(());
|
||||
}
|
||||
|
||||
/// Set stream online (called when streaming starts)
|
||||
pub fn set_online(&self) {
|
||||
self.online.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Check if stream is online
|
||||
pub fn is_online(&self) -> bool {
|
||||
self.online.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Get current client count
|
||||
pub fn client_count(&self) -> u64 {
|
||||
self.clients.read().len() as u64
|
||||
}
|
||||
|
||||
/// Register a new client
|
||||
pub fn register_client(&self, client_id: ClientId) {
|
||||
let session = ClientSession::new(client_id.clone());
|
||||
self.clients.write().insert(client_id.clone(), session);
|
||||
@@ -385,10 +302,9 @@ impl MjpegStreamHandler {
|
||||
);
|
||||
}
|
||||
|
||||
/// Unregister a client
|
||||
pub fn unregister_client(&self, client_id: &str) {
|
||||
if let Some(session) = self.clients.write().remove(client_id) {
|
||||
let duration = session.connected_duration();
|
||||
let duration = session.connected_elapsed();
|
||||
let duration_secs = duration.as_secs_f32();
|
||||
let avg_fps = if duration_secs > 0.1 {
|
||||
session.frames_sent as f32 / duration_secs
|
||||
@@ -402,7 +318,6 @@ impl MjpegStreamHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Record frame sent to a specific client
|
||||
pub fn record_frame_sent(&self, client_id: &str) {
|
||||
if let Some(session) = self.clients.write().get_mut(client_id) {
|
||||
session.last_activity = Instant::now();
|
||||
@@ -411,7 +326,6 @@ impl MjpegStreamHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get per-client statistics
|
||||
pub fn get_clients_stat(&self) -> HashMap<String, crate::events::types::ClientStats> {
|
||||
self.clients
|
||||
.read()
|
||||
@@ -422,43 +336,33 @@ impl MjpegStreamHandler {
|
||||
crate::events::types::ClientStats {
|
||||
id: id.clone(),
|
||||
fps: session.fps_calculator.current_fps(),
|
||||
connected_secs: session.connected_duration().as_secs(),
|
||||
connected_secs: session.connected_elapsed().as_secs(),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get auto-pause configuration
|
||||
pub fn auto_pause_config(&self) -> AutoPauseConfig {
|
||||
self.auto_pause_config.read().clone()
|
||||
}
|
||||
|
||||
/// Update auto-pause configuration
|
||||
pub fn set_auto_pause_config(&self, config: AutoPauseConfig) {
|
||||
let config_clone = config.clone();
|
||||
*self.auto_pause_config.write() = config;
|
||||
info!(
|
||||
"Auto-pause config updated: enabled={}, delay={}s, timeout={}s",
|
||||
config_clone.enabled,
|
||||
config_clone.shutdown_delay_secs,
|
||||
config_clone.client_timeout_secs
|
||||
config.enabled, config.shutdown_delay_secs, config.client_timeout_secs
|
||||
);
|
||||
*self.auto_pause_config.write() = config;
|
||||
}
|
||||
|
||||
/// Get current frame (if any)
|
||||
pub fn current_frame(&self) -> Option<VideoFrame> {
|
||||
(**self.current_frame.load()).clone()
|
||||
}
|
||||
|
||||
/// Subscribe to frame updates
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<()> {
|
||||
self.frame_notify.subscribe()
|
||||
}
|
||||
|
||||
/// Disconnect all clients (used during config changes)
|
||||
/// This clears the client list and sets the stream offline,
|
||||
/// which will cause all active MJPEG streams to terminate.
|
||||
pub fn disconnect_all_clients(&self) {
|
||||
let count = {
|
||||
let mut clients = self.clients.write();
|
||||
@@ -469,32 +373,21 @@ impl MjpegStreamHandler {
|
||||
if count > 0 {
|
||||
info!("Disconnected all {} MJPEG clients for config change", count);
|
||||
}
|
||||
// Set offline to signal all streaming tasks to stop
|
||||
self.set_offline();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MjpegStreamHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII guard for client lifecycle management
|
||||
/// Ensures cleanup even on panic or abrupt disconnection
|
||||
pub struct ClientGuard {
|
||||
client_id: ClientId,
|
||||
handler: Arc<MjpegStreamHandler>,
|
||||
}
|
||||
|
||||
impl ClientGuard {
|
||||
/// Create a new client guard
|
||||
pub fn new(client_id: ClientId, handler: Arc<MjpegStreamHandler>) -> Self {
|
||||
handler.register_client(client_id.clone());
|
||||
Self { client_id, handler }
|
||||
}
|
||||
|
||||
/// Get client ID
|
||||
pub fn id(&self) -> &ClientId {
|
||||
&self.client_id
|
||||
}
|
||||
@@ -507,8 +400,6 @@ impl Drop for ClientGuard {
|
||||
}
|
||||
|
||||
impl MjpegStreamHandler {
|
||||
/// Start stale client cleanup task
|
||||
/// Should be called once when handler is created
|
||||
pub fn start_cleanup_task(self: Arc<Self>) {
|
||||
let handler = self.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -522,7 +413,6 @@ impl MjpegStreamHandler {
|
||||
let now = Instant::now();
|
||||
let mut stale = Vec::new();
|
||||
|
||||
// Find stale clients
|
||||
{
|
||||
let clients = handler.clients.read();
|
||||
for (id, session) in clients.iter() {
|
||||
@@ -532,7 +422,6 @@ impl MjpegStreamHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// Remove stale clients
|
||||
if !stale.is_empty() {
|
||||
let mut clients = handler.clients.write();
|
||||
for id in stale {
|
||||
@@ -550,10 +439,7 @@ impl MjpegStreamHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare two frames for equality (hash-based, ustreamer-style)
|
||||
/// Returns true if frames are identical in geometry and content
|
||||
fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
|
||||
// Quick checks first (geometry)
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
@@ -574,13 +460,10 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Avoid hashing the whole frame for obviously different frames by sampling a few
|
||||
// fixed-size windows first. If all samples match, fall back to the cached hash.
|
||||
let a_data = a.data();
|
||||
let b_data = b.data();
|
||||
let len = a_data.len();
|
||||
|
||||
// Small frames: direct compare is cheap.
|
||||
if len <= 256 {
|
||||
return a_data == b_data;
|
||||
}
|
||||
@@ -588,7 +471,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
|
||||
const SAMPLE: usize = 16;
|
||||
debug_assert!(len == b_data.len());
|
||||
|
||||
// Head + tail.
|
||||
if a_data[..SAMPLE] != b_data[..SAMPLE] {
|
||||
return false;
|
||||
}
|
||||
@@ -596,7 +478,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Two interior samples (quarter + middle) to catch common "same header/footer" cases.
|
||||
let quarter = len / 4;
|
||||
let quarter_start = quarter.saturating_sub(SAMPLE / 2);
|
||||
if a_data[quarter_start..quarter_start + SAMPLE]
|
||||
@@ -610,8 +491,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compare hashes instead of full binary data.
|
||||
// Hash is computed once and cached in OnceLock for efficiency.
|
||||
a.get_hash() == b.get_hash()
|
||||
}
|
||||
|
||||
@@ -627,7 +506,6 @@ mod tests {
|
||||
assert!(!handler.is_online());
|
||||
assert_eq!(handler.client_count(), 0);
|
||||
|
||||
// Create a frame
|
||||
let _frame = VideoFrame::new(
|
||||
Bytes::from(vec![0xFF, 0xD8, 0x00, 0x00, 0xFF, 0xD9]),
|
||||
Resolution::VGA,
|
||||
@@ -641,15 +519,12 @@ mod tests {
|
||||
fn test_fps_calculator() {
|
||||
let mut calc = FpsCalculator::new();
|
||||
|
||||
// Initially empty
|
||||
assert_eq!(calc.current_fps(), 0);
|
||||
|
||||
// Record some frames
|
||||
calc.record_frame();
|
||||
calc.record_frame();
|
||||
calc.record_frame();
|
||||
|
||||
// Should have 3 frames in window
|
||||
assert!(calc.frame_times.len() == 3);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,4 @@
|
||||
//! Video streaming module
|
||||
//!
|
||||
//! Provides MJPEG streaming and WebSocket handlers for MJPEG mode.
|
||||
//!
|
||||
//! # Components
|
||||
//!
|
||||
//! - `MjpegStreamHandler` - HTTP multipart MJPEG video streaming
|
||||
//! - `WsHidHandler` - WebSocket HID input handler
|
||||
//! MJPEG multipart streaming and WebSocket HID (for MJPEG mode).
|
||||
|
||||
pub mod mjpeg;
|
||||
pub mod ws_hid;
|
||||
|
||||
@@ -1,25 +1,4 @@
|
||||
//! WebSocket HID Handler for MJPEG mode
|
||||
//!
|
||||
//! This module provides a standalone WebSocket HID handler that can be used
|
||||
//! independently of the application state. It manages multiple WebSocket
|
||||
//! connections and forwards HID events to the HID controller.
|
||||
//!
|
||||
//! # Protocol
|
||||
//!
|
||||
//! Only binary protocol is supported for optimal performance.
|
||||
//! See `crate::hid::datachannel` for message format details.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! WsHidHandler
|
||||
//! |
|
||||
//! +-- clients: HashMap<ClientId, WsHidClient>
|
||||
//! +-- hid_controller: Arc<HidController>
|
||||
//! |
|
||||
//! +-- add_client() -> spawns client handler task
|
||||
//! +-- remove_client()
|
||||
//! ```
|
||||
//! WebSocket HID for MJPEG mode; binary messages per `crate::hid::datachannel`.
|
||||
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
@@ -34,51 +13,34 @@ use tracing::{debug, error, info, warn};
|
||||
use crate::hid::datachannel::{parse_hid_message, HidChannelEvent};
|
||||
use crate::hid::HidController;
|
||||
|
||||
/// Client ID type
|
||||
pub type ClientId = String;
|
||||
|
||||
/// WebSocket HID client information
|
||||
#[derive(Debug)]
|
||||
pub struct WsHidClient {
|
||||
/// Client ID
|
||||
pub id: ClientId,
|
||||
/// Connection timestamp
|
||||
pub connected_at: Instant,
|
||||
/// Events processed
|
||||
pub events_processed: AtomicU64,
|
||||
/// Shutdown signal sender
|
||||
shutdown_tx: mpsc::Sender<()>,
|
||||
}
|
||||
|
||||
impl WsHidClient {
|
||||
/// Get events processed count
|
||||
pub fn events_count(&self) -> u64 {
|
||||
self.events_processed.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get connection duration in seconds
|
||||
pub fn connected_secs(&self) -> u64 {
|
||||
self.connected_at.elapsed().as_secs()
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket HID Handler
|
||||
///
|
||||
/// Manages WebSocket connections for HID input in MJPEG mode.
|
||||
/// Only binary protocol is supported for optimal performance.
|
||||
pub struct WsHidHandler {
|
||||
/// HID controller reference
|
||||
hid_controller: RwLock<Option<Arc<HidController>>>,
|
||||
/// Active clients
|
||||
clients: RwLock<HashMap<ClientId, Arc<WsHidClient>>>,
|
||||
/// Running state
|
||||
running: AtomicBool,
|
||||
/// Total events processed
|
||||
total_events: AtomicU64,
|
||||
}
|
||||
|
||||
impl WsHidHandler {
|
||||
/// Create a new WebSocket HID handler
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
hid_controller: RwLock::new(None),
|
||||
@@ -88,50 +50,39 @@ impl WsHidHandler {
|
||||
})
|
||||
}
|
||||
|
||||
/// Set HID controller
|
||||
pub fn set_hid_controller(&self, hid: Arc<HidController>) {
|
||||
*self.hid_controller.write() = Some(hid);
|
||||
info!("WsHidHandler: HID controller set");
|
||||
}
|
||||
|
||||
/// Get HID controller
|
||||
pub fn hid_controller(&self) -> Option<Arc<HidController>> {
|
||||
self.hid_controller.read().clone()
|
||||
}
|
||||
|
||||
/// Check if HID controller is available
|
||||
pub fn is_hid_available(&self) -> bool {
|
||||
self.hid_controller.read().is_some()
|
||||
}
|
||||
|
||||
/// Get client count
|
||||
pub fn client_count(&self) -> usize {
|
||||
self.clients.read().len()
|
||||
}
|
||||
|
||||
/// Check if running
|
||||
pub fn is_running(&self) -> bool {
|
||||
self.running.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Stop the handler
|
||||
pub fn stop(&self) {
|
||||
self.running.store(false, Ordering::SeqCst);
|
||||
// Signal all clients to disconnect
|
||||
let clients = self.clients.read();
|
||||
for client in clients.values() {
|
||||
let _ = client.shutdown_tx.try_send(());
|
||||
}
|
||||
}
|
||||
|
||||
/// Get total events processed
|
||||
pub fn total_events(&self) -> u64 {
|
||||
self.total_events.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Add a new WebSocket client
|
||||
///
|
||||
/// This spawns a background task to handle the WebSocket connection.
|
||||
pub async fn add_client(self: &Arc<Self>, client_id: ClientId, socket: WebSocket) {
|
||||
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
|
||||
|
||||
@@ -151,7 +102,6 @@ impl WsHidHandler {
|
||||
self.client_count()
|
||||
);
|
||||
|
||||
// Spawn handler task
|
||||
let handler = self.clone();
|
||||
tokio::spawn(async move {
|
||||
handler
|
||||
@@ -161,7 +111,6 @@ impl WsHidHandler {
|
||||
});
|
||||
}
|
||||
|
||||
/// Remove a client
|
||||
pub fn remove_client(&self, client_id: &str) {
|
||||
if let Some(client) = self.clients.write().remove(client_id) {
|
||||
info!(
|
||||
@@ -173,7 +122,6 @@ impl WsHidHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a WebSocket client connection
|
||||
async fn handle_client(
|
||||
&self,
|
||||
client_id: ClientId,
|
||||
@@ -183,7 +131,6 @@ impl WsHidHandler {
|
||||
) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Send initial status as binary: 0x00 = ok, 0x01 = error
|
||||
let status_byte = if self.is_hid_available() {
|
||||
0x00u8
|
||||
} else {
|
||||
@@ -222,7 +169,6 @@ impl WsHidHandler {
|
||||
debug!("WsHidHandler: Client {} stream ended", client_id);
|
||||
break;
|
||||
}
|
||||
// Ignore text messages - binary protocol only
|
||||
Some(Ok(Message::Text(_))) => {
|
||||
warn!("WsHidHandler: Ignoring text message from client {} (binary protocol only)", client_id);
|
||||
}
|
||||
@@ -232,7 +178,6 @@ impl WsHidHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// Reset HID state when client disconnects to release any held keys/buttons
|
||||
let hid = self.hid_controller.read().clone();
|
||||
if let Some(hid) = hid {
|
||||
if let Err(e) = hid.reset().await {
|
||||
@@ -246,7 +191,6 @@ impl WsHidHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle binary HID message
|
||||
async fn handle_binary_message(&self, data: &[u8], client: &WsHidClient) -> Result<(), String> {
|
||||
let hid = self
|
||||
.hid_controller
|
||||
@@ -279,17 +223,6 @@ impl WsHidHandler {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WsHidHandler {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
hid_controller: RwLock::new(None),
|
||||
clients: RwLock::new(HashMap::new()),
|
||||
running: AtomicBool::new(true),
|
||||
total_events: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
18
src/stream_encoder.rs
Normal file
18
src/stream_encoder.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
//! `EncoderType` → `EncoderBackend` (breaks config ↔ video import cycles).
|
||||
|
||||
use crate::config::EncoderType;
|
||||
use crate::video::encoder::EncoderBackend;
|
||||
|
||||
/// `None` means “auto” in WebRTC / pipeline (same as `EncoderType::Auto`).
|
||||
pub fn encoder_type_to_backend(encoder: EncoderType) -> Option<EncoderBackend> {
|
||||
match encoder {
|
||||
EncoderType::Auto => None,
|
||||
EncoderType::Software => Some(EncoderBackend::Software),
|
||||
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
|
||||
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
|
||||
EncoderType::Qsv => Some(EncoderBackend::Qsv),
|
||||
EncoderType::Amf => Some(EncoderBackend::Amf),
|
||||
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
|
||||
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
|
||||
}
|
||||
}
|
||||
@@ -142,15 +142,10 @@ impl UpdateService {
|
||||
}
|
||||
|
||||
pub async fn overview(&self, channel: UpdateChannel) -> Result<UpdateOverviewResponse> {
|
||||
let channels: ChannelsManifest = self.fetch_json("/v1/channels.json").await?;
|
||||
let releases: ReleasesManifest = self.fetch_json("/v1/releases.json").await?;
|
||||
let (channels, releases) = self.fetch_manifests().await?;
|
||||
|
||||
let current_version = parse_version(env!("CARGO_PKG_VERSION"))?;
|
||||
let latest_version_str = match channel {
|
||||
UpdateChannel::Stable => channels.stable,
|
||||
UpdateChannel::Beta => channels.beta,
|
||||
};
|
||||
let latest_version = parse_version(&latest_version_str)?;
|
||||
let latest_version = parse_version(&channel_head_version(&channels, channel))?;
|
||||
let current_parts = parse_version_parts(¤t_version)?;
|
||||
let latest_parts = parse_version_parts(&latest_version)?;
|
||||
|
||||
@@ -159,11 +154,7 @@ impl UpdateService {
|
||||
if release.channel != channel {
|
||||
continue;
|
||||
}
|
||||
let version = match parse_version(&release.version) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let version_parts = match parse_version_parts(&version) {
|
||||
let version_parts = match parse_version_parts(&release.version) {
|
||||
Ok(parts) => parts,
|
||||
Err(_) => continue,
|
||||
};
|
||||
@@ -253,16 +244,11 @@ impl UpdateService {
|
||||
)
|
||||
.await;
|
||||
|
||||
let channels: ChannelsManifest = self.fetch_json("/v1/channels.json").await?;
|
||||
let releases: ReleasesManifest = self.fetch_json("/v1/releases.json").await?;
|
||||
let (channels, releases) = self.fetch_manifests().await?;
|
||||
|
||||
let current_version = parse_version(env!("CARGO_PKG_VERSION"))?;
|
||||
let target_version = if let Some(channel) = req.channel {
|
||||
let version_str = match channel {
|
||||
UpdateChannel::Stable => channels.stable,
|
||||
UpdateChannel::Beta => channels.beta,
|
||||
};
|
||||
parse_version(&version_str)?
|
||||
parse_version(&channel_head_version(&channels, channel))?
|
||||
} else {
|
||||
parse_version(req.target_version.as_deref().unwrap_or_default())?
|
||||
};
|
||||
@@ -443,6 +429,12 @@ impl UpdateService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_manifests(&self) -> Result<(ChannelsManifest, ReleasesManifest)> {
|
||||
let channels = self.fetch_json("/v1/channels.json").await?;
|
||||
let releases = self.fetch_json("/v1/releases.json").await?;
|
||||
Ok((channels, releases))
|
||||
}
|
||||
|
||||
async fn fetch_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
|
||||
let url = format!("{}{}", self.base_url.trim_end_matches('/'), path);
|
||||
let response = self
|
||||
@@ -494,22 +486,7 @@ impl UpdateService {
|
||||
}
|
||||
|
||||
fn parse_version(input: &str) -> Result<String> {
|
||||
let parts: Vec<&str> = input.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Invalid version {}, expected x.x.x",
|
||||
input
|
||||
)));
|
||||
}
|
||||
if parts
|
||||
.iter()
|
||||
.any(|p| p.is_empty() || !p.chars().all(|c| c.is_ascii_digit()))
|
||||
{
|
||||
return Err(AppError::Internal(format!(
|
||||
"Invalid version {}, expected numeric x.x.x",
|
||||
input
|
||||
)));
|
||||
}
|
||||
parse_version_parts(input)?;
|
||||
Ok(input.to_string())
|
||||
}
|
||||
|
||||
@@ -527,16 +504,26 @@ fn parse_version_parts(input: &str) -> Result<[u64; 3]> {
|
||||
input
|
||||
)));
|
||||
}
|
||||
let major = parts[0]
|
||||
.parse::<u64>()
|
||||
.map_err(|e| AppError::Internal(format!("Invalid major version {}: {}", parts[0], e)))?;
|
||||
let minor = parts[1]
|
||||
.parse::<u64>()
|
||||
.map_err(|e| AppError::Internal(format!("Invalid minor version {}: {}", parts[1], e)))?;
|
||||
let patch = parts[2]
|
||||
.parse::<u64>()
|
||||
.map_err(|e| AppError::Internal(format!("Invalid patch version {}: {}", parts[2], e)))?;
|
||||
Ok([major, minor, patch])
|
||||
let mut out = [0u64; 3];
|
||||
for (i, p) in parts.iter().enumerate() {
|
||||
if p.is_empty() || !p.chars().all(|c| c.is_ascii_digit()) {
|
||||
return Err(AppError::Internal(format!(
|
||||
"Invalid version {}, expected numeric x.x.x",
|
||||
input
|
||||
)));
|
||||
}
|
||||
out[i] = p
|
||||
.parse::<u64>()
|
||||
.map_err(|e| AppError::Internal(format!("Invalid version component {}: {}", p, e)))?;
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn channel_head_version(channels: &ChannelsManifest, channel: UpdateChannel) -> String {
|
||||
match channel {
|
||||
UpdateChannel::Stable => channels.stable.clone(),
|
||||
UpdateChannel::Beta => channels.beta.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn compare_version_parts(a: &[u64; 3], b: &[u64; 3]) -> std::cmp::Ordering {
|
||||
|
||||
23
src/utils/fs.rs
Normal file
23
src/utils/fs.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
//! Small filesystem helpers.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
/// Read a UTF-8 file and trim surrounding whitespace.
|
||||
pub fn read_trimmed(path: &Path) -> Option<String> {
|
||||
std::fs::read_to_string(path)
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
}
|
||||
|
||||
/// Sorted list of directory entry names (lossy exclusion on non-UTF8).
|
||||
pub fn list_dir_names(path: &Path) -> Vec<String> {
|
||||
let mut names = std::fs::read_dir(path)
|
||||
.ok()
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.flatten()
|
||||
.filter_map(|entry| entry.file_name().into_string().ok())
|
||||
.collect::<Vec<_>>();
|
||||
names.sort();
|
||||
names
|
||||
}
|
||||
15
src/utils/host.rs
Normal file
15
src/utils/host.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Host identity helpers.
|
||||
|
||||
/// Truncated content of `/etc/hostname`. Used where RustDesk peers expect the configured static name.
|
||||
pub fn hostname_from_etc() -> String {
|
||||
std::fs::read_to_string("/etc/hostname")
|
||||
.map(|s| s.trim().to_string())
|
||||
.unwrap_or_else(|_| "One-KVM".to_string())
|
||||
}
|
||||
|
||||
/// Current kernel hostname (`gethostname`). Used for live device info in the UI.
|
||||
pub fn hostname_uname() -> String {
|
||||
nix::unistd::gethostname()
|
||||
.map(|s| s.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|_| "unknown".to_string())
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
//! Utility modules for One-KVM
|
||||
//!
|
||||
//! This module contains common utilities used across the codebase.
|
||||
//! Shared utilities.
|
||||
|
||||
pub mod fs;
|
||||
pub mod host;
|
||||
pub mod net;
|
||||
pub mod throttle;
|
||||
|
||||
pub use fs::{list_dir_names, read_trimmed};
|
||||
pub use host::{hostname_from_etc, hostname_uname};
|
||||
pub use net::{bind_tcp_listener, bind_udp_socket};
|
||||
pub use throttle::LogThrottler;
|
||||
|
||||
@@ -1,44 +1,15 @@
|
||||
//! Log throttling utility
|
||||
//!
|
||||
//! Provides a mechanism to limit how often the same log message is recorded,
|
||||
//! preventing log flooding when errors occur repeatedly.
|
||||
//! Limits repeated identical log lines (e.g. reconnect failures).
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Log throttler that limits how often the same message is logged
|
||||
///
|
||||
/// This is useful for preventing log flooding when errors occur repeatedly,
|
||||
/// such as when a device is disconnected and reconnection attempts fail.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use one_kvm::utils::LogThrottler;
|
||||
/// use std::time::Duration;
|
||||
///
|
||||
/// let throttler = LogThrottler::new(Duration::from_secs(5));
|
||||
///
|
||||
/// // First call returns true
|
||||
/// assert!(throttler.should_log("device_error"));
|
||||
///
|
||||
/// // Subsequent calls within 5 seconds return false
|
||||
/// assert!(!throttler.should_log("device_error"));
|
||||
/// ```
|
||||
pub struct LogThrottler {
|
||||
/// Map of message key to last log time
|
||||
last_logged: RwLock<HashMap<String, Instant>>,
|
||||
/// Throttle interval
|
||||
interval: Duration,
|
||||
}
|
||||
|
||||
impl LogThrottler {
|
||||
/// Create a new log throttler with the specified interval
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `interval` - The minimum time between log messages for the same key
|
||||
pub fn new(interval: Duration) -> Self {
|
||||
Self {
|
||||
last_logged: RwLock::new(HashMap::new()),
|
||||
@@ -46,23 +17,14 @@ impl LogThrottler {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new log throttler with interval specified in seconds
|
||||
pub fn with_secs(secs: u64) -> Self {
|
||||
Self::new(Duration::from_secs(secs))
|
||||
}
|
||||
|
||||
/// Check if a message should be logged (not throttled)
|
||||
///
|
||||
/// Returns `true` if the message should be logged, `false` if it should be throttled.
|
||||
/// If `true` is returned, the internal timestamp is updated.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - A unique identifier for the message type
|
||||
/// Returns whether to emit the log line; updates the timestamp when `true`.
|
||||
pub fn should_log(&self, key: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
// First check with read lock (fast path)
|
||||
{
|
||||
let map = self.last_logged.read().unwrap();
|
||||
if let Some(last) = map.get(key) {
|
||||
@@ -72,9 +34,7 @@ impl LogThrottler {
|
||||
}
|
||||
}
|
||||
|
||||
// Update with write lock
|
||||
let mut map = self.last_logged.write().unwrap();
|
||||
// Double-check after acquiring write lock
|
||||
if let Some(last) = map.get(key) {
|
||||
if now.duration_since(*last) < self.interval {
|
||||
return false;
|
||||
@@ -84,32 +44,14 @@ impl LogThrottler {
|
||||
true
|
||||
}
|
||||
|
||||
/// Clear throttle state for a specific key
|
||||
///
|
||||
/// This should be called when an error condition recovers,
|
||||
/// so the next error will be logged immediately.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - The key to clear
|
||||
/// Call when a condition recovers so the next failure logs immediately.
|
||||
pub fn clear(&self, key: &str) {
|
||||
self.last_logged.write().unwrap().remove(key);
|
||||
}
|
||||
|
||||
/// Clear all throttle state
|
||||
pub fn clear_all(&self) {
|
||||
self.last_logged.write().unwrap().clear();
|
||||
}
|
||||
|
||||
/// Get the number of tracked keys
|
||||
pub fn len(&self) -> usize {
|
||||
self.last_logged.read().unwrap().len()
|
||||
}
|
||||
|
||||
/// Check if the throttler is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.last_logged.read().unwrap().is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for LogThrottler {
|
||||
@@ -122,23 +64,11 @@ impl Clone for LogThrottler {
|
||||
}
|
||||
|
||||
impl Default for LogThrottler {
|
||||
/// Create a default log throttler with 5 second interval
|
||||
fn default() -> Self {
|
||||
Self::with_secs(5)
|
||||
}
|
||||
}
|
||||
|
||||
/// Macro for throttled warning logging
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use one_kvm::utils::LogThrottler;
|
||||
/// use one_kvm::warn_throttled;
|
||||
///
|
||||
/// let throttler = LogThrottler::default();
|
||||
/// warn_throttled!(throttler, "my_error", "Error occurred: {}", "details");
|
||||
/// ```
|
||||
#[macro_export]
|
||||
macro_rules! warn_throttled {
|
||||
($throttler:expr, $key:expr, $($arg:tt)*) => {
|
||||
@@ -148,7 +78,6 @@ macro_rules! warn_throttled {
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro for throttled error logging
|
||||
#[macro_export]
|
||||
macro_rules! error_throttled {
|
||||
($throttler:expr, $key:expr, $($arg:tt)*) => {
|
||||
@@ -158,16 +87,6 @@ macro_rules! error_throttled {
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro for throttled info logging
|
||||
#[macro_export]
|
||||
macro_rules! info_throttled {
|
||||
($throttler:expr, $key:expr, $($arg:tt)*) => {
|
||||
if $throttler.should_log($key) {
|
||||
tracing::info!($($arg)*);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -183,16 +102,11 @@ mod tests {
|
||||
fn test_throttling() {
|
||||
let throttler = LogThrottler::new(Duration::from_millis(100));
|
||||
|
||||
// First call should succeed
|
||||
assert!(throttler.should_log("test_key"));
|
||||
|
||||
// Immediate second call should be throttled
|
||||
assert!(!throttler.should_log("test_key"));
|
||||
|
||||
// Wait for throttle to expire
|
||||
thread::sleep(Duration::from_millis(150));
|
||||
|
||||
// Should succeed again
|
||||
assert!(throttler.should_log("test_key"));
|
||||
}
|
||||
|
||||
@@ -200,7 +114,6 @@ mod tests {
|
||||
fn test_different_keys() {
|
||||
let throttler = LogThrottler::with_secs(10);
|
||||
|
||||
// Different keys should be independent
|
||||
assert!(throttler.should_log("key1"));
|
||||
assert!(throttler.should_log("key2"));
|
||||
assert!(!throttler.should_log("key1"));
|
||||
@@ -214,10 +127,8 @@ mod tests {
|
||||
assert!(throttler.should_log("test_key"));
|
||||
assert!(!throttler.should_log("test_key"));
|
||||
|
||||
// Clear the key
|
||||
throttler.clear("test_key");
|
||||
|
||||
// Should be able to log again
|
||||
assert!(throttler.should_log("test_key"));
|
||||
}
|
||||
|
||||
@@ -239,19 +150,4 @@ mod tests {
|
||||
let throttler = LogThrottler::default();
|
||||
assert!(throttler.should_log("test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_len_and_is_empty() {
|
||||
let throttler = LogThrottler::with_secs(10);
|
||||
|
||||
assert!(throttler.is_empty());
|
||||
assert_eq!(throttler.len(), 0);
|
||||
|
||||
throttler.should_log("key1");
|
||||
assert!(!throttler.is_empty());
|
||||
assert_eq!(throttler.len(), 1);
|
||||
|
||||
throttler.should_log("key2");
|
||||
assert_eq!(throttler.len(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
30
src/video/capture_limits.rs
Normal file
30
src/video/capture_limits.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
//! Shared tuning for V4L2 MJPEG capture paths (`Streamer` + `SharedVideoPipeline`).
|
||||
|
||||
/// Frames smaller than this are treated as incomplete / noise.
|
||||
pub(crate) const MIN_CAPTURE_FRAME_SIZE: usize = 128;
|
||||
|
||||
/// After startup, validate JPEG header every N frames to limit CPU use.
|
||||
pub(crate) const JPEG_VALIDATE_INTERVAL: u64 = 30;
|
||||
|
||||
/// Validate every MJPEG frame for the first N frames (UVC warm-up / bad headers).
|
||||
pub(crate) const STARTUP_JPEG_VALIDATE_FRAMES: u64 = 3;
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn should_validate_jpeg_frame(validate_counter: u64) -> bool {
|
||||
validate_counter <= STARTUP_JPEG_VALIDATE_FRAMES
|
||||
|| validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn jpeg_validation_policy_startup_then_interval() {
|
||||
assert!(should_validate_jpeg_frame(1));
|
||||
assert!(should_validate_jpeg_frame(2));
|
||||
assert!(should_validate_jpeg_frame(3));
|
||||
assert!(!should_validate_jpeg_frame(4));
|
||||
assert!(should_validate_jpeg_frame(30));
|
||||
}
|
||||
}
|
||||
@@ -135,7 +135,7 @@ pub async fn enforce_constraints_with_stream_manager(
|
||||
}
|
||||
|
||||
if current_mode == StreamMode::WebRTC {
|
||||
let current_codec = stream_manager.webrtc_streamer().current_video_codec().await;
|
||||
let current_codec = stream_manager.current_video_codec().await;
|
||||
if !constraints.is_webrtc_codec_allowed(current_codec) {
|
||||
let target_codec = constraints.preferred_webrtc_codec();
|
||||
stream_manager.set_video_codec(target_codec).await?;
|
||||
|
||||
@@ -14,9 +14,7 @@ use tracing::{debug, info, warn};
|
||||
use v4l2r::bindings::{
|
||||
v4l2_bt_timings, v4l2_dv_timings, V4L2_DV_BT_656_1120, V4L2_DV_FL_HAS_CEA861_VIC,
|
||||
};
|
||||
use v4l2r::ioctl::{
|
||||
self, Event as V4l2Event, EventType, QueryDvTimingsError, SubscribeEventFlags,
|
||||
};
|
||||
use v4l2r::ioctl::{self, Event as V4l2Event, EventType, QueryDvTimingsError, SubscribeEventFlags};
|
||||
use v4l2r::nix::errno::Errno;
|
||||
|
||||
use crate::video::SignalStatus;
|
||||
@@ -143,9 +141,9 @@ pub fn probe_signal(subdev_fd: &impl AsRawFd, kind: CsiBridgeKind) -> ProbeResul
|
||||
Err(QueryDvTimingsError::NoLink) => ProbeResult::NoCable,
|
||||
Err(QueryDvTimingsError::UnstableSignal) => ProbeResult::NoSync,
|
||||
Err(QueryDvTimingsError::IoctlError(Errno::ERANGE)) => ProbeResult::OutOfRange,
|
||||
Err(QueryDvTimingsError::IoctlError(
|
||||
Errno::EIO | Errno::EREMOTEIO | Errno::ETIMEDOUT,
|
||||
)) => ProbeResult::NoSync,
|
||||
Err(QueryDvTimingsError::IoctlError(Errno::EIO | Errno::EREMOTEIO | Errno::ETIMEDOUT)) => {
|
||||
ProbeResult::NoSync
|
||||
}
|
||||
Err(QueryDvTimingsError::Unsupported) | Err(QueryDvTimingsError::IoctlError(_)) => {
|
||||
ProbeResult::NoSignal
|
||||
}
|
||||
@@ -222,14 +220,8 @@ fn classify_timings(timings: v4l2_dv_timings, kind: CsiBridgeKind) -> ProbeResul
|
||||
return ProbeResult::NoSignal;
|
||||
}
|
||||
|
||||
let total_h: u64 = (width
|
||||
+ bt.hfrontporch
|
||||
+ bt.hsync
|
||||
+ bt.hbackporch) as u64;
|
||||
let total_v: u64 = (height
|
||||
+ bt.vfrontporch
|
||||
+ bt.vsync
|
||||
+ bt.vbackporch) as u64;
|
||||
let total_h: u64 = (width + bt.hfrontporch + bt.hsync + bt.hbackporch) as u64;
|
||||
let total_v: u64 = (height + bt.vfrontporch + bt.vsync + bt.vbackporch) as u64;
|
||||
let fps = if total_h > 0 && total_v > 0 && pixelclock > 0 {
|
||||
Some(pixelclock as f64 / (total_h as f64 * total_v as f64))
|
||||
} else {
|
||||
|
||||
@@ -168,16 +168,15 @@ impl VideoDevice {
|
||||
// subdev (the video node returns ENOTTY). Tc358743 and rk_hdmirx
|
||||
// typically expose DV ioctls on the video node itself, but having
|
||||
// the subdev handle for EDID/event subscription doesn't hurt.
|
||||
let (subdev_path, bridge_kind) = if is_rkcif_driver(&caps.driver)
|
||||
|| is_rk_hdmirx_driver(&caps.driver, &caps.card)
|
||||
{
|
||||
match csi_bridge::discover_subdev_for_video(&self.path) {
|
||||
Some((path, kind)) => (Some(path), Some(format!("{:?}", kind).to_lowercase())),
|
||||
None => (None, None),
|
||||
}
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let (subdev_path, bridge_kind) =
|
||||
if is_rkcif_driver(&caps.driver) || is_rk_hdmirx_driver(&caps.driver, &caps.card) {
|
||||
match csi_bridge::discover_subdev_for_video(&self.path) {
|
||||
Some((path, kind)) => (Some(path), Some(format!("{:?}", kind).to_lowercase())),
|
||||
None => (None, None),
|
||||
}
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Probe the HDMI source for both signal presence *and* the live
|
||||
// frame-rate. rkcif's `VIDIOC_ENUM_FRAMEINTERVALS` returns a
|
||||
@@ -225,9 +224,7 @@ impl VideoDevice {
|
||||
(false, None)
|
||||
}
|
||||
}
|
||||
} else if is_rk_hdmirx_driver(&caps.driver, &caps.card)
|
||||
|| is_rkcif_driver(&caps.driver)
|
||||
{
|
||||
} else if is_rk_hdmirx_driver(&caps.driver, &caps.card) || is_rkcif_driver(&caps.driver) {
|
||||
let dv = self.current_dv_timings_mode();
|
||||
debug!(
|
||||
"has_signal via video node {:?} (driver={}): dv_timings={:?}",
|
||||
@@ -247,21 +244,20 @@ impl VideoDevice {
|
||||
(true, None)
|
||||
};
|
||||
|
||||
let mut formats = if is_rk_hdmirx_driver(&caps.driver, &caps.card)
|
||||
|| is_rkcif_driver(&caps.driver)
|
||||
{
|
||||
// CSI/HDMI bridge drivers (rk_hdmirx, rkcif) expose multiple pixel
|
||||
// formats via ENUM_FMT (e.g. rk_hdmirx: BGR3/NV24/NV16/NV12) but
|
||||
// `ENUM_FRAMESIZES` is fiction for these drivers (rkcif reports a
|
||||
// degenerate `64x64 StepWise 8/8` that only describes its DMA
|
||||
// engine, rk_hdmirx returns ENOTTY). The only authoritative
|
||||
// resolution is whatever the bridge subdev's DV timings report,
|
||||
// so we treat the HDMI source mode as the single allowed
|
||||
// resolution for every pixel format.
|
||||
self.enumerate_bridge_formats(subdev_hdmi_mode)?
|
||||
} else {
|
||||
self.enumerate_formats()?
|
||||
};
|
||||
let mut formats =
|
||||
if is_rk_hdmirx_driver(&caps.driver, &caps.card) || is_rkcif_driver(&caps.driver) {
|
||||
// CSI/HDMI bridge drivers (rk_hdmirx, rkcif) expose multiple pixel
|
||||
// formats via ENUM_FMT (e.g. rk_hdmirx: BGR3/NV24/NV16/NV12) but
|
||||
// `ENUM_FRAMESIZES` is fiction for these drivers (rkcif reports a
|
||||
// degenerate `64x64 StepWise 8/8` that only describes its DMA
|
||||
// engine, rk_hdmirx returns ENOTTY). The only authoritative
|
||||
// resolution is whatever the bridge subdev's DV timings report,
|
||||
// so we treat the HDMI source mode as the single allowed
|
||||
// resolution for every pixel format.
|
||||
self.enumerate_bridge_formats(subdev_hdmi_mode)?
|
||||
} else {
|
||||
self.enumerate_formats()?
|
||||
};
|
||||
|
||||
// For CSI/HDMI bridges, the driver-enumerated fps list is fiction
|
||||
// (rkcif: always `1..30`; rk_hdmirx: typically `ENOTTY`). Replace
|
||||
@@ -923,7 +919,11 @@ pub fn enumerate_devices() -> Result<Vec<VideoDeviceInfo>> {
|
||||
// The path tiebreaker ensures deterministic ordering when multiple sub-devices
|
||||
// share the same priority (e.g. rkcif nodes), so that /dev/video0 is preferred
|
||||
// over /dev/video10 after deduplication.
|
||||
devices.sort_by(|a, b| b.priority.cmp(&a.priority).then_with(|| a.path.cmp(&b.path)));
|
||||
devices.sort_by(|a, b| {
|
||||
b.priority
|
||||
.cmp(&a.priority)
|
||||
.then_with(|| a.path.cmp(&b.path))
|
||||
});
|
||||
|
||||
// Deduplicate rkcif sub-devices: the driver exposes many /dev/video* nodes
|
||||
// for a single MIPI CSI pipeline. Keep only the highest-priority node per
|
||||
@@ -976,8 +976,11 @@ fn collapse_rkcif_probe_candidates(candidates: &mut Vec<PathBuf>) {
|
||||
|
||||
fn sysfs_uevent_driver(path: &Path) -> Option<String> {
|
||||
let name = path.file_name()?.to_str()?;
|
||||
let uevent =
|
||||
read_sysfs_string(&Path::new("/sys/class/video4linux").join(name).join("device/uevent"))?;
|
||||
let uevent = read_sysfs_string(
|
||||
&Path::new("/sys/class/video4linux")
|
||||
.join(name)
|
||||
.join("device/uevent"),
|
||||
)?;
|
||||
extract_uevent_value(&uevent, "driver")
|
||||
}
|
||||
|
||||
@@ -1037,7 +1040,10 @@ fn sysfs_maybe_capture(path: &Path) -> bool {
|
||||
// kernel driver that created them has been unloaded but the device nodes
|
||||
// were never cleaned up. Opening them returns ENODEV; skip the probe.
|
||||
if !sysfs_base.exists() {
|
||||
debug!("Skipping {:?}: no matching /sys/class/video4linux entry", path);
|
||||
debug!(
|
||||
"Skipping {:?}: no matching /sys/class/video4linux entry",
|
||||
path
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1081,7 +1087,13 @@ fn sysfs_maybe_capture(path: &Path) -> bool {
|
||||
// succeed QUERYCAP but expose only VIDEO_M2M / STATS / PARAMS and get
|
||||
// filtered later — skipping here saves an open() + ioctl() per node.
|
||||
let driver_skip = [
|
||||
"rkvenc", "rkvdec", "vepu", "vdpu", "hantro", "mpp_", "rockchip-vpu",
|
||||
"rkvenc",
|
||||
"rkvdec",
|
||||
"vepu",
|
||||
"vdpu",
|
||||
"hantro",
|
||||
"mpp_",
|
||||
"rockchip-vpu",
|
||||
];
|
||||
if let Some(driver) = &driver {
|
||||
if driver_skip.iter().any(|hint| driver.contains(hint)) {
|
||||
|
||||
@@ -1,91 +1,13 @@
|
||||
//! Encoder traits and common types
|
||||
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Instant;
|
||||
use typeshare::typeshare;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
|
||||
/// Bitrate preset for video encoding
|
||||
///
|
||||
/// Simplifies bitrate configuration by providing three intuitive presets
|
||||
/// plus a custom option for advanced users.
|
||||
#[typeshare]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
#[derive(Default)]
|
||||
pub enum BitratePreset {
|
||||
/// Speed priority: 1 Mbps, lowest latency, smaller GOP
|
||||
/// Best for: slow networks, remote management, low-bandwidth scenarios
|
||||
Speed,
|
||||
/// Balanced: 4 Mbps, good quality/latency tradeoff
|
||||
/// Best for: typical usage, recommended default
|
||||
#[default]
|
||||
Balanced,
|
||||
/// Quality priority: 8 Mbps, best visual quality
|
||||
/// Best for: local network, high-bandwidth scenarios, detailed work
|
||||
Quality,
|
||||
/// Custom bitrate in kbps (for advanced users)
|
||||
Custom(u32),
|
||||
}
|
||||
|
||||
impl BitratePreset {
|
||||
/// Get bitrate value in kbps
|
||||
pub fn bitrate_kbps(&self) -> u32 {
|
||||
match self {
|
||||
Self::Speed => 1000,
|
||||
Self::Balanced => 4000,
|
||||
Self::Quality => 8000,
|
||||
Self::Custom(kbps) => *kbps,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get recommended GOP size based on preset
|
||||
///
|
||||
/// Speed preset uses shorter GOP for faster recovery from packet loss.
|
||||
/// Quality preset uses longer GOP for better compression efficiency.
|
||||
pub fn gop_size(&self, fps: u32) -> u32 {
|
||||
match self {
|
||||
Self::Speed => (fps / 2).max(15), // 0.5 second, minimum 15 frames
|
||||
Self::Balanced => fps, // 1 second
|
||||
Self::Quality => fps * 2, // 2 seconds
|
||||
Self::Custom(_) => fps, // Default 1 second for custom
|
||||
}
|
||||
}
|
||||
|
||||
/// Get quality preset name for encoder configuration
|
||||
pub fn quality_level(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Speed => "low", // ultrafast/veryfast preset
|
||||
Self::Balanced => "medium", // medium preset
|
||||
Self::Quality => "high", // slower preset, better quality
|
||||
Self::Custom(_) => "medium",
|
||||
}
|
||||
}
|
||||
|
||||
/// Create from kbps value, mapping to nearest preset or Custom
|
||||
pub fn from_kbps(kbps: u32) -> Self {
|
||||
match kbps {
|
||||
0..=1500 => Self::Speed,
|
||||
1501..=6000 => Self::Balanced,
|
||||
6001..=10000 => Self::Quality,
|
||||
_ => Self::Custom(kbps),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BitratePreset {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Speed => write!(f, "Speed (1 Mbps)"),
|
||||
Self::Balanced => write!(f, "Balanced (4 Mbps)"),
|
||||
Self::Quality => write!(f, "Quality (8 Mbps)"),
|
||||
Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Defined in `config::schema` (typeshare + serde). Re-export for encoder users.
|
||||
pub use crate::config::BitratePreset;
|
||||
|
||||
/// Encoder configuration
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//!
|
||||
//! This module provides V4L2 video capture, encoding, and streaming functionality.
|
||||
|
||||
pub(crate) mod capture_limits;
|
||||
pub mod codec_constraints;
|
||||
pub mod convert;
|
||||
pub mod csi_bridge;
|
||||
@@ -13,6 +14,8 @@ pub mod frame;
|
||||
pub mod shared_video_pipeline;
|
||||
pub mod stream_manager;
|
||||
pub mod streamer;
|
||||
pub mod traits;
|
||||
pub mod types;
|
||||
pub mod usb_reset;
|
||||
pub mod v4l2r_capture;
|
||||
|
||||
|
||||
@@ -38,25 +38,19 @@ const CAPTURE_TIMEOUT_STOP_THRESHOLD: u32 = 60;
|
||||
const CAPTURE_TIMEOUT_SOFT_RESTART_THRESHOLD: u32 = 3;
|
||||
const CSI_BRIDGE_NOSIGNAL_INTERVAL_MS: u64 = 500;
|
||||
const NOSIGNAL_POLL_MAX: Duration = Duration::from_secs(20);
|
||||
/// Minimum valid frame size for capture
|
||||
const MIN_CAPTURE_FRAME_SIZE: usize = 128;
|
||||
/// Validate every JPEG frame during startup to avoid poisoning HW decoders
|
||||
/// with incomplete UVC warm-up frames.
|
||||
const STARTUP_JPEG_VALIDATE_FRAMES: u64 = 3;
|
||||
/// Validate JPEG header every N frames to reduce overhead
|
||||
const JPEG_VALIDATE_INTERVAL: u64 = 30;
|
||||
/// Throttle repeated encoding errors to avoid log flooding
|
||||
const ENCODE_ERROR_THROTTLE_SECS: u64 = 5;
|
||||
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::utils::LogThrottler;
|
||||
use crate::video::capture_limits::{should_validate_jpeg_frame, MIN_CAPTURE_FRAME_SIZE};
|
||||
use crate::video::csi_bridge::{self, ProbeResult};
|
||||
use crate::video::device::parse_bridge_kind;
|
||||
use crate::video::encoder::registry::{EncoderBackend, VideoEncoderType};
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
use crate::video::frame::{FrameBuffer, FrameBufferPool, VideoFrame};
|
||||
use crate::video::device::parse_bridge_kind;
|
||||
use crate::video::SignalStatus;
|
||||
use crate::video::v4l2r_capture::{is_source_changed_error, BridgeContext, V4l2rCaptureStream};
|
||||
use crate::video::SignalStatus;
|
||||
#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
|
||||
use hwcodec::ffmpeg_hw::last_error_message as ffmpeg_hw_last_error;
|
||||
|
||||
@@ -250,11 +244,6 @@ fn log_encoding_error(
|
||||
}
|
||||
}
|
||||
|
||||
fn should_validate_jpeg_frame(validate_counter: u64) -> bool {
|
||||
validate_counter <= STARTUP_JPEG_VALIDATE_FRAMES
|
||||
|| validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL)
|
||||
}
|
||||
|
||||
/// Pipeline statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SharedVideoPipelineStats {
|
||||
@@ -283,10 +272,7 @@ pub struct SharedVideoPipeline {
|
||||
last_state_notification: ParkingMutex<Option<PipelineStateNotification>>,
|
||||
}
|
||||
|
||||
fn poll_bridge_subdev_after_no_signal(
|
||||
bridge_ctx: &BridgeContext,
|
||||
pipeline: &SharedVideoPipeline,
|
||||
) {
|
||||
fn poll_bridge_subdev_after_no_signal(bridge_ctx: &BridgeContext, pipeline: &SharedVideoPipeline) {
|
||||
let Some(subdev_path) = bridge_ctx.subdev_path.as_ref() else {
|
||||
return;
|
||||
};
|
||||
@@ -313,7 +299,10 @@ fn poll_bridge_subdev_after_no_signal(
|
||||
let fd = match csi_bridge::open_subdev(subdev_path) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
debug!("No-signal poll: open subdev {:?} failed: {}", subdev_path, e);
|
||||
debug!(
|
||||
"No-signal poll: open subdev {:?} failed: {}",
|
||||
subdev_path, e
|
||||
);
|
||||
std::thread::sleep(Duration::from_millis(CSI_BRIDGE_NOSIGNAL_INTERVAL_MS));
|
||||
continue;
|
||||
}
|
||||
@@ -565,21 +554,20 @@ impl SharedVideoPipeline {
|
||||
subdev_path.clone(),
|
||||
parse_bridge_kind(bridge_kind.as_deref()),
|
||||
);
|
||||
let preopened: Option<V4l2rCaptureStream> =
|
||||
match V4l2rCaptureStream::open_with_bridge(
|
||||
&device_path,
|
||||
config.resolution,
|
||||
config.input_format,
|
||||
config.fps,
|
||||
buffer_count.max(1),
|
||||
Duration::from_secs(2),
|
||||
bridge_ctx_probe,
|
||||
) {
|
||||
Ok(s) => {
|
||||
let negotiated_res = s.resolution();
|
||||
let negotiated_fmt = s.format();
|
||||
if negotiated_res != config.resolution || negotiated_fmt != config.input_format {
|
||||
info!(
|
||||
let preopened: Option<V4l2rCaptureStream> = match V4l2rCaptureStream::open_with_bridge(
|
||||
&device_path,
|
||||
config.resolution,
|
||||
config.input_format,
|
||||
config.fps,
|
||||
buffer_count.max(1),
|
||||
Duration::from_secs(2),
|
||||
bridge_ctx_probe,
|
||||
) {
|
||||
Ok(s) => {
|
||||
let negotiated_res = s.resolution();
|
||||
let negotiated_fmt = s.format();
|
||||
if negotiated_res != config.resolution || negotiated_fmt != config.input_format {
|
||||
info!(
|
||||
"Negotiated capture {}x{} {:?} (configured {}x{} {:?}) — aligning encoder to source",
|
||||
negotiated_res.width,
|
||||
negotiated_res.height,
|
||||
@@ -588,25 +576,25 @@ impl SharedVideoPipeline {
|
||||
config.resolution.height,
|
||||
config.input_format
|
||||
);
|
||||
config.resolution = negotiated_res;
|
||||
config.input_format = negotiated_fmt;
|
||||
*self.config.write().await = config.clone();
|
||||
}
|
||||
Some(s)
|
||||
config.resolution = negotiated_res;
|
||||
config.input_format = negotiated_fmt;
|
||||
*self.config.write().await = config.clone();
|
||||
}
|
||||
Err(AppError::CaptureNoSignal { kind }) => {
|
||||
debug!(
|
||||
"Pre-probe: no signal — encoder uses configured geometry until capture opens"
|
||||
);
|
||||
let status = SignalStatus::from_str(&kind).unwrap_or(SignalStatus::NoSignal);
|
||||
self.notify_state(PipelineStateNotification::no_signal(
|
||||
status,
|
||||
Some(Duration::from_secs(2).as_millis() as u64),
|
||||
));
|
||||
None
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
Some(s)
|
||||
}
|
||||
Err(AppError::CaptureNoSignal { kind }) => {
|
||||
debug!(
|
||||
"Pre-probe: no signal — encoder uses configured geometry until capture opens"
|
||||
);
|
||||
let status = SignalStatus::from_str(&kind).unwrap_or(SignalStatus::NoSignal);
|
||||
self.notify_state(PipelineStateNotification::no_signal(
|
||||
status,
|
||||
Some(Duration::from_secs(2).as_millis() as u64),
|
||||
));
|
||||
None
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
let mut encoder_state = build_encoder_state(&config)?;
|
||||
let _ = self.running.send(true);
|
||||
@@ -707,10 +695,8 @@ impl SharedVideoPipeline {
|
||||
let latest_frame = latest_frame.clone();
|
||||
let frame_seq_tx = frame_seq_tx.clone();
|
||||
let buffer_pool = buffer_pool.clone();
|
||||
let bridge_ctx = BridgeContext::from_parts(
|
||||
subdev_path,
|
||||
parse_bridge_kind(bridge_kind.as_deref()),
|
||||
);
|
||||
let bridge_ctx =
|
||||
BridgeContext::from_parts(subdev_path, parse_bridge_kind(bridge_kind.as_deref()));
|
||||
std::thread::spawn(move || {
|
||||
let mut stream: Option<V4l2rCaptureStream> = None;
|
||||
let mut initial_geometry: Option<(Resolution, PixelFormat)> = None;
|
||||
@@ -874,7 +860,8 @@ impl SharedVideoPipeline {
|
||||
|
||||
// ── No usable stream? Try to (re)open, back off on failure. ──
|
||||
if stream.is_none() {
|
||||
match open_or_retry(&device_path, &config, buffer_count, bridge_ctx.clone()) {
|
||||
match open_or_retry(&device_path, &config, buffer_count, bridge_ctx.clone())
|
||||
{
|
||||
OpenResult::Opened(new_stream) => {
|
||||
let new_res = new_stream.resolution();
|
||||
let new_fmt = new_stream.format();
|
||||
@@ -884,7 +871,8 @@ impl SharedVideoPipeline {
|
||||
// encoder was sized to saved settings — if DV timings now
|
||||
// disagree, we cannot encode until WebRTC resyncs dimensions.
|
||||
if initial_geometry.is_none()
|
||||
&& (new_res != config.resolution || new_fmt != config.input_format)
|
||||
&& (new_res != config.resolution
|
||||
|| new_fmt != config.input_format)
|
||||
{
|
||||
info!(
|
||||
"Deferred capture open is {}x{} {:?} but encoder expects {}x{} {:?} — stopping for dimension resync",
|
||||
@@ -898,7 +886,8 @@ impl SharedVideoPipeline {
|
||||
pipeline.notify_state(PipelineStateNotification::device_busy(
|
||||
"config_changing",
|
||||
));
|
||||
*pipeline.pending_sync_geometry.lock() = Some((new_res, new_fmt));
|
||||
*pipeline.pending_sync_geometry.lock() =
|
||||
Some((new_res, new_fmt));
|
||||
let _ = pipeline.running.send(false);
|
||||
pipeline.running_flag.store(false, Ordering::Release);
|
||||
let _ = frame_seq_tx.send(sequence.wrapping_add(1));
|
||||
@@ -950,8 +939,7 @@ impl SharedVideoPipeline {
|
||||
);
|
||||
}
|
||||
OpenResult::NoSignal(status) => {
|
||||
consecutive_timeouts =
|
||||
consecutive_timeouts.saturating_add(1);
|
||||
consecutive_timeouts = consecutive_timeouts.saturating_add(1);
|
||||
if consecutive_timeouts >= CAPTURE_TIMEOUT_STOP_THRESHOLD {
|
||||
warn!(
|
||||
"Capture soft-restart gave up after {} attempts, \
|
||||
@@ -1092,9 +1080,7 @@ impl SharedVideoPipeline {
|
||||
}
|
||||
}
|
||||
|
||||
if consecutive_timeouts
|
||||
>= CAPTURE_TIMEOUT_SOFT_RESTART_THRESHOLD
|
||||
{
|
||||
if consecutive_timeouts >= CAPTURE_TIMEOUT_SOFT_RESTART_THRESHOLD {
|
||||
// Drop the stream so the next loop
|
||||
// iteration re-opens via the DV-timings
|
||||
// probe. This catches source-side
|
||||
@@ -1105,12 +1091,10 @@ impl SharedVideoPipeline {
|
||||
closing stream for soft-restart",
|
||||
consecutive_timeouts
|
||||
);
|
||||
pipeline.notify_state(
|
||||
PipelineStateNotification::no_signal(
|
||||
SignalStatus::UvcCaptureStall,
|
||||
Some(Duration::from_secs(2).as_millis() as u64),
|
||||
),
|
||||
);
|
||||
pipeline.notify_state(PipelineStateNotification::no_signal(
|
||||
SignalStatus::UvcCaptureStall,
|
||||
Some(Duration::from_secs(2).as_millis() as u64),
|
||||
));
|
||||
stream = None;
|
||||
continue;
|
||||
}
|
||||
@@ -1551,13 +1535,4 @@ mod tests {
|
||||
let h265 = SharedVideoPipelineConfig::h265(Resolution::HD720, BitratePreset::Speed);
|
||||
assert_eq!(h265.output_codec, VideoEncoderType::H265);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_startup_jpeg_validation_policy() {
|
||||
assert!(should_validate_jpeg_frame(1));
|
||||
assert!(should_validate_jpeg_frame(2));
|
||||
assert!(should_validate_jpeg_frame(3));
|
||||
assert!(!should_validate_jpeg_frame(4));
|
||||
assert!(should_validate_jpeg_frame(30));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,8 +39,8 @@ use crate::stream::MjpegStreamHandler;
|
||||
use crate::video::codec_constraints::StreamCodecConstraints;
|
||||
use crate::video::format::{PixelFormat, Resolution};
|
||||
use crate::video::is_csi_hdmi_bridge;
|
||||
use crate::video::streamer::{Streamer, StreamerStats, StreamerState};
|
||||
use crate::webrtc::WebRtcStreamer;
|
||||
use crate::video::streamer::{Streamer, StreamerState, StreamerStats};
|
||||
use crate::video::traits::VideoOutput;
|
||||
|
||||
/// Video stream manager configuration
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -95,8 +95,8 @@ pub struct VideoStreamManager {
|
||||
mode: RwLock<StreamMode>,
|
||||
/// MJPEG streamer (handles video capture and MJPEG distribution)
|
||||
streamer: Arc<Streamer>,
|
||||
/// WebRTC streamer (unified WebRTC manager with multi-codec support)
|
||||
webrtc_streamer: Arc<WebRtcStreamer>,
|
||||
/// WebRTC output (unified WebRTC manager with multi-codec support)
|
||||
webrtc_streamer: Arc<dyn VideoOutput>,
|
||||
/// Event bus for notifications
|
||||
events: RwLock<Option<Arc<EventBus>>>,
|
||||
/// Configuration store
|
||||
@@ -111,7 +111,7 @@ impl VideoStreamManager {
|
||||
/// Create a new video stream manager with WebRtcStreamer
|
||||
pub fn with_webrtc_streamer(
|
||||
streamer: Arc<Streamer>,
|
||||
webrtc_streamer: Arc<WebRtcStreamer>,
|
||||
webrtc_streamer: Arc<dyn VideoOutput>,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
mode: RwLock::new(StreamMode::Mjpeg),
|
||||
@@ -175,11 +175,6 @@ impl VideoStreamManager {
|
||||
self.streamer.clone()
|
||||
}
|
||||
|
||||
/// Get the WebRTC streamer (unified interface with multi-codec support)
|
||||
pub fn webrtc_streamer(&self) -> Arc<WebRtcStreamer> {
|
||||
self.webrtc_streamer.clone()
|
||||
}
|
||||
|
||||
/// Get the MJPEG stream handler
|
||||
pub fn mjpeg_handler(&self) -> Arc<MjpegStreamHandler> {
|
||||
self.streamer.mjpeg_handler()
|
||||
@@ -812,6 +807,16 @@ impl VideoStreamManager {
|
||||
self.webrtc_streamer.get_pipeline_config().await
|
||||
}
|
||||
|
||||
/// Get current video codec type
|
||||
pub async fn current_video_codec(&self) -> crate::video::encoder::VideoCodecType {
|
||||
self.webrtc_streamer.current_video_codec().await
|
||||
}
|
||||
|
||||
/// Check if hardware encoding is in use
|
||||
pub async fn is_hardware_encoding(&self) -> bool {
|
||||
self.webrtc_streamer.is_hardware_encoding().await
|
||||
}
|
||||
|
||||
/// Set video codec for the shared video pipeline
|
||||
///
|
||||
/// This allows external consumers (like RustDesk) to set the video codec
|
||||
|
||||
@@ -12,7 +12,9 @@ use tokio::sync::RwLock;
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use super::csi_bridge;
|
||||
use super::device::{enumerate_devices, find_best_device, parse_bridge_kind, VideoDevice, VideoDeviceInfo};
|
||||
use super::device::{
|
||||
enumerate_devices, find_best_device, parse_bridge_kind, VideoDevice, VideoDeviceInfo,
|
||||
};
|
||||
use super::format::{PixelFormat, Resolution};
|
||||
use super::frame::{FrameBuffer, FrameBufferPool, VideoFrame};
|
||||
use super::is_csi_hdmi_bridge;
|
||||
@@ -20,13 +22,9 @@ use crate::error::{AppError, Result};
|
||||
use crate::events::{EventBus, SystemEvent};
|
||||
use crate::stream::MjpegStreamHandler;
|
||||
use crate::utils::LogThrottler;
|
||||
use crate::video::capture_limits::{should_validate_jpeg_frame, MIN_CAPTURE_FRAME_SIZE};
|
||||
use crate::video::v4l2r_capture::{is_source_changed_error, BridgeContext, V4l2rCaptureStream};
|
||||
|
||||
/// Minimum valid frame size for capture
|
||||
const MIN_CAPTURE_FRAME_SIZE: usize = 128;
|
||||
/// Validate JPEG header every N frames to reduce overhead
|
||||
const JPEG_VALIDATE_INTERVAL: u64 = 30;
|
||||
|
||||
/// Streamer configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamerConfig {
|
||||
@@ -477,11 +475,10 @@ impl Streamer {
|
||||
);
|
||||
return Ok(preferred);
|
||||
}
|
||||
let fmt = device
|
||||
.formats
|
||||
.first()
|
||||
.map(|f| f.format)
|
||||
.ok_or_else(|| AppError::VideoError("No supported formats found".to_string()))?;
|
||||
let fmt =
|
||||
device.formats.first().map(|f| f.format).ok_or_else(|| {
|
||||
AppError::VideoError("No supported formats found".to_string())
|
||||
})?;
|
||||
info!(
|
||||
"select_format: CSI bridge with signal, preferred {:?} unavailable, selected {:?} from {:?}",
|
||||
preferred,
|
||||
@@ -916,9 +913,7 @@ impl Streamer {
|
||||
"CSI open probe reports no signal ({:?}), will soft-restart",
|
||||
status
|
||||
);
|
||||
set_retry(
|
||||
backoff_secs(no_signal_restart_count).saturating_mul(1000),
|
||||
);
|
||||
set_retry(backoff_secs(no_signal_restart_count).saturating_mul(1000));
|
||||
go_offline();
|
||||
set_state(status.into());
|
||||
last_error = Some(format!("CaptureNoSignal({})", kind));
|
||||
@@ -952,8 +947,9 @@ impl Streamer {
|
||||
// restart path. This lets CSI bridges recover on their
|
||||
// own when the source comes back (resolution change,
|
||||
// host reboot, HDMI cable re-plug).
|
||||
let was_no_signal =
|
||||
handle.block_on(async { self.state().await }).is_no_signal_like();
|
||||
let was_no_signal = handle
|
||||
.block_on(async { self.state().await })
|
||||
.is_no_signal_like();
|
||||
if !was_no_signal {
|
||||
error!(
|
||||
"Failed to open device {:?}: {}",
|
||||
@@ -965,9 +961,7 @@ impl Streamer {
|
||||
break 'session;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Open failed in NoSignal-like state, backing off before soft-restart"
|
||||
);
|
||||
debug!("Open failed in NoSignal-like state, backing off before soft-restart");
|
||||
let wait = backoff_secs(no_signal_restart_count);
|
||||
set_retry(wait.saturating_mul(1000));
|
||||
std::thread::sleep(Duration::from_secs(wait));
|
||||
@@ -1040,9 +1034,7 @@ impl Streamer {
|
||||
Err(e) => {
|
||||
if is_source_changed_error(&e) {
|
||||
info!("Capture SOURCE_CHANGE — soft-restart for DV re-probe");
|
||||
set_retry(
|
||||
backoff_secs(no_signal_restart_count).saturating_mul(1000),
|
||||
);
|
||||
set_retry(backoff_secs(no_signal_restart_count).saturating_mul(1000));
|
||||
go_offline();
|
||||
set_state(StreamerState::NoSignal);
|
||||
need_soft_restart = true;
|
||||
@@ -1112,15 +1104,13 @@ impl Streamer {
|
||||
|
||||
if is_transient_signal_error {
|
||||
if os_err == Some(71) {
|
||||
warn!(
|
||||
"Capture transient error (EPROTO/-71, often UVC USB): {}",
|
||||
e
|
||||
);
|
||||
let is_uvc = handle.block_on(async {
|
||||
self.current_device.read().await.as_ref().is_some_and(
|
||||
|d| d.driver.eq_ignore_ascii_case("uvcvideo"),
|
||||
)
|
||||
});
|
||||
warn!("Capture transient error (EPROTO/-71, often UVC USB): {}", e);
|
||||
let is_uvc =
|
||||
handle.block_on(async {
|
||||
self.current_device.read().await.as_ref().is_some_and(|d| {
|
||||
d.driver.eq_ignore_ascii_case("uvcvideo")
|
||||
})
|
||||
});
|
||||
if is_uvc {
|
||||
go_offline();
|
||||
set_state(StreamerState::UvcUsbError);
|
||||
@@ -1133,9 +1123,7 @@ impl Streamer {
|
||||
e
|
||||
);
|
||||
}
|
||||
set_retry(
|
||||
backoff_secs(no_signal_restart_count).saturating_mul(1000),
|
||||
);
|
||||
set_retry(backoff_secs(no_signal_restart_count).saturating_mul(1000));
|
||||
go_offline();
|
||||
set_state(StreamerState::NoSignal);
|
||||
need_soft_restart = true;
|
||||
@@ -1165,7 +1153,7 @@ impl Streamer {
|
||||
|
||||
validate_counter = validate_counter.wrapping_add(1);
|
||||
if pixel_format.is_compressed()
|
||||
&& validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL)
|
||||
&& should_validate_jpeg_frame(validate_counter)
|
||||
&& !VideoFrame::is_valid_jpeg_bytes(&owned[..frame_size])
|
||||
{
|
||||
continue 'capture;
|
||||
@@ -1567,7 +1555,10 @@ fn probe_subdev_signal(
|
||||
let fd = match csi_bridge::open_subdev(subdev_path) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
debug!("probe_subdev_signal: failed to open {:?}: {}", subdev_path, e);
|
||||
debug!(
|
||||
"probe_subdev_signal: failed to open {:?}: {}",
|
||||
subdev_path, e
|
||||
);
|
||||
return Some(crate::video::SignalStatus::NoSignal);
|
||||
}
|
||||
};
|
||||
@@ -1608,9 +1599,7 @@ fn wait_subdev_for_source_change(
|
||||
let wait = remaining.min(slice);
|
||||
match csi_bridge::wait_source_change(&fd, wait) {
|
||||
Ok(true) => {
|
||||
info!(
|
||||
"Subdev SOURCE_CHANGE during no-signal wait, retrying open immediately"
|
||||
);
|
||||
info!("Subdev SOURCE_CHANGE during no-signal wait, retrying open immediately");
|
||||
return;
|
||||
}
|
||||
Ok(false) => continue,
|
||||
|
||||
47
src/video/traits.rs
Normal file
47
src/video/traits.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
//! Traits for video output consumers (WebRTC, RTSP, RustDesk, etc.)
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::types::{
|
||||
BitratePreset, PixelFormat, Resolution, SharedVideoPipeline, SharedVideoPipelineConfig,
|
||||
SharedVideoPipelineStats, VideoCodecType,
|
||||
};
|
||||
use crate::error::Result;
|
||||
use crate::events::EventBus;
|
||||
use crate::hid::HidController;
|
||||
|
||||
/// Trait for video output consumers that receive encoded video frames.
|
||||
///
|
||||
/// Implemented by `WebRtcStreamer`. `VideoStreamManager` depends on this
|
||||
/// trait instead of the concrete type, breaking the video <-> webrtc
|
||||
/// circular import.
|
||||
#[async_trait::async_trait]
|
||||
pub trait VideoOutput: Send + Sync {
|
||||
async fn set_event_bus(&self, events: Arc<EventBus>);
|
||||
async fn update_video_config(&self, resolution: Resolution, format: PixelFormat, fps: u32);
|
||||
async fn set_capture_device(
|
||||
&self,
|
||||
device_path: PathBuf,
|
||||
jpeg_quality: u8,
|
||||
subdev_path: Option<PathBuf>,
|
||||
bridge_kind: Option<String>,
|
||||
v4l2_driver: Option<String>,
|
||||
);
|
||||
async fn current_video_codec(&self) -> VideoCodecType;
|
||||
async fn is_hardware_encoding(&self) -> bool;
|
||||
async fn close_all_sessions(&self);
|
||||
async fn close_all_sessions_and_release_device(&self) -> usize;
|
||||
async fn session_count(&self) -> usize;
|
||||
async fn set_hid_controller(&self, hid: Arc<HidController>);
|
||||
async fn set_audio_enabled(&self, enabled: bool) -> Result<()>;
|
||||
async fn is_audio_enabled(&self) -> bool;
|
||||
async fn reconnect_audio_sources(&self);
|
||||
async fn ensure_video_pipeline_for_external(&self) -> Result<Arc<SharedVideoPipeline>>;
|
||||
async fn get_pipeline_config(&self) -> Option<SharedVideoPipelineConfig>;
|
||||
async fn set_video_codec(&self, codec: VideoCodecType) -> Result<()>;
|
||||
async fn set_bitrate_preset(&self, preset: BitratePreset) -> Result<()>;
|
||||
async fn request_keyframe(&self) -> Result<()>;
|
||||
async fn current_video_geometry(&self) -> (Resolution, PixelFormat, u32);
|
||||
async fn pipeline_stats(&self) -> Option<SharedVideoPipelineStats>;
|
||||
}
|
||||
22
src/video/types.rs
Normal file
22
src/video/types.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! Re-exports of shared video types used by other modules (e.g., webrtc)
|
||||
//!
|
||||
//! External modules should import from `crate::video::types` instead of
|
||||
//! reaching into internal submodules directly.
|
||||
|
||||
// From video::format
|
||||
pub use super::format::{PixelFormat, Resolution};
|
||||
|
||||
// From video::frame
|
||||
pub use super::frame::VideoFrame;
|
||||
|
||||
// From video::encoder (codec-level types)
|
||||
pub use super::encoder::{BitratePreset, VideoCodecType};
|
||||
|
||||
// From video::encoder::registry
|
||||
pub use super::encoder::registry::{EncoderBackend, VideoEncoderType};
|
||||
|
||||
// From video::shared_video_pipeline
|
||||
pub use super::shared_video_pipeline::{
|
||||
EncodedVideoFrame, PipelineStateNotification, SharedVideoPipeline, SharedVideoPipelineConfig,
|
||||
SharedVideoPipelineStats,
|
||||
};
|
||||
@@ -129,9 +129,7 @@ impl V4l2rCaptureStream {
|
||||
subdev_dv_mode = Some(mode);
|
||||
}
|
||||
other => {
|
||||
let status = other
|
||||
.as_status()
|
||||
.unwrap_or(SignalStatus::NoSignal);
|
||||
let status = other.as_status().unwrap_or(SignalStatus::NoSignal);
|
||||
debug!(
|
||||
"Subdev {:?} reports no signal ({:?}) — refusing STREAMON",
|
||||
subdev_path, status
|
||||
@@ -200,7 +198,10 @@ impl V4l2rCaptureStream {
|
||||
}
|
||||
(Ok(f), _, _) => f,
|
||||
(Err(e), _, _) => {
|
||||
return Err(AppError::VideoError(format!("Failed to get device format: {}", e)));
|
||||
return Err(AppError::VideoError(format!(
|
||||
"Failed to get device format: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -446,10 +447,7 @@ impl V4l2rCaptureStream {
|
||||
let mut poll_fds: Vec<PollFd> = Vec::with_capacity(2);
|
||||
poll_fds.push(PollFd::new(
|
||||
self.fd.as_fd(),
|
||||
PollFlags::POLLIN
|
||||
| PollFlags::POLLPRI
|
||||
| PollFlags::POLLERR
|
||||
| PollFlags::POLLHUP,
|
||||
PollFlags::POLLIN | PollFlags::POLLPRI | PollFlags::POLLERR | PollFlags::POLLHUP,
|
||||
));
|
||||
if let Some(subdev_fd) = self.subdev_fd.as_ref() {
|
||||
poll_fds.push(PollFd::new(subdev_fd.as_fd(), PollFlags::POLLPRI));
|
||||
|
||||
@@ -31,12 +31,8 @@ use tracing::{debug, info, warn};
|
||||
use crate::audio::OpusFrame;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Audio packet type identifier
|
||||
const AUDIO_PACKET_TYPE: u8 = 0x02;
|
||||
|
||||
/// Audio WebSocket upgrade handler
|
||||
///
|
||||
/// Upgrades HTTP connection to WebSocket for audio streaming.
|
||||
pub async fn audio_ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
@@ -44,16 +40,13 @@ pub async fn audio_ws_handler(
|
||||
ws.on_upgrade(move |socket| handle_audio_socket(socket, state))
|
||||
}
|
||||
|
||||
/// Handle audio WebSocket connection
|
||||
async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Try to get Opus frame subscription
|
||||
let opus_rx = match state.audio.subscribe_opus_async().await {
|
||||
let opus_rx = match state.audio.subscribe_opus().await {
|
||||
Some(rx) => rx,
|
||||
None => {
|
||||
warn!("Audio not streaming, rejecting WebSocket connection");
|
||||
// Send error message before closing
|
||||
let _ = sender
|
||||
.send(Message::Text(
|
||||
r#"{"error": "Audio not streaming"}"#.to_string().into(),
|
||||
@@ -68,16 +61,13 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
|
||||
info!("Audio WebSocket client connected");
|
||||
|
||||
// Track connection for cleanup
|
||||
let mut closed = false;
|
||||
|
||||
// Use interval instead of sleep for more efficient keepalive
|
||||
let mut ping_interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Receive Opus frames and send to client
|
||||
opus_result = opus_rx.recv() => {
|
||||
let frame = match opus_result {
|
||||
Some(f) => f,
|
||||
@@ -94,7 +84,6 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle client messages (ping/close)
|
||||
msg = receiver.next() => {
|
||||
match msg {
|
||||
Some(Ok(Message::Close(_))) => {
|
||||
@@ -107,11 +96,8 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Pong(_))) => {
|
||||
// Pong received, connection is alive
|
||||
}
|
||||
Some(Ok(Message::Pong(_))) => {}
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
// Handle potential control messages
|
||||
debug!("Received text message on audio WS: {}", text);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
@@ -119,14 +105,12 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
// Connection closed
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Periodic ping to keep connection alive (using interval)
|
||||
_ = ping_interval.tick() => {
|
||||
if sender.send(Message::Ping(vec![].into())).await.is_err() {
|
||||
warn!("Failed to send ping, disconnecting");
|
||||
@@ -137,39 +121,24 @@ async fn handle_audio_socket(socket: WebSocket, state: Arc<AppState>) {
|
||||
}
|
||||
|
||||
if !closed {
|
||||
// Try to send close message
|
||||
let _ = sender.send(Message::Close(None)).await;
|
||||
}
|
||||
|
||||
info!("Audio WebSocket client disconnected");
|
||||
}
|
||||
|
||||
/// Encode Opus frame to binary packet format
|
||||
///
|
||||
/// ## Format
|
||||
///
|
||||
/// | Offset | Size | Description |
|
||||
/// |--------|------|-------------|
|
||||
/// | 0 | 1 | Packet type (0x02 for audio) |
|
||||
/// | 1 | 4 | Timestamp (u32 LE, ms since start) |
|
||||
/// | 5 | 2 | Duration (u16 LE, ms) |
|
||||
/// | 7 | 4 | Sequence number (u32 LE) |
|
||||
/// | 11 | 4 | Data length (u32 LE) |
|
||||
/// | 15 | N | Opus encoded data |
|
||||
fn encode_audio_packet(frame: &OpusFrame, stream_start: Instant) -> Vec<u8> {
|
||||
let timestamp_ms = stream_start.elapsed().as_millis() as u32;
|
||||
let data_len = frame.data.len() as u32;
|
||||
|
||||
let mut buf = Vec::with_capacity(15 + frame.data.len());
|
||||
|
||||
// Header
|
||||
buf.push(AUDIO_PACKET_TYPE);
|
||||
buf.extend_from_slice(×tamp_ms.to_le_bytes());
|
||||
buf.extend_from_slice(&(frame.duration_ms as u16).to_le_bytes());
|
||||
buf.extend_from_slice(&(frame.sequence as u32).to_le_bytes());
|
||||
buf.extend_from_slice(&data_len.to_le_bytes());
|
||||
|
||||
// Opus data
|
||||
buf.extend_from_slice(&frame.data);
|
||||
|
||||
buf
|
||||
@@ -186,8 +155,6 @@ mod tests {
|
||||
data: Bytes::from(vec![1, 2, 3, 4, 5]),
|
||||
duration_ms: 20,
|
||||
sequence: 42,
|
||||
timestamp: Instant::now(),
|
||||
rtp_timestamp: 0,
|
||||
};
|
||||
|
||||
let stream_start = Instant::now();
|
||||
@@ -195,11 +162,5 @@ mod tests {
|
||||
|
||||
assert!(encoded.len() >= 15);
|
||||
assert_eq!(encoded[0], AUDIO_PACKET_TYPE);
|
||||
// decode_audio_packet function was removed, skip decode test
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_invalid_packet() {
|
||||
// decode_audio_packet function was removed, skip this test
|
||||
}
|
||||
}
|
||||
|
||||
31
src/web/error.rs
Normal file
31
src/web/error.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use crate::error::AppError;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
let body = ErrorResponse {
|
||||
success: false,
|
||||
message: self.to_string(),
|
||||
};
|
||||
|
||||
tracing::error!(
|
||||
error_type = std::any::type_name_of_val(&self),
|
||||
error_message = %body.message,
|
||||
"Request failed"
|
||||
);
|
||||
|
||||
// Always return 200 OK - success/failure is indicated by the success field
|
||||
(StatusCode::OK, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,10 @@
|
||||
//! 配置热重载逻辑
|
||||
//!
|
||||
//! 从 handlers.rs 中抽取的配置应用函数,负责将配置变更应用到各个子系统。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::*;
|
||||
use crate::error::{AppError, Result};
|
||||
use crate::rtsp::RtspService;
|
||||
use crate::state::AppState;
|
||||
use crate::stream_encoder::encoder_type_to_backend;
|
||||
use crate::video::codec_constraints::{
|
||||
enforce_constraints_with_stream_manager, StreamCodecConstraints,
|
||||
};
|
||||
@@ -32,13 +29,11 @@ async fn reconcile_otg_from_store(state: &Arc<AppState>) -> Result<()> {
|
||||
.map_err(|e| AppError::Config(format!("OTG reconcile failed: {}", e)))
|
||||
}
|
||||
|
||||
/// 应用 Video 配置变更
|
||||
pub async fn apply_video_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &VideoConfig,
|
||||
new_config: &VideoConfig,
|
||||
) -> Result<()> {
|
||||
// 检查配置是否实际变更
|
||||
if old_config == new_config {
|
||||
tracing::info!("Video config unchanged, skipping reload");
|
||||
return Ok(());
|
||||
@@ -74,7 +69,6 @@ pub async fn apply_video_config(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 Stream 配置变更
|
||||
pub async fn apply_stream_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &StreamConfig,
|
||||
@@ -82,32 +76,24 @@ pub async fn apply_stream_config(
|
||||
) -> Result<()> {
|
||||
tracing::info!("Applying stream config changes...");
|
||||
|
||||
// 更新编码器后端
|
||||
if old_config.encoder != new_config.encoder {
|
||||
let encoder_backend = new_config.encoder.to_backend();
|
||||
let encoder_backend = encoder_type_to_backend(new_config.encoder.clone());
|
||||
tracing::info!(
|
||||
"Updating encoder backend to: {:?} (from config: {:?})",
|
||||
encoder_backend,
|
||||
new_config.encoder
|
||||
);
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.update_encoder_backend(encoder_backend)
|
||||
.await;
|
||||
state.webrtc.update_encoder_backend(encoder_backend).await;
|
||||
}
|
||||
|
||||
// 更新码率
|
||||
if old_config.bitrate_preset != new_config.bitrate_preset {
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.set_bitrate_preset(new_config.bitrate_preset)
|
||||
.await
|
||||
.ok(); // Ignore error if no active stream
|
||||
}
|
||||
|
||||
// 更新 ICE 配置 (STUN/TURN)
|
||||
let ice_changed = old_config.stun_server != new_config.stun_server
|
||||
|| old_config.turn_server != new_config.turn_server
|
||||
|| old_config.turn_username != new_config.turn_username
|
||||
@@ -120,8 +106,7 @@ pub async fn apply_stream_config(
|
||||
new_config.turn_server
|
||||
);
|
||||
state
|
||||
.stream_manager
|
||||
.webrtc_streamer()
|
||||
.webrtc
|
||||
.update_ice_config(
|
||||
new_config.stun_server.clone(),
|
||||
new_config.turn_server.clone(),
|
||||
@@ -139,7 +124,6 @@ pub async fn apply_stream_config(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 HID 配置变更
|
||||
pub async fn apply_hid_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &HidConfig,
|
||||
@@ -202,7 +186,6 @@ pub async fn apply_hid_config(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 MSD 配置变更
|
||||
pub async fn apply_msd_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &MsdConfig,
|
||||
@@ -218,7 +201,6 @@ pub async fn apply_msd_config(
|
||||
tracing::debug!("Old MSD config: {:?}", old_config);
|
||||
tracing::debug!("New MSD config: {:?}", new_config);
|
||||
|
||||
// Check if MSD enabled state changed
|
||||
let old_msd_enabled = old_config.enabled;
|
||||
let new_msd_enabled = new_config.enabled;
|
||||
let msd_dir_changed = old_config.msd_dir != new_config.msd_dir;
|
||||
@@ -232,7 +214,6 @@ pub async fn apply_msd_config(
|
||||
tracing::info!("MSD directory changed: {}", new_config.msd_dir);
|
||||
}
|
||||
|
||||
// Ensure MSD directories exist (msd/images, msd/ventoy)
|
||||
let msd_dir = new_config.msd_dir_path();
|
||||
if let Err(e) = std::fs::create_dir_all(msd_dir.join("images")) {
|
||||
tracing::warn!("Failed to create MSD images directory: {}", e);
|
||||
@@ -255,7 +236,6 @@ pub async fn apply_msd_config(
|
||||
|
||||
reconcile_otg_from_store(state).await?;
|
||||
|
||||
// Shutdown existing controller if present
|
||||
let mut msd_guard = state.msd.write().await;
|
||||
if let Some(msd) = msd_guard.as_mut() {
|
||||
if let Err(e) = msd.shutdown().await {
|
||||
@@ -271,15 +251,12 @@ pub async fn apply_msd_config(
|
||||
.await
|
||||
.map_err(|e| AppError::Config(format!("MSD initialization failed: {}", e)))?;
|
||||
|
||||
// Set event bus
|
||||
let events = state.events.clone();
|
||||
msd.set_event_bus(events).await;
|
||||
|
||||
// Store the initialized controller
|
||||
*state.msd.write().await = Some(msd);
|
||||
tracing::info!("MSD initialized successfully");
|
||||
} else {
|
||||
// MSD disabled - shutdown
|
||||
tracing::info!("MSD disabled in config, shutting down...");
|
||||
|
||||
let mut msd_guard = state.msd.write().await;
|
||||
@@ -306,7 +283,6 @@ pub async fn apply_msd_config(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 ATX 配置变更
|
||||
pub async fn apply_atx_config(
|
||||
state: &Arc<AppState>,
|
||||
_old_config: &AtxConfig,
|
||||
@@ -314,10 +290,8 @@ pub async fn apply_atx_config(
|
||||
) -> Result<()> {
|
||||
tracing::info!("Applying ATX config changes...");
|
||||
|
||||
// Convert AtxConfig to AtxControllerConfig
|
||||
let controller_config = new_config.to_controller_config();
|
||||
|
||||
// Reload the ATX controller with new configuration
|
||||
let atx_guard = state.atx.read().await;
|
||||
if let Some(atx) = atx_guard.as_ref() {
|
||||
if let Err(e) = atx.reload(controller_config).await {
|
||||
@@ -326,7 +300,6 @@ pub async fn apply_atx_config(
|
||||
}
|
||||
tracing::info!("ATX controller reloaded successfully");
|
||||
} else {
|
||||
// ATX controller not initialized, create a new one if enabled
|
||||
drop(atx_guard);
|
||||
|
||||
if new_config.enabled {
|
||||
@@ -345,7 +318,6 @@ pub async fn apply_atx_config(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 Audio 配置变更
|
||||
pub async fn apply_audio_config(
|
||||
state: &Arc<AppState>,
|
||||
_old_config: &AudioConfig,
|
||||
@@ -353,17 +325,14 @@ pub async fn apply_audio_config(
|
||||
) -> Result<()> {
|
||||
tracing::info!("Applying audio config changes...");
|
||||
|
||||
// Create audio controller config from new config
|
||||
let audio_config = crate::audio::AudioControllerConfig {
|
||||
enabled: new_config.enabled,
|
||||
device: new_config.device.clone(),
|
||||
quality: crate::audio::AudioQuality::from_str(&new_config.quality),
|
||||
quality: new_config.quality.parse::<crate::audio::AudioQuality>()?,
|
||||
};
|
||||
|
||||
// Update audio controller
|
||||
if let Err(e) = state.audio.update_config(audio_config).await {
|
||||
tracing::error!("Audio config update failed: {}", e);
|
||||
// Don't fail - audio errors are not critical
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Audio config applied: enabled={}, device={}",
|
||||
@@ -372,7 +341,6 @@ pub async fn apply_audio_config(
|
||||
);
|
||||
}
|
||||
|
||||
// Also update WebRTC audio enabled state
|
||||
if let Err(e) = state
|
||||
.stream_manager
|
||||
.set_webrtc_audio_enabled(new_config.enabled)
|
||||
@@ -383,7 +351,6 @@ pub async fn apply_audio_config(
|
||||
tracing::info!("WebRTC audio enabled: {}", new_config.enabled);
|
||||
}
|
||||
|
||||
// Reconnect audio sources for existing WebRTC sessions
|
||||
if new_config.enabled {
|
||||
state.stream_manager.reconnect_webrtc_audio_sources().await;
|
||||
}
|
||||
@@ -391,7 +358,6 @@ pub async fn apply_audio_config(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply stream codec constraints derived from global config.
|
||||
pub async fn enforce_stream_codec_constraints(state: &Arc<AppState>) -> Result<Option<String>> {
|
||||
let config = state.config.get();
|
||||
let constraints = StreamCodecConstraints::from_config(&config);
|
||||
@@ -400,7 +366,6 @@ pub async fn enforce_stream_codec_constraints(state: &Arc<AppState>) -> Result<O
|
||||
Ok(enforcement.message)
|
||||
}
|
||||
|
||||
/// 应用 RustDesk 配置变更
|
||||
pub async fn apply_rustdesk_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &crate::rustdesk::config::RustDeskConfig,
|
||||
@@ -411,9 +376,7 @@ pub async fn apply_rustdesk_config(
|
||||
let mut rustdesk_guard = state.rustdesk.write().await;
|
||||
let mut credentials_to_save = None;
|
||||
|
||||
// Check if service needs to be stopped
|
||||
if old_config.enabled && !new_config.enabled {
|
||||
// Disable service
|
||||
if let Some(ref service) = *rustdesk_guard {
|
||||
if let Err(e) = service.stop().await {
|
||||
tracing::error!("Failed to stop RustDesk service: {}", e);
|
||||
@@ -423,14 +386,12 @@ pub async fn apply_rustdesk_config(
|
||||
*rustdesk_guard = None;
|
||||
}
|
||||
|
||||
// Check if service needs to be started or restarted
|
||||
if new_config.enabled {
|
||||
let need_restart = old_config.rendezvous_server != new_config.rendezvous_server
|
||||
|| old_config.device_id != new_config.device_id
|
||||
|| old_config.device_password != new_config.device_password;
|
||||
|
||||
if rustdesk_guard.is_none() {
|
||||
// Create new service
|
||||
tracing::info!("Initializing RustDesk service...");
|
||||
let service = crate::rustdesk::RustDeskService::new(
|
||||
new_config.clone(),
|
||||
@@ -442,12 +403,10 @@ pub async fn apply_rustdesk_config(
|
||||
tracing::error!("Failed to start RustDesk service: {}", e);
|
||||
} else {
|
||||
tracing::info!("RustDesk service started with ID: {}", new_config.device_id);
|
||||
// Save generated keypair and UUID to config
|
||||
credentials_to_save = service.save_credentials();
|
||||
}
|
||||
*rustdesk_guard = Some(std::sync::Arc::new(service));
|
||||
} else if need_restart {
|
||||
// Restart existing service with new config
|
||||
if let Some(ref service) = *rustdesk_guard {
|
||||
if let Err(e) = service.restart(new_config.clone()).await {
|
||||
tracing::error!("Failed to restart RustDesk service: {}", e);
|
||||
@@ -456,14 +415,12 @@ pub async fn apply_rustdesk_config(
|
||||
"RustDesk service restarted with ID: {}",
|
||||
new_config.device_id
|
||||
);
|
||||
// Save generated keypair and UUID to config
|
||||
credentials_to_save = service.save_credentials();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save credentials to persistent config store (outside the lock)
|
||||
drop(rustdesk_guard);
|
||||
if let Some(updated_config) = credentials_to_save {
|
||||
tracing::info!("Saving RustDesk credentials to config store...");
|
||||
@@ -491,7 +448,6 @@ pub async fn apply_rustdesk_config(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 应用 RTSP 配置变更
|
||||
pub async fn apply_rtsp_config(
|
||||
state: &Arc<AppState>,
|
||||
old_config: &RtspConfig,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! ATX configuration handlers
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -11,29 +9,23 @@ use crate::state::AppState;
|
||||
use super::apply::apply_atx_config;
|
||||
use super::types::AtxConfigUpdate;
|
||||
|
||||
/// Get ATX configuration
|
||||
pub async fn get_atx_config(State(state): State<Arc<AppState>>) -> Json<AtxConfig> {
|
||||
Json(state.config.get().atx.clone())
|
||||
}
|
||||
|
||||
/// Update ATX configuration
|
||||
pub async fn update_atx_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<AtxConfigUpdate>,
|
||||
) -> Result<Json<AtxConfig>> {
|
||||
// 1. Read current configuration snapshot
|
||||
let current_config = state.config.get();
|
||||
let old_atx_config = current_config.atx.clone();
|
||||
|
||||
// 2. Validate request, including merged effective serial parameter checks
|
||||
req.validate_with_current(&old_atx_config)?;
|
||||
|
||||
// 3. Ensure ATX serial devices do not conflict with HID CH9329 serial device
|
||||
let mut merged_atx_config = old_atx_config.clone();
|
||||
req.apply_to(&mut merged_atx_config);
|
||||
validate_serial_device_conflict(&merged_atx_config, ¤t_config.hid)?;
|
||||
|
||||
// 4. Persist update into config store
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
@@ -41,10 +33,8 @@ pub async fn update_atx_config(
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 5. Load new config
|
||||
let new_atx_config = state.config.get().atx.clone();
|
||||
|
||||
// 6. Apply to subsystem (hot reload)
|
||||
if let Err(e) = apply_atx_config(&state, &old_atx_config, &new_atx_config).await {
|
||||
tracing::error!("Failed to apply ATX config: {}", e);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//! Audio 配置 Handler
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -10,23 +8,18 @@ use crate::state::AppState;
|
||||
use super::apply::apply_audio_config;
|
||||
use super::types::AudioConfigUpdate;
|
||||
|
||||
/// 获取 Audio 配置
|
||||
pub async fn get_audio_config(State(state): State<Arc<AppState>>) -> Json<AudioConfig> {
|
||||
Json(state.config.get().audio.clone())
|
||||
}
|
||||
|
||||
/// 更新 Audio 配置
|
||||
pub async fn update_audio_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<AudioConfigUpdate>,
|
||||
) -> Result<Json<AudioConfig>> {
|
||||
// 1. 验证请求
|
||||
req.validate()?;
|
||||
|
||||
// 2. 获取旧配置
|
||||
let old_audio_config = state.config.get().audio.clone();
|
||||
|
||||
// 3. 应用更新到配置存储
|
||||
state
|
||||
.config
|
||||
.update(|config| {
|
||||
@@ -34,10 +27,8 @@ pub async fn update_audio_config(
|
||||
})
|
||||
.await?;
|
||||
|
||||
// 4. 获取新配置
|
||||
let new_audio_config = state.config.get().audio.clone();
|
||||
|
||||
// 5. 应用到子系统(热重载)
|
||||
if let Err(e) = apply_audio_config(&state, &old_audio_config, &new_audio_config).await {
|
||||
tracing::error!("Failed to apply audio config: {}", e);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user