feat: 初步增加 Windows 支持

This commit is contained in:
mofeng-git
2026-05-18 22:43:28 +08:00
parent 0b9d94f53f
commit 935fa823f2
163 changed files with 11419 additions and 7581 deletions

View File

@@ -63,12 +63,6 @@ clap = { version = "4", features = ["derive"] }
# Time (cookie max_age + RFC3339 timestamps)
time = { version = "0.3", features = ["serde", "formatting", "parsing"] }
# Video capture (V4L2)
v4l2r = "0.0.7"
# JPEG encoding (libjpeg-turbo, SIMD accelerated)
turbojpeg = "1.3"
# Bytes handling
bytes = "1"
bytemuck = { version = "1.24", features = ["derive"] }
@@ -93,11 +87,6 @@ rtp = "0.14"
rtsp-types = "0.1"
sdp-types = "0.1"
# Audio (ALSA capture + Opus encoding)
# Note: audiopus links to libopus.so (unavoidable for audio support)
alsa = "0.11"
audiopus = "0.2"
# HID (serial port for CH9329)
serialport = "4"
async-trait = "0.1"
@@ -106,22 +95,42 @@ libc = "0.2"
# Ventoy bootable image support
ventoy-img = { path = "libs/ventoy-img-rs" }
# ATX (GPIO control)
gpio-cdev = "0.6"
# H264 hardware/software encoding (hwcodec from rustdesk)
hwcodec = { path = "libs/hwcodec" }
# RustDesk protocol support
protobuf = { version = "3.7", features = ["with-bytes"] }
sodiumoxide = "0.2"
sha2 = "0.10"
# High-performance pixel format conversion (libyuv)
libyuv = { path = "res/vcpkg/libyuv" }
# TypeScript type generation
typeshare = "1.0"
[target.'cfg(any(unix, windows))'.dependencies]
# Video encoding/decoding (FFmpeg/libjpeg-turbo/libyuv; available on Windows and Linux)
hwcodec = { path = "libs/hwcodec" }
libyuv = { path = "res/vcpkg/libyuv" }
turbojpeg = "1.3"
# Note: audiopus links to libopus.so (unavoidable for audio support)
audiopus = "0.2"
[target.'cfg(unix)'.dependencies]
# Video capture (V4L2)
v4l2r = "0.0.7"
# Audio (ALSA capture)
alsa = "0.11"
# ATX (GPIO control)
gpio-cdev = "0.6"
[target.'cfg(windows)'.dependencies]
cpal = { version = "0.17", default-features = false }
windows-sys = { version = "0.61", features = [
"Win32_Foundation",
"Win32_NetworkManagement_IpHelper",
"Win32_NetworkManagement_Ndis",
"Win32_Networking_WinSock",
"Win32_System_SystemInformation",
"Win32_System_Threading",
] }
[dev-dependencies]
tempfile = "3"

14
agents.md Normal file
View File

@@ -0,0 +1,14 @@
# Agents Notes
## Windows MSVC Build
Run from the repository root in PowerShell:
```powershell
$env:VCPKG_ROOT='C:\Users\mofen\code\vcpkg'
$env:TURBOJPEG_SOURCE='explicit'
$env:TURBOJPEG_LIB_DIR='C:\Users\mofen\code\vcpkg\installed\x64-windows-static\lib'
$env:TURBOJPEG_INCLUDE_DIR='C:\Users\mofen\code\vcpkg\installed\x64-windows-static\include'
cargo build --target x86_64-pc-windows-msvc
```

View File

@@ -34,7 +34,9 @@ fn build_common(builder: &mut Build) {
// system
#[cfg(windows)]
{
["d3d11", "dxgi"].map(|lib| println!("cargo:rustc-link-lib={}", lib));
for lib in ["d3d11", "dxgi"] {
println!("cargo:rustc-link-lib={}", lib);
}
}
builder.include(&common_dir);
@@ -99,6 +101,7 @@ mod ffmpeg {
link_os();
build_ffmpeg_ram(builder);
build_ffmpeg_hw(builder);
build_ffmpeg_capture(builder);
}
/// Link system FFmpeg using pkg-config or custom path
@@ -282,15 +285,24 @@ mod ffmpeg {
)
);
{
// Only need avcodec and avutil for encoding
// avdevice/avformat are needed by the Windows DirectShow capture bridge.
let mut static_libs = vec!["avcodec", "avutil"];
if target_os == "windows" {
static_libs.push("libmfx");
static_libs.extend([
"avformat",
"avdevice",
"avfilter",
"swresample",
"swscale",
"vpx",
"libx264",
"x265-static",
"libmfx",
]);
}
for lib in static_libs {
println!("cargo:rustc-link-lib=static={}", lib);
}
static_libs
.iter()
.map(|lib| println!("cargo:rustc-link-lib=static={}", lib))
.count();
}
let include = path.join("include");
@@ -304,7 +316,10 @@ mod ffmpeg {
let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let dyn_libs: Vec<&str> = if target_os == "windows" {
["User32", "bcrypt", "ole32", "advapi32"].to_vec()
[
"User32", "bcrypt", "ole32", "advapi32", "mfuuid", "strmiids",
]
.to_vec()
} else if target_os == "linux" {
// Base libraries for all Linux platforms
let mut v = vec!["drm", "stdc++"];
@@ -375,6 +390,34 @@ mod ffmpeg {
}
}
fn build_ffmpeg_capture(builder: &mut Build) {
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
if target_os != "windows" {
return;
}
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let capture_header = manifest_dir
.join("cpp")
.join("ffmpeg_capture_ffi.h")
.to_string_lossy()
.to_string();
bindgen::builder()
.header(capture_header)
.rustified_enum("*")
.generate()
.unwrap()
.write_to_file(
Path::new(&env::var_os("OUT_DIR").unwrap()).join("ffmpeg_capture_ffi.rs"),
)
.unwrap();
builder.file(manifest_dir.join("cpp").join("ffmpeg_capture.cpp"));
println!("cargo:rustc-link-lib=strmiids");
println!("cargo:rustc-link-lib=oleaut32");
println!("cargo:rustc-link-lib=quartz");
}
fn build_ffmpeg_hw(builder: &mut Build) {
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let ffmpeg_hw_dir = manifest_dir.join("cpp").join("ffmpeg_hw");

View File

@@ -1,4 +1,5 @@
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/opt.h>
}
@@ -99,13 +100,12 @@ void set_av_codec_ctx(AVCodecContext *c, const std::string &name, int kbs,
c->color_primaries = AVCOL_PRI_SMPTE170M;
c->color_trc = AVCOL_TRC_SMPTE170M;
// Profile selection: use BASELINE for software H264 (faster, simpler)
if (is_software_h264(name)) {
c->profile = FF_PROFILE_H264_BASELINE; // Simpler profile for real-time
} else if (name.find("h264") != std::string::npos) {
c->profile = FF_PROFILE_H264_HIGH;
// WebRTC SDP advertises constrained baseline. Keep hardware and software
// encoders on the same browser-friendly H264 profile.
if (name.find("h264") != std::string::npos) {
c->profile = AV_PROFILE_H264_CONSTRAINED_BASELINE;
} else if (name.find("hevc") != std::string::npos) {
c->profile = FF_PROFILE_HEVC_MAIN;
c->profile = AV_PROFILE_HEVC_MAIN;
}
}
@@ -120,8 +120,7 @@ bool set_lantency_free(void *priv_data, const std::string &name) {
}
if (name.find("amf") != std::string::npos) {
if ((ret = av_opt_set(priv_data, "query_timeout", "1000", 0)) < 0) {
LOG_ERROR(std::string("amf set_lantency_free failed, ret = ") + av_err2str(ret));
return false;
LOG_WARN(std::string("amf query_timeout option is unavailable, ret = ") + av_err2str(ret));
}
}
if (name.find("qsv") != std::string::npos) {

View File

@@ -0,0 +1,879 @@
#define NOMINMAX
#include "ffmpeg_capture_ffi.h"
#include <Windows.h>
#include <dshow.h>
#include <dvdmedia.h>
extern "C" {
#include <libavcodec/codec_id.h>
#include <libavdevice/avdevice.h>
#include <libavformat/avformat.h>
#include <libavutil/avutil.h>
#include <libavutil/error.h>
#include <libavutil/pixfmt.h>
}
#include <atomic>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#pragma comment(lib, "strmiids")
thread_local std::string g_last_error;
struct HwcodecDshowCaptureContext {
AVFormatContext* format_ctx = nullptr;
int stream_index = -1;
int width = 0;
int height = 0;
int pixel_format = HWCODEC_CAPTURE_FMT_UNKNOWN;
int stride = 0;
int timeout_ms = 2000;
std::atomic<long long> deadline_ms{0};
std::atomic<int> timed_out{0};
uint64_t sequence = 0;
};
namespace {
struct DshowCapabilityEntry {
std::string format;
int width = 0;
int height = 0;
std::vector<int> fps;
};
const char* requested_pixel_format_name(int requested_format);
void set_last_error(const std::string& message) {
g_last_error = message;
}
std::string ffmpeg_error(int errnum) {
char buffer[AV_ERROR_MAX_STRING_SIZE] = {0};
av_make_error_string(buffer, sizeof(buffer), errnum);
return std::string(buffer);
}
long long now_ms() {
return static_cast<long long>(GetTickCount64());
}
std::string wide_to_utf8(const wchar_t* value) {
if (!value) {
return std::string();
}
int size = WideCharToMultiByte(CP_UTF8, 0, value, -1, nullptr, 0, nullptr, nullptr);
if (size <= 1) {
return std::string();
}
std::string result(static_cast<size_t>(size - 1), '\0');
WideCharToMultiByte(
CP_UTF8,
0,
value,
-1,
result.empty() ? nullptr : &result[0],
size,
nullptr,
nullptr);
return result;
}
void add_fps_candidate(std::vector<int>* fps, LONGLONG interval_100ns) {
if (!fps || interval_100ns <= 0) {
return;
}
double fps_value = 10000000.0 / static_cast<double>(interval_100ns);
int rounded = static_cast<int>(fps_value + 0.5);
if (rounded <= 0) {
return;
}
if (std::find(fps->begin(), fps->end(), rounded) == fps->end()) {
fps->push_back(rounded);
}
}
void normalize_fps(std::vector<int>* fps) {
if (!fps) {
return;
}
std::sort(fps->begin(), fps->end(), std::greater<int>());
fps->erase(std::unique(fps->begin(), fps->end()), fps->end());
}
const char* media_subtype_to_format(const GUID& subtype) {
if (subtype == MEDIASUBTYPE_MJPG) {
return "MJPEG";
}
if (subtype == MEDIASUBTYPE_YUY2) {
return "YUYV";
}
if (subtype == MEDIASUBTYPE_UYVY) {
return "UYVY";
}
if (subtype == MEDIASUBTYPE_YVYU) {
return "YVYU";
}
if (subtype == MEDIASUBTYPE_NV12) {
return "NV12";
}
if (subtype == MEDIASUBTYPE_RGB24) {
return "RGB24";
}
if (subtype == MEDIASUBTYPE_RGB32) {
return "BGR24";
}
if (subtype == MEDIASUBTYPE_IYUV) {
return "YUV420";
}
if (subtype == MEDIASUBTYPE_YV12) {
return "YVU420";
}
return nullptr;
}
void free_media_type(AM_MEDIA_TYPE* media_type) {
if (!media_type) {
return;
}
if (media_type->cbFormat != 0) {
CoTaskMemFree(media_type->pbFormat);
media_type->cbFormat = 0;
media_type->pbFormat = nullptr;
}
if (media_type->pUnk != nullptr) {
media_type->pUnk->Release();
media_type->pUnk = nullptr;
}
CoTaskMemFree(media_type);
}
bool fill_capability_entry(
const AM_MEDIA_TYPE* media_type,
const VIDEO_STREAM_CONFIG_CAPS* caps,
DshowCapabilityEntry* out_entry) {
if (!media_type || !out_entry) {
return false;
}
const char* format = media_subtype_to_format(media_type->subtype);
if (!format) {
return false;
}
LONG width = 0;
LONG height = 0;
REFERENCE_TIME avg_time_per_frame = 0;
if (media_type->formattype == FORMAT_VideoInfo && media_type->pbFormat &&
media_type->cbFormat >= sizeof(VIDEOINFOHEADER)) {
const auto* info = reinterpret_cast<const VIDEOINFOHEADER*>(media_type->pbFormat);
width = info->bmiHeader.biWidth;
height = std::abs(info->bmiHeader.biHeight);
avg_time_per_frame = info->AvgTimePerFrame;
} else if (media_type->formattype == FORMAT_VideoInfo2 && media_type->pbFormat &&
media_type->cbFormat >= sizeof(VIDEOINFOHEADER2)) {
const auto* info = reinterpret_cast<const VIDEOINFOHEADER2*>(media_type->pbFormat);
width = info->bmiHeader.biWidth;
height = std::abs(info->bmiHeader.biHeight);
avg_time_per_frame = info->AvgTimePerFrame;
}
if ((width <= 0 || height <= 0) && caps) {
width = std::max<LONG>(caps->InputSize.cx, caps->MinOutputSize.cx);
height = std::max<LONG>(caps->InputSize.cy, caps->MinOutputSize.cy);
if (width <= 0 || height <= 0) {
width = caps->MaxOutputSize.cx;
height = caps->MaxOutputSize.cy;
}
}
if (width <= 0 || height <= 0) {
return false;
}
out_entry->format = format;
out_entry->width = static_cast<int>(width);
out_entry->height = static_cast<int>(height);
out_entry->fps.clear();
add_fps_candidate(&out_entry->fps, avg_time_per_frame);
if (caps) {
add_fps_candidate(&out_entry->fps, caps->MinFrameInterval);
add_fps_candidate(&out_entry->fps, caps->MaxFrameInterval);
}
normalize_fps(&out_entry->fps);
return true;
}
void append_stream_capabilities(IAMStreamConfig* stream_config, std::vector<DshowCapabilityEntry>* entries) {
if (!stream_config || !entries) {
return;
}
int cap_count = 0;
int cap_size = 0;
HRESULT hr = stream_config->GetNumberOfCapabilities(&cap_count, &cap_size);
if (FAILED(hr) || cap_count <= 0 || cap_size < static_cast<int>(sizeof(VIDEO_STREAM_CONFIG_CAPS))) {
return;
}
std::vector<BYTE> caps_buffer(static_cast<size_t>(cap_size));
for (int index = 0; index < cap_count; ++index) {
AM_MEDIA_TYPE* media_type = nullptr;
hr = stream_config->GetStreamCaps(index, &media_type, caps_buffer.data());
if (FAILED(hr) || !media_type) {
continue;
}
DshowCapabilityEntry entry;
const auto* caps = reinterpret_cast<const VIDEO_STREAM_CONFIG_CAPS*>(caps_buffer.data());
if (fill_capability_entry(media_type, caps, &entry)) {
entries->push_back(std::move(entry));
}
free_media_type(media_type);
}
}
bool find_device_filter(const std::string& device_name, IBaseFilter** out_filter) {
if (!out_filter) {
return false;
}
*out_filter = nullptr;
ICreateDevEnum* dev_enum = nullptr;
IEnumMoniker* enum_moniker = nullptr;
HRESULT hr = CoCreateInstance(
CLSID_SystemDeviceEnum,
nullptr,
CLSCTX_INPROC_SERVER,
IID_ICreateDevEnum,
reinterpret_cast<void**>(&dev_enum));
if (FAILED(hr) || !dev_enum) {
return false;
}
hr = dev_enum->CreateClassEnumerator(CLSID_VideoInputDeviceCategory, &enum_moniker, 0);
dev_enum->Release();
if (hr != S_OK || !enum_moniker) {
return false;
}
bool found = false;
IMoniker* moniker = nullptr;
ULONG fetched = 0;
while (!found && enum_moniker->Next(1, &moniker, &fetched) == S_OK) {
IPropertyBag* bag = nullptr;
hr = moniker->BindToStorage(nullptr, nullptr, IID_IPropertyBag, reinterpret_cast<void**>(&bag));
if (SUCCEEDED(hr) && bag) {
VARIANT name;
VariantInit(&name);
if (SUCCEEDED(bag->Read(L"FriendlyName", &name, nullptr)) && name.vt == VT_BSTR) {
auto utf8_name = wide_to_utf8(name.bstrVal);
if (utf8_name == device_name) {
hr = moniker->BindToObject(nullptr, nullptr, IID_IBaseFilter, reinterpret_cast<void**>(out_filter));
found = SUCCEEDED(hr) && *out_filter != nullptr;
}
}
VariantClear(&name);
bag->Release();
}
moniker->Release();
}
enum_moniker->Release();
return found;
}
std::string build_capabilities_payload(const std::vector<DshowCapabilityEntry>& entries) {
std::string payload;
for (size_t i = 0; i < entries.size(); ++i) {
const auto& entry = entries[i];
payload += entry.format;
payload.push_back('|');
payload += std::to_string(entry.width);
payload.push_back('|');
payload += std::to_string(entry.height);
payload.push_back('|');
for (size_t fps_index = 0; fps_index < entry.fps.size(); ++fps_index) {
payload += std::to_string(entry.fps[fps_index]);
if (fps_index + 1 < entry.fps.size()) {
payload.push_back(',');
}
}
if (i + 1 < entries.size()) {
payload.push_back('\n');
}
}
return payload;
}
char* copy_payload(const std::string& payload) {
char* out = reinterpret_cast<char*>(std::malloc(payload.size() + 1));
if (!out) {
set_last_error("Failed to allocate capture payload buffer");
return nullptr;
}
std::memcpy(out, payload.c_str(), payload.size() + 1);
return out;
}
int open_dshow_input_with_options(
AVFormatContext** format_ctx,
const AVInputFormat* input,
const std::string& device_name,
int width,
int height,
int fps,
int requested_format,
bool use_video_size,
bool use_framerate,
bool use_pixel_format,
std::string* attempt_desc) {
if (!format_ctx || !input) {
return AVERROR(EINVAL);
}
AVDictionary* options = nullptr;
std::vector<std::string> parts;
if (use_video_size && width > 0 && height > 0) {
std::string video_size = std::to_string(width) + "x" + std::to_string(height);
av_dict_set(&options, "video_size", video_size.c_str(), 0);
parts.push_back("video_size=" + video_size);
}
if (use_framerate && fps > 0) {
std::string framerate = std::to_string(fps);
av_dict_set(&options, "framerate", framerate.c_str(), 0);
parts.push_back("framerate=" + framerate);
}
av_dict_set(&options, "rtbufsize", "64M", 0);
parts.push_back("rtbufsize=64M");
const char* pixel_format_name = requested_pixel_format_name(requested_format);
if (use_pixel_format && pixel_format_name) {
av_dict_set(&options, "pixel_format", pixel_format_name, 0);
parts.push_back(std::string("pixel_format=") + pixel_format_name);
}
if (attempt_desc) {
*attempt_desc = parts.empty() ? "default options" : "options{";
if (!parts.empty()) {
for (size_t i = 0; i < parts.size(); ++i) {
if (i > 0) {
attempt_desc->append(", ");
}
attempt_desc->append(parts[i]);
}
attempt_desc->append("}");
}
}
std::string input_name = "video=" + device_name;
int ret = avformat_open_input(format_ctx, input_name.c_str(), input, &options);
av_dict_free(&options);
return ret;
}
class ScopedComInit {
public:
ScopedComInit() {
HRESULT hr = CoInitializeEx(nullptr, COINIT_MULTITHREADED);
initialized_ = hr == S_OK || hr == S_FALSE;
}
~ScopedComInit() {
if (initialized_) {
CoUninitialize();
}
}
private:
bool initialized_ = false;
};
int capture_stride(int pixel_format, int width) {
switch (pixel_format) {
case HWCODEC_CAPTURE_FMT_YUYV:
case HWCODEC_CAPTURE_FMT_YVYU:
case HWCODEC_CAPTURE_FMT_UYVY:
return width * 2;
case HWCODEC_CAPTURE_FMT_RGB24:
case HWCODEC_CAPTURE_FMT_BGR24:
return width * 3;
case HWCODEC_CAPTURE_FMT_NV24:
return width * 2;
case HWCODEC_CAPTURE_FMT_NV12:
case HWCODEC_CAPTURE_FMT_NV21:
case HWCODEC_CAPTURE_FMT_NV16:
case HWCODEC_CAPTURE_FMT_YUV420:
case HWCODEC_CAPTURE_FMT_YVU420:
case HWCODEC_CAPTURE_FMT_GREY:
case HWCODEC_CAPTURE_FMT_MJPEG:
case HWCODEC_CAPTURE_FMT_JPEG:
default:
return width;
}
}
int map_raw_pixfmt(int format) {
switch (format) {
case AV_PIX_FMT_YUYV422:
return HWCODEC_CAPTURE_FMT_YUYV;
case AV_PIX_FMT_UYVY422:
return HWCODEC_CAPTURE_FMT_UYVY;
#ifdef AV_PIX_FMT_YVYU422
case AV_PIX_FMT_YVYU422:
return HWCODEC_CAPTURE_FMT_YVYU;
#endif
case AV_PIX_FMT_NV12:
return HWCODEC_CAPTURE_FMT_NV12;
case AV_PIX_FMT_NV21:
return HWCODEC_CAPTURE_FMT_NV21;
#ifdef AV_PIX_FMT_NV16
case AV_PIX_FMT_NV16:
return HWCODEC_CAPTURE_FMT_NV16;
#endif
#ifdef AV_PIX_FMT_NV24
case AV_PIX_FMT_NV24:
return HWCODEC_CAPTURE_FMT_NV24;
#endif
case AV_PIX_FMT_YUV420P:
return HWCODEC_CAPTURE_FMT_YUV420;
#ifdef AV_PIX_FMT_YVU420P
case AV_PIX_FMT_YVU420P:
return HWCODEC_CAPTURE_FMT_YVU420;
#endif
case AV_PIX_FMT_RGB24:
return HWCODEC_CAPTURE_FMT_RGB24;
case AV_PIX_FMT_BGR24:
return HWCODEC_CAPTURE_FMT_BGR24;
case AV_PIX_FMT_GRAY8:
return HWCODEC_CAPTURE_FMT_GREY;
default:
return HWCODEC_CAPTURE_FMT_UNKNOWN;
}
}
int map_codec_to_capture_format(const AVCodecParameters* codecpar) {
if (!codecpar) {
return HWCODEC_CAPTURE_FMT_UNKNOWN;
}
switch (codecpar->codec_id) {
case AV_CODEC_ID_MJPEG:
return HWCODEC_CAPTURE_FMT_MJPEG;
case AV_CODEC_ID_JPEG2000:
return HWCODEC_CAPTURE_FMT_JPEG;
case AV_CODEC_ID_RAWVIDEO:
return map_raw_pixfmt(codecpar->format);
default:
return HWCODEC_CAPTURE_FMT_UNKNOWN;
}
}
int interrupt_callback(void* opaque) {
auto* ctx = reinterpret_cast<HwcodecDshowCaptureContext*>(opaque);
if (!ctx) {
return 0;
}
auto deadline = ctx->deadline_ms.load();
if (deadline <= 0) {
return 0;
}
if (now_ms() > deadline) {
ctx->timed_out.store(1);
return 1;
}
return 0;
}
const char* requested_pixel_format_name(int requested_format) {
switch (requested_format) {
case HWCODEC_CAPTURE_FMT_YUYV:
return "yuyv422";
case HWCODEC_CAPTURE_FMT_UYVY:
return "uyvy422";
case HWCODEC_CAPTURE_FMT_NV12:
return "nv12";
case HWCODEC_CAPTURE_FMT_NV21:
return "nv21";
case HWCODEC_CAPTURE_FMT_RGB24:
return "rgb24";
case HWCODEC_CAPTURE_FMT_BGR24:
return "bgr24";
case HWCODEC_CAPTURE_FMT_GREY:
return "gray";
default:
return nullptr;
}
}
} // namespace
extern "C" const char* hwcodec_capture_last_error(void) {
return g_last_error.c_str();
}
extern "C" char* hwcodec_dshow_list_video_devices(void) {
ScopedComInit com;
ICreateDevEnum* dev_enum = nullptr;
IEnumMoniker* enum_moniker = nullptr;
HRESULT hr = CoCreateInstance(
CLSID_SystemDeviceEnum,
nullptr,
CLSCTX_INPROC_SERVER,
IID_ICreateDevEnum,
reinterpret_cast<void**>(&dev_enum));
if (FAILED(hr)) {
set_last_error("Failed to create DirectShow device enumerator");
return nullptr;
}
hr = dev_enum->CreateClassEnumerator(CLSID_VideoInputDeviceCategory, &enum_moniker, 0);
dev_enum->Release();
if (hr != S_OK || !enum_moniker) {
char* out = reinterpret_cast<char*>(std::malloc(1));
if (out) {
out[0] = '\0';
}
return out;
}
std::vector<std::string> devices;
IMoniker* moniker = nullptr;
ULONG fetched = 0;
while (enum_moniker->Next(1, &moniker, &fetched) == S_OK) {
IPropertyBag* bag = nullptr;
hr = moniker->BindToStorage(nullptr, nullptr, IID_IPropertyBag, reinterpret_cast<void**>(&bag));
if (SUCCEEDED(hr) && bag) {
VARIANT name;
VariantInit(&name);
if (SUCCEEDED(bag->Read(L"FriendlyName", &name, nullptr)) && name.vt == VT_BSTR) {
auto utf8_name = wide_to_utf8(name.bstrVal);
if (!utf8_name.empty()) {
devices.push_back(utf8_name);
}
}
VariantClear(&name);
bag->Release();
}
moniker->Release();
}
enum_moniker->Release();
std::string payload;
for (size_t i = 0; i < devices.size(); ++i) {
payload += devices[i];
if (i + 1 < devices.size()) {
payload.push_back('\n');
}
}
return copy_payload(payload);
}
extern "C" char* hwcodec_dshow_list_device_capabilities(const char* device_name) {
if (!device_name || device_name[0] == '\0') {
set_last_error("DirectShow device name is empty");
return nullptr;
}
ScopedComInit com;
IBaseFilter* filter = nullptr;
if (!find_device_filter(device_name, &filter) || !filter) {
set_last_error("Failed to find DirectShow device filter");
return nullptr;
}
std::vector<DshowCapabilityEntry> entries;
IEnumPins* enum_pins = nullptr;
HRESULT hr = filter->EnumPins(&enum_pins);
if (SUCCEEDED(hr) && enum_pins) {
IPin* pin = nullptr;
ULONG fetched = 0;
while (enum_pins->Next(1, &pin, &fetched) == S_OK) {
PIN_DIRECTION direction = PINDIR_INPUT;
if (SUCCEEDED(pin->QueryDirection(&direction)) && direction == PINDIR_OUTPUT) {
IAMStreamConfig* stream_config = nullptr;
if (SUCCEEDED(pin->QueryInterface(IID_IAMStreamConfig, reinterpret_cast<void**>(&stream_config))) &&
stream_config) {
append_stream_capabilities(stream_config, &entries);
stream_config->Release();
}
}
pin->Release();
}
enum_pins->Release();
}
filter->Release();
std::sort(entries.begin(), entries.end(), [](const DshowCapabilityEntry& left, const DshowCapabilityEntry& right) {
if (left.format != right.format) {
return left.format < right.format;
}
if (left.width != right.width) {
return left.width < right.width;
}
if (left.height != right.height) {
return left.height < right.height;
}
return left.fps > right.fps;
});
entries.erase(
std::unique(entries.begin(), entries.end(), [](const DshowCapabilityEntry& left, const DshowCapabilityEntry& right) {
return left.format == right.format && left.width == right.width && left.height == right.height && left.fps == right.fps;
}),
entries.end());
return copy_payload(build_capabilities_payload(entries));
}
extern "C" void hwcodec_capture_string_free(char* ptr) {
if (ptr) {
std::free(ptr);
}
}
extern "C" HwcodecDshowCaptureContext* hwcodec_dshow_capture_open(
const char* device_name,
int width,
int height,
int fps,
int requested_format,
int timeout_ms) {
if (!device_name || device_name[0] == '\0') {
set_last_error("Device name is empty");
return nullptr;
}
avdevice_register_all();
const AVInputFormat* input = av_find_input_format("dshow");
if (!input) {
set_last_error("FFmpeg dshow input format is unavailable");
return nullptr;
}
auto* ctx = new HwcodecDshowCaptureContext();
ctx->timeout_ms = timeout_ms > 0 ? timeout_ms : 2000;
ctx->format_ctx = avformat_alloc_context();
if (!ctx->format_ctx) {
delete ctx;
set_last_error("Failed to allocate FFmpeg format context");
return nullptr;
}
ctx->format_ctx->interrupt_callback.callback = interrupt_callback;
ctx->format_ctx->interrupt_callback.opaque = ctx;
std::string open_attempt;
int ret = open_dshow_input_with_options(
&ctx->format_ctx,
input,
device_name,
width,
height,
fps,
requested_format,
true,
true,
true,
&open_attempt);
if (ret < 0) {
avformat_free_context(ctx->format_ctx);
ctx->format_ctx = avformat_alloc_context();
if (!ctx->format_ctx) {
delete ctx;
set_last_error("Failed to allocate FFmpeg format context for fallback open");
return nullptr;
}
ctx->format_ctx->interrupt_callback.callback = interrupt_callback;
ctx->format_ctx->interrupt_callback.opaque = ctx;
std::string fallback_attempt;
ret = open_dshow_input_with_options(
&ctx->format_ctx,
input,
device_name,
width,
height,
fps,
requested_format,
true,
false,
true,
&fallback_attempt);
if (ret >= 0) {
open_attempt = fallback_attempt;
}
}
if (ret < 0) {
avformat_free_context(ctx->format_ctx);
ctx->format_ctx = avformat_alloc_context();
if (!ctx->format_ctx) {
delete ctx;
set_last_error("Failed to allocate FFmpeg format context for final fallback open");
return nullptr;
}
ctx->format_ctx->interrupt_callback.callback = interrupt_callback;
ctx->format_ctx->interrupt_callback.opaque = ctx;
std::string fallback_attempt;
ret = open_dshow_input_with_options(
&ctx->format_ctx,
input,
device_name,
width,
height,
fps,
requested_format,
false,
false,
false,
&fallback_attempt);
if (ret >= 0) {
open_attempt = fallback_attempt;
}
}
if (ret < 0) {
set_last_error("Failed to open dshow input (" + open_attempt + "): " + ffmpeg_error(ret));
avformat_free_context(ctx->format_ctx);
delete ctx;
return nullptr;
}
ret = avformat_find_stream_info(ctx->format_ctx, nullptr);
if (ret < 0) {
set_last_error("Failed to read stream info: " + ffmpeg_error(ret));
avformat_close_input(&ctx->format_ctx);
delete ctx;
return nullptr;
}
for (unsigned int i = 0; i < ctx->format_ctx->nb_streams; ++i) {
AVStream* stream = ctx->format_ctx->streams[i];
if (stream && stream->codecpar && stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
ctx->stream_index = static_cast<int>(i);
ctx->width = stream->codecpar->width > 0 ? stream->codecpar->width : width;
ctx->height = stream->codecpar->height > 0 ? stream->codecpar->height : height;
ctx->pixel_format = map_codec_to_capture_format(stream->codecpar);
ctx->stride = capture_stride(ctx->pixel_format, ctx->width);
break;
}
}
if (ctx->stream_index < 0) {
set_last_error("No video stream found on DirectShow device");
avformat_close_input(&ctx->format_ctx);
delete ctx;
return nullptr;
}
if (ctx->pixel_format == HWCODEC_CAPTURE_FMT_UNKNOWN) {
set_last_error("DirectShow stream format is unsupported in current Windows backend");
avformat_close_input(&ctx->format_ctx);
delete ctx;
return nullptr;
}
return ctx;
}
extern "C" int hwcodec_dshow_capture_info(
HwcodecDshowCaptureContext* ctx,
HwcodecCaptureStreamInfo* out_info) {
if (!ctx || !out_info) {
set_last_error("Invalid capture context");
return -1;
}
out_info->width = ctx->width;
out_info->height = ctx->height;
out_info->pixel_format = ctx->pixel_format;
out_info->stride = ctx->stride;
return 0;
}
extern "C" int hwcodec_dshow_capture_read(
HwcodecDshowCaptureContext* ctx,
uint8_t** out_data,
int* out_len,
uint64_t* out_sequence) {
if (!ctx || !out_data || !out_len || !out_sequence) {
set_last_error("Invalid capture read arguments");
return -1;
}
*out_data = nullptr;
*out_len = 0;
*out_sequence = 0;
AVPacket packet;
av_init_packet(&packet);
packet.data = nullptr;
packet.size = 0;
while (true) {
ctx->timed_out.store(0);
ctx->deadline_ms.store(now_ms() + ctx->timeout_ms);
int ret = av_read_frame(ctx->format_ctx, &packet);
ctx->deadline_ms.store(0);
if (ret < 0) {
if (ctx->timed_out.load() != 0) {
set_last_error("Timed out waiting for frame");
return -110;
}
set_last_error("Failed to read frame: " + ffmpeg_error(ret));
return ret;
}
if (packet.stream_index != ctx->stream_index) {
av_packet_unref(&packet);
continue;
}
if (packet.size <= 0 || !packet.data) {
av_packet_unref(&packet);
continue;
}
auto* buffer = reinterpret_cast<uint8_t*>(std::malloc(static_cast<size_t>(packet.size)));
if (!buffer) {
av_packet_unref(&packet);
set_last_error("Failed to allocate packet buffer");
return -12;
}
std::memcpy(buffer, packet.data, static_cast<size_t>(packet.size));
*out_data = buffer;
*out_len = packet.size;
*out_sequence = ctx->sequence++;
av_packet_unref(&packet);
return 0;
}
}
extern "C" void hwcodec_dshow_capture_packet_free(uint8_t* data) {
if (data) {
std::free(data);
}
}
extern "C" void hwcodec_dshow_capture_close(HwcodecDshowCaptureContext* ctx) {
if (!ctx) {
return;
}
if (ctx->format_ctx) {
avformat_close_input(&ctx->format_ctx);
}
delete ctx;
}

View File

@@ -0,0 +1,64 @@
#ifndef HWCODEC_FFMPEG_CAPTURE_FFI_H
#define HWCODEC_FFMPEG_CAPTURE_FFI_H
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef struct HwcodecDshowCaptureContext HwcodecDshowCaptureContext;
enum HwcodecCapturePixelFormat {
HWCODEC_CAPTURE_FMT_UNKNOWN = 0,
HWCODEC_CAPTURE_FMT_MJPEG = 1,
HWCODEC_CAPTURE_FMT_JPEG = 2,
HWCODEC_CAPTURE_FMT_YUYV = 3,
HWCODEC_CAPTURE_FMT_YVYU = 4,
HWCODEC_CAPTURE_FMT_UYVY = 5,
HWCODEC_CAPTURE_FMT_NV12 = 6,
HWCODEC_CAPTURE_FMT_NV21 = 7,
HWCODEC_CAPTURE_FMT_NV16 = 8,
HWCODEC_CAPTURE_FMT_NV24 = 9,
HWCODEC_CAPTURE_FMT_YUV420 = 10,
HWCODEC_CAPTURE_FMT_YVU420 = 11,
HWCODEC_CAPTURE_FMT_RGB24 = 12,
HWCODEC_CAPTURE_FMT_BGR24 = 13,
HWCODEC_CAPTURE_FMT_GREY = 14,
};
typedef struct HwcodecCaptureStreamInfo {
int width;
int height;
int pixel_format;
int stride;
} HwcodecCaptureStreamInfo;
const char* hwcodec_capture_last_error(void);
char* hwcodec_dshow_list_video_devices(void);
char* hwcodec_dshow_list_device_capabilities(const char* device_name);
void hwcodec_capture_string_free(char* ptr);
HwcodecDshowCaptureContext* hwcodec_dshow_capture_open(
const char* device_name,
int width,
int height,
int fps,
int requested_format,
int timeout_ms);
int hwcodec_dshow_capture_info(
HwcodecDshowCaptureContext* ctx,
HwcodecCaptureStreamInfo* out_info);
int hwcodec_dshow_capture_read(
HwcodecDshowCaptureContext* ctx,
uint8_t** out_data,
int* out_len,
uint64_t* out_sequence);
void hwcodec_dshow_capture_packet_free(uint8_t* data);
void hwcodec_dshow_capture_close(HwcodecDshowCaptureContext* ctx);
#ifdef __cplusplus
}
#endif
#endif

297
libs/hwcodec/src/capture.rs Normal file
View File

@@ -0,0 +1,297 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
use std::ffi::{CStr, CString};
use std::os::raw::c_int;
include!(concat!(env!("OUT_DIR"), "/ffmpeg_capture_ffi.rs"));
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapturePixelFormat {
Unknown,
Mjpeg,
Jpeg,
Yuyv,
Yvyu,
Uyvy,
Nv12,
Nv21,
Nv16,
Nv24,
Yuv420,
Yvu420,
Rgb24,
Bgr24,
Grey,
}
impl CapturePixelFormat {
pub fn to_ffi(self) -> c_int {
match self {
Self::Unknown => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UNKNOWN as c_int,
Self::Mjpeg => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_MJPEG as c_int,
Self::Jpeg => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_JPEG as c_int,
Self::Yuyv => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUYV as c_int,
Self::Yvyu => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVYU as c_int,
Self::Uyvy => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UYVY as c_int,
Self::Nv12 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV12 as c_int,
Self::Nv21 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV21 as c_int,
Self::Nv16 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV16 as c_int,
Self::Nv24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV24 as c_int,
Self::Yuv420 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUV420 as c_int,
Self::Yvu420 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVU420 as c_int,
Self::Rgb24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_RGB24 as c_int,
Self::Bgr24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_BGR24 as c_int,
Self::Grey => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_GREY as c_int,
}
}
pub fn from_ffi(value: c_int) -> Self {
match value {
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_MJPEG as c_int => Self::Mjpeg,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_JPEG as c_int => Self::Jpeg,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUYV as c_int => Self::Yuyv,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVYU as c_int => Self::Yvyu,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UYVY as c_int => Self::Uyvy,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV12 as c_int => Self::Nv12,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV21 as c_int => Self::Nv21,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV16 as c_int => Self::Nv16,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV24 as c_int => Self::Nv24,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUV420 as c_int => {
Self::Yuv420
}
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVU420 as c_int => {
Self::Yvu420
}
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_RGB24 as c_int => Self::Rgb24,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_BGR24 as c_int => Self::Bgr24,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_GREY as c_int => Self::Grey,
_ => Self::Unknown,
}
}
pub fn from_name(name: &str) -> Option<Self> {
match name.trim().to_ascii_uppercase().as_str() {
"MJPEG" | "MJPG" => Some(Self::Mjpeg),
"JPEG" => Some(Self::Jpeg),
"YUYV" => Some(Self::Yuyv),
"YVYU" => Some(Self::Yvyu),
"UYVY" => Some(Self::Uyvy),
"NV12" => Some(Self::Nv12),
"NV21" => Some(Self::Nv21),
"NV16" => Some(Self::Nv16),
"NV24" => Some(Self::Nv24),
"YUV420" | "I420" | "IYUV" => Some(Self::Yuv420),
"YVU420" | "YV12" => Some(Self::Yvu420),
"RGB24" => Some(Self::Rgb24),
"BGR24" => Some(Self::Bgr24),
"GREY" | "GRAY" | "Y800" => Some(Self::Grey),
_ => None,
}
}
}
#[derive(Debug, Clone)]
pub struct DshowCapability {
pub format: CapturePixelFormat,
pub width: u32,
pub height: u32,
pub fps: Vec<u32>,
}
#[derive(Debug, Clone, Copy)]
pub struct CaptureStreamInfo {
pub width: i32,
pub height: i32,
pub pixel_format: CapturePixelFormat,
pub stride: i32,
}
#[derive(Debug)]
pub struct CaptureError {
pub code: i32,
pub message: String,
}
impl std::fmt::Display for CaptureError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for CaptureError {}
fn last_error_message() -> String {
unsafe {
let ptr = hwcodec_capture_last_error();
if ptr.is_null() {
return String::new();
}
CStr::from_ptr(ptr).to_string_lossy().to_string()
}
}
pub fn list_dshow_video_devices() -> Result<Vec<String>, CaptureError> {
unsafe {
let ptr = hwcodec_dshow_list_video_devices();
if ptr.is_null() {
return Err(CaptureError {
code: -1,
message: last_error_message(),
});
}
let payload = CStr::from_ptr(ptr).to_string_lossy().to_string();
hwcodec_capture_string_free(ptr as *mut _);
Ok(payload
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.map(ToOwned::to_owned)
.collect())
}
}
pub fn list_dshow_device_capabilities(device_name: &str) -> Result<Vec<DshowCapability>, CaptureError> {
let device_name = CString::new(device_name).map_err(|_| CaptureError {
code: -1,
message: "device name contains NUL byte".to_string(),
})?;
unsafe {
let ptr = hwcodec_dshow_list_device_capabilities(device_name.as_ptr());
if ptr.is_null() {
return Err(CaptureError {
code: -1,
message: last_error_message(),
});
}
let payload = CStr::from_ptr(ptr).to_string_lossy().to_string();
hwcodec_capture_string_free(ptr as *mut _);
let capabilities = payload
.lines()
.filter_map(parse_dshow_capability_line)
.collect();
Ok(capabilities)
}
}
fn parse_dshow_capability_line(line: &str) -> Option<DshowCapability> {
let mut parts = line.split('|');
let format = CapturePixelFormat::from_name(parts.next()?.trim())?;
let width = parts.next()?.trim().parse::<u32>().ok()?;
let height = parts.next()?.trim().parse::<u32>().ok()?;
let fps = parts
.next()
.unwrap_or_default()
.split(',')
.filter_map(|value| value.trim().parse::<u32>().ok())
.filter(|value| *value > 0)
.collect::<Vec<_>>();
Some(DshowCapability {
format,
width,
height,
fps,
})
}
pub struct DshowCapture {
ctx: *mut HwcodecDshowCaptureContext,
}
unsafe impl Send for DshowCapture {}
impl DshowCapture {
pub fn open(
device_name: &str,
width: i32,
height: i32,
fps: i32,
requested_format: CapturePixelFormat,
timeout_ms: i32,
) -> Result<Self, CaptureError> {
let device_name = CString::new(device_name).map_err(|_| CaptureError {
code: -1,
message: "device name contains NUL byte".to_string(),
})?;
unsafe {
let ctx = hwcodec_dshow_capture_open(
device_name.as_ptr(),
width,
height,
fps,
requested_format.to_ffi(),
timeout_ms,
);
if ctx.is_null() {
return Err(CaptureError {
code: -1,
message: last_error_message(),
});
}
Ok(Self { ctx })
}
}
pub fn info(&self) -> Result<CaptureStreamInfo, CaptureError> {
unsafe {
let mut info = HwcodecCaptureStreamInfo {
width: 0,
height: 0,
pixel_format: HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UNKNOWN as c_int,
stride: 0,
};
let ret = hwcodec_dshow_capture_info(self.ctx, &mut info);
if ret != 0 {
return Err(CaptureError {
code: ret,
message: last_error_message(),
});
}
Ok(CaptureStreamInfo {
width: info.width,
height: info.height,
pixel_format: CapturePixelFormat::from_ffi(info.pixel_format),
stride: info.stride,
})
}
}
pub fn read_packet(&mut self) -> Result<(Vec<u8>, u64), CaptureError> {
unsafe {
let mut data = std::ptr::null_mut();
let mut len = 0;
let mut sequence = 0u64;
let ret = hwcodec_dshow_capture_read(self.ctx, &mut data, &mut len, &mut sequence);
if ret != 0 {
return Err(CaptureError {
code: ret,
message: last_error_message(),
});
}
if data.is_null() || len <= 0 {
return Err(CaptureError {
code: -1,
message: "empty packet returned by capture backend".to_string(),
});
}
let slice = std::slice::from_raw_parts(data, len as usize);
let vec = slice.to_vec();
hwcodec_dshow_capture_packet_free(data);
Ok((vec, sequence))
}
}
}
impl Drop for DshowCapture {
fn drop(&mut self) {
unsafe {
hwcodec_dshow_capture_close(self.ctx);
}
self.ctx = std::ptr::null_mut();
}
}

View File

@@ -257,7 +257,13 @@ struct ProbePolicy {
impl ProbePolicy {
fn for_codec(codec_name: &str) -> Self {
if codec_name.contains("v4l2m2m") {
if codec_name.contains("amf") {
Self {
max_attempts: 5,
request_keyframe: true,
accept_any_output: true,
}
} else if codec_name.contains("v4l2m2m") {
Self {
max_attempts: 5,
request_keyframe: true,

View File

@@ -1,3 +1,5 @@
#[cfg(windows)]
pub mod capture;
pub mod common;
pub mod ffmpeg;
#[cfg(any(target_arch = "aarch64", target_arch = "arm", feature = "rkmpp"))]

View File

@@ -154,11 +154,13 @@ fn link_vcpkg(mut path: PathBuf) -> bool {
if use_static && static_lib.exists() {
// Static linking (for deb packaging)
println!("cargo:rustc-link-lib=static=yuv");
#[cfg(target_os = "linux")]
println!("cargo:rustc-link-lib=stdc++");
println!("cargo:info=Using libyuv from vcpkg (static linking)");
} else {
// Dynamic linking (default for development)
println!("cargo:rustc-link-lib=yuv");
#[cfg(target_os = "linux")]
println!("cargo:rustc-link-lib=stdc++");
println!("cargo:info=Using libyuv from vcpkg (dynamic linking)");
}

View File

@@ -11,20 +11,14 @@ use super::led::LedSensor;
use super::types::{AtxAction, AtxKeyConfig, AtxLedConfig, AtxState, PowerStatus};
use crate::error::{AppError, Result};
/// ATX power control configuration
#[derive(Debug, Clone, Default)]
pub struct AtxControllerConfig {
/// Whether ATX is enabled
pub enabled: bool,
/// Power button configuration (used for both short and long press)
pub power: AtxKeyConfig,
/// Reset button configuration
pub reset: AtxKeyConfig,
/// LED sensing configuration
pub led: AtxLedConfig,
}
/// Internal state holding all ATX components
/// Grouped together to reduce lock acquisitions
struct AtxInner {
config: AtxControllerConfig,
@@ -33,12 +27,9 @@ struct AtxInner {
led_sensor: Option<LedSensor>,
}
/// ATX Controller
///
/// Manages ATX power control through independent executors for each action.
/// Supports hot-reload of configuration.
pub struct AtxController {
/// Single lock for all internal state to reduce lock contention
inner: RwLock<AtxInner>,
}
@@ -53,6 +44,24 @@ impl AtxController {
&& power.baud_rate == reset.baud_rate
}
async fn init_key_executor(
warn_label: &str,
info_label: &str,
config: AtxKeyConfig,
mut executor: AtxKeyExecutor,
) -> Option<AtxKeyExecutor> {
if let Err(e) = executor.init().await {
warn!("Failed to initialize {} executor: {}", warn_label, e);
return None;
}
info!(
"{} executor initialized: {:?} on {} pin {}",
info_label, config.driver, config.device, config.pin
);
Some(executor)
}
async fn init_components(inner: &mut AtxInner) {
if Self::should_share_serial_device(&inner.config.power, &inner.config.reset) {
match AtxKeyExecutor::open_shared_serial(
@@ -60,36 +69,28 @@ impl AtxController {
inner.config.power.baud_rate,
) {
Ok(shared_serial) => {
let mut power_executor = AtxKeyExecutor::new_with_shared_serial(
inner.config.power.clone(),
shared_serial.clone(),
);
if let Err(e) = power_executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver,
inner.config.power.device,
inner.config.power.pin
for (slot, warn_label, info_label, config, serial) in [
(
&mut inner.power_executor,
"power",
"Power",
inner.config.power.clone(),
shared_serial.clone(),
),
(
&mut inner.reset_executor,
"reset",
"Reset",
inner.config.reset.clone(),
shared_serial,
),
] {
let executor = AtxKeyExecutor::new_with_shared_serial(
config.clone(),
serial,
);
inner.power_executor = Some(power_executor);
}
let mut reset_executor = AtxKeyExecutor::new_with_shared_serial(
inner.config.reset.clone(),
shared_serial,
);
if let Err(e) = reset_executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver,
inner.config.reset.device,
inner.config.reset.pin
);
inner.reset_executor = Some(reset_executor);
*slot = Self::init_key_executor(warn_label, info_label, config, executor)
.await;
}
}
Err(e) => {
@@ -100,40 +101,18 @@ impl AtxController {
}
}
} else {
// Initialize power executor
if inner.config.power.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.power.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver,
inner.config.power.device,
inner.config.power.pin
);
inner.power_executor = Some(executor);
}
}
// Initialize reset executor
if inner.config.reset.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.reset.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver,
inner.config.reset.device,
inner.config.reset.pin
);
inner.reset_executor = Some(executor);
for (slot, warn_label, info_label, config) in [
(&mut inner.power_executor, "power", "Power", inner.config.power.clone()),
(&mut inner.reset_executor, "reset", "Reset", inner.config.reset.clone()),
] {
if config.is_configured() {
let executor = AtxKeyExecutor::new(config.clone());
*slot = Self::init_key_executor(warn_label, info_label, config, executor)
.await;
}
}
}
// Initialize LED sensor
if inner.config.led.is_configured() {
let mut sensor = LedSensor::new(inner.config.led.clone());
if let Err(e) = sensor.init().await {
@@ -149,19 +128,17 @@ impl AtxController {
}
async fn shutdown_components(inner: &mut AtxInner) {
if let Some(executor) = inner.power_executor.as_mut() {
if let Err(e) = executor.shutdown().await {
warn!("Failed to shutdown power executor: {}", e);
for (slot, label) in [
(&mut inner.power_executor, "power"),
(&mut inner.reset_executor, "reset"),
] {
if let Some(executor) = slot.as_mut() {
if let Err(e) = executor.shutdown().await {
warn!("Failed to shutdown {} executor: {}", label, e);
}
}
*slot = None;
}
inner.power_executor = None;
if let Some(executor) = inner.reset_executor.as_mut() {
if let Err(e) = executor.shutdown().await {
warn!("Failed to shutdown reset executor: {}", e);
}
}
inner.reset_executor = None;
if let Some(sensor) = inner.led_sensor.as_mut() {
if let Err(e) = sensor.shutdown().await {
@@ -171,7 +148,20 @@ impl AtxController {
inner.led_sensor = None;
}
/// Create a new ATX controller with the specified configuration
async fn read_power_status(sensor: Option<&LedSensor>) -> PowerStatus {
let Some(sensor) = sensor else {
return PowerStatus::Unknown;
};
match sensor.read().await {
Ok(status) => status,
Err(e) => {
debug!("Failed to read ATX LED sensor: {}", e);
PowerStatus::Unknown
}
}
}
pub fn new(config: AtxControllerConfig) -> Self {
Self {
inner: RwLock::new(AtxInner {
@@ -183,12 +173,10 @@ impl AtxController {
}
}
/// Create a disabled ATX controller
pub fn disabled() -> Self {
Self::new(AtxControllerConfig::default())
}
/// Initialize the ATX controller and its executors
pub async fn init(&self) -> Result<()> {
let mut inner = self.inner.write().await;
@@ -204,7 +192,6 @@ impl AtxController {
Ok(())
}
/// Reload ATX controller configuration
pub async fn reload(&self, config: AtxControllerConfig) -> Result<()> {
let mut inner = self.inner.write().await;
@@ -225,7 +212,6 @@ impl AtxController {
Ok(())
}
/// Shutdown ATX controller and release all resources
pub async fn shutdown(&self) -> Result<()> {
let mut inner = self.inner.write().await;
Self::shutdown_components(&mut inner).await;
@@ -233,86 +219,48 @@ impl AtxController {
Ok(())
}
/// Trigger a power action (short/long/reset)
pub async fn trigger_power_action(&self, action: AtxAction) -> Result<()> {
let inner = self.inner.read().await;
match action {
AtxAction::Short | AtxAction::Long => {
if let Some(executor) = &inner.power_executor {
let duration = match action {
AtxAction::Short => timing::SHORT_PRESS,
AtxAction::Long => timing::LONG_PRESS,
_ => unreachable!(),
};
executor.pulse(duration).await?;
} else {
return Err(AppError::Config(
"Power button not configured for ATX controller".to_string(),
));
}
}
AtxAction::Reset => {
if let Some(executor) = &inner.reset_executor {
executor.pulse(timing::RESET_PRESS).await?;
} else {
return Err(AppError::Config(
"Reset button not configured for ATX controller".to_string(),
));
}
}
}
let (executor, duration) = match action {
AtxAction::Short => (inner.power_executor.as_ref(), timing::SHORT_PRESS),
AtxAction::Long => (inner.power_executor.as_ref(), timing::LONG_PRESS),
AtxAction::Reset => (inner.reset_executor.as_ref(), timing::RESET_PRESS),
};
let Some(executor) = executor else {
return Err(AppError::Config(match action {
AtxAction::Reset => "Reset button not configured for ATX controller",
_ => "Power button not configured for ATX controller",
}
.to_string()));
};
executor.pulse(duration).await?;
Ok(())
}
/// Trigger a short power button press
pub async fn power_short(&self) -> Result<()> {
self.trigger_power_action(AtxAction::Short).await
}
/// Trigger a long power button press
pub async fn power_long(&self) -> Result<()> {
self.trigger_power_action(AtxAction::Long).await
}
/// Trigger a reset button press
pub async fn reset(&self) -> Result<()> {
self.trigger_power_action(AtxAction::Reset).await
}
/// Get the current power status using the LED sensor (if configured)
pub async fn power_status(&self) -> PowerStatus {
let inner = self.inner.read().await;
if let Some(sensor) = &inner.led_sensor {
match sensor.read().await {
Ok(status) => status,
Err(e) => {
debug!("Failed to read ATX LED sensor: {}", e);
PowerStatus::Unknown
}
}
} else {
PowerStatus::Unknown
}
Self::read_power_status(inner.led_sensor.as_ref()).await
}
/// Get a snapshot of the ATX state for API responses
pub async fn state(&self) -> AtxState {
let inner = self.inner.read().await;
let power_status = if let Some(sensor) = &inner.led_sensor {
match sensor.read().await {
Ok(status) => status,
Err(e) => {
debug!("Failed to read ATX LED sensor: {}", e);
PowerStatus::Unknown
}
}
} else {
PowerStatus::Unknown
};
let power_status = Self::read_power_status(inner.led_sensor.as_ref()).await;
AtxState {
available: inner.config.enabled,

34
src/atx/disabled_key.rs Normal file
View File

@@ -0,0 +1,34 @@
use async_trait::async_trait;
use std::time::Duration;
use super::traits::AtxKeyBackend;
use crate::error::{AppError, Result};
pub struct DisabledAtxKeyBackend {
reason: &'static str,
}
impl DisabledAtxKeyBackend {
pub fn new(reason: &'static str) -> Self {
Self { reason }
}
}
#[async_trait]
impl AtxKeyBackend for DisabledAtxKeyBackend {
async fn init(&mut self) -> Result<()> {
Err(AppError::Internal(self.reason.to_string()))
}
async fn pulse(&self, _duration: Duration) -> Result<()> {
Err(AppError::Internal(self.reason.to_string()))
}
async fn shutdown(&mut self) -> Result<()> {
Ok(())
}
fn is_initialized(&self) -> bool {
false
}
}

34
src/atx/disabled_led.rs Normal file
View File

@@ -0,0 +1,34 @@
#![allow(dead_code)]
use super::types::{AtxLedConfig, PowerStatus};
use crate::error::Result;
pub struct LedSensor {
config: AtxLedConfig,
}
impl LedSensor {
pub fn new(config: AtxLedConfig) -> Self {
Self { config }
}
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
pub fn is_initialized(&self) -> bool {
false
}
pub async fn init(&mut self) -> Result<()> {
Ok(())
}
pub async fn read(&self) -> Result<PowerStatus> {
Ok(PowerStatus::Unknown)
}
pub async fn shutdown(&mut self) -> Result<()> {
Ok(())
}
}

View File

@@ -1,497 +1,150 @@
//! ATX Key Executor
//!
//! Lightweight executor for a single ATX key operation.
//! Each executor handles one button (power or reset) with its own hardware binding.
//! ATX key executor backend selector.
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use serialport::SerialPort;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::os::fd::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use tracing::debug;
use super::types::{ActiveLevel, AtxDriverType, AtxKeyConfig};
use super::serial_relay::SerialRelayBackend;
use super::traits::{AtxKeyBackend, AtxKeyBackendContext, SharedSerialHandle};
use super::types::{AtxDriverType, AtxKeyConfig};
use crate::error::{AppError, Result};
pub type SharedSerialHandle = Arc<Mutex<Box<dyn SerialPort>>>;
const USB_RELAY_MAX_CHANNEL: u8 = 8;
const USB_RELAY_REPORT_LEN: usize = 9;
const HIDIOCSFEATURE_9: libc::c_ulong = 0xC009_4806; // _IOC(_IOC_READ|_IOC_WRITE, 'H', 0x06, 9)
/// Timing constants for ATX operations
pub mod timing {
use std::time::Duration;
/// Short press duration (power on/graceful shutdown)
pub const SHORT_PRESS: Duration = Duration::from_millis(500);
/// Long press duration (force power off)
pub const LONG_PRESS: Duration = Duration::from_millis(5000);
/// Reset press duration
pub const RESET_PRESS: Duration = Duration::from_millis(500);
}
/// Executor for a single ATX key operation
///
/// Each executor manages one hardware button (power or reset).
/// It handles both GPIO and USB relay backends.
pub struct AtxKeyExecutor {
config: AtxKeyConfig,
gpio_handle: Mutex<Option<LineHandle>>,
/// Cached USB relay file handle to avoid repeated open/close syscalls
usb_relay_handle: Mutex<Option<File>>,
/// Cached Serial port handle (can be shared across power/reset executors)
serial_handle: Mutex<Option<SharedSerialHandle>>,
initialized: AtomicBool,
backend: Option<Box<dyn AtxKeyBackend>>,
}
impl AtxKeyExecutor {
/// Create a new executor with the given configuration
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
gpio_handle: Mutex::new(None),
usb_relay_handle: Mutex::new(None),
serial_handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
Self::with_context(config, AtxKeyBackendContext::Standalone)
}
/// Create a new executor with a pre-opened shared serial handle.
pub fn new_with_shared_serial(config: AtxKeyConfig, serial_handle: SharedSerialHandle) -> Self {
Self {
config,
gpio_handle: Mutex::new(None),
usb_relay_handle: Mutex::new(None),
serial_handle: Mutex::new(Some(serial_handle)),
initialized: AtomicBool::new(false),
}
Self::with_context(config, AtxKeyBackendContext::SharedSerial(serial_handle))
}
/// Open a serial relay device and wrap it for shared use.
pub fn open_shared_serial(device: &str, baud_rate: u32) -> Result<SharedSerialHandle> {
let port = serialport::new(device, baud_rate)
.timeout(Duration::from_millis(100))
.open()
.map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?;
Ok(Arc::new(Mutex::new(port)))
SerialRelayBackend::open_shared_serial(device, baud_rate)
}
fn with_context(config: AtxKeyConfig, context: AtxKeyBackendContext) -> Self {
let backend = build_backend(&config, context);
Self { config, backend }
}
/// Check if this executor is configured
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
/// Check if this executor is initialized
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
/// Initialize the executor
pub async fn init(&mut self) -> Result<()> {
if !self.config.is_configured() {
debug!("ATX key executor not configured, skipping init");
return Ok(());
}
self.validate_runtime_config()?;
match self.config.driver {
AtxDriverType::Gpio => self.init_gpio().await?,
AtxDriverType::UsbRelay => self.init_usb_relay().await?,
AtxDriverType::Serial => self.init_serial().await?,
AtxDriverType::None => {}
}
self.initialized.store(true, Ordering::Relaxed);
Ok(())
}
fn validate_runtime_config(&self) -> Result<()> {
match self.config.driver {
AtxDriverType::Serial => {
if self.config.pin == 0 {
return Err(AppError::Config(
"Serial ATX channel must be 1-based (>= 1)".to_string(),
));
}
if self.config.pin > u8::MAX as u32 {
return Err(AppError::Config(format!(
"Serial ATX channel must be <= {}",
u8::MAX
)));
}
if self.config.baud_rate == 0 {
return Err(AppError::Config(
"Serial ATX baud_rate must be greater than 0".to_string(),
));
}
}
AtxDriverType::UsbRelay => {
if self.config.pin == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if self.config.pin > u8::MAX as u32 {
return Err(AppError::Config(format!(
"USB relay channel must be <= {}",
u8::MAX
)));
}
if self.config.pin > USB_RELAY_MAX_CHANNEL as u32 {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
}
AtxDriverType::Gpio | AtxDriverType::None => {}
}
Ok(())
}
/// Initialize GPIO backend
async fn init_gpio(&mut self) -> Result<()> {
info!(
"Initializing GPIO ATX executor on {} pin {}",
self.config.device, self.config.pin
);
let mut chip = Chip::new(&self.config.device)
.map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?;
let line = chip.get_line(self.config.pin).map_err(|e| {
AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e))
let backend = self.backend.as_mut().ok_or_else(|| {
AppError::Internal(format!(
"ATX backend {:?} is unsupported on this platform",
self.config.driver
))
})?;
// Initial value depends on active level (start in inactive state)
let initial_value = match self.config.active_level {
ActiveLevel::High => 0, // Inactive = low
ActiveLevel::Low => 1, // Inactive = high
};
let handle = line
.request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx")
.map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?;
*self.gpio_handle.lock().unwrap() = Some(handle);
debug!("GPIO pin {} configured successfully", self.config.pin);
Ok(())
backend.init().await
}
/// Initialize USB relay backend
async fn init_usb_relay(&self) -> Result<()> {
info!(
"Initializing USB relay ATX executor on {} channel {}",
self.config.device, self.config.pin
);
// Open and cache the device handle
let device = OpenOptions::new()
.read(true)
.write(true)
.open(&self.config.device)
.map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?;
*self.usb_relay_handle.lock().unwrap() = Some(device);
// Ensure relay is off initially
self.send_usb_relay_command(false)?;
debug!(
"USB relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
/// Initialize Serial relay backend
async fn init_serial(&self) -> Result<()> {
info!(
"Initializing Serial relay ATX executor on {} channel {}",
self.config.device, self.config.pin
);
let existing_handle = self.serial_handle.lock().unwrap().as_ref().cloned();
if existing_handle.is_none() {
let shared = Self::open_shared_serial(&self.config.device, self.config.baud_rate)?;
*self.serial_handle.lock().unwrap() = Some(shared);
}
// Ensure relay is off initially
self.send_serial_relay_command(false)?;
debug!(
"Serial relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
/// Pulse the button for the specified duration
pub async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_configured() {
return Err(AppError::Internal("ATX key not configured".to_string()));
}
if !self.is_initialized() {
let backend = self.backend.as_ref().ok_or_else(|| {
AppError::Internal(format!(
"ATX backend {:?} is unsupported on this platform",
self.config.driver
))
})?;
if !backend.is_initialized() {
return Err(AppError::Internal("ATX key not initialized".to_string()));
}
match self.config.driver {
AtxDriverType::Gpio => self.pulse_gpio(duration).await,
AtxDriverType::UsbRelay => self.pulse_usb_relay(duration).await,
AtxDriverType::Serial => self.pulse_serial(duration).await,
AtxDriverType::None => Ok(()),
}
backend.pulse(duration).await
}
/// Pulse GPIO pin
async fn pulse_gpio(&self, duration: Duration) -> Result<()> {
let (active, inactive) = match self.config.active_level {
ActiveLevel::High => (1u8, 0u8),
ActiveLevel::Low => (0u8, 1u8),
};
// Set to active state
{
let guard = self.gpio_handle.lock().unwrap();
let handle = guard
.as_ref()
.ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?;
handle
.set_value(active)
.map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?;
}
// Wait for duration (no lock held)
sleep(duration).await;
// Set to inactive state
{
let guard = self.gpio_handle.lock().unwrap();
if let Some(handle) = guard.as_ref() {
handle.set_value(inactive).ok();
}
}
Ok(())
}
/// Pulse USB relay
async fn pulse_usb_relay(&self, duration: Duration) -> Result<()> {
// Turn relay on
self.send_usb_relay_command(true)?;
// Wait for duration
sleep(duration).await;
// Turn relay off
self.send_usb_relay_command(false)?;
Ok(())
}
/// Send USB relay command using cached handle
fn send_usb_relay_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"USB relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
if channel == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if channel > USB_RELAY_MAX_CHANNEL {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
let cmd = Self::build_usb_relay_command(channel, on);
let mut guard = self.usb_relay_handle.lock().unwrap();
let device = guard
.as_mut()
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
if let Err(feature_err) = Self::send_usb_relay_feature_report(device, &cmd) {
debug!(
"USB relay feature report failed ({}), falling back to hidraw write",
feature_err
);
device.write_all(&cmd).map_err(|write_err| {
AppError::Internal(format!(
"USB relay feature report failed: {}; raw write failed: {}",
feature_err, write_err
))
})?;
device
.flush()
.map_err(|e| AppError::Internal(format!("USB relay flush failed: {}", e)))?;
}
Ok(())
}
fn build_usb_relay_command(channel: u8, on: bool) -> [u8; USB_RELAY_REPORT_LEN] {
let mut cmd = [0x00; USB_RELAY_REPORT_LEN];
cmd[1] = if on { 0xFF } else { 0xFD };
cmd[2] = channel;
cmd
}
fn send_usb_relay_feature_report(
device: &File,
report: &[u8; USB_RELAY_REPORT_LEN],
) -> std::io::Result<()> {
// Linux hidraw feature reports include the report ID as the first byte.
let rc = unsafe { libc::ioctl(device.as_raw_fd(), HIDIOCSFEATURE_9, report.as_ptr()) };
if rc < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(())
}
}
/// Pulse Serial relay
async fn pulse_serial(&self, duration: Duration) -> Result<()> {
info!(
"Pulse serial relay on {} pin {}",
self.config.device, self.config.pin
);
// Turn relay on
self.send_serial_relay_command(true)?;
// Wait for duration
sleep(duration).await;
// Turn relay off
self.send_serial_relay_command(false)?;
Ok(())
}
/// Send Serial relay command using cached handle
fn send_serial_relay_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"Serial relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
if channel == 0 {
return Err(AppError::Config(
"Serial relay channel must be 1-based (>= 1)".to_string(),
));
}
// LCUS-Type Protocol
// Frame: [StopByte(A0), Channel, State, Checksum]
// Checksum = A0 + channel + state
let state = if on { 1 } else { 0 };
let checksum = 0xA0u8.wrapping_add(channel).wrapping_add(state);
// Example for Channel 1:
// ON: A0 01 01 A2
// OFF: A0 01 00 A1
let cmd = [0xA0, channel, state, checksum];
let serial_handle = self
.serial_handle
.lock()
.unwrap()
.as_ref()
.cloned()
.ok_or_else(|| AppError::Internal("Serial relay not initialized".to_string()))?;
let mut port = serial_handle.lock().unwrap();
port.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?;
port.flush()
.map_err(|e| AppError::Internal(format!("Serial relay flush failed: {}", e)))?;
Ok(())
}
/// Shutdown the executor
pub async fn shutdown(&mut self) -> Result<()> {
if !self.is_initialized() {
return Ok(());
if let Some(backend) = self.backend.as_mut() {
backend.shutdown().await?;
}
match self.config.driver {
AtxDriverType::Gpio => {
// Release GPIO handle
*self.gpio_handle.lock().unwrap() = None;
}
AtxDriverType::UsbRelay => {
// Ensure relay is off before closing handle
let _ = self.send_usb_relay_command(false);
// Release USB relay handle
*self.usb_relay_handle.lock().unwrap() = None;
}
AtxDriverType::Serial => {
// Ensure relay is off before closing handle
let _ = self.send_serial_relay_command(false);
// Release Serial relay handle
*self.serial_handle.lock().unwrap() = None;
}
AtxDriverType::None => {}
}
self.initialized.store(false, Ordering::Relaxed);
debug!("ATX key executor shutdown complete");
Ok(())
}
}
impl Drop for AtxKeyExecutor {
fn drop(&mut self) {
// Ensure GPIO lines are released
*self.gpio_handle.lock().unwrap() = None;
// Ensure USB relay is off and handle released
if self.config.driver == AtxDriverType::UsbRelay && self.is_initialized() {
let _ = self.send_usb_relay_command(false);
}
*self.usb_relay_handle.lock().unwrap() = None;
// Ensure Serial relay is off and handle released
if self.config.driver == AtxDriverType::Serial && self.is_initialized() {
let _ = self.send_serial_relay_command(false);
}
*self.serial_handle.lock().unwrap() = None;
fn build_backend(
config: &AtxKeyConfig,
context: AtxKeyBackendContext,
) -> Option<Box<dyn AtxKeyBackend>> {
match config.driver {
AtxDriverType::Serial => Some(match context {
AtxKeyBackendContext::Standalone => Box::new(SerialRelayBackend::new(config.clone())),
AtxKeyBackendContext::SharedSerial(handle) => Box::new(
SerialRelayBackend::new_with_shared_serial(config.clone(), handle),
),
}),
AtxDriverType::Gpio => build_gpio_backend(config),
AtxDriverType::UsbRelay => build_hidraw_backend(config),
AtxDriverType::None => None,
}
}
#[cfg(unix)]
fn build_gpio_backend(config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::gpio_linux::GpioLinuxBackend::new(
config.clone(),
)))
}
#[cfg(not(unix))]
fn build_gpio_backend(_config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::disabled_key::DisabledAtxKeyBackend::new(
"GPIO ATX backend is only available on Linux",
)))
}
#[cfg(unix)]
fn build_hidraw_backend(config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::hidraw_linux::HidrawLinuxRelayBackend::new(
config.clone(),
)))
}
#[cfg(not(unix))]
fn build_hidraw_backend(_config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::disabled_key::DisabledAtxKeyBackend::new(
"USB hidraw relay backend is only available on Linux",
)))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::atx::ActiveLevel;
#[test]
fn test_executor_creation() {
fn executor_creation() {
let config = AtxKeyConfig::default();
let executor = AtxKeyExecutor::new(config);
assert!(!executor.is_configured());
assert!(!executor.is_initialized());
}
#[test]
fn test_executor_with_gpio_config() {
fn executor_with_gpio_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::Gpio,
device: "/dev/gpiochip0".to_string(),
@@ -501,16 +154,15 @@ mod tests {
};
let executor = AtxKeyExecutor::new(config);
assert!(executor.is_configured());
assert!(!executor.is_initialized());
}
#[test]
fn test_executor_with_usb_relay_config() {
fn executor_with_usb_relay_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 1,
active_level: ActiveLevel::High, // Ignored for USB relay
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let executor = AtxKeyExecutor::new(config);
@@ -518,12 +170,12 @@ mod tests {
}
#[test]
fn test_executor_with_serial_config() {
fn executor_with_serial_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 1,
active_level: ActiveLevel::High, // Ignored
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let executor = AtxKeyExecutor::new(config);
@@ -531,91 +183,9 @@ mod tests {
}
#[test]
fn test_timing_constants() {
fn timing_constants() {
assert_eq!(timing::SHORT_PRESS.as_millis(), 500);
assert_eq!(timing::LONG_PRESS.as_millis(), 5000);
assert_eq!(timing::RESET_PRESS.as_millis(), 500);
}
#[test]
fn test_usb_relay_command_format() {
assert_eq!(
AtxKeyExecutor::build_usb_relay_command(1, true),
[0x00, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
assert_eq!(
AtxKeyExecutor::build_usb_relay_command(1, false),
[0x00, 0xFD, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_zero() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 0,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_usb_relay_channel_zero() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 0,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_usb_relay_channel_overflow() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: USB_RELAY_MAX_CHANNEL as u32 + 1,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_overflow() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 256,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_zero_serial_baud_rate() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 1,
active_level: ActiveLevel::High,
baud_rate: 0,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
}

106
src/atx/gpio_linux.rs Normal file
View File

@@ -0,0 +1,106 @@
use async_trait::async_trait;
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use super::traits::AtxKeyBackend;
use super::types::{ActiveLevel, AtxKeyConfig};
use crate::error::{AppError, Result};
pub struct GpioLinuxBackend {
config: AtxKeyConfig,
handle: Mutex<Option<LineHandle>>,
initialized: AtomicBool,
}
impl GpioLinuxBackend {
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
}
#[async_trait]
impl AtxKeyBackend for GpioLinuxBackend {
async fn init(&mut self) -> Result<()> {
info!(
"Initializing GPIO ATX backend on {} pin {}",
self.config.device, self.config.pin
);
let mut chip = Chip::new(&self.config.device)
.map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?;
let line = chip.get_line(self.config.pin).map_err(|e| {
AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e))
})?;
let initial_value = match self.config.active_level {
ActiveLevel::High => 0,
ActiveLevel::Low => 1,
};
let handle = line
.request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx")
.map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?;
*self.handle.lock().unwrap() = Some(handle);
self.initialized.store(true, Ordering::Relaxed);
debug!("GPIO pin {} configured successfully", self.config.pin);
Ok(())
}
async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_initialized() {
return Err(AppError::Internal("GPIO not initialized".to_string()));
}
let (active, inactive) = match self.config.active_level {
ActiveLevel::High => (1u8, 0u8),
ActiveLevel::Low => (0u8, 1u8),
};
{
let guard = self.handle.lock().unwrap();
let handle = guard
.as_ref()
.ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?;
handle
.set_value(active)
.map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?;
}
sleep(duration).await;
{
let guard = self.handle.lock().unwrap();
if let Some(handle) = guard.as_ref() {
handle.set_value(inactive).ok();
}
}
Ok(())
}
async fn shutdown(&mut self) -> Result<()> {
*self.handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
Ok(())
}
fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
}
impl Drop for GpioLinuxBackend {
fn drop(&mut self) {
*self.handle.lock().unwrap() = None;
}
}

190
src/atx/hidraw_linux.rs Normal file
View File

@@ -0,0 +1,190 @@
use async_trait::async_trait;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::os::fd::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use super::traits::AtxKeyBackend;
use super::types::AtxKeyConfig;
use crate::error::{AppError, Result};
const USB_RELAY_MAX_CHANNEL: u8 = 8;
const USB_RELAY_REPORT_LEN: usize = 9;
const HIDIOCSFEATURE_9: libc::c_ulong = 0xC009_4806;
pub struct HidrawLinuxRelayBackend {
config: AtxKeyConfig,
handle: Mutex<Option<File>>,
initialized: AtomicBool,
}
impl HidrawLinuxRelayBackend {
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
fn validate_config(&self) -> Result<()> {
if self.config.pin == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if self.config.pin > USB_RELAY_MAX_CHANNEL as u32 {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
Ok(())
}
fn send_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"USB relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
if channel == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if channel > USB_RELAY_MAX_CHANNEL {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
let cmd = Self::build_command(channel, on);
let mut guard = self.handle.lock().unwrap();
let device = guard
.as_mut()
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
if let Err(feature_err) = Self::send_feature_report(device, &cmd) {
debug!(
"USB relay feature report failed ({}), falling back to hidraw write",
feature_err
);
device.write_all(&cmd).map_err(|write_err| {
AppError::Internal(format!(
"USB relay feature report failed: {}; raw write failed: {}",
feature_err, write_err
))
})?;
device
.flush()
.map_err(|e| AppError::Internal(format!("USB relay flush failed: {}", e)))?;
}
Ok(())
}
pub fn build_command(channel: u8, on: bool) -> [u8; USB_RELAY_REPORT_LEN] {
let mut cmd = [0x00; USB_RELAY_REPORT_LEN];
cmd[1] = if on { 0xFF } else { 0xFD };
cmd[2] = channel;
cmd
}
fn send_feature_report(
device: &File,
report: &[u8; USB_RELAY_REPORT_LEN],
) -> std::io::Result<()> {
let rc = unsafe { libc::ioctl(device.as_raw_fd(), HIDIOCSFEATURE_9, report.as_ptr()) };
if rc < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(())
}
}
}
#[async_trait]
impl AtxKeyBackend for HidrawLinuxRelayBackend {
async fn init(&mut self) -> Result<()> {
self.validate_config()?;
info!(
"Initializing USB relay ATX backend on {} channel {}",
self.config.device, self.config.pin
);
let device = OpenOptions::new()
.read(true)
.write(true)
.open(&self.config.device)
.map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?;
*self.handle.lock().unwrap() = Some(device);
self.send_command(false)?;
self.initialized.store(true, Ordering::Relaxed);
debug!(
"USB relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_initialized() {
return Err(AppError::Internal("USB relay not initialized".to_string()));
}
self.send_command(true)?;
sleep(duration).await;
self.send_command(false)?;
Ok(())
}
async fn shutdown(&mut self) -> Result<()> {
if self.is_initialized() {
let _ = self.send_command(false);
}
*self.handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
Ok(())
}
fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
}
impl Drop for HidrawLinuxRelayBackend {
fn drop(&mut self) {
if self.is_initialized() {
let _ = self.send_command(false);
}
*self.handle.lock().unwrap() = None;
}
}
#[cfg(test)]
mod tests {
use super::HidrawLinuxRelayBackend;
#[test]
fn usb_relay_command_format() {
assert_eq!(
HidrawLinuxRelayBackend::build_command(1, true),
[0x00, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
assert_eq!(
HidrawLinuxRelayBackend::build_command(1, false),
[0x00, 0xFD, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
}
}

View File

@@ -10,9 +10,6 @@ use tracing::{debug, info};
use super::types::{AtxLedConfig, PowerStatus};
use crate::error::{AppError, Result};
/// LED sensor for reading power status
///
/// Uses GPIO to read the power LED state and determine if the system is on or off.
pub struct LedSensor {
config: AtxLedConfig,
handle: Mutex<Option<LineHandle>>,
@@ -20,7 +17,6 @@ pub struct LedSensor {
}
impl LedSensor {
/// Create a new LED sensor with the given configuration
pub fn new(config: AtxLedConfig) -> Self {
Self {
config,
@@ -29,17 +25,6 @@ impl LedSensor {
}
}
/// Check if the sensor is configured
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
/// Check if the sensor is initialized
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
/// Initialize the LED sensor
pub async fn init(&mut self) -> Result<()> {
if !self.config.is_configured() {
debug!("LED sensor not configured, skipping init");
@@ -72,9 +57,8 @@ impl LedSensor {
Ok(())
}
/// Read the current power status
pub async fn read(&self) -> Result<PowerStatus> {
if !self.is_configured() || !self.is_initialized() {
if !self.config.is_configured() || !self.initialized.load(Ordering::Relaxed) {
return Ok(PowerStatus::Unknown);
}
@@ -85,11 +69,10 @@ impl LedSensor {
.get_value()
.map_err(|e| AppError::Internal(format!("LED read failed: {}", e)))?;
// Apply inversion if configured
let is_on = if self.config.inverted {
value == 0 // Active low: 0 means on
value == 0
} else {
value == 1 // Active high: 1 means on
value == 1
};
Ok(if is_on {
@@ -102,7 +85,6 @@ impl LedSensor {
}
}
/// Shutdown the LED sensor
pub async fn shutdown(&mut self) -> Result<()> {
*self.handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
@@ -125,8 +107,8 @@ mod tests {
fn test_led_sensor_creation() {
let config = AtxLedConfig::default();
let sensor = LedSensor::new(config);
assert!(!sensor.is_configured());
assert!(!sensor.is_initialized());
assert!(!sensor.config.is_configured());
assert!(!sensor.initialized.load(Ordering::Relaxed));
}
#[test]
@@ -138,8 +120,8 @@ mod tests {
inverted: false,
};
let sensor = LedSensor::new(config);
assert!(sensor.is_configured());
assert!(!sensor.is_initialized());
assert!(sensor.config.is_configured());
assert!(!sensor.initialized.load(Ordering::Relaxed));
}
#[test]
@@ -151,7 +133,6 @@ mod tests {
inverted: true,
};
let sensor = LedSensor::new(config);
assert!(sensor.is_configured());
assert!(sensor.config.inverted);
}
}

View File

@@ -2,53 +2,22 @@
//!
//! Provides ATX power management functionality for IP-KVM.
//! Supports flexible hardware binding with independent configuration for each action.
//!
//! # Features
//!
//! - Power button control (short press for on/graceful shutdown, long press for force off)
//! - Reset button control
//! - Power status monitoring via LED sensing (GPIO only)
//! - Independent hardware binding for each action (GPIO or USB relay)
//! - Hot-reload configuration support
//!
//! # Hardware Support
//!
//! - **GPIO**: Uses Linux GPIO character device (/dev/gpiochipX) for direct hardware control
//! - **USB Relay**: Uses HID USB relay modules for isolated switching
//! - **Serial Relay**: Uses LCUS-style serial relay modules
//!
//! # Example
//!
//! ```ignore
//! use one_kvm::atx::{AtxController, AtxControllerConfig, AtxKeyConfig, AtxDriverType, ActiveLevel};
//!
//! let config = AtxControllerConfig {
//! enabled: true,
//! power: AtxKeyConfig {
//! driver: AtxDriverType::Gpio,
//! device: "/dev/gpiochip0".to_string(),
//! pin: 5,
//! active_level: ActiveLevel::High,
//! baud_rate: 9600,
//! },
//! reset: AtxKeyConfig {
//! driver: AtxDriverType::UsbRelay,
//! device: "/dev/hidraw0".to_string(),
//! pin: 0,
//! active_level: ActiveLevel::High,
//! baud_rate: 9600,
//! },
//! led: Default::default(),
//! };
//!
//! let controller = AtxController::new(config);
//! controller.init().await?;
//! controller.power_short().await?; // Turn on or graceful shutdown
//! ```
mod controller;
#[cfg(not(unix))]
mod disabled_key;
mod executor;
#[cfg(unix)]
mod gpio_linux;
#[cfg(unix)]
mod hidraw_linux;
#[cfg(unix)]
mod led;
#[cfg(not(unix))]
#[path = "disabled_led.rs"]
mod led;
mod serial_relay;
mod traits;
mod types;
mod wol;
@@ -58,8 +27,9 @@ pub use types::{
ActiveLevel, AtxAction, AtxDevices, AtxDriverType, AtxKeyConfig, AtxLedConfig, AtxPowerRequest,
AtxState, PowerStatus,
};
pub use wol::send_wol;
pub use wol::{list_wol_history, record_wol_history, send_wol};
#[cfg(any(unix, test))]
fn hidraw_uevent_is_usb_relay(uevent: &str) -> bool {
let upper = uevent.to_ascii_uppercase();
upper.contains("000016C0:000005DF")
@@ -69,6 +39,7 @@ fn hidraw_uevent_is_usb_relay(uevent: &str) -> bool {
|| upper.contains("USB RELAY")
}
#[cfg(unix)]
fn is_usb_relay_hidraw(name: &str) -> bool {
let uevent_path = format!("/sys/class/hidraw/{}/device/uevent", name);
std::fs::read_to_string(uevent_path)
@@ -82,7 +53,9 @@ fn is_usb_relay_hidraw(name: &str) -> bool {
pub fn discover_devices() -> AtxDevices {
let mut devices = AtxDevices::default();
// Single pass through /dev directory
devices.serial_ports = crate::utils::list_serial_ports();
#[cfg(unix)]
if let Ok(entries) = std::fs::read_dir("/dev") {
for entry in entries.flatten() {
let name = entry.file_name();
@@ -100,6 +73,7 @@ pub fn discover_devices() -> AtxDevices {
devices.gpio_chips.sort();
devices.usb_relays.sort();
devices.serial_ports.sort();
devices.serial_ports.dedup();
devices
}
@@ -129,7 +103,6 @@ mod tests {
#[test]
fn test_module_exports() {
// Verify all public exports are accessible
let _: AtxDriverType = AtxDriverType::None;
let _: ActiveLevel = ActiveLevel::High;
let _: AtxKeyConfig = AtxKeyConfig::default();

141
src/atx/serial_relay.rs Normal file
View File

@@ -0,0 +1,141 @@
use async_trait::async_trait;
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use super::traits::{validate_serial_config, AtxKeyBackend, SharedSerialHandle};
use super::types::AtxKeyConfig;
use crate::error::{AppError, Result};
pub struct SerialRelayBackend {
config: AtxKeyConfig,
serial_handle: Mutex<Option<SharedSerialHandle>>,
initialized: AtomicBool,
}
impl SerialRelayBackend {
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
serial_handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
pub fn new_with_shared_serial(config: AtxKeyConfig, serial_handle: SharedSerialHandle) -> Self {
Self {
config,
serial_handle: Mutex::new(Some(serial_handle)),
initialized: AtomicBool::new(false),
}
}
pub fn open_shared_serial(device: &str, baud_rate: u32) -> Result<SharedSerialHandle> {
let port = serialport::new(device, baud_rate)
.timeout(Duration::from_millis(100))
.open()
.map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?;
Ok(Arc::new(Mutex::new(port)))
}
fn send_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"Serial relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
let state = if on { 1 } else { 0 };
let checksum = 0xA0u8.wrapping_add(channel).wrapping_add(state);
let cmd = [0xA0, channel, state, checksum];
let serial_handle = self
.serial_handle
.lock()
.unwrap()
.as_ref()
.cloned()
.ok_or_else(|| AppError::Internal("Serial relay not initialized".to_string()))?;
let mut port = serial_handle.lock().unwrap();
port.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?;
port.flush()
.map_err(|e| AppError::Internal(format!("Serial relay flush failed: {}", e)))?;
Ok(())
}
}
#[async_trait]
impl AtxKeyBackend for SerialRelayBackend {
async fn init(&mut self) -> Result<()> {
validate_serial_config(&self.config)?;
info!(
"Initializing Serial relay ATX backend on {} channel {}",
self.config.device, self.config.pin
);
let existing_handle = self.serial_handle.lock().unwrap().as_ref().cloned();
if existing_handle.is_none() {
let shared = Self::open_shared_serial(&self.config.device, self.config.baud_rate)?;
*self.serial_handle.lock().unwrap() = Some(shared);
}
self.send_command(false)?;
self.initialized.store(true, Ordering::Relaxed);
debug!(
"Serial relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_initialized() {
return Err(AppError::Internal(
"Serial relay not initialized".to_string(),
));
}
info!(
"Pulse serial relay on {} pin {}",
self.config.device, self.config.pin
);
self.send_command(true)?;
sleep(duration).await;
self.send_command(false)?;
Ok(())
}
async fn shutdown(&mut self) -> Result<()> {
if !self.is_initialized() {
return Ok(());
}
let _ = self.send_command(false);
*self.serial_handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
Ok(())
}
fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
}
impl Drop for SerialRelayBackend {
fn drop(&mut self) {
if self.is_initialized() {
let _ = self.send_command(false);
}
*self.serial_handle.lock().unwrap() = None;
}
}

51
src/atx/traits.rs Normal file
View File

@@ -0,0 +1,51 @@
use async_trait::async_trait;
use serialport::SerialPort;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use super::types::AtxKeyConfig;
use crate::error::Result;
pub type SharedSerialHandle = Arc<Mutex<Box<dyn SerialPort>>>;
#[async_trait]
pub trait AtxKeyBackend: Send + Sync {
async fn init(&mut self) -> Result<()>;
async fn pulse(&self, duration: Duration) -> Result<()>;
async fn shutdown(&mut self) -> Result<()>;
fn is_initialized(&self) -> bool;
}
#[derive(Debug, Clone)]
pub enum AtxKeyBackendContext {
Standalone,
SharedSerial(SharedSerialHandle),
}
pub fn validate_serial_config(config: &AtxKeyConfig) -> Result<()> {
if config.device.trim().is_empty() {
return Err(crate::error::AppError::Config(
"Serial ATX device cannot be empty".to_string(),
));
}
if config.pin == 0 {
return Err(crate::error::AppError::Config(
"Serial ATX channel must be 1-based (>= 1)".to_string(),
));
}
if config.pin > u8::MAX as u32 {
return Err(crate::error::AppError::Config(format!(
"Serial ATX channel must be <= {}",
u8::MAX
)));
}
if config.baud_rate == 0 {
return Err(crate::error::AppError::Config(
"Serial ATX baud_rate must be greater than 0".to_string(),
));
}
Ok(())
}

View File

@@ -6,67 +6,43 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Power status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PowerStatus {
/// Power is on
On,
/// Power is off
Off,
/// Power status unknown (no LED connected)
#[default]
Unknown,
}
/// Driver type for ATX key operations
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AtxDriverType {
/// GPIO control via Linux character device
Gpio,
/// USB HID relay module
UsbRelay,
/// Serial/COM port relay (taobao LCUS type)
Serial,
/// Disabled / Not configured
#[default]
None,
}
/// Active level for GPIO pins
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ActiveLevel {
/// Active high (default for most cases)
#[default]
High,
/// Active low (inverted)
Low,
}
/// Configuration for a single ATX key (power or reset)
/// This is the "four-tuple" configuration: (driver, device, pin/channel, level)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct AtxKeyConfig {
/// Driver type (GPIO or USB Relay)
pub driver: AtxDriverType,
/// Device path:
/// - For GPIO: /dev/gpiochipX
/// - For USB Relay: /dev/hidrawX
pub device: String,
/// Pin or channel number:
/// - For GPIO: GPIO pin number
/// - For USB Relay: relay channel (1-based)
/// - For Serial Relay (LCUS): relay channel (1-based)
pub pin: u32,
/// Active level (only applicable to GPIO, ignored for USB Relay)
pub active_level: ActiveLevel,
/// Baud rate for serial relay (start with 9600)
pub baud_rate: u32,
}
@@ -83,77 +59,54 @@ impl Default for AtxKeyConfig {
}
impl AtxKeyConfig {
/// Check if this key is configured
pub fn is_configured(&self) -> bool {
self.driver != AtxDriverType::None && !self.device.is_empty()
}
}
/// LED sensing configuration (optional)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct AtxLedConfig {
/// Whether LED sensing is enabled
pub enabled: bool,
/// GPIO chip for LED sensing
pub gpio_chip: String,
/// GPIO pin for LED input
pub gpio_pin: u32,
/// Whether LED is active low (inverted logic)
pub inverted: bool,
}
impl AtxLedConfig {
/// Check if LED sensing is configured
pub fn is_configured(&self) -> bool {
self.enabled && !self.gpio_chip.is_empty()
}
}
/// ATX state information
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AtxState {
/// Whether ATX feature is available/enabled
pub available: bool,
/// Whether power button is configured
pub power_configured: bool,
/// Whether reset button is configured
pub reset_configured: bool,
/// Current power status
pub power_status: PowerStatus,
/// Whether power LED sensing is supported
pub led_supported: bool,
}
/// ATX power action request
#[derive(Debug, Clone, Deserialize)]
pub struct AtxPowerRequest {
/// Action to perform: "short", "long", "reset"
pub action: AtxAction,
}
/// ATX power action
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AtxAction {
/// Short press power button (turn on or graceful shutdown)
Short,
/// Long press power button (force power off)
Long,
/// Press reset button
Reset,
}
/// Available ATX devices for discovery
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxDevices {
/// Available GPIO chips (/dev/gpiochip*)
pub gpio_chips: Vec<String>,
/// Available USB HID relay devices (/dev/hidraw*)
pub usb_relays: Vec<String>,
/// Available Serial ports (/dev/ttyUSB*)
pub serial_ports: Vec<String>,
}
@@ -201,13 +154,13 @@ mod tests {
assert!(!config.is_configured());
config.driver = AtxDriverType::Gpio;
assert!(!config.is_configured()); // device still empty
assert!(!config.is_configured());
config.device = "/dev/gpiochip0".to_string();
assert!(config.is_configured());
config.driver = AtxDriverType::None;
assert!(!config.is_configured()); // driver is None
assert!(!config.is_configured());
}
#[test]
@@ -224,7 +177,7 @@ mod tests {
assert!(!config.is_configured());
config.enabled = true;
assert!(!config.is_configured()); // gpio_chip still empty
assert!(!config.is_configured());
config.gpio_chip = "/dev/gpiochip0".to_string();
assert!(config.is_configured());

View File

@@ -3,18 +3,14 @@
//! Sends magic packets to wake up remote machines.
use std::net::{SocketAddr, UdpSocket};
use tracing::{debug, info};
use tracing::info;
use crate::error::{AppError, Result};
/// WOL magic packet structure:
/// - 6 bytes of 0xFF
/// - 16 repetitions of the target MAC address (6 bytes each)
/// Total: 6 + 16 * 6 = 102 bytes
const WOL_HISTORY_MAX_ENTRIES: i64 = 50;
const MAGIC_PACKET_SIZE: usize = 102;
/// Parse MAC address string into bytes
/// Supports formats: "AA:BB:CC:DD:EE:FF" or "AA-BB-CC-DD-EE-FF"
fn parse_mac_address(mac: &str) -> Result<[u8; 6]> {
let mac = mac.trim().to_uppercase();
let parts: Vec<&str> = if mac.contains(':') {
@@ -44,16 +40,13 @@ fn parse_mac_address(mac: &str) -> Result<[u8; 6]> {
Ok(bytes)
}
/// Build WOL magic packet
fn build_magic_packet(mac: &[u8; 6]) -> [u8; MAGIC_PACKET_SIZE] {
let mut packet = [0u8; MAGIC_PACKET_SIZE];
// First 6 bytes are 0xFF
for byte in packet.iter_mut().take(6) {
*byte = 0xFF;
}
// Next 96 bytes are 16 repetitions of the MAC address
for i in 0..16 {
let offset = 6 + i * 6;
packet[offset..offset + 6].copy_from_slice(mac);
@@ -73,16 +66,13 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
info!("Sending WOL packet to {} via {:?}", mac_address, interface);
// Create UDP socket
let socket = UdpSocket::bind("0.0.0.0:0")
.map_err(|e| AppError::Internal(format!("Failed to create UDP socket: {}", e)))?;
// Enable broadcast
socket
.set_broadcast(true)
.map_err(|e| AppError::Internal(format!("Failed to enable broadcast: {}", e)))?;
// Bind to specific interface if specified
#[cfg(target_os = "linux")]
if let Some(iface) = interface {
if !iface.is_empty() {
@@ -90,8 +80,7 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
let fd = socket.as_raw_fd();
let iface_bytes = iface.as_bytes();
// SO_BINDTODEVICE requires interface name as null-terminated string
let mut iface_buf = [0u8; 16]; // IFNAMSIZ is typically 16
let mut iface_buf = [0u8; 16];
let len = iface_bytes.len().min(15);
iface_buf[..len].copy_from_slice(&iface_bytes[..len]);
@@ -112,18 +101,16 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
iface, err
)));
}
debug!("Bound to interface: {}", iface);
tracing::debug!("Bound to interface: {}", iface);
}
}
// Send to broadcast address on port 9 (discard protocol, commonly used for WOL)
let broadcast_addr: SocketAddr = "255.255.255.255:9".parse().unwrap();
socket
.send_to(&packet, broadcast_addr)
.map_err(|e| AppError::Internal(format!("Failed to send WOL packet: {}", e)))?;
// Also try sending to port 7 (echo protocol, alternative WOL port)
let broadcast_addr_7: SocketAddr = "255.255.255.255:7".parse().unwrap();
let _ = socket.send_to(&packet, broadcast_addr_7);
@@ -131,6 +118,55 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
Ok(())
}
pub async fn record_wol_history(pool: &sqlx::Pool<sqlx::Sqlite>, mac_address: &str) -> Result<()> {
sqlx::query(
r#"
INSERT INTO wol_history (mac_address, updated_at)
VALUES (?1, CAST(strftime('%s', 'now') AS INTEGER))
ON CONFLICT(mac_address) DO UPDATE SET
updated_at = excluded.updated_at
"#,
)
.bind(mac_address)
.execute(pool)
.await?;
sqlx::query(
r#"
DELETE FROM wol_history
WHERE mac_address NOT IN (
SELECT mac_address FROM wol_history
ORDER BY updated_at DESC
LIMIT ?1
)
"#,
)
.bind(WOL_HISTORY_MAX_ENTRIES)
.execute(pool)
.await?;
Ok(())
}
pub async fn list_wol_history(
pool: &sqlx::Pool<sqlx::Sqlite>,
limit: usize,
) -> Result<Vec<(String, i64)>> {
let rows = sqlx::query_as(
r#"
SELECT mac_address, updated_at
FROM wol_history
ORDER BY updated_at DESC
LIMIT ?1
"#,
)
.bind(limit as i64)
.fetch_all(pool)
.await?;
Ok(rows)
}
#[cfg(test)]
mod tests {
use super::*;
@@ -159,12 +195,10 @@ mod tests {
let mac = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF];
let packet = build_magic_packet(&mac);
// Check header (6 bytes of 0xFF)
for byte in packet.iter().take(6) {
assert_eq!(*byte, 0xFF);
}
// Check MAC repetitions
for i in 0..16 {
let offset = 6 + i * 6;
assert_eq!(&packet[offset..offset + 6], &mac);

View File

@@ -1,334 +1,9 @@
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
use alsa::{Direction, ValueOr, PCM};
use bytes::Bytes;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, info};
#[cfg(unix)]
#[path = "capture_linux.rs"]
mod imp;
use super::device::AudioDeviceInfo;
use crate::error::{AppError, Result};
use crate::utils::LogThrottler;
use crate::{error_throttled, warn_throttled};
#[cfg(windows)]
#[path = "capture_windows.rs"]
mod imp;
#[derive(Debug, Clone)]
pub struct AudioConfig {
pub device_name: String,
pub sample_rate: u32,
pub channels: u32,
pub frame_size: u32,
pub buffer_frames: u32,
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: String::new(),
sample_rate: 48000,
channels: 2,
frame_size: 960,
buffer_frames: 4096,
period_frames: 960,
}
}
}
impl AudioConfig {
pub fn for_device(device: &AudioDeviceInfo) -> Self {
Self {
device_name: device.name.clone(),
..Default::default()
}
}
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
#[derive(Debug, Clone)]
pub struct AudioFrame {
pub data: Bytes,
pub sample_rate: u32,
pub channels: u32,
pub samples: u32,
pub sequence: u64,
pub timestamp: Instant,
}
impl AudioFrame {
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
samples: data.len() as u32 / bps,
data,
sample_rate,
channels,
sequence,
timestamp: Instant::now(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
Running,
Error,
}
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
log_throttler: LogThrottler,
}
impl AudioCapturer {
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16);
Self {
config,
state: Arc::new(state_tx),
state_rx,
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
log_throttler: LogThrottler::with_secs(5),
}
}
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
debug!(
"Starting audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
pub async fn stop(&self) -> Result<()> {
info!("Stopping audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
frame_tx: &broadcast::Sender<AudioFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
AppError::AudioError(format!(
"Failed to open audio device {}: {}",
config.device_name, e
))
})?;
{
let hwp = HwParams::any(&pcm)
.map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?;
hwp.set_channels(config.channels)
.map_err(|e| AppError::AudioError(format!("Failed to set channels: {}", e)))?;
hwp.set_rate(config.sample_rate, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set sample rate: {}", e)))?;
hwp.set_format(Format::s16())
.map_err(|e| AppError::AudioError(format!("Failed to set format: {}", e)))?;
hwp.set_access(Access::RWInterleaved)
.map_err(|e| AppError::AudioError(format!("Failed to set access: {}", e)))?;
hwp.set_buffer_size_near(config.buffer_frames as Frames)
.map_err(|e| AppError::AudioError(format!("Failed to set buffer size: {}", e)))?;
hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?;
pcm.hw_params(&hwp)
.map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?;
}
let hw_now = pcm.hw_params_current().map_err(|e| {
AppError::AudioError(format!("Failed to read hw_params after apply: {}", e))
})?;
let actual_rate = hw_now
.get_rate()
.map_err(|e| AppError::AudioError(format!("Failed to read sample rate: {}", e)))?;
let actual_ch = hw_now
.get_channels()
.map_err(|e| AppError::AudioError(format!("Failed to read channels: {}", e)))?;
if actual_rate != 48_000 {
return Err(AppError::AudioError(format!(
"Audio capture requires 48000 Hz; device is {} Hz",
actual_rate
)));
}
if actual_ch != 2 {
return Err(AppError::AudioError(format!(
"Audio capture requires 2 channels (stereo); device has {}",
actual_ch
)));
}
debug!("Audio capture: 48000 Hz, 2 ch");
pcm.prepare()
.map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?;
let _ = state.send(CaptureState::Running);
let period_frames = pcm
.hw_params_current()
.ok()
.and_then(|h| h.get_period_size().ok())
.map(|f| f as usize)
.unwrap_or(1024)
.max(256);
let buf_frames = period_frames.saturating_mul(4).max(2048);
let bytes_per_frame = (config.channels as usize) * 2;
let mut buffer = vec![0u8; buf_frames * bytes_per_frame];
while !stop_flag.load(Ordering::Relaxed) {
match pcm.state() {
State::XRun => {
warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering");
let _ = pcm.prepare();
continue;
}
State::Suspended => {
warn_throttled!(
log_throttler,
"suspended",
"Audio device suspended, recovering"
);
let _ = pcm.resume();
continue;
}
_ => {}
}
// io_bytes: USB capture often lacks mmap (io_checked requires it).
let io: IO<u8> = pcm.io_bytes();
match io.readi(&mut buffer) {
Ok(frames_read) => {
if frames_read == 0 {
continue;
}
let byte_count = frames_read * config.channels as usize * 2;
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(&buffer[..byte_count]),
config.channels,
48_000,
seq,
);
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
}
}
}
Err(e) => {
let desc = e.to_string();
if is_device_lost_error(&desc) {
return Err(AppError::AudioError(format!(
"Audio device lost while reading {}: {}",
config.device_name, e
)));
} else if desc.contains("EPIPE") || desc.contains("Broken pipe") {
warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun");
let _ = pcm.prepare();
} else {
error_throttled!(log_throttler, "read_error", "Audio read error: {}", e);
}
}
}
}
info!("Audio capture stopped");
Ok(())
}
fn is_device_lost_error(desc: &str) -> bool {
desc.contains("No such device")
|| desc.contains("ENODEV")
|| desc.contains("ENXIO")
|| desc.contains("ESHUTDOWN")
}
pub use imp::*;

334
src/audio/capture_linux.rs Normal file
View File

@@ -0,0 +1,334 @@
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
use alsa::{Direction, ValueOr, PCM};
use bytes::Bytes;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, info};
use super::device::AudioDeviceInfo;
use crate::error::{AppError, Result};
use crate::utils::LogThrottler;
use crate::{error_throttled, warn_throttled};
#[derive(Debug, Clone)]
pub struct AudioConfig {
pub device_name: String,
pub sample_rate: u32,
pub channels: u32,
pub frame_size: u32,
pub buffer_frames: u32,
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: String::new(),
sample_rate: 48000,
channels: 2,
frame_size: 960,
buffer_frames: 4096,
period_frames: 960,
}
}
}
impl AudioConfig {
pub fn for_device(device: &AudioDeviceInfo) -> Self {
Self {
device_name: device.name.clone(),
..Default::default()
}
}
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
#[derive(Debug, Clone)]
pub struct AudioFrame {
pub data: Bytes,
pub sample_rate: u32,
pub channels: u32,
pub samples: u32,
pub sequence: u64,
pub timestamp: Instant,
}
impl AudioFrame {
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
samples: data.len() as u32 / bps,
data,
sample_rate,
channels,
sequence,
timestamp: Instant::now(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
Running,
Error,
}
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
log_throttler: LogThrottler,
}
impl AudioCapturer {
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16);
Self {
config,
state: Arc::new(state_tx),
state_rx,
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
log_throttler: LogThrottler::with_secs(5),
}
}
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
debug!(
"Starting audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
pub async fn stop(&self) -> Result<()> {
info!("Stopping audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
frame_tx: &broadcast::Sender<AudioFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
AppError::AudioError(format!(
"Failed to open audio device {}: {}",
config.device_name, e
))
})?;
{
let hwp = HwParams::any(&pcm)
.map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?;
hwp.set_channels(config.channels)
.map_err(|e| AppError::AudioError(format!("Failed to set channels: {}", e)))?;
hwp.set_rate(config.sample_rate, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set sample rate: {}", e)))?;
hwp.set_format(Format::s16())
.map_err(|e| AppError::AudioError(format!("Failed to set format: {}", e)))?;
hwp.set_access(Access::RWInterleaved)
.map_err(|e| AppError::AudioError(format!("Failed to set access: {}", e)))?;
hwp.set_buffer_size_near(config.buffer_frames as Frames)
.map_err(|e| AppError::AudioError(format!("Failed to set buffer size: {}", e)))?;
hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?;
pcm.hw_params(&hwp)
.map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?;
}
let hw_now = pcm.hw_params_current().map_err(|e| {
AppError::AudioError(format!("Failed to read hw_params after apply: {}", e))
})?;
let actual_rate = hw_now
.get_rate()
.map_err(|e| AppError::AudioError(format!("Failed to read sample rate: {}", e)))?;
let actual_ch = hw_now
.get_channels()
.map_err(|e| AppError::AudioError(format!("Failed to read channels: {}", e)))?;
if actual_rate != 48_000 {
return Err(AppError::AudioError(format!(
"Audio capture requires 48000 Hz; device is {} Hz",
actual_rate
)));
}
if actual_ch != 2 {
return Err(AppError::AudioError(format!(
"Audio capture requires 2 channels (stereo); device has {}",
actual_ch
)));
}
debug!("Audio capture: 48000 Hz, 2 ch");
pcm.prepare()
.map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?;
let _ = state.send(CaptureState::Running);
let period_frames = pcm
.hw_params_current()
.ok()
.and_then(|h| h.get_period_size().ok())
.map(|f| f as usize)
.unwrap_or(1024)
.max(256);
let buf_frames = period_frames.saturating_mul(4).max(2048);
let bytes_per_frame = (config.channels as usize) * 2;
let mut buffer = vec![0u8; buf_frames * bytes_per_frame];
while !stop_flag.load(Ordering::Relaxed) {
match pcm.state() {
State::XRun => {
warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering");
let _ = pcm.prepare();
continue;
}
State::Suspended => {
warn_throttled!(
log_throttler,
"suspended",
"Audio device suspended, recovering"
);
let _ = pcm.resume();
continue;
}
_ => {}
}
// io_bytes: USB capture often lacks mmap (io_checked requires it).
let io: IO<u8> = pcm.io_bytes();
match io.readi(&mut buffer) {
Ok(frames_read) => {
if frames_read == 0 {
continue;
}
let byte_count = frames_read * config.channels as usize * 2;
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(&buffer[..byte_count]),
config.channels,
48_000,
seq,
);
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
}
}
}
Err(e) => {
let desc = e.to_string();
if is_device_lost_error(&desc) {
return Err(AppError::AudioError(format!(
"Audio device lost while reading {}: {}",
config.device_name, e
)));
} else if desc.contains("EPIPE") || desc.contains("Broken pipe") {
warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun");
let _ = pcm.prepare();
} else {
error_throttled!(log_throttler, "read_error", "Audio read error: {}", e);
}
}
}
}
info!("Audio capture stopped");
Ok(())
}
fn is_device_lost_error(desc: &str) -> bool {
desc.contains("No such device")
|| desc.contains("ENODEV")
|| desc.contains("ENXIO")
|| desc.contains("ESHUTDOWN")
}

View File

@@ -0,0 +1,516 @@
use bytes::Bytes;
use cpal::traits::{DeviceTrait, StreamTrait};
use cpal::{BufferSize, SampleFormat, StreamConfig};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, info};
use crate::audio::device::{find_wasapi_device, AudioDeviceInfo};
use crate::error::{AppError, Result};
use crate::error_throttled;
use crate::utils::LogThrottler;
#[derive(Debug, Clone)]
pub struct AudioConfig {
pub device_name: String,
pub sample_rate: u32,
pub channels: u32,
pub frame_size: u32,
pub buffer_frames: u32,
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: String::new(),
sample_rate: 48000,
channels: 2,
frame_size: 960,
buffer_frames: 4096,
period_frames: 960,
}
}
}
impl AudioConfig {
pub fn for_device(device: &AudioDeviceInfo) -> Self {
Self {
device_name: device.name.clone(),
..Default::default()
}
}
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
#[derive(Debug, Clone)]
pub struct AudioFrame {
pub data: Bytes,
pub sample_rate: u32,
pub channels: u32,
pub samples: u32,
pub sequence: u64,
pub timestamp: Instant,
}
impl AudioFrame {
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
samples: data.len() as u32 / bps,
data,
sample_rate,
channels,
sequence,
timestamp: Instant::now(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
Running,
Error,
}
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
log_throttler: LogThrottler,
}
impl AudioCapturer {
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16);
Self {
config,
state: Arc::new(state_tx),
state_rx,
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
log_throttler: LogThrottler::with_secs(5),
}
}
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
debug!(
"Starting WASAPI audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(
log_throttler,
"capture_error",
"WASAPI audio capture error: {}",
e
);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
pub async fn stop(&self) -> Result<()> {
info!("Stopping WASAPI audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
frame_tx: &broadcast::Sender<AudioFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
let device = find_wasapi_device(&config.device_name)?;
let device_label = device_label(&device);
let supported = select_input_config(&device, config)?;
let sample_format = supported.sample_format();
let input_channels = supported.channels() as u32;
let input_rate = supported.sample_rate();
let stream_config = StreamConfig {
channels: supported.channels(),
sample_rate: supported.sample_rate(),
buffer_size: BufferSize::Fixed(config.period_frames.max(128)),
};
debug!(
"WASAPI capture selected: {} @ {}Hz {}ch {:?}",
device_label, input_rate, input_channels, sample_format
);
let (tx, rx) = mpsc::sync_channel::<Vec<i16>>(8);
let (err_tx, err_rx) = mpsc::sync_channel::<String>(1);
let callback_stop = Arc::new(AtomicBool::new(false));
let stream = match sample_format {
SampleFormat::F32 => build_stream::<f32>(
&device,
&stream_config,
input_channels,
input_rate,
tx.clone(),
err_tx.clone(),
callback_stop.clone(),
),
SampleFormat::I16 => build_stream::<i16>(
&device,
&stream_config,
input_channels,
input_rate,
tx.clone(),
err_tx.clone(),
callback_stop.clone(),
),
SampleFormat::U16 => build_stream::<u16>(
&device,
&stream_config,
input_channels,
input_rate,
tx.clone(),
err_tx.clone(),
callback_stop.clone(),
),
other => {
return Err(AppError::AudioError(format!(
"Unsupported WASAPI sample format: {:?}",
other
)));
}
}?;
stream
.play()
.map_err(|e| AppError::AudioError(format!("Failed to start WASAPI stream: {}", e)))?;
let _ = state.send(CaptureState::Running);
while !stop_flag.load(Ordering::Relaxed) {
if let Ok(err) = err_rx.try_recv() {
return Err(AppError::AudioError(format!(
"WASAPI stream error for {}: {}",
device_label, err
)));
}
match rx.recv_timeout(Duration::from_millis(100)) {
Ok(samples) => {
if samples.is_empty() {
continue;
}
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(bytemuck::cast_slice(&samples)),
2,
48_000,
seq,
);
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
}
}
}
Err(mpsc::RecvTimeoutError::Timeout) => {}
Err(mpsc::RecvTimeoutError::Disconnected) => {
return Err(AppError::AudioError(format!(
"WASAPI capture callback stopped for {}",
device_label
)));
}
}
}
callback_stop.store(true, Ordering::SeqCst);
drop(stream);
info!("WASAPI audio capture stopped");
let _ = log_throttler;
Ok(())
}
fn select_input_config(
device: &cpal::Device,
config: &AudioConfig,
) -> Result<cpal::SupportedStreamConfig> {
let requested_rate = config.sample_rate;
let mut fallback = None;
let configs = device.supported_input_configs().map_err(|e| {
AppError::AudioError(format!("Failed to query WASAPI input configs: {}", e))
})?;
for range in configs {
let sample_format = range.sample_format();
if !matches!(
sample_format,
SampleFormat::F32 | SampleFormat::I16 | SampleFormat::U16
) {
continue;
}
if fallback
.as_ref()
.is_none_or(|best: &cpal::SupportedStreamConfigRange| {
range.cmp_default_heuristics(best).is_gt()
})
{
fallback = Some(range);
}
if range.channels() >= 2
&& range.min_sample_rate() <= requested_rate
&& requested_rate <= range.max_sample_rate()
{
return Ok(range.with_sample_rate(requested_rate));
}
}
if let Some(range) = fallback {
let rate = if range.min_sample_rate() <= requested_rate
&& requested_rate <= range.max_sample_rate()
{
requested_rate
} else {
range.with_max_sample_rate().sample_rate()
};
return Ok(range.with_sample_rate(rate));
}
device.default_input_config().map_err(|e| {
AppError::AudioError(format!(
"No supported WASAPI input format found, and default config failed: {}",
e
))
})
}
fn build_stream<T>(
device: &cpal::Device,
config: &StreamConfig,
input_channels: u32,
input_rate: u32,
tx: mpsc::SyncSender<Vec<i16>>,
err_tx: mpsc::SyncSender<String>,
stop_flag: Arc<AtomicBool>,
) -> Result<cpal::Stream>
where
T: cpal::SizedSample + SampleToI16,
{
let mut converter = PcmConverter::new(input_channels, input_rate, 2, 48_000);
let data_tx = tx.clone();
let stream = device
.build_input_stream(
config,
move |data: &[T], _| {
if stop_flag.load(Ordering::Relaxed) {
return;
}
let pcm = converter.convert(data);
if !pcm.is_empty() {
let _ = data_tx.try_send(pcm);
}
},
move |err| {
let _ = err_tx.try_send(err.to_string());
},
Some(Duration::from_secs(2)),
)
.map_err(|e| AppError::AudioError(format!("Failed to build WASAPI input stream: {}", e)))?;
Ok(stream)
}
trait SampleToI16: Copy + Send + 'static {
fn to_i16_sample(self) -> i16;
}
impl SampleToI16 for i16 {
fn to_i16_sample(self) -> i16 {
self
}
}
impl SampleToI16 for u16 {
fn to_i16_sample(self) -> i16 {
(self as i32 - 32768).clamp(i16::MIN as i32, i16::MAX as i32) as i16
}
}
impl SampleToI16 for f32 {
fn to_i16_sample(self) -> i16 {
(self.clamp(-1.0, 1.0) * i16::MAX as f32).round() as i16
}
}
struct PcmConverter {
input_channels: usize,
input_rate: u32,
output_channels: usize,
output_rate: u32,
input_position: u64,
next_output_position: u64,
}
impl PcmConverter {
fn new(input_channels: u32, input_rate: u32, output_channels: u32, output_rate: u32) -> Self {
Self {
input_channels: input_channels.max(1) as usize,
input_rate: input_rate.max(1),
output_channels: output_channels.max(1) as usize,
output_rate: output_rate.max(1),
input_position: 0,
next_output_position: 0,
}
}
fn convert<T: SampleToI16>(&mut self, input: &[T]) -> Vec<i16> {
let frames = input.len() / self.input_channels;
if frames == 0 {
return Vec::new();
}
if self.input_rate == self.output_rate {
self.input_position = self.input_position.saturating_add(frames as u64);
return self.convert_channels(input, frames);
}
let start = self.input_position;
let end = start.saturating_add(frames as u64);
let mut out = Vec::with_capacity(
((frames as u64 * self.output_rate as u64 / self.input_rate as u64 + 2) as usize)
* self.output_channels,
);
while self.source_position_for_output(self.next_output_position) < end {
let src = self.source_position_for_output(self.next_output_position);
if src >= start {
let local = (src - start) as usize;
self.push_frame(input, local.min(frames - 1), &mut out);
}
self.next_output_position = self.next_output_position.saturating_add(1);
}
self.input_position = end;
out
}
fn source_position_for_output(&self, output_position: u64) -> u64 {
output_position.saturating_mul(self.input_rate as u64) / self.output_rate as u64
}
fn convert_channels<T: SampleToI16>(&self, input: &[T], frames: usize) -> Vec<i16> {
let mut out = Vec::with_capacity(frames * self.output_channels);
for frame in 0..frames {
self.push_frame(input, frame, &mut out);
}
out
}
fn push_frame<T: SampleToI16>(&self, input: &[T], frame: usize, out: &mut Vec<i16>) {
let base = frame * self.input_channels;
let left = input
.get(base)
.copied()
.map(SampleToI16::to_i16_sample)
.unwrap_or(0);
let right = if self.input_channels > 1 {
input
.get(base + 1)
.copied()
.map(SampleToI16::to_i16_sample)
.unwrap_or(left)
} else {
left
};
out.push(left);
if self.output_channels > 1 {
out.push(right);
}
}
}
fn device_label(device: &cpal::Device) -> String {
device
.description()
.map(|desc| desc.to_string())
.or_else(|_| {
#[allow(deprecated)]
device.name()
})
.unwrap_or_else(|_| "Unknown WASAPI capture device".to_string())
}

View File

@@ -1,107 +1,21 @@
//! Device selection, quality presets, streaming.
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use tracing::{debug, info};
use super::capture::AudioConfig;
use super::device::{
enumerate_audio_devices, enumerate_audio_devices_with_current, find_best_audio_device,
AudioDeviceInfo,
};
use super::encoder::{OpusConfig, OpusFrame};
use super::device::{enumerate_audio_devices_with_current, find_best_audio_device, AudioDeviceInfo};
use super::encoder::OpusFrame;
use super::monitor::AudioHealthMonitor;
use super::streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
use super::streamer::{AudioStreamer, AudioStreamerConfig};
use super::recovery;
use super::types::{AudioControllerConfig, AudioQuality, AudioStatus};
use crate::error::{AppError, Result};
use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent};
use crate::events::EventBus;
const AUDIO_RECOVERY_RETRY_DELAY: Duration = Duration::from_secs(1);
type AudioRecoveredCallback = Arc<dyn Fn() + Send + Sync>;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioQuality {
Voice,
#[default]
Balanced,
High,
}
impl AudioQuality {
pub fn bitrate(&self) -> u32 {
match self {
AudioQuality::Voice => 32000,
AudioQuality::Balanced => 64000,
AudioQuality::High => 128000,
}
}
pub fn to_opus_config(&self) -> OpusConfig {
match self {
AudioQuality::Voice => OpusConfig::voice(),
AudioQuality::Balanced => OpusConfig::default(),
AudioQuality::High => OpusConfig::music(),
}
}
}
impl FromStr for AudioQuality {
type Err = AppError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.trim().to_lowercase().as_str() {
"voice" => Ok(Self::Voice),
"balanced" => Ok(Self::Balanced),
"high" => Ok(Self::High),
_ => Err(AppError::BadRequest(format!(
"invalid audio quality {:?} (expected voice, balanced, or high)",
s.trim()
))),
}
}
}
impl std::fmt::Display for AudioQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AudioQuality::Voice => write!(f, "voice"),
AudioQuality::Balanced => write!(f, "balanced"),
AudioQuality::High => write!(f, "high"),
}
}
}
#[derive(Debug, Clone)]
pub struct AudioControllerConfig {
pub enabled: bool,
pub device: String,
pub quality: AudioQuality,
}
impl Default for AudioControllerConfig {
fn default() -> Self {
Self {
enabled: false,
device: String::new(),
quality: AudioQuality::Balanced,
}
}
}
#[derive(Debug, Clone, Serialize)]
pub struct AudioStatus {
pub enabled: bool,
pub streaming: bool,
pub device: Option<String>,
pub quality: AudioQuality,
pub subscriber_count: usize,
pub error: Option<String>,
}
pub(super) type AudioRecoveredCallback = Arc<dyn Fn() + Send + Sync>;
pub struct AudioController {
config: Arc<RwLock<AudioControllerConfig>>,
@@ -135,274 +49,12 @@ impl AudioController {
}
async fn mark_device_info_dirty(&self) {
if let Some(ref bus) = *self.event_bus.read().await {
if let Some(bus) = self.event_bus.read().await.as_ref() {
bus.mark_device_info_dirty();
}
}
async fn publish_state(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
state: &str,
device: Option<String>,
reason: Option<&str>,
next_retry_ms: Option<u64>,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamStateChanged {
state: state.to_string(),
device,
reason: reason.map(str::to_string),
next_retry_ms,
});
bus.mark_device_info_dirty();
}
}
async fn publish_device_lost(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
reason: &str,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: device.to_string(),
reason: reason.to_string(),
});
}
}
async fn publish_reconnecting(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
attempt: u32,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamReconnecting {
device: device.to_string(),
attempt,
});
}
}
async fn publish_recovered(event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>, device: &str) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamRecovered {
device: device.to_string(),
});
}
}
fn select_recovery_device(
devices: &[AudioDeviceInfo],
preferred: &str,
) -> Option<AudioDeviceInfo> {
if !preferred.trim().is_empty() {
if let Some(device) = devices.iter().find(|d| d.name == preferred) {
return Some(device.clone());
}
}
devices
.iter()
.find(|d| d.is_hdmi && d.sample_rates.contains(&48_000) && d.channels.contains(&2))
.or_else(|| {
devices
.iter()
.find(|d| d.sample_rates.contains(&48_000) && d.channels.contains(&2))
})
.or_else(|| devices.first())
.cloned()
}
fn spawn_stream_monitor_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
streamer: Arc<AudioStreamer>,
device: String,
) {
let mut state_rx = streamer.state_watch();
tokio::spawn(async move {
loop {
if state_rx.changed().await.is_err() {
return;
}
if *state_rx.borrow() != AudioStreamState::Error {
continue;
}
{
let current = streamer_slot.read().await;
if !current
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &streamer))
{
return;
}
}
let reason = format!("Audio device lost: {}", device);
monitor.report_error(&reason, "device_lost").await;
Self::spawn_recovery_task_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
device,
reason,
);
return;
}
});
}
fn spawn_recovery_task_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
lost_device: String,
reason: String,
) {
if recovery_in_progress.swap(true, Ordering::SeqCst) {
debug!("Audio recovery already in progress");
return;
}
tokio::spawn(async move {
warn!("Audio recovery started for {}: {}", lost_device, reason);
Self::publish_device_lost(&event_bus, &lost_device, &reason).await;
Self::publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_device_lost"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
let mut attempt = 0u32;
loop {
if !recovery_in_progress.load(Ordering::SeqCst) {
debug!("Audio recovery canceled");
return;
}
if streamer_slot
.read()
.await
.as_ref()
.is_some_and(|s| s.is_running())
{
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
let cfg = config.read().await.clone();
if !cfg.enabled {
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
attempt = attempt.saturating_add(1);
Self::publish_reconnecting(&event_bus, &lost_device, attempt).await;
Self::publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_reconnecting"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
tokio::time::sleep(AUDIO_RECOVERY_RETRY_DELAY).await;
let devices = match enumerate_audio_devices() {
Ok(devices) => devices,
Err(e) => {
debug!(
"Audio recovery enumerate failed (attempt {}): {}",
attempt, e
);
continue;
}
};
let Some(device) = Self::select_recovery_device(&devices, &cfg.device) else {
debug!("No audio devices found during recovery attempt {}", attempt);
continue;
};
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device.name.clone(),
..Default::default()
},
opus: cfg.quality.to_opus_config(),
};
let new_streamer = Arc::new(AudioStreamer::with_config(streamer_config));
match new_streamer.start().await {
Ok(()) => {
{
let mut cfg = config.write().await;
cfg.device = device.name.clone();
}
*streamer_slot.write().await = Some(new_streamer.clone());
monitor.report_recovered().await;
Self::publish_recovered(&event_bus, &device.name).await;
if let Some(callback) = recovered_callback.read().await.clone() {
callback();
}
Self::publish_state(
&event_bus,
"streaming",
Some(device.name.clone()),
None,
None,
)
.await;
recovery_in_progress.store(false, Ordering::SeqCst);
info!(
"Audio device recovered with {} after {} attempts",
device.name, attempt
);
Self::spawn_stream_monitor_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
new_streamer,
device.name,
);
return;
}
Err(e) => {
debug!(
"Audio recovery start failed with {} (attempt {}): {}",
device.name, attempt, e
);
}
}
}
});
}
fn spawn_recovery_task(&self, lost_device: String, reason: String) {
Self::spawn_recovery_task_from_parts(
recovery::spawn_recovery_task(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
@@ -415,7 +67,7 @@ impl AudioController {
}
fn spawn_stream_monitor(&self, streamer: Arc<AudioStreamer>, device: String) {
Self::spawn_stream_monitor_from_parts(
recovery::spawn_stream_monitor(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
@@ -477,7 +129,7 @@ impl AudioController {
config.quality = quality;
}
if let Some(ref streamer) = *self.streamer.read().await {
if let Some(streamer) = self.streamer.read().await.as_ref() {
streamer.set_bitrate(quality.bitrate()).await?;
}
@@ -578,11 +230,11 @@ impl AudioController {
}
pub async fn is_streaming(&self) -> bool {
if let Some(ref streamer) = *self.streamer.read().await {
streamer.is_running()
} else {
false
}
self.streamer
.read()
.await
.as_ref()
.is_some_and(|streamer| streamer.is_running())
}
pub async fn status(&self) -> AudioStatus {

View File

@@ -1,201 +1,9 @@
use alsa::pcm::HwParams;
use alsa::{Direction, PCM};
use serde::Serialize;
use tracing::{debug, info, warn};
#[cfg(unix)]
#[path = "device_linux.rs"]
mod imp;
use crate::error::{AppError, Result};
#[cfg(windows)]
#[path = "device_windows.rs"]
mod imp;
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
pub name: String,
pub description: String,
pub card_index: i32,
pub device_index: i32,
pub sample_rates: Vec<u32>,
pub channels: Vec<u32>,
pub is_capture: bool,
pub is_hdmi: bool,
pub usb_bus: Option<String>,
}
fn get_usb_bus_info(card_index: i32) -> Option<String> {
if card_index < 0 {
return None;
}
let device_path = format!("/sys/class/sound/card{}/device", card_index);
let link_target = std::fs::read_link(&device_path).ok()?;
let link_str = link_target.to_string_lossy();
for component in link_str.split('/') {
if component.contains('-') && !component.contains(':') {
if component
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
{
return Some(component.to_string());
}
}
}
None
}
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let mut devices = Vec::new();
let cards = alsa::card::Iter::new();
for card_result in cards {
let card = match card_result {
Ok(c) => c,
Err(e) => {
debug!("Error iterating card: {}", e);
continue;
}
};
let card_index = card.get_index();
let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string());
let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone());
debug!("Found audio card {}: {}", card_index, card_longname);
let long_lower = card_longname.to_lowercase();
let is_hdmi = long_lower.contains("hdmi")
|| long_lower.contains("capture")
|| long_lower.contains("usb");
let usb_bus = get_usb_bus_info(card_index);
for device_index in 0..8 {
let device_name = format!("hw:{},{}", card_index, device_index);
let is_current_device = current_device == Some(device_name.as_str());
let mut push_info =
|sample_rates: Vec<u32>, channels: Vec<u32>, description: String| {
devices.push(AudioDeviceInfo {
name: device_name.clone(),
description,
card_index,
device_index,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
};
match PCM::new(&device_name, Direction::Capture, false) {
Ok(pcm) => {
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
push_info(
sample_rates,
channels,
format!("{} - Device {}", card_longname, device_index),
);
}
}
Err(_) => {
if is_current_device {
debug!(
"Device {} is busy (in use by us), adding with default caps",
device_name
);
push_info(
vec![44100, 48000],
vec![2],
format!("{} - Device {} (in use)", card_longname, device_index),
);
}
}
}
}
}
info!("Found {} audio capture devices", devices.len());
Ok(devices)
}
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
let hwp = match HwParams::any(pcm) {
Ok(h) => h,
Err(_) => return (vec![], vec![]),
};
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
let mut supported_rates = Vec::new();
for rate in &common_rates {
if hwp.test_rate(*rate).is_ok() {
supported_rates.push(*rate);
}
}
let mut supported_channels = Vec::new();
for ch in 1..=8 {
if hwp.test_channels(ch).is_ok() {
supported_channels.push(ch);
}
}
(supported_rates, supported_channels)
}
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
if devices.is_empty() {
return Err(AppError::AudioError(
"No audio capture devices found".to_string(),
));
}
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
for device in &devices {
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
continue;
}
if device.is_hdmi {
info!("Selected HDMI audio device: {}", device.description);
return Ok(device.clone());
}
if first_48k_stereo.is_none() {
first_48k_stereo = Some(device);
}
}
if let Some(device) = first_48k_stereo {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
}
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback audio device: {} (may not support optimal settings)",
device.description
);
Ok(device)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enumerate_devices() {
let result = enumerate_audio_devices();
println!("Audio devices: {:?}", result);
assert!(result.is_ok());
}
}
pub use imp::*;

201
src/audio/device_linux.rs Normal file
View File

@@ -0,0 +1,201 @@
use alsa::pcm::HwParams;
use alsa::{Direction, PCM};
use serde::Serialize;
use tracing::{debug, info, warn};
use crate::error::{AppError, Result};
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
pub name: String,
pub description: String,
pub card_index: i32,
pub device_index: i32,
pub sample_rates: Vec<u32>,
pub channels: Vec<u32>,
pub is_capture: bool,
pub is_hdmi: bool,
pub usb_bus: Option<String>,
}
fn get_usb_bus_info(card_index: i32) -> Option<String> {
if card_index < 0 {
return None;
}
let device_path = format!("/sys/class/sound/card{}/device", card_index);
let link_target = std::fs::read_link(&device_path).ok()?;
let link_str = link_target.to_string_lossy();
for component in link_str.split('/') {
if component.contains('-') && !component.contains(':') {
if component
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
{
return Some(component.to_string());
}
}
}
None
}
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let mut devices = Vec::new();
let cards = alsa::card::Iter::new();
for card_result in cards {
let card = match card_result {
Ok(c) => c,
Err(e) => {
debug!("Error iterating card: {}", e);
continue;
}
};
let card_index = card.get_index();
let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string());
let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone());
debug!("Found audio card {}: {}", card_index, card_longname);
let long_lower = card_longname.to_lowercase();
let is_hdmi = long_lower.contains("hdmi")
|| long_lower.contains("capture")
|| long_lower.contains("usb");
let usb_bus = get_usb_bus_info(card_index);
for device_index in 0..8 {
let device_name = format!("hw:{},{}", card_index, device_index);
let is_current_device = current_device == Some(device_name.as_str());
let mut push_info =
|sample_rates: Vec<u32>, channels: Vec<u32>, description: String| {
devices.push(AudioDeviceInfo {
name: device_name.clone(),
description,
card_index,
device_index,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
};
match PCM::new(&device_name, Direction::Capture, false) {
Ok(pcm) => {
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
push_info(
sample_rates,
channels,
format!("{} - Device {}", card_longname, device_index),
);
}
}
Err(_) => {
if is_current_device {
debug!(
"Device {} is busy (in use by us), adding with default caps",
device_name
);
push_info(
vec![44100, 48000],
vec![2],
format!("{} - Device {} (in use)", card_longname, device_index),
);
}
}
}
}
}
info!("Found {} audio capture devices", devices.len());
Ok(devices)
}
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
let hwp = match HwParams::any(pcm) {
Ok(h) => h,
Err(_) => return (vec![], vec![]),
};
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
let mut supported_rates = Vec::new();
for rate in &common_rates {
if hwp.test_rate(*rate).is_ok() {
supported_rates.push(*rate);
}
}
let mut supported_channels = Vec::new();
for ch in 1..=8 {
if hwp.test_channels(ch).is_ok() {
supported_channels.push(ch);
}
}
(supported_rates, supported_channels)
}
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
if devices.is_empty() {
return Err(AppError::AudioError(
"No audio capture devices found".to_string(),
));
}
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
for device in &devices {
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
continue;
}
if device.is_hdmi {
info!("Selected HDMI audio device: {}", device.description);
return Ok(device.clone());
}
if first_48k_stereo.is_none() {
first_48k_stereo = Some(device);
}
}
if let Some(device) = first_48k_stereo {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
}
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback audio device: {} (may not support optimal settings)",
device.description
);
Ok(device)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enumerate_devices() {
let result = enumerate_audio_devices();
println!("Audio devices: {:?}", result);
assert!(result.is_ok());
}
}

232
src/audio/device_windows.rs Normal file
View File

@@ -0,0 +1,232 @@
use cpal::traits::{DeviceTrait, HostTrait};
use cpal::DeviceId;
use serde::Serialize;
use std::str::FromStr;
use tracing::{debug, info, warn};
use crate::error::{AppError, Result};
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
pub name: String,
pub description: String,
pub card_index: i32,
pub device_index: i32,
pub sample_rates: Vec<u32>,
pub channels: Vec<u32>,
pub is_capture: bool,
pub is_hdmi: bool,
pub usb_bus: Option<String>,
}
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let host = cpal::default_host();
let devices = host
.input_devices()
.map_err(|e| AppError::AudioError(format!("Failed to enumerate WASAPI devices: {}", e)))?;
let mut result = Vec::new();
for (index, device) in devices.enumerate() {
let labels = device_labels(&device);
let id = device
.id()
.map(|id| id.to_string())
.unwrap_or_else(|_| format!("wasapi-index:{}", index));
let (sample_rates, channels) = query_device_caps(&device);
if sample_rates.is_empty() || channels.is_empty() {
debug!(
"Skipping WASAPI endpoint without usable input caps: {}",
labels.search_text
);
continue;
}
let is_current =
current_device == Some(id.as_str()) || current_device == Some(labels.display.as_str());
let description = if is_current {
format!("{} (in use)", labels.display)
} else {
labels.display.clone()
};
let lower = labels.search_text.to_lowercase();
let is_hdmi = lower.contains("hdmi")
|| lower.contains("capture")
|| lower.contains("usb")
|| lower.contains("digital");
result.push(AudioDeviceInfo {
name: id,
description,
card_index: index as i32,
device_index: 0,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: None,
});
}
info!("Found {} WASAPI audio capture devices", result.len());
Ok(result)
}
fn query_device_caps(device: &cpal::Device) -> (Vec<u32>, Vec<u32>) {
let mut sample_rates = Vec::new();
let mut channels = Vec::new();
if let Ok(configs) = device.supported_input_configs() {
for cfg in configs {
for rate in [8000, 16000, 22050, 44100, 48000, 96000] {
if cfg.min_sample_rate() <= rate
&& rate <= cfg.max_sample_rate()
&& !sample_rates.contains(&rate)
{
sample_rates.push(rate);
}
}
let ch = cfg.channels() as u32;
if !channels.contains(&ch) {
channels.push(ch);
}
}
}
if (sample_rates.is_empty() || channels.is_empty()) && device.default_input_config().is_ok() {
if let Ok(default_cfg) = device.default_input_config() {
if !sample_rates.contains(&default_cfg.sample_rate()) {
sample_rates.push(default_cfg.sample_rate());
}
let ch = default_cfg.channels() as u32;
if !channels.contains(&ch) {
channels.push(ch);
}
}
}
sample_rates.sort_unstable();
channels.sort_unstable();
(sample_rates, channels)
}
struct DeviceLabels {
display: String,
search_text: String,
}
fn device_labels(device: &cpal::Device) -> DeviceLabels {
match device.description() {
Ok(desc) => {
let formatted = desc.to_string();
let display = desc
.extended()
.first()
.cloned()
.unwrap_or_else(|| formatted.clone());
let mut parts = vec![formatted, desc.name().to_string(), display.clone()];
parts.extend(desc.extended().iter().cloned());
DeviceLabels {
display,
search_text: parts.join(" "),
}
}
Err(_) => {
#[allow(deprecated)]
let display = device
.name()
.unwrap_or_else(|_| "Unknown WASAPI capture device".to_string());
DeviceLabels {
display: display.clone(),
search_text: display,
}
}
}
}
pub(crate) fn find_wasapi_device(requested_device: &str) -> Result<cpal::Device> {
let host = cpal::default_host();
let trimmed = requested_device.trim();
if trimmed.is_empty()
|| trimmed.eq_ignore_ascii_case("auto")
|| trimmed.eq_ignore_ascii_case("default")
{
return host.default_input_device().ok_or_else(|| {
AppError::AudioError("No default WASAPI input device found".to_string())
});
}
if let Ok(id) = DeviceId::from_str(trimmed) {
if let Some(device) = host.device_by_id(&id) {
return Ok(device);
}
}
let needle = trimmed.to_lowercase();
let devices = host
.input_devices()
.map_err(|e| AppError::AudioError(format!("Failed to enumerate WASAPI devices: {}", e)))?;
for device in devices {
let id_match = device
.id()
.map(|id| id.to_string() == trimmed)
.unwrap_or(false);
let labels = device_labels(&device);
if id_match || labels.search_text.to_lowercase().contains(&needle) {
return Ok(device);
}
}
Err(AppError::AudioError(format!(
"WASAPI audio device not found: {}",
requested_device
)))
}
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
if devices.is_empty() {
return Err(AppError::AudioError(
"No WASAPI audio capture devices found".to_string(),
));
}
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
for device in &devices {
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
continue;
}
if device.is_hdmi {
info!("Selected WASAPI capture device: {}", device.description);
return Ok(device.clone());
}
if first_48k_stereo.is_none() {
first_48k_stereo = Some(device);
}
}
if let Some(device) = first_48k_stereo {
info!("Selected WASAPI capture device: {}", device.description);
return Ok(device.clone());
}
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback WASAPI audio device: {} (will resample if needed)",
device.description
);
Ok(device)
}

View File

@@ -1,15 +1,21 @@
//! ALSA capture, Opus encode, device enumeration, streaming, controller, health monitor.
//! Platform audio capture, Opus encode, device enumeration, streaming, controller, health monitor.
#[cfg(any(unix, windows))]
pub mod capture;
pub mod controller;
#[cfg(any(unix, windows))]
pub mod device;
#[cfg(any(unix, windows))]
pub mod encoder;
pub mod monitor;
pub mod recovery;
pub mod streamer;
pub mod types;
pub use capture::{AudioCapturer, AudioConfig, AudioFrame};
pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus};
pub use controller::AudioController;
pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo};
pub use encoder::{OpusConfig, OpusEncoder, OpusFrame};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus};
pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
pub use types::{AudioControllerConfig, AudioQuality, AudioStatus};

320
src/audio/recovery.rs Normal file
View File

@@ -0,0 +1,320 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use super::capture::AudioConfig;
use super::device::{enumerate_audio_devices, AudioDeviceInfo};
use super::monitor::AudioHealthMonitor;
use super::streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
use super::types::AudioControllerConfig;
use super::controller::AudioRecoveredCallback;
use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent};
const AUDIO_RECOVERY_RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(1);
pub(super) fn select_recovery_device(
devices: &[AudioDeviceInfo],
preferred: &str,
) -> Option<AudioDeviceInfo> {
if let Some(device) = devices
.iter()
.find(|d| !preferred.trim().is_empty() && d.name == preferred)
{
return Some(device.clone());
}
devices
.iter()
.find(|d| d.is_hdmi && d.sample_rates.contains(&48_000) && d.channels.contains(&2))
.or_else(|| {
devices
.iter()
.find(|d| d.sample_rates.contains(&48_000) && d.channels.contains(&2))
})
.or_else(|| devices.first())
.cloned()
}
async fn publish_state(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
state: &str,
device: Option<String>,
reason: Option<&str>,
next_retry_ms: Option<u64>,
) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamStateChanged {
state: state.to_string(),
device,
reason: reason.map(str::to_string),
next_retry_ms,
});
bus.mark_device_info_dirty();
}
}
async fn publish_device_lost(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
reason: &str,
) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: device.to_string(),
reason: reason.to_string(),
});
}
}
async fn publish_reconnecting(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
attempt: u32,
) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamReconnecting {
device: device.to_string(),
attempt,
});
}
}
async fn publish_recovered(event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>, device: &str) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamRecovered {
device: device.to_string(),
});
}
}
fn spawn_stream_monitor_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
streamer: Arc<AudioStreamer>,
device: String,
) {
let mut state_rx = streamer.state_watch();
tokio::spawn(async move {
loop {
if state_rx.changed().await.is_err() {
return;
}
if *state_rx.borrow() != AudioStreamState::Error {
continue;
}
{
let current = streamer_slot.read().await;
if !current
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &streamer))
{
return;
}
}
let reason = format!("Audio device lost: {}", device);
monitor.report_error(&reason, "device_lost").await;
spawn_recovery_task_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
device,
reason,
);
return;
}
});
}
fn spawn_recovery_task_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
lost_device: String,
reason: String,
) {
if recovery_in_progress.swap(true, Ordering::SeqCst) {
debug!("Audio recovery already in progress");
return;
}
tokio::spawn(async move {
warn!("Audio recovery started for {}: {}", lost_device, reason);
publish_device_lost(&event_bus, &lost_device, &reason).await;
publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_device_lost"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
let mut attempt = 0u32;
loop {
if !recovery_in_progress.load(Ordering::SeqCst) {
debug!("Audio recovery canceled");
return;
}
if streamer_slot
.read()
.await
.as_ref()
.is_some_and(|s| s.is_running())
{
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
let cfg: AudioControllerConfig = config.read().await.clone();
if !cfg.enabled {
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
attempt = attempt.saturating_add(1);
publish_reconnecting(&event_bus, &lost_device, attempt).await;
publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_reconnecting"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
tokio::time::sleep(AUDIO_RECOVERY_RETRY_DELAY).await;
let devices = match enumerate_audio_devices() {
Ok(devices) => devices,
Err(e) => {
debug!(
"Audio recovery enumerate failed (attempt {}): {}",
attempt, e
);
continue;
}
};
let Some(device) = select_recovery_device(&devices, &cfg.device) else {
debug!("No audio devices found during recovery attempt {}", attempt);
continue;
};
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device.name.clone(),
..Default::default()
},
opus: cfg.quality.to_opus_config(),
};
let new_streamer = Arc::new(AudioStreamer::with_config(streamer_config));
match new_streamer.start().await {
Ok(()) => {
{
let mut cfg = config.write().await;
cfg.device = device.name.clone();
}
*streamer_slot.write().await = Some(new_streamer.clone());
monitor.report_recovered().await;
publish_recovered(&event_bus, &device.name).await;
if let Some(callback) = recovered_callback.read().await.clone() {
callback();
}
publish_state(
&event_bus,
"streaming",
Some(device.name.clone()),
None,
None,
)
.await;
recovery_in_progress.store(false, Ordering::SeqCst);
info!(
"Audio device recovered with {} after {} attempts",
device.name, attempt
);
spawn_stream_monitor_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
new_streamer,
device.name,
);
return;
}
Err(e) => {
debug!(
"Audio recovery start failed with {} (attempt {}): {}",
device.name, attempt, e
);
}
}
}
});
}
pub(super) fn spawn_stream_monitor(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
streamer: Arc<AudioStreamer>,
device: String,
) {
spawn_stream_monitor_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
streamer,
device,
);
}
pub(super) fn spawn_recovery_task(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
lost_device: String,
reason: String,
) {
spawn_recovery_task_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
lost_device,
reason,
);
}

85
src/audio/types.rs Normal file
View File

@@ -0,0 +1,85 @@
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use super::encoder::OpusConfig;
use crate::error::AppError;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioQuality {
Voice,
#[default]
Balanced,
High,
}
impl AudioQuality {
pub fn bitrate(&self) -> u32 {
match self {
AudioQuality::Voice => 32000,
AudioQuality::Balanced => 64000,
AudioQuality::High => 128000,
}
}
pub fn to_opus_config(&self) -> OpusConfig {
match self {
AudioQuality::Voice => OpusConfig::voice(),
AudioQuality::Balanced => OpusConfig::default(),
AudioQuality::High => OpusConfig::music(),
}
}
}
impl FromStr for AudioQuality {
type Err = AppError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.trim().to_lowercase().as_str() {
"voice" => Ok(Self::Voice),
"balanced" => Ok(Self::Balanced),
"high" => Ok(Self::High),
_ => Err(AppError::BadRequest(format!(
"invalid audio quality {:?} (expected voice, balanced, or high)",
s.trim()
))),
}
}
}
impl std::fmt::Display for AudioQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AudioQuality::Voice => write!(f, "voice"),
AudioQuality::Balanced => write!(f, "balanced"),
AudioQuality::High => write!(f, "high"),
}
}
}
#[derive(Debug, Clone)]
pub struct AudioControllerConfig {
pub enabled: bool,
pub device: String,
pub quality: AudioQuality,
}
impl Default for AudioControllerConfig {
fn default() -> Self {
Self {
enabled: false,
device: String::new(),
quality: AudioQuality::Balanced,
}
}
}
#[derive(Debug, Clone, Serialize)]
pub struct AudioStatus {
pub enabled: bool,
pub streaming: bool,
pub device: Option<String>,
pub quality: AudioQuality,
pub subscriber_count: usize,
pub error: Option<String>,
}

View File

@@ -1,7 +1,11 @@
mod persistence;
mod schema;
mod store;
pub use persistence::ConfigChange;
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}
pub use schema::*;
pub use store::ConfigStore;

View File

@@ -1,5 +0,0 @@
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}

View File

@@ -1,827 +0,0 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
// Re-export domain config types that are embedded in AppConfig.
// These are simple data types defined in their respective modules;
// keeping the re-export here is acceptable since they flow inward.
pub use crate::extensions::ExtensionsConfig;
pub use crate::rustdesk::config::RustDeskConfig;
/// Bitrate preset for video encoding
///
/// Simplifies bitrate configuration by providing three intuitive presets
/// plus a custom option for advanced users.
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
#[derive(Default)]
pub enum BitratePreset {
/// Speed priority: 1 Mbps, lowest latency, smaller GOP
Speed,
/// Balanced: 4 Mbps, good quality/latency tradeoff
#[default]
Balanced,
/// Quality priority: 8 Mbps, best visual quality
Quality,
/// Custom bitrate in kbps (for advanced users)
Custom(u32),
}
impl BitratePreset {
/// Get bitrate value in kbps
pub fn bitrate_kbps(&self) -> u32 {
match self {
Self::Speed => 1000,
Self::Balanced => 4000,
Self::Quality => 8000,
Self::Custom(kbps) => *kbps,
}
}
/// Get recommended GOP size based on preset
pub fn gop_size(&self, fps: u32) -> u32 {
match self {
Self::Speed => (fps / 2).max(15),
Self::Balanced => fps,
Self::Quality => fps * 2,
Self::Custom(_) => fps,
}
}
/// Get quality preset name for encoder configuration
pub fn quality_level(&self) -> &'static str {
match self {
Self::Speed => "low",
Self::Balanced => "medium",
Self::Quality => "high",
Self::Custom(_) => "medium",
}
}
/// Create from kbps value, mapping to nearest preset or Custom
pub fn from_kbps(kbps: u32) -> Self {
match kbps {
0..=1500 => Self::Speed,
1501..=6000 => Self::Balanced,
6001..=10000 => Self::Quality,
_ => Self::Custom(kbps),
}
}
}
impl std::fmt::Display for BitratePreset {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Speed => write!(f, "Speed (1 Mbps)"),
Self::Balanced => write!(f, "Balanced (4 Mbps)"),
Self::Quality => write!(f, "Quality (8 Mbps)"),
Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps),
}
}
}
/// Main application configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AppConfig {
/// Whether initial setup has been completed
pub initialized: bool,
/// Authentication settings
pub auth: AuthConfig,
/// Video capture settings
pub video: VideoConfig,
/// HID (keyboard/mouse) settings
pub hid: HidConfig,
/// Mass Storage Device settings
pub msd: MsdConfig,
/// ATX power control settings
pub atx: AtxConfig,
/// Audio settings
pub audio: AudioConfig,
/// Streaming settings
pub stream: StreamConfig,
/// Web server settings
pub web: WebConfig,
/// Extensions settings (ttyd, gostc, easytier)
pub extensions: ExtensionsConfig,
/// RustDesk remote access settings
pub rustdesk: RustDeskConfig,
/// RTSP streaming settings
pub rtsp: RtspConfig,
/// Redfish API settings
pub redfish: RedfishConfig,
}
/// Authentication configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AuthConfig {
/// Session timeout in seconds
pub session_timeout_secs: u32,
/// Allow multiple concurrent web sessions (single-user mode)
pub single_user_allow_multiple_sessions: bool,
/// Enable 2FA
pub totp_enabled: bool,
/// TOTP secret (encrypted)
pub totp_secret: Option<String>,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
session_timeout_secs: 3600 * 24, // 24 hours
single_user_allow_multiple_sessions: false,
totp_enabled: false,
totp_secret: None,
}
}
}
/// Video capture configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct VideoConfig {
/// Video device path (e.g., /dev/video0)
pub device: Option<String>,
/// Video pixel format (e.g., "MJPEG", "YUYV", "NV12")
pub format: Option<String>,
/// Resolution width
pub width: u32,
/// Resolution height
pub height: u32,
/// Frame rate
pub fps: u32,
/// JPEG quality (1-100)
pub quality: u32,
}
impl Default for VideoConfig {
fn default() -> Self {
Self {
device: None,
format: None, // Auto-detect or use MJPEG as default
width: 1920,
height: 1080,
fps: 30,
quality: 80,
}
}
}
/// HID backend type
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum HidBackend {
/// USB OTG HID gadget
Otg,
/// CH9329 serial HID controller
Ch9329,
/// Disabled
#[default]
None,
}
/// OTG USB device descriptor configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OtgDescriptorConfig {
/// USB Vendor ID (e.g., 0x1d6b)
pub vendor_id: u16,
/// USB Product ID (e.g., 0x0104)
pub product_id: u16,
/// Manufacturer string
pub manufacturer: String,
/// Product string
pub product: String,
/// Serial number (optional, auto-generated if not set)
pub serial_number: Option<String>,
}
impl Default for OtgDescriptorConfig {
fn default() -> Self {
Self {
vendor_id: 0x1d6b, // Linux Foundation
product_id: 0x0104, // Multifunction Composite Gadget
manufacturer: "One-KVM".to_string(),
product: "One-KVM USB Device".to_string(),
serial_number: None,
}
}
}
/// OTG HID function profile
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgHidProfile {
/// Full HID device set (keyboard + relative mouse + absolute mouse + consumer control)
#[default]
#[serde(alias = "full_no_msd")]
Full,
/// Full HID device set without consumer control
#[serde(alias = "full_no_consumer_no_msd")]
FullNoConsumer,
/// Legacy profile: only keyboard
LegacyKeyboard,
/// Legacy profile: only relative mouse
LegacyMouseRelative,
/// Custom function selection
Custom,
}
/// OTG endpoint budget policy.
#[typeshare]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgEndpointBudget {
/// Derive a safe default from the selected UDC.
#[default]
Auto,
/// Limit OTG gadget functions to 5 endpoints.
Five,
/// Limit OTG gadget functions to 6 endpoints.
Six,
/// Do not impose a software endpoint budget.
Unlimited,
}
impl OtgEndpointBudget {
/// Resolve endpoint limit assuming a known budget variant (not Auto).
pub fn endpoint_limit_raw(&self) -> Option<u8> {
match self {
Self::Five => Some(5),
Self::Six => Some(6),
Self::Unlimited => None,
Self::Auto => None, // resolved via `HidConfig::resolved_otg_endpoint_limit`
}
}
}
/// OTG HID function selection (used when profile is Custom)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct OtgHidFunctions {
pub keyboard: bool,
pub mouse_relative: bool,
pub mouse_absolute: bool,
pub consumer: bool,
}
impl OtgHidFunctions {
pub fn full() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: true,
}
}
pub fn full_no_consumer() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: false,
}
}
pub fn legacy_keyboard() -> Self {
Self {
keyboard: true,
mouse_relative: false,
mouse_absolute: false,
consumer: false,
}
}
pub fn legacy_mouse_relative() -> Self {
Self {
keyboard: false,
mouse_relative: true,
mouse_absolute: false,
consumer: false,
}
}
pub fn is_empty(&self) -> bool {
!self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer
}
pub fn endpoint_cost(&self, keyboard_leds: bool) -> u8 {
let mut endpoints = 0;
if self.keyboard {
endpoints += 1;
if keyboard_leds {
endpoints += 1;
}
}
if self.mouse_relative {
endpoints += 1;
}
if self.mouse_absolute {
endpoints += 1;
}
if self.consumer {
endpoints += 1;
}
endpoints
}
}
impl Default for OtgHidFunctions {
fn default() -> Self {
Self::full()
}
}
impl OtgHidProfile {
pub fn from_legacy_str(value: &str) -> Option<Self> {
match value {
"full" | "full_no_msd" => Some(Self::Full),
"full_no_consumer" | "full_no_consumer_no_msd" => Some(Self::FullNoConsumer),
"legacy_keyboard" => Some(Self::LegacyKeyboard),
"legacy_mouse_relative" => Some(Self::LegacyMouseRelative),
"custom" => Some(Self::Custom),
_ => None,
}
}
pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions {
match self {
Self::Full => OtgHidFunctions::full(),
Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(),
Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(),
Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(),
Self::Custom => custom.clone(),
}
}
}
/// HID configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct HidConfig {
/// HID backend type
pub backend: HidBackend,
/// OTG UDC (USB Device Controller) name
pub otg_udc: Option<String>,
/// OTG USB device descriptor configuration
#[serde(default)]
pub otg_descriptor: OtgDescriptorConfig,
/// OTG HID function profile
#[serde(default)]
pub otg_profile: OtgHidProfile,
/// OTG endpoint budget policy
#[serde(default)]
pub otg_endpoint_budget: OtgEndpointBudget,
/// OTG HID function selection (used when profile is Custom)
#[serde(default)]
pub otg_functions: OtgHidFunctions,
/// Enable keyboard LED/status feedback for OTG keyboard
#[serde(default)]
pub otg_keyboard_leds: bool,
/// CH9329 serial port
pub ch9329_port: String,
/// CH9329 baud rate
pub ch9329_baudrate: u32,
/// Mouse mode: absolute or relative
pub mouse_absolute: bool,
}
impl Default for HidConfig {
fn default() -> Self {
Self {
backend: HidBackend::None,
otg_udc: None,
otg_descriptor: OtgDescriptorConfig::default(),
otg_profile: OtgHidProfile::default(),
otg_endpoint_budget: OtgEndpointBudget::default(),
otg_functions: OtgHidFunctions::default(),
otg_keyboard_leds: false,
ch9329_port: "/dev/ttyUSB0".to_string(),
ch9329_baudrate: 9600,
mouse_absolute: true,
}
}
}
impl HidConfig {
/// Resolve effective OTG HID functions from profile + custom selection.
/// Pure logic, no external dependency.
pub fn effective_otg_functions(&self) -> OtgHidFunctions {
self.otg_profile.resolve_functions(&self.otg_functions)
}
/// Whether keyboard LED feedback is effectively enabled.
pub fn effective_otg_keyboard_leds(&self) -> bool {
self.otg_keyboard_leds && self.effective_otg_functions().keyboard
}
/// Effective HID functions after applying all constraints.
pub fn constrained_otg_functions(&self) -> OtgHidFunctions {
self.effective_otg_functions()
}
/// Calculate required endpoint count for the current function selection.
pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 {
let functions = self.effective_otg_functions();
let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds());
if msd_enabled {
endpoints += 2;
}
endpoints
}
/// Validate endpoint budget for the current OTG configuration (UDC-aware when budget is Auto).
pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> {
if self.backend != HidBackend::Otg {
return Ok(());
}
let functions = self.effective_otg_functions();
if functions.is_empty() {
return Err(crate::error::AppError::BadRequest(
"OTG HID functions cannot be empty".to_string(),
));
}
let resolved_limit = self.resolved_otg_endpoint_limit();
let required = self.effective_otg_required_endpoints(msd_enabled);
if let Some(limit) = resolved_limit {
if required > limit {
return Err(crate::error::AppError::BadRequest(format!(
"OTG selection requires {} endpoints, but the configured limit is {}",
required, limit
)));
}
}
Ok(())
}
/// Effective OTG UDC name (for change detection / service).
#[inline]
pub fn resolved_otg_udc(&self) -> Option<String> {
if self.backend != HidBackend::Otg {
return None;
}
self.otg_udc
.as_ref()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.or_else(|| crate::otg::OtgGadgetManager::find_udc())
}
/// Resolved endpoint limit used for OTG gadget allocator / validation.
#[inline]
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
if self.backend != HidBackend::Otg {
return None;
}
match self.otg_endpoint_budget {
OtgEndpointBudget::Five => Some(5),
OtgEndpointBudget::Six => Some(6),
OtgEndpointBudget::Unlimited => None,
OtgEndpointBudget::Auto => {
let udc = self.resolved_otg_udc().unwrap_or_default();
if crate::otg::configfs::is_low_endpoint_udc(&udc) {
Some(5)
} else {
Some(6)
}
}
}
}
}
/// MSD configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct MsdConfig {
/// Enable MSD functionality
pub enabled: bool,
/// MSD base directory (absolute path)
pub msd_dir: String,
}
impl Default for MsdConfig {
fn default() -> Self {
Self {
enabled: true,
msd_dir: String::new(),
}
}
}
impl MsdConfig {
pub fn msd_dir_path(&self) -> std::path::PathBuf {
std::path::PathBuf::from(&self.msd_dir)
}
pub fn images_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("images")
}
pub fn ventoy_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("ventoy")
}
pub fn drive_path(&self) -> std::path::PathBuf {
self.ventoy_dir().join("ventoy.img")
}
}
// Re-export ATX types from atx module for configuration
pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig};
/// ATX power control configuration
///
/// Each ATX action (power, reset) can be independently configured with its own
/// hardware binding using the four-tuple: (driver, device, pin, active_level).
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AtxConfig {
/// Enable ATX functionality
pub enabled: bool,
/// Power button configuration (used for both short and long press)
pub power: AtxKeyConfig,
/// Reset button configuration
pub reset: AtxKeyConfig,
/// LED sensing configuration (optional)
pub led: AtxLedConfig,
/// Network interface for WOL packets (empty = auto)
pub wol_interface: String,
}
impl AtxConfig {
/// Convert to AtxControllerConfig for the controller
pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig {
crate::atx::AtxControllerConfig {
enabled: self.enabled,
power: self.power.clone(),
reset: self.reset.clone(),
led: self.led.clone(),
}
}
}
/// Audio configuration
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
/// These are optimal for Opus encoding and match WebRTC requirements.
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AudioConfig {
/// Enable audio capture
pub enabled: bool,
/// ALSA device name
pub device: String,
/// Audio quality preset: "voice", "balanced", "high"
pub quality: String,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
device: String::new(),
quality: "balanced".to_string(),
}
}
}
/// Stream mode
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum StreamMode {
/// WebRTC with H264/H265
WebRTC,
/// MJPEG over HTTP
#[default]
Mjpeg,
}
/// RTSP output codec
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum RtspCodec {
#[default]
H264,
H265,
}
/// RTSP configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RtspConfig {
/// Enable RTSP output
pub enabled: bool,
/// Bind IP address
pub bind: String,
/// RTSP TCP listen port
pub port: u16,
/// Stream path (without leading slash)
pub path: String,
/// Allow only one client connection at a time
pub allow_one_client: bool,
/// Output codec (H264/H265)
pub codec: RtspCodec,
/// Optional username for authentication
pub username: Option<String>,
/// Optional password for authentication
#[typeshare(skip)]
pub password: Option<String>,
}
impl Default for RtspConfig {
fn default() -> Self {
Self {
enabled: false,
bind: "0.0.0.0".to_string(),
port: 8554,
path: "live".to_string(),
allow_one_client: true,
codec: RtspCodec::H264,
username: None,
password: None,
}
}
}
/// Encoder type
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum EncoderType {
/// Auto-detect best encoder
#[default]
Auto,
/// Software encoder (libx264)
Software,
/// VAAPI hardware encoder
Vaapi,
/// NVIDIA NVENC hardware encoder
Nvenc,
/// Intel Quick Sync hardware encoder
Qsv,
/// AMD AMF hardware encoder
Amf,
/// Rockchip MPP hardware encoder
Rkmpp,
/// V4L2 M2M hardware encoder
V4l2m2m,
}
impl EncoderType {
/// Get display name for UI
pub fn display_name(&self) -> &'static str {
match self {
EncoderType::Auto => "Auto (Recommended)",
EncoderType::Software => "Software (CPU)",
EncoderType::Vaapi => "VAAPI",
EncoderType::Nvenc => "NVIDIA NVENC",
EncoderType::Qsv => "Intel Quick Sync",
EncoderType::Amf => "AMD AMF",
EncoderType::Rkmpp => "Rockchip MPP",
EncoderType::V4l2m2m => "V4L2 M2M",
}
}
}
/// Streaming configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
/// Stream mode
pub mode: StreamMode,
/// Encoder type for H264/H265
pub encoder: EncoderType,
/// Bitrate preset (Speed/Balanced/Quality)
pub bitrate_preset: BitratePreset,
/// Custom STUN server (e.g., "stun:stun.l.google.com:19302")
/// If empty, uses public ICE servers from secrets.toml
pub stun_server: Option<String>,
/// Custom TURN server (e.g., "turn:turn.example.com:3478")
/// If empty, uses public ICE servers from secrets.toml
pub turn_server: Option<String>,
/// TURN username
pub turn_username: Option<String>,
/// TURN password (stored encrypted in DB, not exposed via API)
pub turn_password: Option<String>,
/// Auto-pause when no clients connected
#[typeshare(skip)]
pub auto_pause_enabled: bool,
/// Auto-pause delay (seconds)
#[typeshare(skip)]
pub auto_pause_delay_secs: u64,
/// Client timeout for cleanup (seconds)
#[typeshare(skip)]
pub client_timeout_secs: u64,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
mode: StreamMode::Mjpeg,
encoder: EncoderType::Auto,
bitrate_preset: BitratePreset::Balanced,
// Empty means use public ICE servers (like RustDesk)
stun_server: None,
turn_server: None,
turn_username: None,
turn_password: None,
auto_pause_enabled: false,
auto_pause_delay_secs: 10,
client_timeout_secs: 30,
}
}
}
impl StreamConfig {
/// Whether built-in / public ICE is used (no custom STUN or TURN URL configured).
pub fn is_using_public_ice_servers(&self) -> bool {
let no_custom_stun = self
.stun_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
let no_custom_turn = self
.turn_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
no_custom_stun && no_custom_turn
}
}
/// Web server configuration persisted in the database (includes on-disk TLS paths).
///
/// The HTTP API for `/api/config/web` uses `WebConfigResponse` instead: no path fields, includes `has_custom_cert`.
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WebConfig {
/// HTTP port
pub http_port: u16,
/// HTTPS port
pub https_port: u16,
/// Bind addresses (preferred)
pub bind_addresses: Vec<String>,
/// Bind address (legacy)
pub bind_address: String,
/// Enable HTTPS
pub https_enabled: bool,
/// Custom SSL certificate path
pub ssl_cert_path: Option<String>,
/// Custom SSL key path
pub ssl_key_path: Option<String>,
}
impl Default for WebConfig {
fn default() -> Self {
Self {
http_port: 8080,
https_port: 8443,
bind_addresses: Vec::new(),
bind_address: "0.0.0.0".to_string(),
https_enabled: false,
ssl_cert_path: None,
ssl_key_path: None,
}
}
}
/// Redfish API configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RedfishConfig {
/// Enable Redfish API endpoint
pub enabled: bool,
}
impl Default for RedfishConfig {
fn default() -> Self {
Self { enabled: false }
}
}

28
src/config/schema/atx.rs Normal file
View File

@@ -0,0 +1,28 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig};
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AtxConfig {
pub enabled: bool,
pub power: AtxKeyConfig,
pub reset: AtxKeyConfig,
pub led: AtxLedConfig,
pub wol_interface: String,
}
impl AtxConfig {
pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig {
crate::atx::AtxControllerConfig {
enabled: self.enabled,
power: self.power.clone(),
reset: self.reset.clone(),
led: self.led.clone(),
}
}
}

View File

@@ -0,0 +1,64 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
#[derive(Default)]
pub enum BitratePreset {
Speed,
#[default]
Balanced,
Quality,
Custom(u32),
}
impl BitratePreset {
pub fn bitrate_kbps(&self) -> u32 {
match self {
Self::Speed => 1000,
Self::Balanced => 4000,
Self::Quality => 8000,
Self::Custom(kbps) => *kbps,
}
}
pub fn gop_size(&self, fps: u32) -> u32 {
match self {
Self::Speed => (fps / 2).max(15),
Self::Balanced => fps,
Self::Quality => fps * 2,
Self::Custom(_) => fps,
}
}
pub fn quality_level(&self) -> &'static str {
match self {
Self::Speed => "low",
Self::Balanced => "medium",
Self::Quality => "high",
Self::Custom(_) => "medium",
}
}
pub fn from_kbps(kbps: u32) -> Self {
match kbps {
0..=1500 => Self::Speed,
1501..=6000 => Self::Balanced,
6001..=10000 => Self::Quality,
_ => Self::Custom(kbps),
}
}
}
impl std::fmt::Display for BitratePreset {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Speed => write!(f, "Speed (1 Mbps)"),
Self::Balanced => write!(f, "Balanced (4 Mbps)"),
Self::Quality => write!(f, "Quality (8 Mbps)"),
Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps),
}
}
}

309
src/config/schema/hid.rs Normal file
View File

@@ -0,0 +1,309 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum HidBackend {
Otg,
Ch9329,
#[default]
None,
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OtgDescriptorConfig {
pub vendor_id: u16,
pub product_id: u16,
pub manufacturer: String,
pub product: String,
pub serial_number: Option<String>,
}
impl Default for OtgDescriptorConfig {
fn default() -> Self {
Self {
vendor_id: 0x1d6b,
product_id: 0x0104,
manufacturer: "One-KVM".to_string(),
product: "One-KVM USB Device".to_string(),
serial_number: None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgHidProfile {
#[default]
#[serde(alias = "full_no_msd")]
Full,
#[serde(alias = "full_no_consumer_no_msd")]
FullNoConsumer,
LegacyKeyboard,
LegacyMouseRelative,
Custom,
}
#[typeshare]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgEndpointBudget {
#[default]
Auto,
Five,
Six,
Unlimited,
}
impl OtgEndpointBudget {
pub fn endpoint_limit_raw(&self) -> Option<u8> {
match self {
Self::Five => Some(5),
Self::Six => Some(6),
Self::Unlimited => None,
Self::Auto => None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct OtgHidFunctions {
pub keyboard: bool,
pub mouse_relative: bool,
pub mouse_absolute: bool,
pub consumer: bool,
}
impl OtgHidFunctions {
pub fn full() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: true,
}
}
pub fn full_no_consumer() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: false,
}
}
pub fn legacy_keyboard() -> Self {
Self {
keyboard: true,
mouse_relative: false,
mouse_absolute: false,
consumer: false,
}
}
pub fn legacy_mouse_relative() -> Self {
Self {
keyboard: false,
mouse_relative: true,
mouse_absolute: false,
consumer: false,
}
}
pub fn is_empty(&self) -> bool {
!self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer
}
pub fn endpoint_cost(&self, keyboard_leds: bool) -> u8 {
let mut endpoints = 0;
if self.keyboard {
endpoints += 1;
if keyboard_leds {
endpoints += 1;
}
}
if self.mouse_relative {
endpoints += 1;
}
if self.mouse_absolute {
endpoints += 1;
}
if self.consumer {
endpoints += 1;
}
endpoints
}
}
impl Default for OtgHidFunctions {
fn default() -> Self {
Self::full()
}
}
impl OtgHidProfile {
pub fn from_legacy_str(value: &str) -> Option<Self> {
match value {
"full" | "full_no_msd" => Some(Self::Full),
"full_no_consumer" | "full_no_consumer_no_msd" => Some(Self::FullNoConsumer),
"legacy_keyboard" => Some(Self::LegacyKeyboard),
"legacy_mouse_relative" => Some(Self::LegacyMouseRelative),
"custom" => Some(Self::Custom),
_ => None,
}
}
pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions {
match self {
Self::Full => OtgHidFunctions::full(),
Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(),
Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(),
Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(),
Self::Custom => custom.clone(),
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct HidConfig {
pub backend: HidBackend,
pub otg_udc: Option<String>,
#[serde(default)]
pub otg_descriptor: OtgDescriptorConfig,
#[serde(default)]
pub otg_profile: OtgHidProfile,
#[serde(default)]
pub otg_endpoint_budget: OtgEndpointBudget,
#[serde(default)]
pub otg_functions: OtgHidFunctions,
#[serde(default)]
pub otg_keyboard_leds: bool,
pub ch9329_port: String,
pub ch9329_baudrate: u32,
pub mouse_absolute: bool,
}
impl Default for HidConfig {
fn default() -> Self {
Self {
backend: HidBackend::None,
otg_udc: None,
otg_descriptor: OtgDescriptorConfig::default(),
otg_profile: OtgHidProfile::default(),
otg_endpoint_budget: OtgEndpointBudget::default(),
otg_functions: OtgHidFunctions::default(),
otg_keyboard_leds: false,
ch9329_port: "/dev/ttyUSB0".to_string(),
ch9329_baudrate: 9600,
mouse_absolute: true,
}
}
}
impl HidConfig {
pub fn effective_otg_functions(&self) -> OtgHidFunctions {
self.otg_profile.resolve_functions(&self.otg_functions)
}
pub fn effective_otg_keyboard_leds(&self) -> bool {
self.otg_keyboard_leds && self.effective_otg_functions().keyboard
}
pub fn constrained_otg_functions(&self) -> OtgHidFunctions {
self.effective_otg_functions()
}
pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 {
let functions = self.effective_otg_functions();
let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds());
if msd_enabled {
endpoints += 2;
}
endpoints
}
pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> {
if self.backend != HidBackend::Otg {
return Ok(());
}
let functions = self.effective_otg_functions();
if functions.is_empty() {
return Err(crate::error::AppError::BadRequest(
"OTG HID functions cannot be empty".to_string(),
));
}
let resolved_limit = self.resolved_otg_endpoint_limit();
let required = self.effective_otg_required_endpoints(msd_enabled);
if let Some(limit) = resolved_limit {
if required > limit {
return Err(crate::error::AppError::BadRequest(format!(
"OTG selection requires {} endpoints, but the configured limit is {}",
required, limit
)));
}
}
Ok(())
}
#[inline]
pub fn resolved_otg_udc(&self) -> Option<String> {
if self.backend != HidBackend::Otg {
return None;
}
self.otg_udc
.as_ref()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.or_else(|| {
#[cfg(unix)]
{
crate::otg::OtgGadgetManager::find_udc()
}
#[cfg(not(unix))]
{
None
}
})
}
#[inline]
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
if self.backend != HidBackend::Otg {
return None;
}
match self.otg_endpoint_budget {
OtgEndpointBudget::Five => Some(5),
OtgEndpointBudget::Six => Some(6),
OtgEndpointBudget::Unlimited => None,
OtgEndpointBudget::Auto => {
#[cfg(unix)]
let udc = self.resolved_otg_udc().unwrap_or_default();
#[cfg(unix)]
if crate::otg::configfs::is_low_endpoint_udc(&udc) {
Some(5)
} else {
Some(6)
}
#[cfg(not(unix))]
{
Some(6)
}
}
}
}
}

44
src/config/schema/mod.rs Normal file
View File

@@ -0,0 +1,44 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
pub use crate::extensions::ExtensionsConfig;
pub use crate::rustdesk::config::RustDeskConfig;
mod atx;
mod common;
mod hid;
mod stream;
mod web;
pub use atx::*;
pub use common::*;
pub use hid::*;
pub use stream::*;
pub use web::*;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AppConfig {
pub initialized: bool,
pub auth: AuthConfig,
pub video: VideoConfig,
pub hid: HidConfig,
pub msd: MsdConfig,
pub atx: AtxConfig,
pub audio: AudioConfig,
pub stream: StreamConfig,
pub web: WebConfig,
pub extensions: ExtensionsConfig,
pub rustdesk: RustDeskConfig,
pub rtsp: RtspConfig,
pub redfish: RedfishConfig,
}
impl AppConfig {
pub fn apply_platform_defaults(&mut self) {
crate::platform::defaults::apply(self);
}
}

149
src/config/schema/stream.rs Normal file
View File

@@ -0,0 +1,149 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
use super::BitratePreset;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum StreamMode {
WebRTC,
#[default]
Mjpeg,
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum RtspCodec {
#[default]
H264,
H265,
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RtspConfig {
pub enabled: bool,
pub bind: String,
pub port: u16,
pub path: String,
pub allow_one_client: bool,
pub codec: RtspCodec,
pub username: Option<String>,
#[typeshare(skip)]
pub password: Option<String>,
}
impl Default for RtspConfig {
fn default() -> Self {
Self {
enabled: false,
bind: "0.0.0.0".to_string(),
port: 8554,
path: "live".to_string(),
allow_one_client: true,
codec: RtspCodec::H264,
username: None,
password: None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum EncoderType {
#[default]
Auto,
Software,
Vaapi,
Nvenc,
Qsv,
Amf,
Rkmpp,
V4l2m2m,
}
impl EncoderType {
pub fn display_name(&self) -> &'static str {
match self {
EncoderType::Auto => "Auto (Recommended)",
EncoderType::Software => "Software (CPU)",
EncoderType::Vaapi => "VAAPI",
EncoderType::Nvenc => "NVIDIA NVENC",
EncoderType::Qsv => "Intel Quick Sync",
EncoderType::Amf => "AMD AMF",
EncoderType::Rkmpp => "Rockchip MPP",
EncoderType::V4l2m2m => "V4L2 M2M",
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
pub mode: StreamMode,
pub encoder: EncoderType,
pub bitrate_preset: BitratePreset,
pub stun_server: Option<String>,
pub turn_server: Option<String>,
pub turn_username: Option<String>,
pub turn_password: Option<String>,
#[typeshare(skip)]
pub auto_pause_enabled: bool,
#[typeshare(skip)]
pub auto_pause_delay_secs: u64,
#[typeshare(skip)]
pub client_timeout_secs: u64,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
mode: StreamMode::Mjpeg,
encoder: EncoderType::Auto,
bitrate_preset: BitratePreset::Balanced,
stun_server: None,
turn_server: None,
turn_username: None,
turn_password: None,
auto_pause_enabled: false,
auto_pause_delay_secs: 10,
client_timeout_secs: 30,
}
}
}
impl StreamConfig {
pub fn is_using_public_ice_servers(&self) -> bool {
let no_custom_stun = self
.stun_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
let no_custom_turn = self
.turn_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
no_custom_stun && no_custom_turn
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RedfishConfig {
pub enabled: bool,
}
impl Default for RedfishConfig {
fn default() -> Self {
Self { enabled: false }
}
}

129
src/config/schema/web.rs Normal file
View File

@@ -0,0 +1,129 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AuthConfig {
pub session_timeout_secs: u32,
pub single_user_allow_multiple_sessions: bool,
pub totp_enabled: bool,
pub totp_secret: Option<String>,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
session_timeout_secs: 3600 * 24,
single_user_allow_multiple_sessions: false,
totp_enabled: false,
totp_secret: None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct VideoConfig {
pub device: Option<String>,
pub format: Option<String>,
pub width: u32,
pub height: u32,
pub fps: u32,
pub quality: u32,
}
impl Default for VideoConfig {
fn default() -> Self {
Self {
device: None,
format: None,
width: 1920,
height: 1080,
fps: 30,
quality: 80,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct MsdConfig {
pub enabled: bool,
pub msd_dir: String,
}
impl Default for MsdConfig {
fn default() -> Self {
Self {
enabled: true,
msd_dir: String::new(),
}
}
}
impl MsdConfig {
pub fn msd_dir_path(&self) -> std::path::PathBuf {
std::path::PathBuf::from(&self.msd_dir)
}
pub fn images_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("images")
}
pub fn ventoy_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("ventoy")
}
pub fn drive_path(&self) -> std::path::PathBuf {
self.ventoy_dir().join("ventoy.img")
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AudioConfig {
pub enabled: bool,
pub device: String,
pub quality: String,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
device: String::new(),
quality: "balanced".to_string(),
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WebConfig {
pub http_port: u16,
pub https_port: u16,
pub bind_addresses: Vec<String>,
pub bind_address: String,
pub https_enabled: bool,
pub ssl_cert_path: Option<String>,
pub ssl_key_path: Option<String>,
}
impl Default for WebConfig {
fn default() -> Self {
Self {
http_port: 8080,
https_port: 8443,
bind_addresses: Vec::new(),
bind_address: "0.0.0.0".to_string(),
https_enabled: false,
ssl_cert_path: None,
ssl_key_path: None,
}
}
}

View File

@@ -4,29 +4,20 @@ use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
use super::persistence::ConfigChange;
use super::AppConfig;
use super::ConfigChange;
use crate::error::{AppError, Result};
/// Configuration store backed by SQLite
///
/// Uses `ArcSwap` for lock-free reads, providing high performance
/// for frequent configuration access in hot paths.
#[derive(Clone)]
pub struct ConfigStore {
pool: Pool<Sqlite>,
/// Lock-free cache using ArcSwap for zero-cost reads
cache: Arc<ArcSwap<AppConfig>>,
change_tx: broadcast::Sender<ConfigChange>,
/// Serializes `set` / `update` so concurrent PATCH handlers cannot clobber each other
write_lock: Arc<Mutex<()>>,
}
impl ConfigStore {
/// Create a new configuration store
pub fn new(pool: Pool<Sqlite>) -> Result<Self> {
// Load or create default config synchronously wrapper
// (actual DB load is async, handled in init())
Ok(Self {
pool,
cache: Arc::new(ArcSwap::from_pointee(AppConfig::default())),
@@ -35,14 +26,12 @@ impl ConfigStore {
})
}
/// Load configuration from database (call after new())
pub async fn load(&self) -> Result<()> {
let config = Self::load_config(&self.pool).await?;
self.cache.store(Arc::new(config));
Ok(())
}
/// Load configuration from database
async fn load_config(pool: &Pool<Sqlite>) -> Result<AppConfig> {
let row: Option<(String,)> =
sqlx::query_as("SELECT value FROM config WHERE key = 'app_config'")
@@ -54,7 +43,6 @@ impl ConfigStore {
serde_json::from_str(&json).map_err(|e| AppError::Config(e.to_string()))
}
None => {
// Create default config
let config = AppConfig::default();
Self::save_config_to_db(pool, &config).await?;
Ok(config)
@@ -62,7 +50,6 @@ impl ConfigStore {
}
}
/// Save configuration to database
async fn save_config_to_db(pool: &Pool<Sqlite>, config: &AppConfig) -> Result<()> {
let json = serde_json::to_string(config)?;
@@ -80,21 +67,15 @@ impl ConfigStore {
Ok(())
}
/// Get current configuration (lock-free, zero-copy)
///
/// Returns an `Arc<AppConfig>` for efficient sharing without cloning.
/// This is a lock-free operation with minimal overhead.
pub fn get(&self) -> Arc<AppConfig> {
self.cache.load_full()
}
/// Set entire configuration
pub async fn set(&self, config: AppConfig) -> Result<()> {
let _guard = self.write_lock.lock().await;
Self::save_config_to_db(&self.pool, &config).await?;
self.cache.store(Arc::new(config));
// Notify subscribers
let _ = self.change_tx.send(ConfigChange {
key: "app_config".to_string(),
});
@@ -102,27 +83,19 @@ impl ConfigStore {
Ok(())
}
/// Update configuration with a closure
///
/// Uses read-modify-write under a mutex so concurrent `update` / `set` calls are serialized
/// and merged correctly (each closure sees the latest stored config).
pub async fn update<F>(&self, f: F) -> Result<()>
where
F: FnOnce(&mut AppConfig),
{
let _guard = self.write_lock.lock().await;
// Load current config, clone it for modification
let current = self.cache.load();
let mut config = (**current).clone();
f(&mut config);
// Persist to database first
Self::save_config_to_db(&self.pool, &config).await?;
// Then update cache atomically
self.cache.store(Arc::new(config));
// Notify subscribers
let _ = self.change_tx.send(ConfigChange {
key: "app_config".to_string(),
});
@@ -130,12 +103,10 @@ impl ConfigStore {
Ok(())
}
/// Subscribe to configuration changes
pub fn subscribe(&self) -> broadcast::Receiver<ConfigChange> {
self.change_tx.subscribe()
}
/// Check if system is initialized (lock-free)
pub fn is_initialized(&self) -> bool {
self.cache.load().initialized
}
@@ -158,11 +129,9 @@ mod tests {
let store = ConfigStore::new(db.clone_pool()).unwrap();
store.load().await.unwrap();
// Check default config (now lock-free, no await needed)
let config = store.get();
assert!(!config.initialized);
// Update config
store
.update(|c| {
c.initialized = true;
@@ -171,12 +140,10 @@ mod tests {
.await
.unwrap();
// Verify update
let config = store.get();
assert!(config.initialized);
assert_eq!(config.web.http_port, 9000);
// Create new store instance and verify persistence
let store2 = ConfigStore::new(db.clone_pool()).unwrap();
store2.load().await.unwrap();
let config = store2.get();

280
src/diagnostics/linux.rs Normal file
View File

@@ -0,0 +1,280 @@
use super::{DeviceInfo, DiskSpaceInfo, NetworkAddress};
use crate::error::{AppError, Result};
use crate::utils::hostname_uname;
pub fn get_disk_space(path: &std::path::Path) -> Result<DiskSpaceInfo> {
let stat = nix::sys::statvfs::statvfs(path)
.map_err(|e| AppError::Internal(format!("Failed to get disk space: {}", e)))?;
let block_size = stat.block_size() as u64;
let total = stat.blocks() as u64 * block_size;
let available = stat.blocks_available() as u64 * block_size;
let used = total - available;
Ok(DiskSpaceInfo {
total,
available,
used,
})
}
pub fn get_device_info() -> DeviceInfo {
let mem_info = get_meminfo();
DeviceInfo {
hostname: hostname_uname(),
cpu_model: get_cpu_model(),
cpu_usage: get_cpu_usage(),
memory_total: mem_info.total,
memory_used: mem_info.total.saturating_sub(mem_info.available),
network_addresses: get_network_addresses(),
serial_ports: crate::utils::list_serial_ports(),
}
}
fn get_cpu_model() -> String {
let cpuinfo = std::fs::read_to_string("/proc/cpuinfo").ok();
if let Some(model) = parse_cpu_model_from_cpuinfo_content(cpuinfo.as_deref()) {
return model;
}
if let Some(model) = read_device_tree_model() {
return model;
}
if let Some(content) = cpuinfo.as_deref() {
let cores = content
.lines()
.filter(|line| line.starts_with("processor"))
.count();
if cores > 0 {
return format!("{} {}C", std::env::consts::ARCH, cores);
}
}
std::env::consts::ARCH.to_string()
}
fn parse_cpu_model_from_cpuinfo_content(content: Option<&str>) -> Option<String> {
let content = content?;
content
.lines()
.find(|line| line.starts_with("model name") || line.starts_with("Model"))
.and_then(|line| line.split(':').nth(1))
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}
fn read_device_tree_model() -> Option<String> {
std::fs::read("/proc/device-tree/model")
.ok()
.and_then(|bytes| parse_device_tree_model_bytes(bytes.as_slice()))
}
fn parse_device_tree_model_bytes(bytes: &[u8]) -> Option<String> {
let model = String::from_utf8_lossy(bytes)
.trim_matches(|c: char| c == '\0' || c.is_whitespace())
.to_string();
if model.is_empty() {
None
} else {
Some(model)
}
}
static CPU_PREV_STATS: std::sync::OnceLock<std::sync::Mutex<(u64, u64)>> =
std::sync::OnceLock::new();
fn get_cpu_usage() -> f32 {
let content = match std::fs::read_to_string("/proc/stat") {
Ok(c) => c,
Err(_) => return 0.0,
};
let cpu_line = match content.lines().next() {
Some(line) if line.starts_with("cpu ") => line,
_ => return 0.0,
};
let parts: Vec<u64> = cpu_line
.split_whitespace()
.skip(1)
.take(8)
.filter_map(|s| s.parse().ok())
.collect();
if parts.len() < 4 {
return 0.0;
}
let idle = parts[3] + parts.get(4).unwrap_or(&0);
let total: u64 = parts.iter().sum();
let prev_mutex = CPU_PREV_STATS.get_or_init(|| std::sync::Mutex::new((0, 0)));
let mut prev = prev_mutex.lock().unwrap();
let (prev_idle, prev_total) = *prev;
let idle_delta = idle.saturating_sub(prev_idle);
let total_delta = total.saturating_sub(prev_total);
*prev = (idle, total);
if total_delta == 0 {
return 0.0;
}
let usage = 100.0 * (1.0 - (idle_delta as f64 / total_delta as f64));
usage as f32
}
struct MemInfo {
total: u64,
available: u64,
}
fn get_meminfo() -> MemInfo {
let content = match std::fs::read_to_string("/proc/meminfo") {
Ok(c) => c,
Err(_) => {
return MemInfo {
total: 0,
available: 0,
}
}
};
let mut total = 0u64;
let mut available = 0u64;
for line in content.lines() {
if line.starts_with("MemTotal:") {
if let Some(kb) = line
.split_whitespace()
.nth(1)
.and_then(|v| v.parse::<u64>().ok())
{
total = kb * 1024;
}
} else if line.starts_with("MemAvailable:") {
if let Some(kb) = line
.split_whitespace()
.nth(1)
.and_then(|v| v.parse::<u64>().ok())
{
available = kb * 1024;
}
}
if total > 0 && available > 0 {
break;
}
}
MemInfo { total, available }
}
fn get_network_addresses() -> Vec<NetworkAddress> {
let all_addrs = match nix::ifaddrs::getifaddrs() {
Ok(addrs) => addrs,
Err(_) => return Vec::new(),
};
let mut up_ifaces = std::collections::HashSet::new();
let net_dir = match std::fs::read_dir("/sys/class/net") {
Ok(dir) => dir,
Err(_) => return Vec::new(),
};
for entry in net_dir.flatten() {
let iface_name = match entry.file_name().into_string() {
Ok(name) => name,
Err(_) => continue,
};
if iface_name == "lo" {
continue;
}
let operstate_path = entry.path().join("operstate");
let is_up = std::fs::read_to_string(&operstate_path)
.map(|s| s.trim() == "up")
.unwrap_or(false);
if is_up {
up_ifaces.insert(iface_name);
}
}
let mut addresses = Vec::new();
let mut seen = std::collections::HashSet::new();
for ifaddr in all_addrs {
let iface_name = &ifaddr.interface_name;
if iface_name == "lo" || !up_ifaces.contains(iface_name) {
continue;
}
if let Some(addr) = ifaddr.address {
if let Some(sockaddr_in) = addr.as_sockaddr_in() {
let ip = sockaddr_in.ip();
if ip.is_loopback() {
continue;
}
let ip_str = ip.to_string();
if seen.insert((iface_name.clone(), ip_str.clone())) {
addresses.push(NetworkAddress {
interface: iface_name.clone(),
ip: ip_str,
});
}
} else if let Some(sockaddr_in6) = addr.as_sockaddr_in6() {
let ip = sockaddr_in6.ip();
if ip.is_loopback() || ip.is_unspecified() || ip.is_unicast_link_local() {
continue;
}
let ip_str = ip.to_string();
if seen.insert((iface_name.clone(), ip_str.clone())) {
addresses.push(NetworkAddress {
interface: iface_name.clone(),
ip: ip_str,
});
}
}
}
}
addresses
}
#[cfg(test)]
mod tests {
use super::{parse_cpu_model_from_cpuinfo_content, parse_device_tree_model_bytes};
#[test]
fn parse_cpu_model_from_model_name_field() {
let input = "processor\t: 0\nmodel name\t: Intel(R) Xeon(R)\n";
assert_eq!(
parse_cpu_model_from_cpuinfo_content(input),
Some("Intel(R) Xeon(R)".to_string())
);
}
#[test]
fn parse_cpu_model_from_model_field() {
let input = "processor\t: 0\nModel\t\t: Raspberry Pi 4 Model B Rev 1.4\n";
assert_eq!(
parse_cpu_model_from_cpuinfo_content(input),
Some("Raspberry Pi 4 Model B Rev 1.4".to_string())
);
}
#[test]
fn parse_device_tree_model_trimmed() {
let input = b"Onething OEC Box\0\n";
assert_eq!(
parse_device_tree_model_bytes(input),
Some("Onething OEC Box".to_string())
);
}
}

47
src/diagnostics/mod.rs Normal file
View File

@@ -0,0 +1,47 @@
//! Host diagnostics used by the web status API.
use serde::Serialize;
use crate::error::Result;
#[derive(Debug, Clone, Serialize)]
pub struct DeviceInfo {
pub hostname: String,
pub cpu_model: String,
pub cpu_usage: f32,
pub memory_total: u64,
pub memory_used: u64,
pub network_addresses: Vec<NetworkAddress>,
pub serial_ports: Vec<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct NetworkAddress {
pub interface: String,
pub ip: String,
}
#[derive(Debug, Clone, Serialize)]
pub struct DiskSpaceInfo {
pub total: u64,
pub available: u64,
pub used: u64,
}
#[cfg(unix)]
mod linux;
#[cfg(windows)]
mod windows;
#[cfg(unix)]
use linux as platform;
#[cfg(windows)]
use windows as platform;
pub fn get_disk_space(path: &std::path::Path) -> Result<DiskSpaceInfo> {
platform::get_disk_space(path)
}
pub fn get_device_info() -> DeviceInfo {
platform::get_device_info()
}

249
src/diagnostics/windows.rs Normal file
View File

@@ -0,0 +1,249 @@
use super::{DeviceInfo, DiskSpaceInfo, NetworkAddress};
use crate::error::{AppError, Result};
use crate::utils::hostname_uname;
use std::ffi::CStr;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::{Mutex, OnceLock};
use windows_sys::Win32::Foundation::{ERROR_BUFFER_OVERFLOW, ERROR_SUCCESS, FILETIME};
use windows_sys::Win32::NetworkManagement::IpHelper::{
GetAdaptersAddresses, GAA_FLAG_SKIP_ANYCAST, GAA_FLAG_SKIP_DNS_SERVER, GAA_FLAG_SKIP_MULTICAST,
IP_ADAPTER_ADDRESSES_LH,
};
use windows_sys::Win32::NetworkManagement::Ndis::IfOperStatusUp;
use windows_sys::Win32::Networking::WinSock::{
AF_INET, AF_INET6, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6,
};
use windows_sys::Win32::System::SystemInformation::{
GetNativeSystemInfo, GlobalMemoryStatusEx, MEMORYSTATUSEX, PROCESSOR_ARCHITECTURE_AMD64,
PROCESSOR_ARCHITECTURE_ARM64, PROCESSOR_ARCHITECTURE_INTEL, SYSTEM_INFO,
};
use windows_sys::Win32::System::Threading::GetSystemTimes;
pub fn get_disk_space(_path: &std::path::Path) -> Result<DiskSpaceInfo> {
Err(AppError::Internal(
"Disk space reporting is unavailable on Windows".to_string(),
))
}
pub fn get_device_info() -> DeviceInfo {
let (memory_total, memory_used) = get_memory_usage();
DeviceInfo {
hostname: hostname_uname(),
cpu_model: get_cpu_model(),
cpu_usage: get_cpu_usage(),
memory_total,
memory_used,
network_addresses: get_network_addresses(),
serial_ports: crate::utils::list_serial_ports(),
}
}
fn get_cpu_model() -> String {
std::env::var("PROCESSOR_IDENTIFIER")
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.unwrap_or_else(get_cpu_arch_label)
}
fn get_cpu_arch_label() -> String {
let mut info = std::mem::MaybeUninit::<SYSTEM_INFO>::zeroed();
unsafe {
GetNativeSystemInfo(info.as_mut_ptr());
let info = info.assume_init();
match info.Anonymous.Anonymous.wProcessorArchitecture {
PROCESSOR_ARCHITECTURE_AMD64 => "x86_64".to_string(),
PROCESSOR_ARCHITECTURE_ARM64 => "aarch64".to_string(),
PROCESSOR_ARCHITECTURE_INTEL => "x86".to_string(),
_ => std::env::consts::ARCH.to_string(),
}
}
}
fn get_memory_usage() -> (u64, u64) {
let mut status = MEMORYSTATUSEX {
dwLength: std::mem::size_of::<MEMORYSTATUSEX>() as u32,
..unsafe { std::mem::zeroed() }
};
let ok = unsafe { GlobalMemoryStatusEx(&mut status) };
if ok == 0 {
return (0, 0);
}
(
status.ullTotalPhys,
status.ullTotalPhys.saturating_sub(status.ullAvailPhys),
)
}
fn get_cpu_usage() -> f32 {
static LAST_SAMPLE: OnceLock<Mutex<Option<CpuTimes>>> = OnceLock::new();
let Some(current) = read_cpu_times() else {
return 0.0;
};
let sample = LAST_SAMPLE.get_or_init(|| Mutex::new(None));
let Ok(mut last) = sample.lock() else {
return 0.0;
};
let (previous, current) = if let Some(previous) = last.replace(current) {
(previous, current)
} else {
drop(last);
std::thread::sleep(std::time::Duration::from_millis(100));
let Some(next) = read_cpu_times() else {
return 0.0;
};
if let Ok(mut last) = sample.lock() {
*last = Some(next);
}
(current, next)
};
let idle = current.idle.saturating_sub(previous.idle);
let kernel = current.kernel.saturating_sub(previous.kernel);
let user = current.user.saturating_sub(previous.user);
let total = kernel.saturating_add(user);
if total == 0 {
return 0.0;
}
((total.saturating_sub(idle)) as f64 * 100.0 / total as f64).clamp(0.0, 100.0) as f32
}
#[derive(Clone, Copy)]
struct CpuTimes {
idle: u64,
kernel: u64,
user: u64,
}
fn read_cpu_times() -> Option<CpuTimes> {
let mut idle = FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let mut kernel = idle;
let mut user = idle;
let ok = unsafe { GetSystemTimes(&mut idle, &mut kernel, &mut user) };
if ok == 0 {
return None;
}
Some(CpuTimes {
idle: filetime_to_u64(idle),
kernel: filetime_to_u64(kernel),
user: filetime_to_u64(user),
})
}
fn filetime_to_u64(time: FILETIME) -> u64 {
((time.dwHighDateTime as u64) << 32) | time.dwLowDateTime as u64
}
fn get_network_addresses() -> Vec<NetworkAddress> {
let mut buffer_len = 15_000u32;
let flags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER;
for _ in 0..2 {
let mut buffer = vec![0u8; buffer_len as usize];
let ret = unsafe {
GetAdaptersAddresses(
0,
flags,
std::ptr::null_mut(),
buffer.as_mut_ptr() as *mut IP_ADAPTER_ADDRESSES_LH,
&mut buffer_len,
)
};
if ret == ERROR_BUFFER_OVERFLOW {
continue;
}
if ret != ERROR_SUCCESS {
return Vec::new();
}
let mut addresses = Vec::new();
let mut adapter = buffer.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH;
while !adapter.is_null() {
let adapter_ref = unsafe { &*adapter };
if adapter_ref.OperStatus != IfOperStatusUp {
adapter = adapter_ref.Next;
continue;
}
let interface = adapter_name(adapter_ref);
let mut unicast = adapter_ref.FirstUnicastAddress;
while !unicast.is_null() {
let unicast_ref = unsafe { &*unicast };
if let Some(ip) = sockaddr_to_ip(unicast_ref.Address.lpSockaddr) {
addresses.push(NetworkAddress {
interface: interface.clone(),
ip,
});
}
unicast = unicast_ref.Next;
}
adapter = adapter_ref.Next;
}
addresses.sort_by(|a, b| a.interface.cmp(&b.interface).then(a.ip.cmp(&b.ip)));
addresses.dedup_by(|a, b| a.interface == b.interface && a.ip == b.ip);
return addresses;
}
Vec::new()
}
fn adapter_name(adapter: &IP_ADAPTER_ADDRESSES_LH) -> String {
unsafe {
if !adapter.FriendlyName.is_null() {
let mut len = 0usize;
while *adapter.FriendlyName.add(len) != 0 {
len += 1;
}
let name =
String::from_utf16_lossy(std::slice::from_raw_parts(adapter.FriendlyName, len));
if !name.trim().is_empty() {
return name;
}
}
if !adapter.AdapterName.is_null() {
return CStr::from_ptr(adapter.AdapterName.cast())
.to_string_lossy()
.into_owned();
}
}
"unknown".to_string()
}
fn sockaddr_to_ip(sockaddr: *const SOCKADDR) -> Option<String> {
if sockaddr.is_null() {
return None;
}
let family = unsafe { (*sockaddr).sa_family };
match family {
AF_INET => {
let addr = unsafe { *(sockaddr as *const SOCKADDR_IN) };
let bytes = unsafe { addr.sin_addr.S_un.S_addr.to_ne_bytes() };
Some(Ipv4Addr::from(bytes).to_string())
}
AF_INET6 => {
let addr = unsafe { *(sockaddr as *const SOCKADDR_IN6) };
let bytes = unsafe { addr.sin6_addr.u.Byte };
Some(Ipv6Addr::from(bytes).to_string())
}
_ => None,
}
}

View File

@@ -1,5 +1,4 @@
use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
@@ -13,8 +12,16 @@ use crate::events::EventBus;
const LOG_BUFFER_SIZE: usize = 200;
const LOG_BATCH_SIZE: usize = 16;
#[cfg(unix)]
pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock";
#[cfg(windows)]
pub const TTYD_TCP_ADDR: &str = "127.0.0.1:7681";
#[cfg(windows)]
const TTYD_TCP_HOST: &str = "127.0.0.1";
#[cfg(windows)]
const TTYD_TCP_PORT: &str = "7681";
struct ExtensionProcess {
child: Child,
logs: Arc<RwLock<VecDeque<String>>>,
@@ -36,7 +43,7 @@ impl ExtensionManager {
pub fn new() -> Self {
let availability = ExtensionId::all()
.iter()
.map(|id| (*id, Path::new(id.binary_path()).exists()))
.map(|id| (*id, id.binary_path().exists()))
.collect();
Self {
@@ -64,6 +71,20 @@ impl ExtensionManager {
*self.availability.get(&id).unwrap_or(&false)
}
fn is_enabled_for_config(id: ExtensionId, config: &ExtensionsConfig) -> bool {
match id {
ExtensionId::Ttyd => config.ttyd.enabled,
ExtensionId::Gostc => {
config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
}
ExtensionId::Easytier => {
config.easytier.enabled && !config.easytier.network_name.is_empty()
}
}
}
pub async fn status(&self, id: ExtensionId) -> ExtensionStatus {
if !self.check_available(id) {
return ExtensionStatus::Unavailable;
@@ -105,7 +126,11 @@ impl ExtensionManager {
pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> {
if !self.check_available(id) {
return Err(format!("{} not found at {}", id, id.binary_path()));
return Err(format!(
"{} not found at {}",
id,
id.binary_path().display()
));
}
self.stop(id).await.ok();
@@ -115,7 +140,7 @@ impl ExtensionManager {
tracing::info!(
"Starting extension {}: {} {}",
id,
id.binary_path(),
id.binary_path().display(),
Self::redact_args_for_log(&args).join(" ")
);
@@ -232,15 +257,7 @@ impl ExtensionManager {
ExtensionId::Ttyd => {
let c = &config.ttyd;
Self::prepare_ttyd_socket().await?;
let mut args = vec![
"-i".to_string(),
TTYD_SOCKET_PATH.to_string(),
"-b".to_string(),
"/api/terminal".to_string(),
"-W".to_string(),
];
let mut args = Self::build_ttyd_listen_args().await?;
args.push(c.shell.clone());
Ok(args)
@@ -302,6 +319,43 @@ impl ExtensionManager {
}
}
#[cfg(unix)]
async fn build_ttyd_listen_args() -> Result<Vec<String>, String> {
Self::prepare_ttyd_socket().await?;
Ok(vec![
"-i".to_string(),
TTYD_SOCKET_PATH.to_string(),
"-b".to_string(),
"/api/terminal".to_string(),
"-W".to_string(),
])
}
#[cfg(windows)]
async fn build_ttyd_listen_args() -> Result<Vec<String>, String> {
let cwd = std::env::var("USERPROFILE")
.ok()
.filter(|path| !path.trim().is_empty())
.unwrap_or_else(|| {
std::env::current_dir()
.map(|path| path.to_string_lossy().to_string())
.unwrap_or_else(|_| ".".to_string())
});
Ok(vec![
"-i".to_string(),
TTYD_TCP_HOST.to_string(),
"-p".to_string(),
TTYD_TCP_PORT.to_string(),
"-b".to_string(),
"/api/terminal".to_string(),
"-w".to_string(),
cwd,
"-W".to_string(),
])
}
fn redact_args_for_log(args: &[String]) -> Vec<String> {
let mut redacted = Vec::with_capacity(args.len());
let mut redact_next = false;
@@ -330,8 +384,9 @@ impl ExtensionManager {
redacted
}
#[cfg(unix)]
async fn prepare_ttyd_socket() -> Result<(), String> {
let socket_path = Path::new(TTYD_SOCKET_PATH);
let socket_path = std::path::Path::new(TTYD_SOCKET_PATH);
if let Some(socket_dir) = socket_path.parent() {
if !socket_dir.exists() {
@@ -357,18 +412,7 @@ impl ExtensionManager {
let checks: Vec<_> = ExtensionId::all()
.iter()
.filter_map(|id| {
let should_run = match id {
ExtensionId::Ttyd => config.ttyd.enabled,
ExtensionId::Gostc => {
config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
}
ExtensionId::Easytier => {
config.easytier.enabled && !config.easytier.network_name.is_empty()
}
};
if should_run && self.check_available(*id) {
if Self::is_enabled_for_config(*id, config) && self.check_available(*id) {
Some(*id)
} else {
None
@@ -404,41 +448,15 @@ impl ExtensionManager {
}
pub async fn start_enabled(&self, config: &ExtensionsConfig) {
use futures::Future;
use std::pin::Pin;
let mut start_futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + '_>>> = Vec::new();
if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) {
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Ttyd, config).await {
tracing::error!("Failed to start ttyd: {}", e);
let start_futures: Vec<_> = ExtensionId::all()
.iter()
.filter(|id| Self::is_enabled_for_config(**id, config) && self.check_available(**id))
.map(|id| async move {
if let Err(e) = self.start(*id, config).await {
tracing::error!("Failed to start {}: {}", id, e);
}
}));
}
if config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
&& self.check_available(ExtensionId::Gostc)
{
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Gostc, config).await {
tracing::error!("Failed to start gostc: {}", e);
}
}));
}
if config.easytier.enabled
&& !config.easytier.network_name.is_empty()
&& self.check_available(ExtensionId::Easytier)
{
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Easytier, config).await {
tracing::error!("Failed to start easytier: {}", e);
}
}));
}
})
.collect();
futures::future::join_all(start_futures).await;
}

View File

@@ -1,5 +1,10 @@
mod manager;
mod software;
mod types;
pub use manager::{ExtensionManager, TTYD_SOCKET_PATH};
pub use manager::ExtensionManager;
#[cfg(unix)]
pub use manager::TTYD_SOCKET_PATH;
#[cfg(windows)]
pub use manager::TTYD_TCP_ADDR;
pub use types::*;

View File

@@ -0,0 +1,15 @@
use std::path::PathBuf;
use super::ExtensionId;
#[cfg_attr(windows, path = "software_windows.rs")]
#[cfg_attr(not(windows), path = "software_linux.rs")]
mod platform;
pub fn binary_path(id: ExtensionId) -> PathBuf {
platform::binary_path(id)
}
pub fn default_ttyd_shell() -> &'static str {
platform::default_ttyd_shell()
}

View File

@@ -0,0 +1,19 @@
use std::path::PathBuf;
use super::ExtensionId;
pub fn default_binary_path(id: ExtensionId) -> &'static str {
match id {
ExtensionId::Ttyd => "/usr/bin/ttyd",
ExtensionId::Gostc => "/usr/bin/gostc",
ExtensionId::Easytier => "/usr/bin/easytier-core",
}
}
pub fn binary_path(id: ExtensionId) -> PathBuf {
PathBuf::from(default_binary_path(id))
}
pub fn default_ttyd_shell() -> &'static str {
"/bin/bash"
}

View File

@@ -0,0 +1,47 @@
use std::path::PathBuf;
use super::ExtensionId;
pub fn default_binary_path(id: ExtensionId) -> &'static str {
match id {
ExtensionId::Ttyd => "ttyd.win32.exe",
ExtensionId::Gostc => "gostc.exe",
ExtensionId::Easytier => "easytier-core.exe",
}
}
pub fn binary_path(id: ExtensionId) -> PathBuf {
if id == ExtensionId::Ttyd {
if let Some(path) = env_path("ONE_KVM_TTYD_PATH") {
return path;
}
}
find_in_app_dir(default_binary_path(id))
.unwrap_or_else(|| PathBuf::from(default_binary_path(id)))
}
pub fn default_ttyd_shell() -> &'static str {
"cmd"
}
fn env_path(name: &str) -> Option<PathBuf> {
std::env::var(name)
.ok()
.map(|path| path.trim().to_string())
.filter(|path| !path.is_empty())
.map(PathBuf::from)
}
fn find_in_app_dir(binary_name: &str) -> Option<PathBuf> {
if let Ok(exe_path) = std::env::current_exe() {
if let Some(exe_dir) = exe_path.parent() {
let bundled = exe_dir.join(binary_name);
if bundled.exists() {
return Some(bundled);
}
}
}
None
}

View File

@@ -1,6 +1,8 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
use super::software;
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
@@ -11,12 +13,8 @@ pub enum ExtensionId {
}
impl ExtensionId {
pub fn binary_path(&self) -> &'static str {
match self {
Self::Ttyd => "/usr/bin/ttyd",
Self::Gostc => "/usr/bin/gostc",
Self::Easytier => "/usr/bin/easytier-core",
}
pub fn binary_path(&self) -> std::path::PathBuf {
software::binary_path(*self)
}
pub fn all() -> &'static [ExtensionId] {
@@ -54,7 +52,6 @@ pub enum ExtensionStatus {
Unavailable,
Stopped,
Running { pid: u32 },
Failed { error: String },
}
impl ExtensionStatus {
@@ -75,7 +72,7 @@ impl Default for TtydConfig {
fn default() -> Self {
Self {
enabled: false,
shell: "/bin/bash".to_string(),
shell: software::default_ttyd_shell().to_string(),
}
}
}

View File

@@ -10,29 +10,24 @@
use async_trait::async_trait;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::{Duration, Instant};
use tokio::sync::watch;
use tracing::{info, trace, warn};
use tracing::{info, trace};
use super::backend::{HidBackend, HidBackendRuntimeSnapshot};
use super::ch9329_proto::{
build_packet, cmd, expected_response_cmd, try_extract_response, ChipInfo, LedStatus, Response,
DEFAULT_ADDR, DEFAULT_BAUD_RATE, MAX_PACKET_SIZE,
};
use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType};
use crate::error::{AppError, Result};
use crate::events::LedState;
const PACKET_HEADER: [u8; 2] = [0x57, 0xAB];
const DEFAULT_ADDR: u8 = 0x00;
pub const DEFAULT_BAUD_RATE: u32 = 9600;
const RESPONSE_TIMEOUT_MS: u64 = 500;
const MAX_DATA_LEN: usize = 64;
const CH9329_MOUSE_RESOLUTION: u32 = 4096;
const PROBE_INTERVAL_MS: u64 = 100;
@@ -41,173 +36,6 @@ const RECONNECT_DELAY_MS: u64 = 2000;
const INIT_WAIT_MS: u64 = 3000;
pub mod cmd {
pub const GET_INFO: u8 = 0x01;
pub const SEND_KB_GENERAL_DATA: u8 = 0x02;
pub const SEND_KB_MEDIA_DATA: u8 = 0x03;
pub const SEND_MS_ABS_DATA: u8 = 0x04;
pub const SEND_MS_REL_DATA: u8 = 0x05;
pub const SEND_MY_HID_DATA: u8 = 0x06;
pub const SET_DEFAULT_CFG: u8 = 0x0C;
pub const RESET: u8 = 0x0F;
}
const RESPONSE_SUCCESS_MASK: u8 = 0x80;
const RESPONSE_ERROR_MASK: u8 = 0xC0;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Ch9329Error {
Success = 0x00,
Timeout = 0xE1,
InvalidHeader = 0xE2,
InvalidCommand = 0xE3,
ChecksumError = 0xE4,
ParameterError = 0xE5,
OperationFailed = 0xE6,
}
impl From<u8> for Ch9329Error {
fn from(code: u8) -> Self {
match code {
0x00 => Ch9329Error::Success,
0xE1 => Ch9329Error::Timeout,
0xE2 => Ch9329Error::InvalidHeader,
0xE3 => Ch9329Error::InvalidCommand,
0xE4 => Ch9329Error::ChecksumError,
0xE5 => Ch9329Error::ParameterError,
0xE6 => Ch9329Error::OperationFailed,
_ => Ch9329Error::OperationFailed,
}
}
}
impl std::fmt::Display for Ch9329Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ch9329Error::Success => write!(f, "Success"),
Ch9329Error::Timeout => write!(f, "Serial receive timeout"),
Ch9329Error::InvalidHeader => write!(f, "Invalid packet header"),
Ch9329Error::InvalidCommand => write!(f, "Invalid command code"),
Ch9329Error::ChecksumError => write!(f, "Checksum mismatch"),
Ch9329Error::ParameterError => write!(f, "Parameter error"),
Ch9329Error::OperationFailed => write!(f, "Operation failed"),
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChipInfo {
pub version: String,
pub version_raw: u8,
pub usb_connected: bool,
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
}
impl ChipInfo {
pub fn from_response(data: &[u8]) -> Option<Self> {
if data.len() < 8 {
return None;
}
let version_raw = data[0];
let version = format!("V{}.{}", version_raw >> 4, version_raw & 0x0F);
let usb_connected = data[1] == 0x01;
let led_status = data[2];
Some(Self {
version,
version_raw,
usb_connected,
num_lock: (led_status & 0x01) != 0,
caps_lock: (led_status & 0x02) != 0,
scroll_lock: (led_status & 0x04) != 0,
})
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LedStatus {
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
}
impl From<u8> for LedStatus {
fn from(byte: u8) -> Self {
Self {
num_lock: (byte & 0x01) != 0,
caps_lock: (byte & 0x02) != 0,
scroll_lock: (byte & 0x04) != 0,
}
}
}
#[derive(Debug)]
pub struct Response {
pub address: u8,
pub cmd: u8,
pub data: Vec<u8>,
pub is_error: bool,
pub error_code: Option<Ch9329Error>,
}
impl Response {
pub fn parse(bytes: &[u8]) -> Option<Self> {
if bytes.len() < 6 {
return None;
}
if bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] {
return None;
}
let address = bytes[2];
let cmd = bytes[3];
let len = bytes[4] as usize;
if bytes.len() < 5 + len + 1 {
return None;
}
let expected_checksum = bytes[5 + len];
let calculated_checksum = bytes[..5 + len]
.iter()
.fold(0u8, |acc, &x| acc.wrapping_add(x));
if expected_checksum != calculated_checksum {
warn!(
"CH9329 checksum mismatch: expected {:02X}, got {:02X}",
expected_checksum, calculated_checksum
);
return None;
}
let data = bytes[5..5 + len].to_vec();
let is_error = (cmd & RESPONSE_ERROR_MASK) == RESPONSE_ERROR_MASK;
let error_code = if is_error && !data.is_empty() {
Some(Ch9329Error::from(data[0]))
} else {
None
};
Some(Self {
address,
cmd,
data,
is_error,
error_code,
})
}
pub fn is_success(&self) -> bool {
!self.is_error && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8)
}
}
const MAX_PACKET_SIZE: usize = 70;
struct Ch9329RuntimeState {
initialized: AtomicBool,
@@ -331,6 +159,13 @@ impl Ch9329Backend {
}
pub fn check_port_exists(&self) -> bool {
#[cfg(windows)]
{
return crate::utils::list_serial_ports()
.iter()
.any(|port| port.eq_ignore_ascii_case(&self.port_path));
}
#[cfg(not(windows))]
std::path::Path::new(&self.port_path).exists()
}
@@ -358,39 +193,8 @@ impl Ch9329Backend {
}
#[inline]
fn calculate_checksum(data: &[u8]) -> u8 {
data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x))
}
#[inline]
fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) {
debug_assert!(
data.len() <= MAX_DATA_LEN,
"Data too long for CH9329 packet"
);
let len = data.len() as u8;
let packet_len = 6 + data.len();
let mut packet = [0u8; MAX_PACKET_SIZE];
packet[0] = PACKET_HEADER[0];
packet[1] = PACKET_HEADER[1];
packet[2] = address;
packet[3] = cmd;
packet[4] = len;
packet[5..5 + data.len()].copy_from_slice(data);
let checksum = Self::calculate_checksum(&packet[..5 + data.len()]);
packet[5 + data.len()] = checksum;
(packet, packet_len)
}
fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec<u8> {
let (buf, len) = Self::build_packet_buf(address, cmd, data);
buf[..len].to_vec()
}
fn open_port(port_path: &str, baud_rate: u32) -> Result<Box<dyn serialport::SerialPort>> {
#[cfg(not(windows))]
if !std::path::Path::new(port_path).exists() {
return Err(Self::backend_error(
format!("Serial port {} not found", port_path),
@@ -410,46 +214,13 @@ impl Ch9329Backend {
cmd: u8,
data: &[u8],
) -> Result<()> {
let packet = Self::build_packet(address, cmd, data);
let packet = build_packet(address, cmd, data);
port.write_all(&packet).map_err(|e| {
Self::backend_error(format!("Failed to write to CH9329: {}", e), "write_failed")
})?;
Ok(())
}
fn try_extract_response(buffer: &[u8]) -> Option<(Response, usize)> {
let mut offset = 0;
while offset + 6 <= buffer.len() {
if buffer[offset] != PACKET_HEADER[0] || buffer[offset + 1] != PACKET_HEADER[1] {
offset += 1;
continue;
}
let len = buffer[offset + 4] as usize;
let frame_len = 6 + len;
if offset + frame_len > buffer.len() {
return None;
}
let frame = &buffer[offset..offset + frame_len];
if let Some(response) = Response::parse(frame) {
return Some((response, offset + frame_len));
}
offset += 1;
}
None
}
fn expected_response_cmd(cmd: u8, is_error: bool) -> u8 {
cmd | if is_error {
RESPONSE_ERROR_MASK
} else {
RESPONSE_SUCCESS_MASK
}
}
fn xfer_packet(
port: &mut dyn serialport::SerialPort,
address: u8,
@@ -460,8 +231,8 @@ impl Ch9329Backend {
let mut pending = Vec::with_capacity(128);
let deadline = Instant::now() + Duration::from_millis(RESPONSE_TIMEOUT_MS);
let expected_ok = Self::expected_response_cmd(cmd, false);
let expected_err = Self::expected_response_cmd(cmd, true);
let expected_ok = expected_response_cmd(cmd, false);
let expected_err = expected_response_cmd(cmd, true);
loop {
let mut chunk = [0u8; 128];
@@ -469,7 +240,7 @@ impl Ch9329Backend {
Ok(n) if n > 0 => {
pending.extend_from_slice(&chunk[..n]);
while let Some((response, consumed)) = Self::try_extract_response(&pending) {
while let Some((response, consumed)) = try_extract_response(&pending) {
pending.drain(..consumed);
if response.cmd == expected_ok || response.cmd == expected_err {
return Ok(response);
@@ -1023,7 +794,14 @@ impl HidBackend for Ch9329Backend {
let mut online = initialized && self.runtime.online.load(Ordering::Relaxed);
let mut error = self.runtime.last_error.read().clone();
if initialized && !self.check_port_exists() {
#[cfg(windows)]
let port_still_present = crate::utils::list_serial_ports()
.iter()
.any(|port| port.eq_ignore_ascii_case(&self.port_path));
#[cfg(not(windows))]
let port_still_present = self.check_port_exists();
if initialized && !port_still_present {
online = false;
error = Some((
format!("Serial port {} not found", self.port_path),
@@ -1066,14 +844,15 @@ impl HidBackend for Ch9329Backend {
#[cfg(test)]
mod tests {
use super::*;
use super::ch9329_proto::{build_packet, calculate_checksum};
#[test]
fn test_packet_building() {
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]);
let packet = build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]);
assert_eq!(packet, vec![0x57, 0xAB, 0x00, 0x01, 0x00, 0x03]);
let data = [0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; // 'A' key
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data);
let packet = build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data);
assert_eq!(packet[0], 0x57); // Header
assert_eq!(packet[1], 0xAB); // Header
@@ -1090,7 +869,7 @@ mod tests {
#[test]
fn test_relative_mouse_packet() {
let data = [0x01, 0x00, 50u8, 0x00, 0x00];
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data);
let packet = build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data);
assert_eq!(packet[0], 0x57);
assert_eq!(packet[1], 0xAB);
@@ -1105,13 +884,13 @@ mod tests {
#[test]
fn test_checksum_calculation() {
let packet = [0x57u8, 0xAB, 0x00, 0x01, 0x00];
let checksum = Ch9329Backend::calculate_checksum(&packet);
let checksum = calculate_checksum(&packet);
assert_eq!(checksum, 0x03);
let packet = [
0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00,
];
let checksum = Ch9329Backend::calculate_checksum(&packet);
let checksum = calculate_checksum(&packet);
assert_eq!(checksum, 0x10);
}

225
src/hid/ch9329_proto.rs Normal file
View File

@@ -0,0 +1,225 @@
//! Shared CH9329 protocol types and packet helpers.
use serde::{Deserialize, Serialize};
const PACKET_HEADER: [u8; 2] = [0x57, 0xAB];
pub const RESPONSE_SUCCESS_MASK: u8 = 0x80;
pub const RESPONSE_ERROR_MASK: u8 = 0xC0;
pub const DEFAULT_ADDR: u8 = 0x00;
pub const DEFAULT_BAUD_RATE: u32 = 9600;
pub const MAX_DATA_LEN: usize = 64;
pub const MAX_PACKET_SIZE: usize = 70;
pub mod cmd {
pub const GET_INFO: u8 = 0x01;
pub const SEND_KB_GENERAL_DATA: u8 = 0x02;
pub const SEND_KB_MEDIA_DATA: u8 = 0x03;
pub const SEND_MS_ABS_DATA: u8 = 0x04;
pub const SEND_MS_REL_DATA: u8 = 0x05;
pub const RESET: u8 = 0x0F;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Ch9329Error {
Success = 0x00,
Timeout = 0xE1,
InvalidHeader = 0xE2,
InvalidCommand = 0xE3,
ChecksumError = 0xE4,
ParameterError = 0xE5,
OperationFailed = 0xE6,
}
impl From<u8> for Ch9329Error {
fn from(code: u8) -> Self {
match code {
0x00 => Ch9329Error::Success,
0xE1 => Ch9329Error::Timeout,
0xE2 => Ch9329Error::InvalidHeader,
0xE3 => Ch9329Error::InvalidCommand,
0xE4 => Ch9329Error::ChecksumError,
0xE5 => Ch9329Error::ParameterError,
0xE6 => Ch9329Error::OperationFailed,
_ => Ch9329Error::OperationFailed,
}
}
}
impl std::fmt::Display for Ch9329Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ch9329Error::Success => write!(f, "Success"),
Ch9329Error::Timeout => write!(f, "Serial receive timeout"),
Ch9329Error::InvalidHeader => write!(f, "Invalid packet header"),
Ch9329Error::InvalidCommand => write!(f, "Invalid command code"),
Ch9329Error::ChecksumError => write!(f, "Checksum mismatch"),
Ch9329Error::ParameterError => write!(f, "Parameter error"),
Ch9329Error::OperationFailed => write!(f, "Operation failed"),
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChipInfo {
pub version: String,
pub version_raw: u8,
pub usb_connected: bool,
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
}
impl ChipInfo {
pub fn from_response(data: &[u8]) -> Option<Self> {
if data.len() < 8 {
return None;
}
let version_raw = data[0];
let version = format!("V{}.{}", version_raw >> 4, version_raw & 0x0F);
let usb_connected = data[1] == 0x01;
let led_status = data[2];
Some(Self {
version,
version_raw,
usb_connected,
num_lock: (led_status & 0x01) != 0,
caps_lock: (led_status & 0x02) != 0,
scroll_lock: (led_status & 0x04) != 0,
})
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LedStatus {
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
}
impl From<u8> for LedStatus {
fn from(byte: u8) -> Self {
Self {
num_lock: (byte & 0x01) != 0,
caps_lock: (byte & 0x02) != 0,
scroll_lock: (byte & 0x04) != 0,
}
}
}
#[derive(Debug)]
pub struct Response {
pub cmd: u8,
pub data: Vec<u8>,
pub is_error: bool,
pub error_code: Option<Ch9329Error>,
}
impl Response {
pub fn parse(bytes: &[u8]) -> Option<Self> {
if bytes.len() < 6 || bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] {
return None;
}
let cmd = bytes[3];
let len = bytes[4] as usize;
if bytes.len() < 5 + len + 1 {
return None;
}
let expected_checksum = bytes[5 + len];
let calculated_checksum = bytes[..5 + len]
.iter()
.fold(0u8, |acc, &x| acc.wrapping_add(x));
if expected_checksum != calculated_checksum {
tracing::warn!(
"CH9329 checksum mismatch: expected {:02X}, got {:02X}",
expected_checksum,
calculated_checksum
);
return None;
}
let data = bytes[5..5 + len].to_vec();
let is_error = (cmd & RESPONSE_ERROR_MASK) == RESPONSE_ERROR_MASK;
let error_code = if is_error && !data.is_empty() {
Some(Ch9329Error::from(data[0]))
} else {
None
};
Some(Self {
cmd,
data,
is_error,
error_code,
})
}
}
#[inline]
pub fn calculate_checksum(data: &[u8]) -> u8 {
data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x))
}
#[inline]
pub fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) {
debug_assert!(data.len() <= MAX_DATA_LEN, "Data too long for CH9329 packet");
let len = data.len() as u8;
let packet_len = 6 + data.len();
let mut packet = [0u8; MAX_PACKET_SIZE];
packet[0] = PACKET_HEADER[0];
packet[1] = PACKET_HEADER[1];
packet[2] = address;
packet[3] = cmd;
packet[4] = len;
packet[5..5 + data.len()].copy_from_slice(data);
packet[5 + data.len()] = calculate_checksum(&packet[..5 + data.len()]);
(packet, packet_len)
}
#[inline]
pub fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec<u8> {
let (buf, len) = build_packet_buf(address, cmd, data);
buf[..len].to_vec()
}
#[inline]
pub fn expected_response_cmd(cmd: u8, is_error: bool) -> u8 {
cmd | if is_error {
RESPONSE_ERROR_MASK
} else {
RESPONSE_SUCCESS_MASK
}
}
pub fn try_extract_response(buffer: &[u8]) -> Option<(Response, usize)> {
let mut offset = 0;
while offset + 6 <= buffer.len() {
if buffer[offset] != PACKET_HEADER[0] || buffer[offset + 1] != PACKET_HEADER[1] {
offset += 1;
continue;
}
let len = buffer[offset + 4] as usize;
let frame_len = 6 + len;
if offset + frame_len > buffer.len() {
return None;
}
let frame = &buffer[offset..offset + frame_len];
if let Some(response) = Response::parse(frame) {
return Some((response, offset + frame_len));
}
offset += 1;
}
None
}

80
src/hid/factory.rs Normal file
View File

@@ -0,0 +1,80 @@
use std::sync::Arc;
use tracing::{info, warn};
use super::{ch9329, HidBackend, HidBackendType};
use crate::error::{AppError, Result};
#[cfg(unix)]
use crate::otg::OtgService;
pub struct HidBackendFactory {
#[cfg(unix)]
otg_service: Option<Arc<OtgService>>,
}
impl HidBackendFactory {
#[cfg(unix)]
pub fn new(otg_service: Option<Arc<OtgService>>) -> Self {
Self { otg_service }
}
#[cfg(not(unix))]
pub fn new() -> Self {
Self {}
}
pub async fn create_initialized(
&self,
backend_type: &HidBackendType,
) -> Result<Option<Arc<dyn HidBackend>>> {
let backend = match self.create(backend_type).await? {
Some(backend) => backend,
None => return Ok(None),
};
backend.init().await?;
Ok(Some(backend))
}
async fn create(&self, backend_type: &HidBackendType) -> Result<Option<Arc<dyn HidBackend>>> {
match backend_type {
HidBackendType::Otg => self.create_otg_backend().await.map(Some),
HidBackendType::Ch9329 { port, baud_rate } => {
info!(
"Initializing CH9329 HID backend on {} @ {} baud",
port, baud_rate
);
Ok(Some(Arc::new(ch9329::Ch9329Backend::with_baud_rate(
port, *baud_rate,
)?)))
}
HidBackendType::None => {
warn!("HID backend disabled");
Ok(None)
}
}
}
#[cfg(unix)]
async fn create_otg_backend(&self) -> Result<Arc<dyn HidBackend>> {
let otg_service = self
.otg_service
.as_ref()
.ok_or_else(|| AppError::Config("OTG backend not available".to_string()))?;
let handles = otg_service
.hid_device_paths()
.await
.ok_or_else(|| AppError::Config("OTG HID paths are not available".to_string()))?;
info!("Creating OTG HID backend from device paths");
Ok(Arc::new(super::otg::OtgBackend::from_handles(handles)?))
}
#[cfg(not(unix))]
async fn create_otg_backend(&self) -> Result<Arc<dyn HidBackend>> {
Err(AppError::Config(
"OTG HID is only available on Linux".to_string(),
))
}
}

View File

@@ -1,11 +1,16 @@
//! HID path: browser (WebSocket or WebRTC DataChannel) → queue → OTG gadget or CH9329.
pub mod backend;
mod ch9329_proto;
pub mod ch9329;
pub mod consumer;
pub mod datachannel;
mod factory;
pub mod keyboard;
#[cfg(unix)]
pub mod otg;
#[cfg(unix)]
mod otg_device;
pub mod types;
pub mod websocket;
@@ -95,7 +100,9 @@ use tracing::{info, warn};
use crate::error::{AppError, Result};
use crate::events::EventBus;
#[cfg(unix)]
use crate::otg::OtgService;
use factory::HidBackendFactory;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
@@ -112,7 +119,7 @@ enum QueuedHidEvent {
}
pub struct HidController {
otg_service: Option<Arc<OtgService>>,
backend_factory: HidBackendFactory,
backend: Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
backend_type: Arc<RwLock<HidBackendType>>,
events: Arc<tokio::sync::RwLock<Option<Arc<EventBus>>>>,
@@ -127,11 +134,33 @@ pub struct HidController {
}
impl HidController {
#[cfg(unix)]
pub fn new(backend_type: HidBackendType, otg_service: Option<Arc<OtgService>>) -> Self {
let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY);
Self {
otg_service,
backend: Arc::new(RwLock::new(None)),
backend_factory: HidBackendFactory::new(otg_service),
backend_type: Arc::new(RwLock::new(backend_type.clone())),
events: Arc::new(tokio::sync::RwLock::new(None)),
runtime_state: Arc::new(RwLock::new(HidRuntimeState::from_backend_type(
&backend_type,
))),
hid_tx,
hid_rx: Mutex::new(Some(hid_rx)),
pending_move: Arc::new(parking_lot::Mutex::new(None)),
pending_move_flag: Arc::new(AtomicBool::new(false)),
hid_worker: Mutex::new(None),
runtime_worker: Mutex::new(None),
backend_available: Arc::new(AtomicBool::new(false)),
}
}
#[cfg(not(unix))]
pub fn new(backend_type: HidBackendType) -> Self {
let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY);
Self {
backend: Arc::new(RwLock::new(None)),
backend_factory: HidBackendFactory::new(),
backend_type: Arc::new(RwLock::new(backend_type.clone())),
events: Arc::new(tokio::sync::RwLock::new(None)),
runtime_state: Arc::new(RwLock::new(HidRuntimeState::from_backend_type(
@@ -153,51 +182,22 @@ impl HidController {
pub async fn init(&self) -> Result<()> {
let backend_type = self.backend_type.read().await.clone();
let backend: Arc<dyn HidBackend> = match backend_type {
HidBackendType::Otg => {
let otg_service = self
.otg_service
.as_ref()
.ok_or_else(|| AppError::Internal("OtgService not available".into()))?;
let handles = otg_service.hid_device_paths().await.ok_or_else(|| {
AppError::Config("OTG HID paths are not available".to_string())
})?;
info!("Creating OTG HID backend from device paths");
Arc::new(otg::OtgBackend::from_handles(handles)?)
}
HidBackendType::Ch9329 {
ref port,
baud_rate,
} => {
info!(
"Initializing CH9329 HID backend on {} @ {} baud",
port, baud_rate
);
Arc::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?)
}
HidBackendType::None => {
warn!("HID backend disabled");
return Ok(());
}
};
if let Err(e) = backend.init().await {
self.backend_available.store(false, Ordering::Release);
let error_state = {
let backend_type = self.backend_type.read().await.clone();
let backend = match self.backend_factory.create_initialized(&backend_type).await {
Ok(Some(backend)) => backend,
Ok(None) => return Ok(()),
Err(error) => {
self.backend_available.store(false, Ordering::Release);
let current = self.runtime_state.read().await.clone();
HidRuntimeState::with_error(
let error_state = HidRuntimeState::with_error(
&backend_type,
&current,
format!("Failed to initialize HID backend: {}", e),
format!("Failed to initialize HID backend: {}", error),
"init_failed",
)
};
self.apply_runtime_state(error_state).await;
return Err(e);
}
);
self.apply_runtime_state(error_state).await;
return Err(error);
}
};
*self.backend.write().await = Some(backend);
self.sync_runtime_state_from_backend().await;
@@ -298,73 +298,15 @@ impl HidController {
}
}
let new_backend: Option<Arc<dyn HidBackend>> = match new_backend_type {
HidBackendType::Otg => {
info!("Initializing OTG HID backend");
let otg_service = match self.otg_service.as_ref() {
Some(svc) => svc,
None => {
warn!("OTG backend requires OtgService, but it's not available");
return Err(AppError::Config(
"OTG backend not available (OtgService missing)".to_string(),
));
}
};
match otg_service.hid_device_paths().await {
Some(handles) => match otg::OtgBackend::from_handles(handles) {
Ok(backend) => {
let backend = Arc::new(backend);
match backend.init().await {
Ok(_) => {
info!("OTG backend initialized successfully");
Some(backend)
}
Err(e) => {
warn!("Failed to initialize OTG backend: {}", e);
None
}
}
}
Err(e) => {
warn!("Failed to create OTG backend: {}", e);
None
}
},
None => {
warn!("OTG HID paths are not available");
None
}
}
}
HidBackendType::Ch9329 {
ref port,
baud_rate,
} => {
info!(
"Initializing CH9329 HID backend on {} @ {} baud",
port, baud_rate
);
match ch9329::Ch9329Backend::with_baud_rate(port, baud_rate) {
Ok(b) => {
let backend = Arc::new(b);
match backend.init().await {
Ok(_) => Some(backend),
Err(e) => {
warn!("Failed to initialize CH9329 backend: {}", e);
None
}
}
}
Err(e) => {
warn!("Failed to create CH9329 backend: {}", e);
None
}
}
}
HidBackendType::None => {
warn!("HID backend disabled");
let new_backend = match self
.backend_factory
.create_initialized(&new_backend_type)
.await
{
Ok(backend) => backend,
Err(error) if matches!(&new_backend_type, HidBackendType::None) => return Err(error),
Err(error) => {
warn!("Failed to initialize HID backend: {}", error);
None
}
};

View File

@@ -4,12 +4,10 @@
//! Polled timed writes (JetKVM-style). Treat `ESHUTDOWN` (108) by closing handles and reopening; keep fd on `EAGAIN` (11). Host/gadget teardown during MSD resembles PiKVM. <https://github.com/raspberrypi/linux/issues/4373>
use async_trait::async_trait;
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
use parking_lot::Mutex;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::os::unix::fs::OpenOptionsExt;
use std::os::unix::io::AsFd;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::Arc;
@@ -19,6 +17,7 @@ use tokio::sync::watch;
use tracing::{debug, info, trace, warn};
use super::backend::{HidBackend, HidBackendRuntimeSnapshot};
use super::otg_device::OtgDeviceIo;
use super::types::{
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType,
};
@@ -87,7 +86,6 @@ pub struct OtgBackend {
last_error: parking_lot::RwLock<Option<(String, String)>>,
last_error_log: parking_lot::Mutex<std::time::Instant>,
error_count: AtomicU8,
eagain_count: AtomicU8,
runtime_notify_tx: watch::Sender<()>,
runtime_worker_stop: Arc<AtomicBool>,
runtime_worker: Mutex<Option<thread::JoinHandle<()>>>,
@@ -119,7 +117,6 @@ impl OtgBackend {
last_error: parking_lot::RwLock::new(None),
last_error_log: parking_lot::Mutex::new(std::time::Instant::now()),
error_count: AtomicU8::new(0),
eagain_count: AtomicU8::new(0),
runtime_notify_tx,
runtime_worker_stop: Arc::new(AtomicBool::new(false)),
runtime_worker: Mutex::new(None),
@@ -179,34 +176,11 @@ impl OtgBackend {
fn reset_error_count(&self) {
self.error_count.store(0, Ordering::Relaxed);
self.eagain_count.store(0, Ordering::Relaxed);
}
/// Poll-based write with `HID_WRITE_TIMEOUT_MS`; timeout → drop (JetKVM-style).
fn write_with_timeout(&self, file: &mut File, data: &[u8]) -> std::io::Result<bool> {
let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)];
match poll(&mut pollfd, PollTimeout::from(HID_WRITE_TIMEOUT_MS as u16)) {
Ok(1) => {
if let Some(revents) = pollfd[0].revents() {
if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP)
{
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Device error or hangup",
));
}
}
file.write_all(data)?;
Ok(true)
}
Ok(0) => {
trace!("HID write timeout, dropping data");
Ok(false)
}
Ok(_) => Ok(false),
Err(e) => Err(std::io::Error::other(e)),
}
OtgDeviceIo::write_with_timeout(file, data, HID_WRITE_TIMEOUT_MS)
}
pub fn set_udc_name(&self, udc: &str) {
@@ -357,6 +331,32 @@ impl OtgBackend {
}
}
fn handle_write_error(
&self,
dev: &mut Option<File>,
err: std::io::Error,
operation: &str,
device_label: &str,
) -> Result<()> {
match err.raw_os_error() {
Some(108) => {
debug!("{} ESHUTDOWN, closing for recovery", device_label);
*dev = None;
self.record_error(format!("{}: {}", operation, err), "eshutdown");
Err(Self::io_error_to_hid_error(err, operation))
}
Some(11) => {
trace!("{} EAGAIN after poll, dropping", device_label);
Ok(())
}
_ => {
warn!("{} write error: {}", device_label, err);
self.record_error(format!("{}: {}", operation, err), Self::io_error_code(&err));
Err(Self::io_error_to_hid_error(err, operation))
}
}
}
pub fn check_devices_exist(&self) -> bool {
self.keyboard_path.as_ref().is_none_or(|p| p.exists())
&& self.mouse_rel_path.as_ref().is_none_or(|p| p.exists())
@@ -405,41 +405,12 @@ impl OtgBackend {
self.log_throttled_error("HID keyboard write timeout, dropped");
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Keyboard ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write keyboard report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write keyboard report",
))
}
Some(11) => {
trace!("Keyboard EAGAIN after poll, dropping");
Ok(())
}
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Keyboard write error: {}", e);
self.record_error(
format!("Failed to write keyboard report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write keyboard report",
))
}
}
}
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write keyboard report",
"Keyboard",
),
}
} else {
Err(AppError::HidError {
@@ -468,38 +439,12 @@ impl OtgBackend {
Ok(())
}
Ok(false) => Ok(()),
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Relative mouse ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write mouse report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
Some(11) => Ok(()),
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Relative mouse write error: {}", e);
self.record_error(
format!("Failed to write mouse report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
}
}
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write mouse report",
"Relative mouse",
),
}
} else {
Err(AppError::HidError {
@@ -534,38 +479,12 @@ impl OtgBackend {
Ok(())
}
Ok(false) => Ok(()),
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Absolute mouse ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write mouse report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
Some(11) => Ok(()),
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Absolute mouse write error: {}", e);
self.record_error(
format!("Failed to write mouse report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
}
}
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write mouse report",
"Absolute mouse",
),
}
} else {
Err(AppError::HidError {
@@ -597,35 +516,12 @@ impl OtgBackend {
Ok(())
}
Ok(false) => Ok(()),
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
debug!("Consumer control ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write consumer report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write consumer report",
))
}
Some(11) => Ok(()),
_ => {
warn!("Consumer control write error: {}", e);
self.record_error(
format!("Failed to write consumer report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write consumer report",
))
}
}
}
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write consumer report",
"Consumer control",
),
}
} else {
Err(AppError::HidError {

50
src/hid/otg_device.rs Normal file
View File

@@ -0,0 +1,50 @@
#[cfg(unix)]
use std::fs::{File, OpenOptions};
#[cfg(unix)]
use std::io::{Read, Write};
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
use std::os::unix::io::AsFd;
#[cfg(unix)]
use std::path::PathBuf;
#[cfg(unix)]
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
#[cfg(unix)]
use tracing::trace;
#[cfg(unix)]
pub struct OtgDeviceIo;
#[cfg(unix)]
impl OtgDeviceIo {
pub fn write_with_timeout(
file: &mut File,
data: &[u8],
timeout_ms: i32,
) -> std::io::Result<bool> {
let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)];
match poll(&mut pollfd, PollTimeout::from(timeout_ms as u16)) {
Ok(1) => {
if let Some(revents) = pollfd[0].revents() {
if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP)
{
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Device error or hangup",
));
}
}
file.write_all(data)?;
Ok(true)
}
Ok(0) => {
trace!("HID write timeout, dropping data");
Ok(false)
}
Ok(_) => Ok(false),
Err(e) => Err(std::io::Error::other(e)),
}
}
}

View File

@@ -1,16 +1,23 @@
//! Core library for One-KVM (IPKVM: capture, HID, OTG, streaming, Web UI glue).
#[cfg(not(any(unix, windows)))]
compile_error!("One-KVM supports Linux and Windows targets only.");
pub mod atx;
pub mod audio;
pub mod auth;
pub mod config;
pub mod db;
pub mod diagnostics;
pub mod error;
pub mod events;
pub mod extensions;
pub mod hid;
#[cfg(unix)]
pub mod msd;
#[cfg(unix)]
pub mod otg;
pub mod platform;
pub mod redfish;
pub mod rtsp;
pub mod rustdesk;

View File

@@ -20,8 +20,11 @@ use one_kvm::db::DatabasePool;
use one_kvm::events::EventBus;
use one_kvm::extensions::ExtensionManager;
use one_kvm::hid::{HidBackendType, HidController};
#[cfg(unix)]
use one_kvm::msd::MsdController;
#[cfg(unix)]
use one_kvm::otg::OtgService;
use one_kvm::platform::PlatformCapabilities;
use one_kvm::rtsp::RtspService;
use one_kvm::rustdesk::RustDeskService;
use one_kvm::state::AppState;
@@ -78,7 +81,7 @@ struct CliArgs {
#[arg(long, value_name = "FILE", requires = "ssl_cert")]
ssl_key: Option<PathBuf>,
/// Data directory path (default: /etc/one-kvm)
/// Data directory path (default: /etc/one-kvm, or the executable directory on Windows)
#[arg(short = 'd', long, value_name = "DIR")]
data_dir: Option<PathBuf>,
@@ -119,6 +122,12 @@ async fn main() -> anyhow::Result<()> {
.expect("Failed to install rustls crypto provider");
tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION"));
let platform = PlatformCapabilities::current();
tracing::info!(
"Platform mode: {:?} ({})",
platform.mode,
platform.mode_label
);
let data_dir = args.data_dir.clone().unwrap_or_else(get_data_dir);
tracing::info!("Data directory: {}", data_dir.display());
@@ -128,37 +137,7 @@ async fn main() -> anyhow::Result<()> {
return Ok(());
}
tokio::fs::create_dir_all(&data_dir).await?;
let db = open_database_pool(&data_dir).await?;
let config_store = ConfigStore::new(db.clone_pool())?;
config_store.load().await?;
let mut config = (*config_store.get()).clone();
let mut msd_dir_updated = false;
if config.msd.msd_dir.trim().is_empty() {
let msd_dir = data_dir.join("msd");
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
} else if !PathBuf::from(&config.msd.msd_dir).is_absolute() {
let msd_dir = data_dir.join(&config.msd.msd_dir);
tracing::warn!(
"MSD directory is relative, rebasing to {}",
msd_dir.display()
);
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
}
if msd_dir_updated {
config_store.set(config.clone()).await?;
}
let msd_dir = PathBuf::from(&config.msd.msd_dir);
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await {
tracing::warn!("Failed to create MSD images directory: {}", e);
}
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("ventoy")).await {
tracing::warn!("Failed to create MSD ventoy directory: {}", e);
}
let (db, config_store, mut config) = load_runtime_config(&data_dir).await?;
if let Some(addr) = args.address {
config.web.bind_address = addr.clone();
@@ -311,9 +290,12 @@ async fn main() -> anyhow::Result<()> {
};
tracing::info!("WebRTC streamer created");
#[cfg(unix)]
let otg_service = Arc::new(OtgService::new());
#[cfg(unix)]
tracing::info!("OTG Service created");
#[cfg(unix)]
if let Err(e) = otg_service.apply_config(&config.hid, &config.msd).await {
tracing::warn!("Failed to apply OTG config: {}", e);
}
@@ -326,12 +308,16 @@ async fn main() -> anyhow::Result<()> {
},
config::HidBackend::None => HidBackendType::None,
};
#[cfg(unix)]
let hid = Arc::new(HidController::new(hid_backend, Some(otg_service.clone())));
#[cfg(not(unix))]
let hid = Arc::new(HidController::new(hid_backend));
hid.set_event_bus(events.clone()).await;
if let Err(e) = hid.init().await {
tracing::warn!("Failed to initialize HID backend: {}", e);
}
#[cfg(unix)]
let msd = if config.msd.enabled {
let ventoy_resource_dir = data_dir.join("ventoy");
if ventoy_resource_dir.exists() {
@@ -544,10 +530,12 @@ async fn main() -> anyhow::Result<()> {
config_store.clone(),
session_store,
user_store,
#[cfg(unix)]
otg_service,
stream_manager,
webrtc_streamer.clone(),
hid,
#[cfg(unix)]
msd,
atx,
audio,
@@ -703,12 +691,12 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
};
let filter = match effective_level {
LogLevel::Error => "one_kvm=error,tower_http=error",
LogLevel::Warn => "one_kvm=warn,tower_http=warn",
LogLevel::Info => "one_kvm=info,tower_http=info",
LogLevel::Verbose => "one_kvm=debug,tower_http=info",
LogLevel::Debug => "one_kvm=debug,tower_http=debug",
LogLevel::Trace => "one_kvm=trace,tower_http=debug",
LogLevel::Error => "one_kvm=error,tower_http=error,webrtc_sctp=warn",
LogLevel::Warn => "one_kvm=warn,tower_http=warn,webrtc_sctp=warn",
LogLevel::Info => "one_kvm=info,tower_http=info,webrtc_sctp=warn",
LogLevel::Verbose => "one_kvm=debug,tower_http=info,webrtc_sctp=warn",
LogLevel::Debug => "one_kvm=debug,tower_http=debug,webrtc_sctp=warn",
LogLevel::Trace => "one_kvm=trace,tower_http=debug,webrtc_sctp=warn",
};
let env_filter =
@@ -728,6 +716,19 @@ fn get_data_dir() -> PathBuf {
return PathBuf::from(path);
}
#[cfg(windows)]
{
if let Ok(exe_path) = std::env::current_exe() {
if let Some(exe_dir) = exe_path.parent() {
return exe_dir.join("one-kvm");
}
}
return std::env::current_dir()
.map(|dir| dir.join("one-kvm"))
.unwrap_or_else(|_| PathBuf::from("one-kvm"));
}
#[cfg(not(windows))]
PathBuf::from("/etc/one-kvm")
}
@@ -771,6 +772,64 @@ async fn run_cli_command(command: CliCommand, data_dir: PathBuf) -> anyhow::Resu
}
}
async fn load_runtime_config(
data_dir: &Path,
) -> anyhow::Result<(DatabasePool, ConfigStore, AppConfig)> {
tokio::fs::create_dir_all(data_dir).await?;
let db = open_database_pool(data_dir).await?;
let config_store = ConfigStore::new(db.clone_pool())?;
config_store.load().await?;
let mut config = (*config_store.get()).clone();
config.apply_platform_defaults();
prepare_linux_runtime_dirs(data_dir, &config_store, &mut config).await?;
Ok((db, config_store, config))
}
#[cfg(unix)]
async fn prepare_linux_runtime_dirs(
data_dir: &Path,
config_store: &ConfigStore,
config: &mut AppConfig,
) -> anyhow::Result<()> {
let mut msd_dir_updated = false;
if config.msd.msd_dir.trim().is_empty() {
let msd_dir = data_dir.join("msd");
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
} else if !PathBuf::from(&config.msd.msd_dir).is_absolute() {
let msd_dir = data_dir.join(&config.msd.msd_dir);
tracing::warn!(
"MSD directory is relative, rebasing to {}",
msd_dir.display()
);
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
}
if msd_dir_updated {
config_store.set(config.clone()).await?;
}
let msd_dir = PathBuf::from(&config.msd.msd_dir);
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await {
tracing::warn!("Failed to create MSD images directory: {}", e);
}
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("ventoy")).await {
tracing::warn!("Failed to create MSD ventoy directory: {}", e);
}
Ok(())
}
#[cfg(not(unix))]
async fn prepare_linux_runtime_dirs(
_data_dir: &Path,
_config_store: &ConfigStore,
_config: &mut AppConfig,
) -> anyhow::Result<()> {
Ok(())
}
async fn run_user_action(
action: UserAction,
users: &UserStore,
@@ -1048,6 +1107,7 @@ async fn cleanup(state: &Arc<AppState>) {
tracing::warn!("Failed to shutdown HID: {}", e);
}
#[cfg(unix)]
if let Some(msd) = state.msd.write().await.as_mut() {
if let Err(e) = msd.shutdown().await {
tracing::warn!("Failed to shutdown MSD: {}", e);

View File

@@ -1,14 +1,37 @@
//! USB OTG composite gadget (HID + MSD).
#[cfg(unix)]
pub mod configfs;
pub mod endpoint;
#[cfg(unix)]
pub mod function;
#[cfg(unix)]
pub mod hid;
#[cfg(unix)]
pub mod manager;
#[cfg(unix)]
pub mod msd;
pub mod report_desc;
pub mod self_check;
#[cfg(unix)]
pub mod service;
#[cfg(unix)]
pub use manager::{wait_for_hid_devices, OtgGadgetManager};
#[cfg(unix)]
pub use msd::{MsdFunction, MsdLunConfig};
#[cfg(unix)]
pub use service::{HidDevicePaths, OtgService};
/// List USB Device Controller names exposed by sysfs.
pub fn list_udc_devices() -> Vec<String> {
let mut devices: Vec<String> = std::fs::read_dir("/sys/class/udc")
.ok()
.into_iter()
.flat_map(|entries| entries.filter_map(|entry| entry.ok()))
.filter_map(|entry| entry.file_name().to_str().map(str::to_owned))
.collect();
devices.sort();
devices
}

740
src/otg/self_check.rs Normal file
View File

@@ -0,0 +1,740 @@
use serde::Serialize;
use crate::utils::{list_dir_names, read_trimmed};
#[derive(Serialize, Clone, Copy, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum OtgSelfCheckLevel {
Info,
Warn,
Error,
}
#[derive(Serialize)]
pub struct OtgSelfCheckItem {
pub id: &'static str,
pub ok: bool,
pub level: OtgSelfCheckLevel,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub hint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub path: Option<String>,
}
#[derive(Serialize)]
pub struct OtgSelfCheckResponse {
pub overall_ok: bool,
pub error_count: usize,
pub warning_count: usize,
pub hid_backend: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub selected_udc: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bound_udc: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub udc_state: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub udc_speed: Option<String>,
pub available_udcs: Vec<String>,
pub other_gadgets: Vec<String>,
pub checks: Vec<OtgSelfCheckItem>,
}
fn push_otg_check(
checks: &mut Vec<OtgSelfCheckItem>,
id: &'static str,
ok: bool,
level: OtgSelfCheckLevel,
message: impl Into<String>,
hint: Option<impl Into<String>>,
path: Option<impl Into<String>>,
) {
checks.push(OtgSelfCheckItem {
id,
ok,
level,
message: message.into(),
hint: hint.map(|v| v.into()),
path: path.map(|v| v.into()),
});
}
fn proc_modules_has(module_name: &str) -> bool {
std::fs::read_to_string("/proc/modules")
.ok()
.map(|content| {
content
.lines()
.filter_map(|line| line.split_whitespace().next())
.any(|name| name == module_name)
})
.unwrap_or(false)
}
fn modules_metadata_has(module_name: &str) -> bool {
let kernel_release = match read_trimmed(std::path::Path::new("/proc/sys/kernel/osrelease")) {
Some(value) if !value.is_empty() => value,
_ => return false,
};
let module_dir = std::path::Path::new("/lib/modules").join(kernel_release);
let candidates = ["modules.builtin", "modules.builtin.modinfo", "modules.dep"];
candidates.iter().any(|filename| {
let path = module_dir.join(filename);
std::fs::read_to_string(path)
.ok()
.map(|content| {
let module_token = format!("/{module_name}.ko");
content.lines().any(|line| {
line.contains(&module_token)
|| line.contains(module_name)
|| line.contains(&module_name.replace('_', "-"))
})
})
.unwrap_or(false)
})
}
fn kernel_config_option_enabled(option_name: &str) -> bool {
let kernel_release = match read_trimmed(std::path::Path::new("/proc/sys/kernel/osrelease")) {
Some(value) if !value.is_empty() => value,
_ => return false,
};
let config_paths = [
std::path::PathBuf::from(format!("/boot/config-{kernel_release}")),
std::path::PathBuf::from("/boot/config"),
std::path::PathBuf::from(format!("/lib/modules/{kernel_release}/build/.config")),
];
config_paths.iter().any(|path| {
std::fs::read_to_string(path)
.ok()
.map(|content| {
let enabled_y = format!("{option_name}=y");
let enabled_m = format!("{option_name}=m");
content
.lines()
.any(|line| line == enabled_y || line == enabled_m)
})
.unwrap_or(false)
})
}
fn detect_libcomposite_available(gadget_root: &std::path::Path) -> bool {
let sys_module = std::path::Path::new("/sys/module/libcomposite").exists();
if sys_module {
return true;
}
if proc_modules_has("libcomposite") {
return true;
}
if modules_metadata_has("libcomposite") {
return true;
}
if kernel_config_option_enabled("CONFIG_USB_LIBCOMPOSITE")
|| kernel_config_option_enabled("CONFIG_USB_CONFIGFS")
{
return true;
}
// Fallback: if usb_gadget path exists, libcomposite may be built-in and already active.
gadget_root.exists()
}
/// OTG self-check status for troubleshooting USB gadget issues
pub fn run(config: &crate::config::AppConfig) -> OtgSelfCheckResponse {
let hid_backend_is_otg = matches!(config.hid.backend, crate::config::HidBackend::Otg);
let mut checks = Vec::new();
let build_response = |checks: Vec<OtgSelfCheckItem>,
selected_udc: Option<String>,
bound_udc: Option<String>,
udc_state: Option<String>,
udc_speed: Option<String>,
available_udcs: Vec<String>,
other_gadgets: Vec<String>| {
let error_count = checks
.iter()
.filter(|item| item.level == OtgSelfCheckLevel::Error)
.count();
let warning_count = checks
.iter()
.filter(|item| item.level == OtgSelfCheckLevel::Warn)
.count();
OtgSelfCheckResponse {
overall_ok: error_count == 0,
error_count,
warning_count,
hid_backend: format!("{:?}", config.hid.backend).to_lowercase(),
selected_udc,
bound_udc,
udc_state,
udc_speed,
available_udcs,
other_gadgets,
checks,
}
};
let udc_root = std::path::Path::new("/sys/class/udc");
let available_udcs = list_dir_names(udc_root);
let selected_udc = config
.hid
.otg_udc
.clone()
.filter(|udc| !udc.trim().is_empty())
.or_else(|| available_udcs.first().cloned());
let mut udc_stage_ok = true;
if !udc_root.exists() {
udc_stage_ok = false;
push_otg_check(
&mut checks,
"udc_dir_exists",
false,
OtgSelfCheckLevel::Error,
"Check /sys/class/udc existence",
Some("Ensure UDC/OTG kernel drivers are enabled"),
Some("/sys/class/udc"),
);
} else if available_udcs.is_empty() {
udc_stage_ok = false;
push_otg_check(
&mut checks,
"udc_has_entries",
false,
OtgSelfCheckLevel::Error,
"Check available UDC entries",
Some("Ensure OTG controller is enabled in device tree"),
Some("/sys/class/udc"),
);
} else {
push_otg_check(
&mut checks,
"udc_has_entries",
true,
OtgSelfCheckLevel::Info,
"Check available UDC entries",
None::<String>,
Some("/sys/class/udc"),
);
}
let mut configured_udc_ok = true;
if let Some(config_udc) = config
.hid
.otg_udc
.clone()
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
{
if available_udcs.iter().any(|item| item == &config_udc) {
push_otg_check(
&mut checks,
"configured_udc_valid",
true,
OtgSelfCheckLevel::Info,
"Check configured UDC validity",
None::<String>,
Some("/sys/class/udc"),
);
} else {
configured_udc_ok = false;
push_otg_check(
&mut checks,
"configured_udc_valid",
false,
OtgSelfCheckLevel::Error,
"Check configured UDC validity",
Some("Please reselect UDC in HID OTG settings"),
Some("/sys/class/udc"),
);
}
} else {
push_otg_check(
&mut checks,
"configured_udc_valid",
!available_udcs.is_empty(),
if available_udcs.is_empty() {
OtgSelfCheckLevel::Warn
} else {
OtgSelfCheckLevel::Info
},
"Check configured UDC validity",
Some(
"You can set hid_otg_udc in settings to avoid ambiguity in multi-controller setups",
),
Some("/sys/class/udc"),
);
}
if !udc_stage_ok || !configured_udc_ok {
return build_response(
checks,
selected_udc,
None,
None,
None,
available_udcs,
vec![],
);
}
let gadget_root = std::path::Path::new("/sys/kernel/config/usb_gadget");
let configfs_mounted = std::fs::read_to_string("/proc/mounts")
.ok()
.map(|mounts| {
mounts.lines().any(|line| {
let mut parts = line.split_whitespace();
let _src = parts.next();
let mount_point = parts.next();
let fs_type = parts.next();
mount_point == Some("/sys/kernel/config") && fs_type == Some("configfs")
})
})
.unwrap_or(false);
let mut gadget_config_ok = true;
if configfs_mounted {
push_otg_check(
&mut checks,
"configfs_mounted",
true,
OtgSelfCheckLevel::Info,
"Check configfs mount status",
None::<String>,
Some("/sys/kernel/config"),
);
} else {
gadget_config_ok = false;
push_otg_check(
&mut checks,
"configfs_mounted",
false,
OtgSelfCheckLevel::Error,
"Check configfs mount status",
Some("Try: mount -t configfs none /sys/kernel/config"),
Some("/sys/kernel/config"),
);
}
if gadget_root.exists() {
push_otg_check(
&mut checks,
"usb_gadget_dir_exists",
true,
OtgSelfCheckLevel::Info,
"Check /sys/kernel/config/usb_gadget access",
None::<String>,
Some("/sys/kernel/config/usb_gadget"),
);
} else {
gadget_config_ok = false;
push_otg_check(
&mut checks,
"usb_gadget_dir_exists",
false,
OtgSelfCheckLevel::Error,
"Check /sys/kernel/config/usb_gadget access",
Some("Ensure configfs and USB gadget support are enabled"),
Some("/sys/kernel/config/usb_gadget"),
);
}
let libcomposite_available = detect_libcomposite_available(gadget_root);
if libcomposite_available {
push_otg_check(
&mut checks,
"libcomposite_loaded",
true,
OtgSelfCheckLevel::Info,
"Check libcomposite module status",
None::<String>,
Some("/sys/module/libcomposite"),
);
} else {
gadget_config_ok = false;
push_otg_check(
&mut checks,
"libcomposite_loaded",
false,
OtgSelfCheckLevel::Error,
"Check libcomposite module status",
Some("Try: modprobe libcomposite"),
Some("/sys/module/libcomposite"),
);
}
if !gadget_config_ok {
return build_response(
checks,
selected_udc,
None,
None,
None,
available_udcs,
vec![],
);
}
let gadget_names = list_dir_names(gadget_root);
let one_kvm_path = gadget_root.join("one-kvm");
let one_kvm_exists = one_kvm_path.exists();
if one_kvm_exists {
push_otg_check(
&mut checks,
"one_kvm_gadget_exists",
true,
OtgSelfCheckLevel::Info,
"Check one-kvm gadget presence",
None::<String>,
Some(one_kvm_path.display().to_string()),
);
} else {
push_otg_check(
&mut checks,
"one_kvm_gadget_exists",
false,
if hid_backend_is_otg {
OtgSelfCheckLevel::Error
} else {
OtgSelfCheckLevel::Warn
},
"Check one-kvm gadget presence",
Some("Enable OTG HID or MSD to let one-kvm gadget be created automatically"),
Some(one_kvm_path.display().to_string()),
);
}
let other_gadgets = gadget_names
.iter()
.filter(|name| name.as_str() != "one-kvm")
.cloned()
.collect::<Vec<_>>();
if other_gadgets.is_empty() {
push_otg_check(
&mut checks,
"other_gadgets",
true,
OtgSelfCheckLevel::Info,
"Check for other gadget services",
None::<String>,
Some("/sys/kernel/config/usb_gadget"),
);
} else {
push_otg_check(
&mut checks,
"other_gadgets",
false,
OtgSelfCheckLevel::Warn,
"Check for other gadget services",
Some("Potential UDC contention with one-kvm; check other OTG services"),
Some("/sys/kernel/config/usb_gadget"),
);
}
let mut bound_udc = None;
if one_kvm_exists {
let one_kvm_udc_path = one_kvm_path.join("UDC");
let current_udc = read_trimmed(&one_kvm_udc_path).unwrap_or_default();
if current_udc.is_empty() {
push_otg_check(
&mut checks,
"one_kvm_bound_udc",
false,
OtgSelfCheckLevel::Warn,
"Check one-kvm UDC binding",
Some("Ensure HID/MSD is enabled and initialized successfully"),
Some(one_kvm_udc_path.display().to_string()),
);
} else {
push_otg_check(
&mut checks,
"one_kvm_bound_udc",
true,
OtgSelfCheckLevel::Info,
"Check one-kvm UDC binding",
None::<String>,
Some(one_kvm_udc_path.display().to_string()),
);
bound_udc = Some(current_udc);
}
let functions_path = one_kvm_path.join("functions");
let function_names = list_dir_names(&functions_path)
.into_iter()
.filter(|name| name.contains(".usb"))
.collect::<Vec<_>>();
let hid_functions = function_names
.iter()
.filter(|name| name.starts_with("hid.usb"))
.cloned()
.collect::<Vec<_>>();
if hid_functions.is_empty() {
push_otg_check(
&mut checks,
"hid_functions_present",
false,
if hid_backend_is_otg {
OtgSelfCheckLevel::Error
} else {
OtgSelfCheckLevel::Warn
},
"Check HID function creation",
Some("Check OTG HID config and enable at least one HID function"),
Some(functions_path.display().to_string()),
);
} else {
push_otg_check(
&mut checks,
"hid_functions_present",
true,
OtgSelfCheckLevel::Info,
"Check HID function creation",
None::<String>,
Some(functions_path.display().to_string()),
);
}
let config_path = one_kvm_path.join("configs/c.1");
if !config_path.exists() {
push_otg_check(
&mut checks,
"config_c1_exists",
false,
OtgSelfCheckLevel::Error,
"Check configs/c.1 structure",
Some("Gadget structure is incomplete; try restarting One-KVM"),
Some(config_path.display().to_string()),
);
} else {
push_otg_check(
&mut checks,
"config_c1_exists",
true,
OtgSelfCheckLevel::Info,
"Check configs/c.1 structure",
None::<String>,
Some(config_path.display().to_string()),
);
let linked_functions = list_dir_names(&config_path)
.into_iter()
.filter(|name| name.contains(".usb"))
.collect::<Vec<_>>();
let missing_links = function_names
.iter()
.filter(|func| !linked_functions.iter().any(|link| link == *func))
.cloned()
.collect::<Vec<_>>();
if missing_links.is_empty() {
push_otg_check(
&mut checks,
"function_links_ok",
true,
OtgSelfCheckLevel::Info,
"Check function links in configs/c.1",
None::<String>,
Some(config_path.display().to_string()),
);
} else {
push_otg_check(
&mut checks,
"function_links_ok",
false,
OtgSelfCheckLevel::Warn,
"Check function links in configs/c.1",
Some("Reinitialize OTG (toggle HID backend once or restart service)"),
Some(config_path.display().to_string()),
);
}
}
let missing_hid_devices = hid_functions
.iter()
.filter_map(|name| {
let index = name.strip_prefix("hid.usb")?.parse::<u8>().ok()?;
let dev_path = std::path::PathBuf::from(format!("/dev/hidg{}", index));
if dev_path.exists() {
None
} else {
Some(dev_path.display().to_string())
}
})
.collect::<Vec<_>>();
if !hid_functions.is_empty() {
if missing_hid_devices.is_empty() {
push_otg_check(
&mut checks,
"hid_device_nodes",
true,
OtgSelfCheckLevel::Info,
"Check /dev/hidg* device nodes",
None::<String>,
Some("/dev/hidg*"),
);
} else {
push_otg_check(
&mut checks,
"hid_device_nodes",
false,
OtgSelfCheckLevel::Warn,
"Check /dev/hidg* device nodes",
Some("Ensure gadget is bound and check kernel logs"),
Some("/dev/hidg*"),
);
}
}
}
if !other_gadgets.is_empty() {
let check_udc = bound_udc.clone().or_else(|| selected_udc.clone());
if let Some(target_udc) = check_udc {
let conflicting_gadgets = other_gadgets
.iter()
.filter_map(|name| {
let udc_file = gadget_root.join(name).join("UDC");
let udc = read_trimmed(&udc_file)?;
if udc == target_udc {
Some(name.clone())
} else {
None
}
})
.collect::<Vec<_>>();
if conflicting_gadgets.is_empty() {
push_otg_check(
&mut checks,
"udc_conflict",
true,
OtgSelfCheckLevel::Info,
"Check UDC binding conflicts",
None::<String>,
Some("/sys/kernel/config/usb_gadget/*/UDC"),
);
} else {
push_otg_check(
&mut checks,
"udc_conflict",
false,
OtgSelfCheckLevel::Error,
"Check UDC binding conflicts",
Some("Stop other OTG services or switch one-kvm to an idle UDC"),
Some("/sys/kernel/config/usb_gadget/*/UDC"),
);
}
}
}
let active_udc = bound_udc.clone().or_else(|| selected_udc.clone());
let mut udc_state = None;
let mut udc_speed = None;
if let Some(udc) = active_udc.clone() {
let state_path = udc_root.join(&udc).join("state");
match read_trimmed(&state_path) {
Some(state_name) if state_name.eq_ignore_ascii_case("configured") => {
udc_state = Some(state_name.clone());
push_otg_check(
&mut checks,
"udc_state",
true,
OtgSelfCheckLevel::Info,
"Check UDC connection state",
None::<String>,
Some(state_path.display().to_string()),
);
}
Some(state_name) => {
udc_state = Some(state_name.clone());
push_otg_check(
&mut checks,
"udc_state",
false,
OtgSelfCheckLevel::Warn,
"Check UDC connection state",
Some("Ensure target host is connected and has recognized the USB device"),
Some(state_path.display().to_string()),
);
}
None => {
push_otg_check(
&mut checks,
"udc_state",
false,
OtgSelfCheckLevel::Warn,
"Check UDC connection state",
Some("Ensure UDC name is valid and check kernel permissions"),
Some(state_path.display().to_string()),
);
}
}
let speed_path = udc_root.join(&udc).join("current_speed");
if let Some(speed) = read_trimmed(&speed_path) {
udc_speed = Some(speed.clone());
let is_unknown = speed.eq_ignore_ascii_case("unknown");
push_otg_check(
&mut checks,
"udc_speed",
!is_unknown,
if is_unknown {
OtgSelfCheckLevel::Warn
} else {
OtgSelfCheckLevel::Info
},
"Check UDC current link speed",
if is_unknown {
Some("Device may not be fully enumerated; try reconnecting USB".to_string())
} else {
None
},
Some(speed_path.display().to_string()),
);
}
} else {
push_otg_check(
&mut checks,
"udc_state",
false,
OtgSelfCheckLevel::Warn,
"Check UDC connection state",
Some("Ensure UDC is available and one-kvm gadget is bound first"),
Some("/sys/class/udc"),
);
}
let error_count = checks
.iter()
.filter(|item| item.level == OtgSelfCheckLevel::Error)
.count();
let warning_count = checks
.iter()
.filter(|item| item.level == OtgSelfCheckLevel::Warn)
.count();
OtgSelfCheckResponse {
overall_ok: error_count == 0,
error_count,
warning_count,
hid_backend: format!("{:?}", config.hid.backend).to_lowercase(),
selected_udc,
bound_udc,
udc_state,
udc_speed,
available_udcs,
other_gadgets,
checks,
}
}

View File

@@ -0,0 +1,89 @@
//! Runtime platform mode and feature capability reporting.
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlatformMode {
Linux,
Windows,
}
impl PlatformMode {
pub const fn current() -> Self {
if cfg!(windows) {
Self::Windows
} else {
Self::Linux
}
}
pub const fn label(self) -> &'static str {
match self {
Self::Linux => "Linux",
Self::Windows => "Windows",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureCapability {
pub available: bool,
pub backends: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub selected_backend: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
}
impl FeatureCapability {
pub fn available(backends: impl IntoIterator<Item = impl Into<String>>) -> Self {
let backends = backends.into_iter().map(Into::into).collect();
Self {
available: true,
backends,
selected_backend: None,
reason: None,
}
}
pub fn unsupported(reason: impl Into<String>) -> Self {
Self {
available: false,
backends: Vec::new(),
selected_backend: None,
reason: Some(reason.into()),
}
}
pub fn with_selected_backend(mut self, backend: Option<String>) -> Self {
self.selected_backend = backend;
self
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlatformCapabilities {
pub mode: PlatformMode,
pub mode_label: &'static str,
pub video_capture: FeatureCapability,
pub encoder: FeatureCapability,
pub hid: FeatureCapability,
pub atx: FeatureCapability,
pub msd: FeatureCapability,
pub otg: FeatureCapability,
pub audio: FeatureCapability,
pub rustdesk: FeatureCapability,
pub diagnostics: FeatureCapability,
pub extensions: FeatureCapability,
pub service_installation: FeatureCapability,
}
impl PlatformCapabilities {
pub fn current() -> Self {
match PlatformMode::current() {
PlatformMode::Linux => crate::platform::linux::capabilities(),
PlatformMode::Windows => crate::platform::windows::capabilities(),
}
}
}

62
src/platform/defaults.rs Normal file
View File

@@ -0,0 +1,62 @@
use crate::config::{AppConfig, AtxDriverType, HidBackend};
pub fn apply(config: &mut AppConfig) {
if cfg!(windows) {
apply_windows(config);
}
}
fn apply_windows(config: &mut AppConfig) {
config.msd.enabled = false;
config.hid.otg_udc = None;
if config.hid.backend == HidBackend::Otg {
config.hid.backend = HidBackend::None;
}
if config.hid.ch9329_port == "/dev/ttyUSB0" {
config.hid.ch9329_port = "COM3".to_string();
}
if !config.initialized {
config.audio.enabled = false;
config.audio.device.clear();
}
if matches!(
config.atx.power.driver,
AtxDriverType::Gpio | AtxDriverType::UsbRelay
) {
config.atx.power.driver = AtxDriverType::None;
}
if matches!(
config.atx.reset.driver,
AtxDriverType::Gpio | AtxDriverType::UsbRelay
) {
config.atx.reset.driver = AtxDriverType::None;
}
if !config.initialized
&& config.atx.power.driver == AtxDriverType::None
&& config.atx.power.device.is_empty()
{
config.atx.power.driver = AtxDriverType::Serial;
config.atx.power.device = "COM4".to_string();
config.atx.power.pin = 1;
config.atx.power.baud_rate = 9600;
}
if !config.initialized
&& config.atx.reset.driver == AtxDriverType::None
&& config.atx.reset.device.is_empty()
{
config.atx.reset.driver = AtxDriverType::Serial;
config.atx.reset.device = "COM4".to_string();
config.atx.reset.pin = 2;
config.atx.reset.baud_rate = 9600;
}
config
.video
.device
.get_or_insert_with(|| "auto".to_string());
config
.video
.format
.get_or_insert_with(|| "MJPEG".to_string());
}

23
src/platform/linux.rs Normal file
View File

@@ -0,0 +1,23 @@
//! Linux platform capabilities.
use super::{FeatureCapability, PlatformCapabilities, PlatformMode};
pub fn capabilities() -> PlatformCapabilities {
PlatformCapabilities {
mode: PlatformMode::Linux,
mode_label: PlatformMode::Linux.label(),
video_capture: FeatureCapability::available(["v4l2"]),
encoder: FeatureCapability::available([
"software", "vaapi", "nvenc", "qsv", "amf", "rkmpp", "v4l2m2m",
]),
hid: FeatureCapability::available(["otg", "ch9329", "none"]),
atx: FeatureCapability::available(["gpio", "usb_relay", "serial", "wol", "none"]),
msd: FeatureCapability::available(["configfs"]),
otg: FeatureCapability::available(["configfs"]),
audio: FeatureCapability::available(["alsa"]),
rustdesk: FeatureCapability::available(["builtin"]),
diagnostics: FeatureCapability::available(["linux"]),
extensions: FeatureCapability::available(["linux"]),
service_installation: FeatureCapability::available(["systemd"]),
}
}

10
src/platform/mod.rs Normal file
View File

@@ -0,0 +1,10 @@
//! Platform selection and capability reporting.
pub mod capabilities;
pub mod defaults;
pub mod linux;
#[cfg(unix)]
pub mod usb_reset;
pub mod windows;
pub use capabilities::{FeatureCapability, PlatformCapabilities, PlatformMode};

View File

@@ -1,14 +1,9 @@
//! USB device enumeration and reset via sysfs `authorized`.
//!
//! Provides APIs for the settings page to list and reset USB devices.
//! Requires write access to `/sys/bus/usb/devices/.../authorized` (typically root).
use std::io;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
/// Walk up from a V4L sysfs `device` link until we find a USB device node
/// (`busnum` + `devnum` present).
fn usb_device_dir_for_v4l_sysfs(device_link: &Path) -> io::Result<PathBuf> {
let mut p = device_link.canonicalize()?;
loop {
@@ -28,57 +23,38 @@ fn usb_device_dir_for_v4l_sysfs(device_link: &Path) -> io::Result<PathBuf> {
}
}
// ---------------------------------------------------------------------------
// USB device enumeration & reset-by-bus/dev (for the settings API)
// ---------------------------------------------------------------------------
use serde::Serialize;
/// Information about a single USB device, read from `/sys/bus/usb/devices/`.
#[derive(Debug, Serialize)]
pub struct UsbDeviceInfo {
/// USB bus number (`busnum` sysfs attribute).
pub bus_num: u32,
/// USB device number on the bus (`devnum` sysfs attribute).
pub dev_num: u32,
/// Vendor ID hex string, e.g. `"1d6b"`.
pub id_vendor: String,
/// Product ID hex string, e.g. `"0002"`.
pub id_product: String,
/// Product name from sysfs `product`.
#[serde(skip_serializing_if = "Option::is_none")]
pub product: Option<String>,
/// Manufacturer name from sysfs `manufacturer`.
#[serde(skip_serializing_if = "Option::is_none")]
pub manufacturer: Option<String>,
/// Speed in Mbps from sysfs `speed`, e.g. `"480"`.
#[serde(skip_serializing_if = "Option::is_none")]
pub speed: Option<String>,
/// `true` if authorized=1, `false` if authorized=0, `None` if no file.
#[serde(skip_serializing_if = "Option::is_none")]
pub authorized: Option<bool>,
/// Kernel driver bound to this device (from driver symlink).
#[serde(skip_serializing_if = "Option::is_none")]
pub driver: Option<String>,
/// Associated `/dev/videoN` node, if this USB device has a V4L2 child.
#[serde(skip_serializing_if = "Option::is_none")]
pub video_device: Option<String>,
}
/// Read a sysfs string attribute, trimming trailing newline.
fn read_sysfs_str(dir: &Path, attr: &str) -> Option<String> {
std::fs::read_to_string(dir.join(attr))
.ok()
.map(|s| s.trim_end().to_string())
}
/// Read a sysfs u32 attribute.
fn read_sysfs_u32(dir: &Path, attr: &str) -> Option<u32> {
read_sysfs_str(dir, attr).and_then(|s| s.parse().ok())
}
/// Build a map from USB sysfs dir → video device name by scanning
/// `/sys/class/video4linux/`.
fn build_usb_to_video_map() -> std::collections::HashMap<String, String> {
let mut map = std::collections::HashMap::new();
let v4l_class = Path::new("/sys/class/video4linux");
@@ -92,7 +68,6 @@ fn build_usb_to_video_map() -> std::collections::HashMap<String, String> {
Some(s) if s.starts_with("video") => s,
_ => continue,
};
// Resolve the device symlink and walk up to find the USB parent
let device_link = v4l_class.join(name_str).join("device");
if let Ok(usb_dir) = usb_device_dir_for_v4l_sysfs(&device_link) {
if let Some(key) = usb_dir.file_name().and_then(|k| k.to_str()) {
@@ -103,7 +78,6 @@ fn build_usb_to_video_map() -> std::collections::HashMap<String, String> {
map
}
/// List all USB devices visible in `/sys/bus/usb/devices/`.
pub fn list_usb_devices() -> Vec<UsbDeviceInfo> {
let usb_bus = Path::new("/sys/bus/usb/devices");
let entries = match std::fs::read_dir(usb_bus) {
@@ -117,7 +91,6 @@ pub fn list_usb_devices() -> Vec<UsbDeviceInfo> {
.flatten()
.filter_map(|entry| {
let dir = entry.path();
// Only consider entries that have busnum + devnum (actual devices, not interfaces)
let bus_num = read_sysfs_u32(&dir, "busnum")?;
let dev_num = read_sysfs_u32(&dir, "devnum")?;
@@ -158,13 +131,10 @@ pub fn list_usb_devices() -> Vec<UsbDeviceInfo> {
})
.collect();
// Sort by bus, then device number for stable ordering.
devices.sort_by(|a, b| (a.bus_num, a.dev_num).cmp(&(b.bus_num, b.dev_num)));
devices
}
/// Reset a USB device identified by bus/dev numbers via the `authorized` sysfs
/// attribute. After re-authorizing, waits for the device to reappear.
pub fn reset_usb_device(bus_num: u32, dev_num: u32) -> io::Result<()> {
let usb_bus = Path::new("/sys/bus/usb/devices");
let entries = std::fs::read_dir(usb_bus)?;
@@ -187,7 +157,6 @@ pub fn reset_usb_device(bus_num: u32, dev_num: u32) -> io::Result<()> {
std::thread::sleep(Duration::from_millis(300));
std::fs::write(&authorized, b"1")?;
// Wait for device to reappear
let wait_until = Instant::now() + Duration::from_secs(2);
while !dir.join("busnum").exists() {
if Instant::now() >= wait_until {

33
src/platform/windows.rs Normal file
View File

@@ -0,0 +1,33 @@
//! Windows platform capabilities.
use super::{FeatureCapability, PlatformCapabilities, PlatformMode};
pub fn capabilities() -> PlatformCapabilities {
let linux_only = "unsupported on Windows";
PlatformCapabilities {
mode: PlatformMode::Windows,
mode_label: PlatformMode::Windows.label(),
video_capture: FeatureCapability::available(["directshow_uvc", "mjpeg"])
.with_selected_backend(Some("directshow_uvc".to_string())),
encoder: FeatureCapability::available([
"ffmpeg_h264",
"ffmpeg_h265",
"ffmpeg_vp8",
"ffmpeg_vp9",
"software",
"mjpeg",
]),
hid: FeatureCapability::available(["ch9329", "none"])
.with_selected_backend(Some("ch9329".to_string())),
atx: FeatureCapability::available(["serial", "wol", "none"]),
msd: FeatureCapability::unsupported(linux_only),
otg: FeatureCapability::unsupported(linux_only),
audio: FeatureCapability::available(["wasapi", "opus"])
.with_selected_backend(Some("wasapi".to_string())),
rustdesk: FeatureCapability::available(["builtin", "tcp_direct", "relay"])
.with_selected_backend(Some("builtin".to_string())),
diagnostics: FeatureCapability::available(["windows"]),
extensions: FeatureCapability::available(["windows_safe"]),
service_installation: FeatureCapability::available(["windows_service"]),
}
}

View File

@@ -62,10 +62,8 @@ pub async fn redfish_auth_middleware(
}
fn is_redfish_public_endpoint(path: &str, method: &Method) -> bool {
matches!(
path,
"/" | "/v1" | "/v1/" | "/v1/odata"
) || path.starts_with("/v1/$metadata")
matches!(path, "/" | "/v1" | "/v1/" | "/v1/odata")
|| path.starts_with("/v1/$metadata")
|| (path == "/v1/SessionService/Sessions" && *method == Method::POST)
}

View File

@@ -8,29 +8,20 @@ use axum::{
use std::sync::Arc;
use super::{empty_collection, resource_not_found};
use super::super::schema::*;
use super::{empty_collection, resource_not_found};
use crate::state::AppState;
pub(crate) fn router(state: Arc<AppState>) -> Router<Arc<AppState>> {
Router::new()
.route("/v1/AccountService", get(account_service))
.route(
"/v1/AccountService/Accounts",
get(account_list),
)
.route("/v1/AccountService/Accounts", get(account_list))
.route(
"/v1/AccountService/Accounts/{account_id}",
get(account_detail),
)
.route(
"/v1/AccountService/Roles",
get(roles_stub),
)
.route(
"/v1/AccountService/Roles/{role_id}",
get(role_detail_stub),
)
.route("/v1/AccountService/Roles", get(roles_stub))
.route("/v1/AccountService/Roles/{role_id}", get(role_detail_stub))
.with_state(state)
}
@@ -77,7 +68,10 @@ async fn account_list(State(state): State<Arc<AppState>>) -> Response {
"/redfish/v1/$metadata#ManagerAccountCollection.ManagerAccountCollection",
"Accounts Collection",
"Collection of Accounts",
vec![odata_ref(&format!("/redfish/v1/AccountService/Accounts/{}", user.id))],
vec![odata_ref(&format!(
"/redfish/v1/AccountService/Accounts/{}",
user.id
))],
))
.into_response()
}

View File

@@ -7,8 +7,8 @@ use axum::{
use std::sync::Arc;
use super::{empty_collection, get_power_state, validate_id, RESOURCE_ID};
use super::super::schema::*;
use super::{empty_collection, get_power_state, validate_id, RESOURCE_ID};
use crate::state::AppState;
pub(crate) fn router(state: Arc<AppState>) -> Router<Arc<AppState>> {
@@ -64,9 +64,7 @@ async fn chassis_detail(
.into_response()
}
async fn chassis_power(
Path(chassis_id): Path<String>,
) -> Response {
async fn chassis_power(Path(chassis_id): Path<String>) -> Response {
if let Some(resp) = validate_id(&chassis_id) {
return resp;
}

View File

@@ -48,8 +48,7 @@ async fn event_service() -> Json<EventService> {
server_sent_event_uri: Some("/redfish/v1/EventService/SSE".to_string()),
actions: EventServiceActions {
submit_test_event: ActionTarget {
target: "/redfish/v1/EventService/Actions/EventService.SubmitTestEvent"
.to_string(),
target: "/redfish/v1/EventService/Actions/EventService.SubmitTestEvent".to_string(),
},
},
})

View File

@@ -7,8 +7,8 @@ use axum::{
use std::sync::Arc;
use super::{empty_collection, validate_id, RESOURCE_ID};
use super::super::schema::*;
use super::{empty_collection, validate_id, RESOURCE_ID};
use crate::state::AppState;
pub(crate) fn router(state: Arc<AppState>) -> Router<Arc<AppState>> {
@@ -86,7 +86,10 @@ async fn manager_detail(
manager_for_servers: vec![odata_ref(&format!("/redfish/v1/Systems/{}", RESOURCE_ID))],
manager_for_chassis: vec![odata_ref(&format!("/redfish/v1/Chassis/{}", RESOURCE_ID))],
},
network_protocol: odata_ref(&format!("/redfish/v1/Managers/{}/NetworkProtocol", manager_id)),
network_protocol: odata_ref(&format!(
"/redfish/v1/Managers/{}/NetworkProtocol",
manager_id
)),
})
.into_response()
}

View File

@@ -4,6 +4,7 @@ mod event;
mod managers;
mod session;
mod systems;
#[cfg(unix)]
mod virtual_media;
use axum::{
@@ -191,7 +192,6 @@ pub fn create_redfish_router(state: Arc<AppState>) -> Router {
.merge(systems::router(state.clone()))
.merge(chassis::router(state.clone()))
.merge(managers::router(state.clone()))
.merge(virtual_media::router(state.clone()))
.merge(session::router(state.clone()))
.merge(account::router(state.clone()))
.merge(event::router(state.clone()))
@@ -200,6 +200,9 @@ pub fn create_redfish_router(state: Arc<AppState>) -> Router {
redfish_auth_middleware,
));
#[cfg(unix)]
let redfish_routes = redfish_routes.merge(virtual_media::router(state.clone()));
Router::new()
.route("/redfish", get(service_root_redirect))
.nest("/redfish/", redfish_routes)

View File

@@ -9,8 +9,8 @@ use tracing::info;
use std::sync::Arc;
use super::empty_collection;
use super::super::schema::*;
use super::empty_collection;
use crate::state::AppState;
pub(crate) fn router(state: Arc<AppState>) -> Router<Arc<AppState>> {
@@ -56,7 +56,10 @@ async fn session_list(State(state): State<Arc<AppState>>) -> Response {
let mut members = Vec::new();
for id in &session_ids {
if state.sessions.get(id).await.ok().flatten().is_some() {
members.push(odata_ref(&format!("/redfish/v1/SessionService/Sessions/{}", id)));
members.push(odata_ref(&format!(
"/redfish/v1/SessionService/Sessions/{}",
id
)));
}
}

View File

@@ -8,8 +8,8 @@ use axum::{
use std::sync::Arc;
use tracing::info;
use super::{get_power_state, validate_id, service_unavailable, empty_collection, RESOURCE_ID};
use super::super::schema::*;
use super::{empty_collection, get_power_state, service_unavailable, validate_id, RESOURCE_ID};
use crate::state::AppState;
pub(crate) fn router(state: Arc<AppState>) -> Router<Arc<AppState>> {
@@ -208,13 +208,14 @@ async fn system_reset(
}
}
async fn system_set_default_boot_order(
Path(system_id): Path<String>,
) -> Response {
async fn system_set_default_boot_order(Path(system_id): Path<String>) -> Response {
if let Some(resp) = validate_id(&system_id) {
return resp;
}
info!("Redfish: SetDefaultBootOrder for system {} (accepted, no-op)", system_id);
info!(
"Redfish: SetDefaultBootOrder for system {} (accepted, no-op)",
system_id
);
StatusCode::NO_CONTENT.into_response()
}

View File

@@ -9,8 +9,8 @@ use tracing::{info, warn};
use std::sync::Arc;
use super::{empty_collection, validate_id, service_unavailable, resource_not_found, RESOURCE_ID};
use super::super::schema::*;
use super::{empty_collection, resource_not_found, service_unavailable, validate_id, RESOURCE_ID};
use crate::state::AppState;
pub(crate) fn router(state: Arc<AppState>) -> Router<Arc<AppState>> {
@@ -95,7 +95,10 @@ async fn virtual_media_detail(
Json(VirtualMedia {
odata_type: "#VirtualMedia.v1_6_2.VirtualMedia".to_string(),
odata_id: format!("/redfish/v1/Managers/{}/VirtualMedia/{}", manager_id, media_id),
odata_id: format!(
"/redfish/v1/Managers/{}/VirtualMedia/{}",
manager_id, media_id
),
odata_context: "/redfish/v1/$metadata#VirtualMedia.VirtualMedia".to_string(),
id: media_id.clone(),
name: "Virtual Media 1".to_string(),
@@ -153,7 +156,9 @@ async fn virtual_media_insert(
if msd.state().await.connected {
return (
StatusCode::CONFLICT,
Json(RedfishError::general_error("Virtual media already inserted")),
Json(RedfishError::general_error(
"Virtual media already inserted",
)),
)
.into_response();
}

View File

@@ -527,7 +527,10 @@ pub struct RedfishError {
pub struct RedfishErrorBody {
pub code: String,
pub message: String,
#[serde(rename = "@Message.ExtendedInfo", skip_serializing_if = "Vec::is_empty")]
#[serde(
rename = "@Message.ExtendedInfo",
skip_serializing_if = "Vec::is_empty"
)]
pub extended_info: Vec<RedfishExtendedInfo>,
}

View File

@@ -1,7 +1,7 @@
use bytes::Bytes;
use crate::video::encoder::registry::VideoEncoderType;
use crate::video::shared_video_pipeline::EncodedVideoFrame;
use crate::video::codec::registry::VideoEncoderType;
use crate::video::pipeline::EncodedVideoFrame;
use super::state::ParameterSets;

View File

@@ -1,5 +1,5 @@
use crate::config::RtspCodec;
use crate::video::encoder::VideoCodecType;
use crate::video::codec::VideoCodecType;
pub(crate) fn rtsp_codec_to_video(codec: RtspCodec) -> VideoCodecType {
match codec {

View File

@@ -2,8 +2,10 @@ use base64::Engine;
use sdp_types as sdp;
use crate::config::RtspConfig;
use crate::video::encoder::VideoCodecType;
use crate::webrtc::rtp::parse_profile_level_id_from_sps;
use crate::video::codec::h264_bitstream::{
parse_profile_level_id_from_sps, FALLBACK_WEBRTC_PROFILE_LEVEL_ID,
};
use crate::video::codec::VideoCodecType;
use super::state::ParameterSets;
@@ -15,7 +17,10 @@ pub(crate) fn build_h264_fmtp(payload_type: u8, params: &ParameterSets) -> Strin
attrs.push(format!("profile-level-id={}", profile_level_id));
}
} else {
attrs.push("profile-level-id=42e01f".to_string());
attrs.push(format!(
"profile-level-id={}",
FALLBACK_WEBRTC_PROFILE_LEVEL_ID
));
}
if let (Some(sps), Some(pps)) = (params.h264_sps.as_ref(), params.h264_pps.as_ref()) {

View File

@@ -11,8 +11,8 @@ use webrtc::util::{Marshal, MarshalSize};
use crate::config::RtspCodec;
use crate::error::{AppError, Result};
use crate::video::encoder::registry::VideoEncoderType;
use crate::video::shared_video_pipeline::EncodedVideoFrame;
use crate::video::codec::registry::VideoEncoderType;
use crate::video::pipeline::EncodedVideoFrame;
use crate::video::VideoStreamManager;
use crate::webrtc::h265_payloader::H265Payloader;

View File

@@ -16,11 +16,11 @@ use tracing::{debug, error, info, warn};
use crate::audio::AudioController;
use crate::hid::{CanonicalKey, HidController, KeyEventType, KeyboardEvent, KeyboardModifiers};
use crate::utils::hostname_from_etc;
use crate::video::codec::registry::{EncoderRegistry, VideoEncoderType};
use crate::video::codec::BitratePreset;
use crate::video::codec_constraints::{
encoder_codec_to_id, encoder_codec_to_video_codec, video_codec_to_encoder_codec,
};
use crate::video::encoder::registry::{EncoderRegistry, VideoEncoderType};
use crate::video::encoder::BitratePreset;
use crate::video::stream_manager::VideoStreamManager;
use super::bytes_codec::{read_frame, write_frame, write_frame_buffered};
@@ -637,22 +637,22 @@ impl Connection {
// Check availability in priority order
// H264 is preferred because it has the best hardware encoder support (RKMPP, VAAPI, etc.)
// and most RustDesk clients support H264 hardware decoding
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H264)
if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H264)
&& registry.is_codec_available(VideoEncoderType::H264)
{
return VideoEncoderType::H264;
}
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H265)
if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H265)
&& registry.is_codec_available(VideoEncoderType::H265)
{
return VideoEncoderType::H265;
}
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP8)
if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP8)
&& registry.is_codec_available(VideoEncoderType::VP8)
{
return VideoEncoderType::VP8;
}
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP9)
if constraints.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP9)
&& registry.is_codec_available(VideoEncoderType::VP9)
{
return VideoEncoderType::VP9;
@@ -1106,16 +1106,16 @@ impl Connection {
// Check which encoders are available (include software fallback)
let h264_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H264)
.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H264)
&& registry.is_codec_available(VideoEncoderType::H264);
let h265_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H265)
.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::H265)
&& registry.is_codec_available(VideoEncoderType::H265);
let vp8_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP8)
.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP8)
&& registry.is_codec_available(VideoEncoderType::VP8);
let vp9_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP9)
.is_webrtc_codec_allowed(crate::video::codec::VideoCodecType::VP9)
&& registry.is_codec_available(VideoEncoderType::VP9);
info!(
@@ -1574,7 +1574,7 @@ async fn run_video_streaming(
shutdown_tx: broadcast::Sender<()>,
negotiated_codec: VideoEncoderType,
) -> anyhow::Result<()> {
use crate::video::encoder::VideoCodecType;
use crate::video::codec::VideoCodecType;
// Convert VideoEncoderType to VideoCodecType for the pipeline
let webrtc_codec = match negotiated_codec {

View File

@@ -91,7 +91,7 @@ impl VideoFrameAdapter {
return data;
}
let (sps, pps) = crate::webrtc::rtp::extract_sps_pps(&data);
let (sps, pps) = crate::video::codec::h264_bitstream::extract_sps_pps(&data);
let mut has_sps = false;
let mut has_pps = false;

View File

@@ -696,6 +696,7 @@ fn try_into_v4(addr: SocketAddr) -> SocketAddr {
addr
}
#[cfg(target_os = "linux")]
fn is_virtual_interface(name: &str) -> bool {
name.starts_with("docker")
|| name.starts_with("br-")

View File

@@ -12,7 +12,9 @@ use crate::events::{
};
use crate::extensions::{ExtensionId, ExtensionManager};
use crate::hid::HidController;
#[cfg(unix)]
use crate::msd::MsdController;
#[cfg(unix)]
use crate::otg::OtgService;
use crate::rtsp::RtspService;
use crate::rustdesk::RustDeskService;
@@ -51,10 +53,12 @@ pub struct AppState {
pub config: ConfigStore,
pub sessions: SessionStore,
pub users: UserStore,
#[cfg(unix)]
pub otg_service: Arc<OtgService>,
pub stream_manager: Arc<VideoStreamManager>,
pub webrtc: Arc<WebRtcStreamer>,
pub hid: Arc<HidController>,
#[cfg(unix)]
pub msd: Arc<RwLock<Option<MsdController>>>,
pub atx: Arc<RwLock<Option<AtxController>>>,
pub audio: Arc<AudioController>,
@@ -77,11 +81,11 @@ impl AppState {
config: ConfigStore,
sessions: SessionStore,
users: UserStore,
otg_service: Arc<OtgService>,
#[cfg(unix)] otg_service: Arc<OtgService>,
stream_manager: Arc<VideoStreamManager>,
webrtc: Arc<WebRtcStreamer>,
hid: Arc<HidController>,
msd: Option<MsdController>,
#[cfg(unix)] msd: Option<MsdController>,
atx: Option<AtxController>,
audio: Arc<AudioController>,
rustdesk: Option<Arc<RustDeskService>>,
@@ -99,10 +103,12 @@ impl AppState {
config,
sessions,
users,
#[cfg(unix)]
otg_service,
stream_manager,
webrtc,
hid,
#[cfg(unix)]
msd: Arc::new(RwLock::new(msd)),
atx: Arc::new(RwLock::new(atx)),
audio,
@@ -202,23 +208,30 @@ impl AppState {
}
async fn collect_msd_info(&self) -> Option<MsdDeviceInfo> {
let msd_guard = self.msd.read().await;
let msd = msd_guard.as_ref()?;
#[cfg(not(unix))]
{
None
}
#[cfg(unix)]
{
let msd_guard = self.msd.read().await;
let msd = msd_guard.as_ref()?;
let state = msd.state().await;
let error = msd.monitor().error_message().await;
Some(MsdDeviceInfo {
available: state.available,
mode: match state.mode {
crate::msd::MsdMode::None => "none",
crate::msd::MsdMode::Image => "image",
crate::msd::MsdMode::Drive => "drive",
}
.to_string(),
connected: state.connected,
image_id: state.current_image.map(|img| img.id),
error,
})
let state = msd.state().await;
let error = msd.monitor().error_message().await;
Some(MsdDeviceInfo {
available: state.available,
mode: match state.mode {
crate::msd::MsdMode::None => "none",
crate::msd::MsdMode::Image => "image",
crate::msd::MsdMode::Drive => "drive",
}
.to_string(),
connected: state.connected,
image_id: state.current_image.map(|img| img.id),
error,
})
}
}
async fn collect_atx_info(&self) -> Option<AtxDeviceInfo> {

View File

@@ -11,8 +11,8 @@ use tracing::{debug, info, warn};
/// Generation token paired with `client_id` so [`unregister_client`] ignores stale drops.
pub type ClientGeneration = u64;
use crate::video::encoder::traits::{Encoder, EncoderConfig};
use crate::video::encoder::JpegEncoder;
use crate::video::codec::traits::{Encoder, EncoderConfig};
use crate::video::codec::JpegEncoder;
use crate::video::format::PixelFormat;
use crate::video::VideoFrame;

View File

@@ -1,7 +1,7 @@
//! `EncoderType` → `EncoderBackend` (breaks config ↔ video import cycles).
use crate::config::EncoderType;
use crate::video::encoder::EncoderBackend;
use crate::video::codec::EncoderBackend;
/// `None` means “auto” in WebRTC / pipeline (same as `EncoderType::Auto`).
pub fn encoder_type_to_backend(encoder: EncoderType) -> Option<EncoderBackend> {

View File

@@ -9,7 +9,15 @@ pub fn hostname_from_etc() -> String {
/// Current kernel hostname (`gethostname`). Used for live device info in the UI.
pub fn hostname_uname() -> String {
nix::unistd::gethostname()
.map(|s| s.to_string_lossy().into_owned())
.unwrap_or_else(|_| "unknown".to_string())
#[cfg(unix)]
{
nix::unistd::gethostname()
.map(|s| s.to_string_lossy().into_owned())
.unwrap_or_else(|_| "unknown".to_string())
}
#[cfg(not(unix))]
{
std::env::var("COMPUTERNAME").unwrap_or_else(|_| "unknown".to_string())
}
}

View File

@@ -2,10 +2,16 @@
pub mod fs;
pub mod host;
#[cfg(unix)]
pub mod net;
#[cfg(not(unix))]
#[path = "net_disabled.rs"]
pub mod net;
pub mod serial;
pub mod throttle;
pub use fs::{list_dir_names, read_trimmed};
pub use host::{hostname_from_etc, hostname_uname};
pub use net::{bind_tcp_listener, bind_udp_socket};
pub use serial::list_serial_ports;
pub use throttle::LogThrottler;

14
src/utils/net_disabled.rs Normal file
View File

@@ -0,0 +1,14 @@
use std::io;
use std::net::{SocketAddr, TcpListener, UdpSocket};
pub fn bind_tcp_listener(addr: SocketAddr) -> io::Result<TcpListener> {
let listener = TcpListener::bind(addr)?;
listener.set_nonblocking(true)?;
Ok(listener)
}
pub fn bind_udp_socket(addr: SocketAddr) -> io::Result<UdpSocket> {
let socket = UdpSocket::bind(addr)?;
socket.set_nonblocking(true)?;
Ok(socket)
}

12
src/utils/serial.rs Normal file
View File

@@ -0,0 +1,12 @@
//! Cross-platform serial port discovery helpers.
/// Return serial port names that users can put directly into the config.
pub fn list_serial_ports() -> Vec<String> {
let mut ports: Vec<String> = serialport::available_ports()
.map(|ports| ports.into_iter().map(|port| port.port_name).collect())
.unwrap_or_default();
ports.sort();
ports.dedup();
ports
}

View File

@@ -22,9 +22,9 @@ use v4l2r::nix::errno::Errno;
use v4l2r::{Format as V4l2rFormat, PixelFormat as V4l2rPixelFormat, QueueType};
use crate::error::{AppError, Result};
use crate::video::csi_bridge::{self, CsiBridgeKind, ProbeResult};
use crate::video::device::bridge::{self as csi_bridge, CsiBridgeKind, ProbeResult};
use crate::video::format::{PixelFormat, Resolution};
use crate::video::SignalStatus;
use crate::video::signal::SignalStatus;
/// `io::Error` payload when the driver posts `V4L2_EVENT_SOURCE_CHANGE`.
pub const SOURCE_CHANGED_MARKER: &str = "v4l2_source_changed";
@@ -60,7 +60,7 @@ impl BridgeContext {
}
/// V4L2 capture stream backed by v4l2r ioctl.
pub struct V4l2rCaptureStream {
pub struct CaptureStream {
fd: File,
queue: QueueType,
resolution: Resolution,
@@ -72,7 +72,7 @@ pub struct V4l2rCaptureStream {
bridge_kind: Option<CsiBridgeKind>,
}
impl V4l2rCaptureStream {
impl CaptureStream {
/// UVC: uses `resolution`. CSI bridges: DV-probe first; may return `CaptureNoSignal`.
pub fn open(
device_path: impl AsRef<Path>,
@@ -538,7 +538,7 @@ impl V4l2rCaptureStream {
}
}
impl Drop for V4l2rCaptureStream {
impl Drop for CaptureStream {
fn drop(&mut self) {
// Release ordering matters on rkcif: a subsequent open()/S_FMT from a
// freshly-constructed stream returns EBUSY if the previous capture has
@@ -571,9 +571,9 @@ impl Drop for V4l2rCaptureStream {
}
/// Driver-name check for CSI/HDMI bridge devices (rk_hdmirx, rkcif, tc358743,
/// …) that expose DV timings. Kept in sync with `video::is_csi_hdmi_bridge`
/// …) that expose DV timings. Kept in sync with `video::device::is_csi_hdmi_bridge`
/// but queries the raw V4L2 driver string so we don't need a full
/// `VideoDeviceInfo` at `V4l2rCaptureStream::open` time.
/// `VideoDeviceInfo` at `CaptureStream::open` time.
fn is_csi_bridge_driver(driver: &str) -> bool {
let d = driver.to_ascii_lowercase();
d == "rk_hdmirx" || d == "rkcif" || d == "tc358743" || d.starts_with("rkcif")

12
src/video/capture/mod.rs Normal file
View File

@@ -0,0 +1,12 @@
//! Video capture implementations and capture-state helpers.
pub(crate) mod runtime;
pub(crate) mod status;
#[cfg(unix)]
mod linux;
#[cfg(windows)]
#[path = "windows.rs"]
mod linux;
pub use linux::*;

View File

@@ -0,0 +1,70 @@
use std::path::Path;
use std::time::Duration;
use crate::error::AppError;
use crate::video::capture::status::signal_status_from_capture_kind;
use crate::video::format::{PixelFormat, Resolution};
use crate::video::signal::SignalStatus;
use super::{BridgeContext, CaptureStream};
pub enum CaptureOpenResult {
Opened(CaptureStream),
NoSignal(SignalStatus),
DeviceLost(String),
Fatal,
}
pub fn open_capture_stream(
device_path: &Path,
resolution: Resolution,
format: PixelFormat,
fps: u32,
buffer_count: u32,
timeout: Duration,
bridge_ctx: BridgeContext,
) -> Result<CaptureStream, AppError> {
CaptureStream::open_with_bridge(
device_path,
resolution,
format,
fps,
buffer_count.max(1),
timeout,
bridge_ctx,
)
}
pub fn open_capture_stream_for_retry(
device_path: &Path,
resolution: Resolution,
format: PixelFormat,
fps: u32,
buffer_count: u32,
timeout: Duration,
bridge_ctx: BridgeContext,
is_device_lost_message: impl FnOnce(&str) -> bool,
) -> CaptureOpenResult {
match open_capture_stream(
device_path,
resolution,
format,
fps,
buffer_count,
timeout,
bridge_ctx,
) {
Ok(stream) => CaptureOpenResult::Opened(stream),
Err(AppError::CaptureNoSignal { kind }) => {
CaptureOpenResult::NoSignal(signal_status_from_capture_kind(&kind))
}
Err(error) => {
let reason = error.to_string();
if is_device_lost_message(&reason) {
CaptureOpenResult::DeviceLost(reason)
} else {
CaptureOpenResult::Fatal
}
}
}
}

View File

@@ -2,7 +2,7 @@
use std::io;
use crate::video::SignalStatus;
use crate::video::signal::SignalStatus;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureIoErrorKind {

View File

@@ -0,0 +1,181 @@
use std::io;
use std::path::{Path, PathBuf};
use std::time::Duration;
use crate::error::{AppError, Result};
use crate::video::device::bridge::{CsiBridgeKind, ProbeResult};
use crate::video::device::{directshow_display_name_from_path, normalize_windows_device_path};
use crate::video::format::{PixelFormat, Resolution};
pub const SOURCE_CHANGED_MARKER: &str = "dshow_source_changed";
pub fn is_source_changed_error(err: &io::Error) -> bool {
err.get_ref()
.map(|inner| inner.to_string() == SOURCE_CHANGED_MARKER)
.unwrap_or(false)
}
#[derive(Debug, Clone, Copy)]
pub struct CaptureMeta {
pub bytes_used: usize,
pub sequence: u64,
}
#[derive(Debug, Clone, Default)]
pub struct BridgeContext {
pub subdev_path: Option<PathBuf>,
pub kind: Option<CsiBridgeKind>,
}
impl BridgeContext {
pub fn from_parts(subdev_path: Option<PathBuf>, kind: Option<CsiBridgeKind>) -> Self {
Self { subdev_path, kind }
}
pub fn has_subdev(&self) -> bool {
false
}
}
pub struct CaptureStream {
capture: hwcodec::capture::DshowCapture,
resolution: Resolution,
format: PixelFormat,
stride: u32,
}
unsafe impl Send for CaptureStream {}
impl CaptureStream {
pub fn open(
device_path: impl AsRef<Path>,
resolution: Resolution,
format: PixelFormat,
fps: u32,
buffer_count: u32,
timeout: Duration,
) -> Result<Self> {
let _ = buffer_count;
let path = normalize_windows_device_path(device_path);
let display_name = directshow_display_name_from_path(&path).ok_or_else(|| {
AppError::VideoError(format!(
"Unsupported DirectShow device path: {}",
path.display()
))
})?;
let capture = hwcodec::capture::DshowCapture::open(
&display_name,
resolution.width as i32,
resolution.height as i32,
fps as i32,
map_pixel_format(format),
timeout.as_millis().clamp(1, i32::MAX as u128) as i32,
)
.map_err(|e| AppError::VideoError(format!("Failed to open DirectShow capture: {}", e)))?;
let info = capture.info().map_err(|e| {
AppError::VideoError(format!("Failed to query DirectShow capture: {}", e))
})?;
let actual_format = map_capture_format(info.pixel_format)?;
let actual_resolution =
Resolution::new(info.width.max(1) as u32, info.height.max(1) as u32);
Ok(Self {
capture,
resolution: actual_resolution,
format: actual_format,
stride: info.stride.max(0) as u32,
})
}
pub fn open_with_bridge(
device_path: impl AsRef<Path>,
resolution: Resolution,
format: PixelFormat,
fps: u32,
buffer_count: u32,
timeout: Duration,
bridge: BridgeContext,
) -> Result<Self> {
let _ = bridge;
Self::open(device_path, resolution, format, fps, buffer_count, timeout)
}
pub fn resolution(&self) -> Resolution {
self.resolution
}
pub fn format(&self) -> PixelFormat {
self.format
}
pub fn stride(&self) -> u32 {
self.stride
}
pub fn next_into(&mut self, dst: &mut Vec<u8>) -> io::Result<CaptureMeta> {
match self.capture.read_packet() {
Ok((packet, sequence)) => {
dst.clear();
dst.extend_from_slice(&packet);
Ok(CaptureMeta {
bytes_used: packet.len(),
sequence,
})
}
Err(err) => {
let kind = if err.code == -110 {
io::ErrorKind::TimedOut
} else {
io::ErrorKind::Other
};
Err(io::Error::new(kind, err.message))
}
}
}
pub fn probe_bridge_signal_with_timeout(&self, _limit: Duration) -> Option<ProbeResult> {
None
}
}
fn map_pixel_format(format: PixelFormat) -> hwcodec::capture::CapturePixelFormat {
match format {
PixelFormat::Mjpeg => hwcodec::capture::CapturePixelFormat::Mjpeg,
PixelFormat::Jpeg => hwcodec::capture::CapturePixelFormat::Jpeg,
PixelFormat::Yuyv => hwcodec::capture::CapturePixelFormat::Yuyv,
PixelFormat::Yvyu => hwcodec::capture::CapturePixelFormat::Yvyu,
PixelFormat::Uyvy => hwcodec::capture::CapturePixelFormat::Uyvy,
PixelFormat::Nv12 => hwcodec::capture::CapturePixelFormat::Nv12,
PixelFormat::Nv21 => hwcodec::capture::CapturePixelFormat::Nv21,
PixelFormat::Nv16 => hwcodec::capture::CapturePixelFormat::Nv16,
PixelFormat::Nv24 => hwcodec::capture::CapturePixelFormat::Nv24,
PixelFormat::Yuv420 => hwcodec::capture::CapturePixelFormat::Yuv420,
PixelFormat::Yvu420 => hwcodec::capture::CapturePixelFormat::Yvu420,
PixelFormat::Rgb24 => hwcodec::capture::CapturePixelFormat::Rgb24,
PixelFormat::Bgr24 => hwcodec::capture::CapturePixelFormat::Bgr24,
PixelFormat::Grey => hwcodec::capture::CapturePixelFormat::Grey,
PixelFormat::Rgb565 => hwcodec::capture::CapturePixelFormat::Unknown,
}
}
fn map_capture_format(format: hwcodec::capture::CapturePixelFormat) -> Result<PixelFormat> {
match format {
hwcodec::capture::CapturePixelFormat::Mjpeg => Ok(PixelFormat::Mjpeg),
hwcodec::capture::CapturePixelFormat::Jpeg => Ok(PixelFormat::Jpeg),
hwcodec::capture::CapturePixelFormat::Yuyv => Ok(PixelFormat::Yuyv),
hwcodec::capture::CapturePixelFormat::Yvyu => Ok(PixelFormat::Yvyu),
hwcodec::capture::CapturePixelFormat::Uyvy => Ok(PixelFormat::Uyvy),
hwcodec::capture::CapturePixelFormat::Nv12 => Ok(PixelFormat::Nv12),
hwcodec::capture::CapturePixelFormat::Nv21 => Ok(PixelFormat::Nv21),
hwcodec::capture::CapturePixelFormat::Nv16 => Ok(PixelFormat::Nv16),
hwcodec::capture::CapturePixelFormat::Nv24 => Ok(PixelFormat::Nv24),
hwcodec::capture::CapturePixelFormat::Yuv420 => Ok(PixelFormat::Yuv420),
hwcodec::capture::CapturePixelFormat::Yvu420 => Ok(PixelFormat::Yvu420),
hwcodec::capture::CapturePixelFormat::Rgb24 => Ok(PixelFormat::Rgb24),
hwcodec::capture::CapturePixelFormat::Bgr24 => Ok(PixelFormat::Bgr24),
hwcodec::capture::CapturePixelFormat::Grey => Ok(PixelFormat::Grey),
hwcodec::capture::CapturePixelFormat::Unknown => Err(AppError::ServiceUnavailable(
"DirectShow returned an unsupported pixel format".to_string(),
)),
}
}

View File

@@ -1,30 +0,0 @@
//! Shared tuning for V4L2 MJPEG capture paths (`Streamer` + `SharedVideoPipeline`).
/// Frames smaller than this are treated as incomplete / noise.
pub(crate) const MIN_CAPTURE_FRAME_SIZE: usize = 128;
/// After startup, validate JPEG header every N frames to limit CPU use.
pub(crate) const JPEG_VALIDATE_INTERVAL: u64 = 30;
/// Validate every MJPEG frame for the first N frames (UVC warm-up / bad headers).
pub(crate) const STARTUP_JPEG_VALIDATE_FRAMES: u64 = 3;
#[inline]
pub(crate) fn should_validate_jpeg_frame(validate_counter: u64) -> bool {
validate_counter <= STARTUP_JPEG_VALIDATE_FRAMES
|| validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn jpeg_validation_policy_startup_then_interval() {
assert!(should_validate_jpeg_frame(1));
assert!(should_validate_jpeg_frame(2));
assert!(should_validate_jpeg_frame(3));
assert!(!should_validate_jpeg_frame(4));
assert!(should_validate_jpeg_frame(30));
}
}

View File

@@ -0,0 +1,299 @@
//! H.264 Annex-B/AVCC bitstream helpers shared by WebRTC, RTSP and RustDesk.
pub const FALLBACK_WEBRTC_PROFILE_LEVEL_ID: &str = "42e01f";
pub fn webrtc_fmtp_line(profile_level_id: &str) -> String {
format!(
"level-asymmetry-allowed=1;packetization-mode=1;profile-level-id={}",
profile_level_id
)
}
pub fn fallback_webrtc_fmtp_line() -> String {
webrtc_fmtp_line(FALLBACK_WEBRTC_PROFILE_LEVEL_ID)
}
pub fn strip_aud_nal_units(data: &[u8]) -> Vec<u8> {
let mut result = Vec::with_capacity(data.len());
let mut i = 0;
while i < data.len() {
let (start_code_pos, start_code_len) = if i + 4 <= data.len()
&& data[i] == 0
&& data[i + 1] == 0
&& data[i + 2] == 0
&& data[i + 3] == 1
{
(i, 4)
} else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 {
(i, 3)
} else {
i += 1;
continue;
};
let nal_start = start_code_pos + start_code_len;
if nal_start >= data.len() {
break;
}
let nal_type = data[nal_start] & 0x1F;
let mut nal_end = data.len();
let mut j = nal_start + 1;
while j + 3 <= data.len() {
if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1)
|| (j + 4 <= data.len()
&& data[j] == 0
&& data[j + 1] == 0
&& data[j + 2] == 0
&& data[j + 3] == 1)
{
nal_end = j;
break;
}
j += 1;
}
if nal_type != 9 && nal_type != 12 {
result.extend_from_slice(&data[start_code_pos..nal_end]);
}
i = nal_end;
}
if result.is_empty() && !data.is_empty() {
return data.to_vec();
}
result
}
pub fn extract_sps_pps(data: &[u8]) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
let mut sps: Option<Vec<u8>> = None;
let mut pps: Option<Vec<u8>> = None;
let mut i = 0;
while i < data.len() {
let start_code_len = if i + 4 <= data.len()
&& data[i] == 0
&& data[i + 1] == 0
&& data[i + 2] == 0
&& data[i + 3] == 1
{
4
} else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 {
3
} else {
i += 1;
continue;
};
let nal_start = i + start_code_len;
if nal_start >= data.len() {
break;
}
let nal_type = data[nal_start] & 0x1F;
let mut nal_end = data.len();
let mut j = nal_start + 1;
while j + 3 <= data.len() {
if (data[j] == 0 && data[j + 1] == 0 && data[j + 2] == 1)
|| (j + 4 <= data.len()
&& data[j] == 0
&& data[j + 1] == 0
&& data[j + 2] == 0
&& data[j + 3] == 1)
{
nal_end = j;
break;
}
j += 1;
}
match nal_type {
7 => {
sps = Some(data[nal_start..nal_end].to_vec());
}
8 => {
pps = Some(data[nal_start..nal_end].to_vec());
}
_ => {}
}
i = nal_end;
}
(sps, pps)
}
pub fn has_sps_pps(data: &[u8]) -> bool {
let mut has_sps = false;
let mut has_pps = false;
let mut i = 0;
while i < data.len() {
let start_code_len = if i + 4 <= data.len()
&& data[i] == 0
&& data[i + 1] == 0
&& data[i + 2] == 0
&& data[i + 3] == 1
{
4
} else if i + 3 <= data.len() && data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 {
3
} else {
i += 1;
continue;
};
let nal_start = i + start_code_len;
if nal_start >= data.len() {
break;
}
let nal_type = data[nal_start] & 0x1F;
match nal_type {
7 => has_sps = true,
8 => has_pps = true,
_ => {}
}
if has_sps && has_pps {
return true;
}
i = nal_start + 1;
}
has_sps && has_pps
}
pub fn is_keyframe(data: &[u8]) -> bool {
let mut i = 0;
while i < data.len() {
if i + 3 < data.len() && data[i] == 0 && data[i + 1] == 0 {
let nal_start = if data[i + 2] == 1 {
i + 3
} else if i + 4 < data.len() && data[i + 2] == 0 && data[i + 3] == 1 {
i + 4
} else {
i += 1;
continue;
};
if nal_start < data.len() {
let nal_type = data[nal_start] & 0x1F;
if nal_type == 5 {
return true;
}
}
i = nal_start;
} else {
i += 1;
}
}
false
}
/// `profile-level-id` hex for SDP (`42001f` etc.); expects SPS NAL without start code.
pub fn parse_profile_level_id_from_sps(sps: &[u8]) -> Option<String> {
if sps.len() < 4 {
return None;
}
let profile_idc = sps[1];
let constraint_set_flags = sps[2];
let level_idc = sps[3];
Some(format!(
"{:02x}{:02x}{:02x}",
profile_idc, constraint_set_flags, level_idc
))
}
pub fn extract_profile_level_id(data: &[u8]) -> Option<String> {
let (sps, _) = extract_sps_pps(data);
sps.and_then(|sps_data| parse_profile_level_id_from_sps(&sps_data))
}
pub fn is_annex_b(data: &[u8]) -> bool {
data.starts_with(&[0, 0, 1]) || data.starts_with(&[0, 0, 0, 1])
}
pub fn avcc_to_annex_b(data: &[u8]) -> Option<Vec<u8>> {
let mut pos = 0;
let mut output = Vec::with_capacity(data.len() + 16);
let mut nalu_count = 0usize;
while pos + 4 <= data.len() {
let nalu_len =
u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
pos += 4;
if nalu_len == 0 || pos + nalu_len > data.len() {
return None;
}
let nal_type = data[pos] & 0x1F;
if nal_type != 9 && nal_type != 12 {
output.extend_from_slice(&[0, 0, 0, 1]);
output.extend_from_slice(&data[pos..pos + nalu_len]);
}
nalu_count += 1;
pos += nalu_len;
}
if pos == data.len() && nalu_count > 0 && !output.is_empty() {
Some(output)
} else {
None
}
}
pub fn normalize_for_webrtc(data: &[u8]) -> Vec<u8> {
if is_annex_b(data) {
return strip_aud_nal_units(data);
}
if let Some(annex_b) = avcc_to_annex_b(data) {
return strip_aud_nal_units(&annex_b);
}
data.to_vec()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn detects_h264_keyframes() {
let idr_frame = vec![0x00, 0x00, 0x00, 0x01, 0x65];
assert!(is_keyframe(&idr_frame));
let idr_frame_3 = vec![0x00, 0x00, 0x01, 0x65];
assert!(is_keyframe(&idr_frame_3));
let p_frame = vec![0x00, 0x00, 0x00, 0x01, 0x41];
assert!(!is_keyframe(&p_frame));
let sps = vec![0x00, 0x00, 0x00, 0x01, 0x67];
assert!(!is_keyframe(&sps));
let multi_nal = vec![
0x00, 0x00, 0x00, 0x01, 0x67, 0x42, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x01, 0x68, 0xce,
0x38, 0x80, 0x00, 0x00, 0x00, 0x01, 0x65, 0x88, 0x84,
];
assert!(is_keyframe(&multi_nal));
}
#[test]
fn parses_profile_level_id_from_sps() {
assert_eq!(
parse_profile_level_id_from_sps(&[0x67, 0x42, 0x40, 0x2a]),
Some("42402a".to_string())
);
}
}

Some files were not shown because too many files have changed in this diff Show More