Compare commits

...

40 Commits

Author SHA1 Message Date
mofeng-git
2e0ca89943 chore: bump version to v0.2.1 2026-05-20 00:09:31 +08:00
mofeng-git
1f7cfb373c fix: 修复设置页滚动和 HID 继电器识别 #252 2026-05-19 22:17:50 +08:00
mofeng-git
da05656a89 fix(web): 设置页按菜单加载并优化错误提示 2026-05-19 21:38:56 +08:00
mofeng-git
265852b312 fix: 避免 CH9329 配置保存时误触发 OTG reconcile 2026-05-19 20:50:33 +08:00
mofeng-git
02bf04ed7f fix: 修复 MSD 状态卡片 i18n 键名 2026-05-19 20:44:12 +08:00
mofeng-git
8915d36bcf fix: 升级 Vite、Rollup、PostCSS 等依赖清除 Github 安全漏洞提示 2026-05-19 11:45:45 +00:00
mofeng-git
3ea15e37a4 feat: 增加 CHINAMIRRO 构建环境变量 2026-05-19 11:22:04 +00:00
mofeng-git
cb0c66af96 ci: 调整 GitHub Actions 构建与发布流程 2026-05-19 18:01:53 +08:00
SilentWind
a3ebcded34 Merge pull request #261 from fcsha/fix/issue-260-msd-endpoint-budget
fix: 关闭 MSD 后保存 HID 配置时端点预算校验误判超限
2026-05-19 09:58:03 +08:00
mofeng-git
f7c2cd1b90 ci: 支持 GitHub Actions 构建 2026-05-19 09:54:54 +08:00
mofeng-git
e774210ae3 fix: 修复构建错误并清理未使用导入 2026-05-18 15:23:42 +00:00
mofeng-git
935fa823f2 feat: 初步增加 Windows 支持 2026-05-18 22:44:59 +08:00
Fucheng Sha
dd3f73ae54 fix: 关闭 MSD 后保存 HID 配置时先更新 MSD 状态再校验端点预算
saveConfig 中调换 updateMsd 和 updateHid 的调用顺序,确保 HID
校验端点预算时 MSD enabled 状态已是最新值,避免被误判为超限。

Fixes mofeng-git/One-KVM#260
2026-05-16 13:00:49 +08:00
SilentWind
0b9d94f53f docs: Update README with 贝塔网络 sponsorship 2026-05-13 22:38:58 +08:00
SilentWind
e5d6279a54 Merge pull request #257 from btzen/redfish
feat: 实现 Redfish API 标准接口;支持通过前端开关控制 Redfish 服务
2026-05-13 13:47:07 +08:00
Fucheng Sha
57d4091497 fix: 恢复被误删的 MSD Section 注释 2026-05-12 14:53:04 +08:00
Fucheng Sha
4e8c342905 feat: 实现 Redfish API 标准接口;支持通过前端开关控制 Redfish 服务 2026-05-12 10:53:26 +08:00
SilentWind
17cd74f64c Merge pull request #250 from arounyf/pr/audio-fix
fix: 修复 WebRTC 音频/视频接收器重启时破音问题
2026-05-05 12:12:17 +08:00
arounyf
9923670426 fix: 修复 WebRTC 音频/视频接收器重启时破音问题
start_audio_from_opus 和 start_from_video_frames 替换旧 handle 时先
abort 旧任务,防止新旧两个任务同时向同一个 track 写数据导致破音。
2026-05-05 05:11:04 +08:00
mofeng-git
3ee3df77b8 chroe: 不再配置 iceCandidatePoolSize,沿用浏览器默认 2026-05-05 01:19:17 +08:00
mofeng-git
8ec2f25e82 chore: bump version to v0.2.0 2026-05-05 00:59:16 +08:00
mofeng-git
c27d3a6703 fix:改进atx usb 继电器适配;修复 webrtc 无法建立连接问题;网页样式优化 2026-05-05 00:52:16 +08:00
mofeng-git
6723f432a3 feat: 允许通过环境变量手动指定前端资源路径,删除 debug 分支默认资源路径 2026-05-04 17:53:27 +08:00
mofeng-git
12a3f1c947 feat: 增加设备丢失自恢复机制
增加音频设备丢失自恢复机制,完善视频设备丢失自恢复机制

降级部分日志级别,GOSTC key打印脱敏

代码格式化
2026-05-02 10:55:05 +08:00
mofeng-git
52754c862b feat: 优化网页消息提醒样式 2026-05-01 21:46:32 +08:00
mofeng-git
e51d243324 feat: 增加 MSD 虚拟盘文件路径编码 2026-05-01 21:27:03 +08:00
mofeng-git
a1ebd34083 feat: 外部扩展程序输出日志级别修改为 info 级别 2026-05-01 21:21:54 +08:00
mofeng-git
89b19ea7dd refactor: 修改为同步请求 2026-05-01 20:06:22 +08:00
mofeng-git
0d47d8395d refactor: 重构视频采集状态与错误分类公共逻辑 2026-05-01 17:56:56 +08:00
mofeng-git
d82c863f40 refactor: 精简依赖 2026-05-01 17:41:11 +08:00
mofeng-git
d8e7de74a6 refactor: 删除部分多余的代码和注释 2026-05-01 17:31:04 +08:00
SilentWind
74035f8e12 Merge pull request #247 from tedaimengtech/main
Update: Add CQU Mirror Information
2026-04-29 14:13:25 +08:00
tedaimeng
8d45186eba Add Mirror Download Services to README
Added a section for Mirror Download Services and updated sponsors.
2026-04-29 13:12:35 +08:00
tedaimeng
c484580b8f Update README with CQUMirror details
Added CQUMirror information and links to the README.
2026-04-29 13:09:23 +08:00
SilentWind
56bce7937c Add GNU General Public License v3 2026-04-29 10:38:30 +08:00
mofeng-git
07b982d1d2 feat: 完善 USB UVC 设备异常处理,添加 USB 设备复位功能 2026-04-27 16:37:04 +08:00
mofeng-git
9065e01225 feat: 优化控制台页面状态工具栏在不同宽度网页下的自适应能力 2026-04-25 20:32:44 +08:00
mofeng-git
cc3cc15774 refactor: 删除部分多余的 Ventoy 逻辑 2026-04-20 14:07:28 +08:00
mofeng-git
fcb39c73fc refactor: 删除未使用的公共 STUN/TURN 逻辑 2026-04-20 10:15:53 +08:00
mofeng-git
7c703b8b4b feat: 深入适配 RK628D CSI 采集卡的设备识别、参数读取、自恢复和音频采集 2026-04-19 11:26:21 +08:00
288 changed files with 25282 additions and 20142 deletions

177
.github/workflows/build.yml vendored Normal file
View File

@@ -0,0 +1,177 @@
name: Build
on:
pull_request:
workflow_dispatch:
inputs:
publish_release:
description: Publish GitHub Release
required: false
default: false
type: boolean
release_tag:
description: Release tag name when publishing
required: false
default: ""
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
jobs:
frontend:
runs-on: ubuntu-22.04
timeout-minutes: 30
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 24
cache: npm
cache-dependency-path: web/package-lock.json
- name: Build frontend
working-directory: web
run: |
npm ci
npm run build
- name: Upload frontend dist
uses: actions/upload-artifact@v4
with:
name: web-dist
path: web/dist
if-no-files-found: error
retention-days: 7
deb:
runs-on: ubuntu-22.04
needs: frontend
timeout-minutes: 120
steps:
- uses: actions/checkout@v4
- name: Download frontend dist
uses: actions/download-artifact@v4
with:
name: web-dist
path: web/dist
- uses: dtolnay/rust-toolchain@stable
- name: Install cross
run: cargo install cross --locked
- name: Build linux binary
run: bash build/build-images.sh
- name: Package deb
run: bash build/package-deb.sh
- name: Upload deb
uses: actions/upload-artifact@v4
with:
name: one-kvm-deb
path: target/debian/*.deb
if-no-files-found: error
retention-days: 7
windows:
runs-on: windows-2022
needs: frontend
timeout-minutes: 120
steps:
- uses: actions/checkout@v4
- name: Download frontend dist
uses: actions/download-artifact@v4
with:
name: web-dist
path: web/dist
- uses: dtolnay/rust-toolchain@stable
- name: Set up MSVC
uses: ilammy/msvc-dev-cmd@v1
- name: Prepare vcpkg and dependencies
shell: pwsh
run: |
$env:VCPKG_ROOT = "C:\vcpkg"
$env:VCPKG_DEFAULT_TRIPLET = "x64-windows-static"
$env:VCPKG_INSTALLED_DIR = Join-Path $pwd "vcpkg_installed"
if (-not (Test-Path $env:VCPKG_ROOT)) {
git clone https://github.com/microsoft/vcpkg $env:VCPKG_ROOT
}
& "$env:VCPKG_ROOT\bootstrap-vcpkg.bat" -disableMetrics
& "$env:VCPKG_ROOT\vcpkg.exe" install --triplet $env:VCPKG_DEFAULT_TRIPLET --x-install-root="$env:VCPKG_INSTALLED_DIR"
$tripletRoot = Join-Path $env:VCPKG_INSTALLED_DIR $env:VCPKG_DEFAULT_TRIPLET
$env:TURBOJPEG_SOURCE = "explicit"
$env:TURBOJPEG_LIB_DIR = Join-Path $tripletRoot "lib"
$env:TURBOJPEG_INCLUDE_DIR = Join-Path $tripletRoot "include"
"VCPKG_ROOT=$env:VCPKG_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append
"VCPKG_DEFAULT_TRIPLET=$env:VCPKG_DEFAULT_TRIPLET" | Out-File -FilePath $env:GITHUB_ENV -Append
"VCPKG_INSTALLED_DIR=$env:VCPKG_INSTALLED_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append
"TURBOJPEG_SOURCE=$env:TURBOJPEG_SOURCE" | Out-File -FilePath $env:GITHUB_ENV -Append
"TURBOJPEG_LIB_DIR=$env:TURBOJPEG_LIB_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append
"TURBOJPEG_INCLUDE_DIR=$env:TURBOJPEG_INCLUDE_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append
- name: Build Windows exe
shell: pwsh
run: .\build\windows\build.ps1 -Configuration release -Package
- name: Upload exe
uses: actions/upload-artifact@v4
with:
name: one-kvm-windows-exe
path: target/x86_64-pc-windows-msvc/release/one-kvm_*.exe
if-no-files-found: error
retention-days: 7
release:
runs-on: ubuntu-22.04
needs: [deb, windows]
if: ${{ github.event_name == 'workflow_dispatch' && inputs.publish_release }}
timeout-minutes: 30
permissions:
contents: write
steps:
- name: Validate release tag
run: |
if [ -z "${{ inputs.release_tag }}" ]; then
echo "release_tag is required when publish_release is true"
exit 1
fi
- name: Download deb artifact
uses: actions/download-artifact@v4
with:
name: one-kvm-deb
path: release-artifacts/deb
- name: Download exe artifact
uses: actions/download-artifact@v4
with:
name: one-kvm-windows-exe
path: release-artifacts/windows
- name: Publish GitHub Release
uses: softprops/action-gh-release@v2
with:
tag_name: ${{ inputs.release_tag }}
prerelease: true
generate_release_notes: true
files: |
release-artifacts/deb/*.deb
release-artifacts/windows/*.exe

View File

@@ -1,6 +1,6 @@
[package]
name = "one-kvm"
version = "0.1.9"
version = "0.2.1"
edition = "2021"
authors = ["SilentWind"]
description = "A open and lightweight IP-KVM solution written in Rust"
@@ -16,8 +16,8 @@ tokio-util = { version = "0.7", features = ["rt"] }
# Web framework
axum = { version = "0.8", features = ["ws", "multipart", "tokio"] }
axum-extra = { version = "0.12", features = ["typed-header", "cookie"] }
tower-http = { version = "0.6", features = ["fs", "cors", "trace", "compression-gzip"] }
axum-extra = { version = "0.12", features = ["cookie"] }
tower-http = { version = "0.6", features = ["cors", "trace", "set-header"] }
# Database - Use bundled SQLite for static linking
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
@@ -29,7 +29,6 @@ serde_json = "1"
# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "tracing-log"] }
tracing-log = "0.2"
# Error handling
thiserror = "2"
@@ -41,7 +40,6 @@ rand = "0.9"
# Utilities
uuid = { version = "1", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
base64 = "0.22"
nix = { version = "0.30", features = ["fs", "net", "hostname", "poll"] }
@@ -51,7 +49,7 @@ reqwest = { version = "0.13", features = ["stream", "rustls", "json"], default-f
urlencoding = "2"
# Static file embedding
rust-embed = { version = "8", features = ["compression"] }
rust-embed = { version = "8", features = ["compression", "debug-embed"] }
mime_guess = "2"
# TLS/HTTPS
@@ -62,14 +60,8 @@ axum-server = { version = "0.8", features = ["tls-rustls"] }
# CLI argument parsing
clap = { version = "4", features = ["derive"] }
# Time
time = "0.3"
# Video capture (V4L2)
v4l2r = "0.0.7"
# JPEG encoding (libjpeg-turbo, SIMD accelerated)
turbojpeg = "1.3"
# Time (cookie max_age + RFC3339 timestamps)
time = { version = "0.3", features = ["serde", "formatting", "parsing"] }
# Bytes handling
bytes = "1"
@@ -95,11 +87,6 @@ rtp = "0.14"
rtsp-types = "0.1"
sdp-types = "0.1"
# Audio (ALSA capture + Opus encoding)
# Note: audiopus links to libopus.so (unavoidable for audio support)
alsa = "0.11"
audiopus = "0.2"
# HID (serial port for CH9329)
serialport = "4"
async-trait = "0.1"
@@ -108,30 +95,47 @@ libc = "0.2"
# Ventoy bootable image support
ventoy-img = { path = "libs/ventoy-img-rs" }
# ATX (GPIO control)
gpio-cdev = "0.6"
# H264 hardware/software encoding (hwcodec from rustdesk)
hwcodec = { path = "libs/hwcodec" }
# RustDesk protocol support
protobuf = { version = "3.7", features = ["with-bytes"] }
sodiumoxide = "0.2"
sha2 = "0.10"
# High-performance pixel format conversion (libyuv)
libyuv = { path = "res/vcpkg/libyuv" }
# TypeScript type generation
typeshare = "1.0"
[target.'cfg(any(unix, windows))'.dependencies]
# Video encoding/decoding (FFmpeg/libjpeg-turbo/libyuv; available on Windows and Linux)
hwcodec = { path = "libs/hwcodec" }
libyuv = { path = "res/vcpkg/libyuv" }
turbojpeg = "1.3"
# Note: audiopus links to libopus.so (unavoidable for audio support)
audiopus = "0.2"
[target.'cfg(unix)'.dependencies]
# Video capture (V4L2)
v4l2r = "0.0.7"
# Audio (ALSA capture)
alsa = "0.11"
# ATX (GPIO control)
gpio-cdev = "0.6"
[target.'cfg(windows)'.dependencies]
cpal = { version = "0.17", default-features = false }
windows-sys = { version = "0.61", features = [
"Win32_Foundation",
"Win32_NetworkManagement_IpHelper",
"Win32_NetworkManagement_Ndis",
"Win32_Networking_WinSock",
"Win32_System_SystemInformation",
"Win32_System_Threading",
] }
[dev-dependencies]
tokio-test = "0.4"
tempfile = "3"
[build-dependencies]
protobuf-codegen = "3.7"
toml = "0.9"
cc = "1"
[profile.release]
opt-level = 3

674
LICENSE Normal file
View File

@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

View File

@@ -222,6 +222,9 @@ One-KVM builds on many great open-source projects; a lot of time goes into testi
### Sponsors
**Mirror Download Services:**
- **[Chongqing University Open Source Software Mirror](https://mirrors.cqu.edu.cn/)** — provides mirror download services
**File hosting**
- **[Huang1111 public-interest program](https://pan.huang1111.cn/s/mxkx3T1)** — login-free downloads

View File

@@ -218,6 +218,9 @@ One-KVM 已上架飞牛 **应用市场**,在 NAS 上直接搜索安装即可
本项目得到以下赞助商的支持:
**镜像下载服务:**
- **[重庆大学开源软件镜像站](https://mirrors.cqu.edu.cn/)** - 提供镜像站下载服务
**文件存储服务:**
- **[Huang1111公益计划](https://pan.huang1111.cn/s/mxkx3T1)** - 提供免登录下载服务
@@ -225,6 +228,12 @@ One-KVM 已上架飞牛 **应用市场**,在 NAS 上直接搜索安装即可
- **[林枫云](https://www.dkdun.cn)** - 赞助了本项目服务器
![林枫云](https://docs.one-kvm.cn/img/36076FEFF0898A80EBD5756D28F4076C.png)
<img height="128" alt="林枫云" src="https://docs.one-kvm.cn/img/36076FEFF0898A80EBD5756D28F4076C.png" />
林枫云主营国内外地域的精品线路业务服务器、高主频游戏服务器和大带宽服务器。
林枫云主营国内外地域的精品线路业务服务器、高主频游戏服务器和大带宽服务器。
- **[贝塔网络](https://my.beita.cc/?ref=github_onekvm)** - 赞助了本项目服务器
<img height="128" alt="BTBT" src="https://github.com/user-attachments/assets/c442d5f5-d72f-4a07-b9f4-400a6a0c3f1e" />
远程电脑、消费级GPU服务器、独服物理机全自动在线交付。

14
agents.md Normal file
View File

@@ -0,0 +1,14 @@
# Agents Notes
## Windows MSVC Build
Run from the repository root in PowerShell:
```powershell
$env:VCPKG_ROOT='C:\Users\mofen\code\vcpkg'
$env:TURBOJPEG_SOURCE='explicit'
$env:TURBOJPEG_LIB_DIR='C:\Users\mofen\code\vcpkg\installed\x64-windows-static\lib'
$env:TURBOJPEG_INCLUDE_DIR='C:\Users\mofen\code\vcpkg\installed\x64-windows-static\include'
cargo build --target x86_64-pc-windows-msvc
```

View File

@@ -64,25 +64,6 @@ fn generate_secrets() {
pub mod ice {
/// Google public STUN server URL (hardcoded)
pub const STUN_SERVER: &str = "stun:stun.l.google.com:19302";
/// TURN server URLs - not provided, users must configure their own
pub const TURN_URLS: &str = "";
/// TURN authentication username
pub const TURN_USERNAME: &str = "";
/// TURN authentication password
pub const TURN_PASSWORD: &str = "";
/// Always returns true since we have STUN
pub const fn is_configured() -> bool {
true
}
/// Always returns false since TURN is not provided
pub const fn has_turn() -> bool {
false
}
}
/// RustDesk public server configuration - NOT PROVIDED

View File

@@ -19,6 +19,23 @@ ARCH_MAP=(
build_arch() {
local rust_target="$1"
case "${CHINAMIRRO:-}" in
1|true|TRUE|yes|YES|on|ON)
local cross_build_opts="${CROSS_BUILD_OPTS:+$CROSS_BUILD_OPTS }--build-arg CHINAMIRRO=1"
echo "=== China mirror acceleration: enabled (Tsinghua) ==="
echo "=== Building: $rust_target (via cross with custom Dockerfile) ==="
env \
CROSS_BUILD_OPTS="$cross_build_opts" \
CARGO_SOURCE_CRATES_IO_REPLACE_WITH=tuna \
CARGO_SOURCE_TUNA_REGISTRY=sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/ \
CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse \
RUSTUP_DIST_SERVER=https://mirrors.tuna.tsinghua.edu.cn/rustup \
RUSTUP_UPDATE_ROOT=https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup \
cross build --release --target "$rust_target"
return
;;
esac
echo "=== Building: $rust_target (via cross with custom Dockerfile) ==="
cross build --release --target "$rust_target"
}
@@ -49,6 +66,7 @@ case "${1:-all}" in
echo "Examples:"
echo " $0 # Build all"
echo " $0 x86_64 # Build x86_64 only"
echo " CHINAMIRRO=1 $0 arm64 # Build with Tsinghua mirrors"
exit 0
;;
*)

View File

@@ -6,16 +6,36 @@ FROM debian:11
# Linux headers used by v4l2r bindgen
ARG LINUX_HEADERS_VERSION=6.6
ARG LINUX_HEADERS_SHA256=
ARG CHINAMIRRO=0
# Set Rustup mirrors (Aliyun)
#ENV RUSTUP_UPDATE_ROOT=https://mirrors.aliyun.com/rustup/rustup \
# RUSTUP_DIST_SERVER=https://mirrors.aliyun.com/rustup
# Optionally use Tsinghua mirrors for builds in China.
RUN if [ "$CHINAMIRRO" = "1" ]; then \
sed -i \
-e 's|http://deb.debian.org/debian|http://mirrors.tuna.tsinghua.edu.cn/debian|g' \
-e 's|http://security.debian.org/debian-security|http://mirrors.tuna.tsinghua.edu.cn/debian-security|g' \
/etc/apt/sources.list; \
fi
# Install Rust toolchain
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
ca-certificates \
&& if [ "$CHINAMIRRO" = "1" ]; then \
export RUSTUP_DIST_SERVER=https://mirrors.tuna.tsinghua.edu.cn/rustup; \
export RUSTUP_UPDATE_ROOT=https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup; \
fi \
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable \
&& if [ "$CHINAMIRRO" = "1" ]; then \
mkdir -p /root/.cargo; \
printf '%s\n' \
'[source.crates-io]' \
"replace-with = 'tuna'" \
'[source.tuna]' \
'registry = "sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/"' \
'[registries.tuna]' \
'index = "sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/"' \
> /root/.cargo/config.toml; \
fi \
&& rm -rf /var/lib/apt/lists/*
ENV PATH="/root/.cargo/bin:${PATH}"
@@ -327,7 +347,11 @@ RUN mkdir -p /tmp/ffmpeg-build && cd /tmp/ffmpeg-build \
&& rm -rf /tmp/ffmpeg-build /tmp/aarch64-cross.txt /tmp/aarch64-pkg-config
# Add Rust target
RUN rustup target add aarch64-unknown-linux-gnu
RUN if [ "$CHINAMIRRO" = "1" ]; then \
export RUSTUP_DIST_SERVER=https://mirrors.tuna.tsinghua.edu.cn/rustup; \
export RUSTUP_UPDATE_ROOT=https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup; \
fi \
&& rustup target add aarch64-unknown-linux-gnu
# Configure environment for cross-compilation
ENV CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc \

View File

@@ -6,16 +6,36 @@ FROM debian:11
# Linux headers used by v4l2r bindgen
ARG LINUX_HEADERS_VERSION=6.6
ARG LINUX_HEADERS_SHA256=
ARG CHINAMIRRO=0
# Set Rustup mirrors (Aliyun)
#ENV RUSTUP_UPDATE_ROOT=https://mirrors.aliyun.com/rustup/rustup \
# RUSTUP_DIST_SERVER=https://mirrors.aliyun.com/rustup
# Optionally use Tsinghua mirrors for builds in China.
RUN if [ "$CHINAMIRRO" = "1" ]; then \
sed -i \
-e 's|http://deb.debian.org/debian|http://mirrors.tuna.tsinghua.edu.cn/debian|g' \
-e 's|http://security.debian.org/debian-security|http://mirrors.tuna.tsinghua.edu.cn/debian-security|g' \
/etc/apt/sources.list; \
fi
# Install Rust toolchain
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
ca-certificates \
&& if [ "$CHINAMIRRO" = "1" ]; then \
export RUSTUP_DIST_SERVER=https://mirrors.tuna.tsinghua.edu.cn/rustup; \
export RUSTUP_UPDATE_ROOT=https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup; \
fi \
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable \
&& if [ "$CHINAMIRRO" = "1" ]; then \
mkdir -p /root/.cargo; \
printf '%s\n' \
'[source.crates-io]' \
"replace-with = 'tuna'" \
'[source.tuna]' \
'registry = "sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/"' \
'[registries.tuna]' \
'index = "sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/"' \
> /root/.cargo/config.toml; \
fi \
&& rm -rf /var/lib/apt/lists/*
ENV PATH="/root/.cargo/bin:${PATH}"
@@ -316,7 +336,11 @@ RUN mkdir -p /tmp/ffmpeg-build && cd /tmp/ffmpeg-build \
&& rm -rf /tmp/ffmpeg-build /tmp/armhf-cross.txt /tmp/armhf-pkg-config
# Add Rust target
RUN rustup target add armv7-unknown-linux-gnueabihf
RUN if [ "$CHINAMIRRO" = "1" ]; then \
export RUSTUP_DIST_SERVER=https://mirrors.tuna.tsinghua.edu.cn/rustup; \
export RUSTUP_UPDATE_ROOT=https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup; \
fi \
&& rustup target add armv7-unknown-linux-gnueabihf
# Configure environment for cross-compilation
ENV CARGO_TARGET_ARMV7_UNKNOWN_LINUX_GNUEABIHF_LINKER=arm-linux-gnueabihf-gcc \

View File

@@ -6,16 +6,36 @@ FROM debian:11
# Linux headers used by v4l2r bindgen
ARG LINUX_HEADERS_VERSION=6.6
ARG LINUX_HEADERS_SHA256=
ARG CHINAMIRRO=0
# Set Rustup mirrors (Aliyun)
#ENV RUSTUP_UPDATE_ROOT=https://mirrors.aliyun.com/rustup/rustup \
# RUSTUP_DIST_SERVER=https://mirrors.aliyun.com/rustup
# Optionally use Tsinghua mirrors for builds in China.
RUN if [ "$CHINAMIRRO" = "1" ]; then \
sed -i \
-e 's|http://deb.debian.org/debian|http://mirrors.tuna.tsinghua.edu.cn/debian|g' \
-e 's|http://security.debian.org/debian-security|http://mirrors.tuna.tsinghua.edu.cn/debian-security|g' \
/etc/apt/sources.list; \
fi
# Install Rust toolchain
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
ca-certificates \
&& if [ "$CHINAMIRRO" = "1" ]; then \
export RUSTUP_DIST_SERVER=https://mirrors.tuna.tsinghua.edu.cn/rustup; \
export RUSTUP_UPDATE_ROOT=https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup; \
fi \
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable \
&& if [ "$CHINAMIRRO" = "1" ]; then \
mkdir -p /root/.cargo; \
printf '%s\n' \
'[source.crates-io]' \
"replace-with = 'tuna'" \
'[source.tuna]' \
'registry = "sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/"' \
'[registries.tuna]' \
'index = "sparse+https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/"' \
> /root/.cargo/config.toml; \
fi \
&& rm -rf /var/lib/apt/lists/*
ENV PATH="/root/.cargo/bin:${PATH}"
@@ -221,7 +241,11 @@ RUN mkdir -p /tmp/ffmpeg-build && cd /tmp/ffmpeg-build \
&& rm -rf /tmp/ffmpeg-build
# Add Rust target
RUN rustup target add x86_64-unknown-linux-gnu
RUN if [ "$CHINAMIRRO" = "1" ]; then \
export RUSTUP_DIST_SERVER=https://mirrors.tuna.tsinghua.edu.cn/rustup; \
export RUSTUP_UPDATE_ROOT=https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup; \
fi \
&& rustup target add x86_64-unknown-linux-gnu
# Configure environment for static linking
ENV PKG_CONFIG_ALLOW_CROSS=1\

87
build/windows/build.ps1 Normal file
View File

@@ -0,0 +1,87 @@
param(
[string]$Configuration = "debug",
[string]$Target = "x86_64-pc-windows-msvc",
[string]$Triplet = "x64-windows-static",
[string]$VcpkgRoot = $env:VCPKG_ROOT,
[string]$VcpkgInstalledRoot = $env:VCPKG_INSTALLED_DIR,
[switch]$NoDefaultFeatures,
[string[]]$Features = @(),
[switch]$Package,
[Parameter(ValueFromRemainingArguments = $true)]
[string[]]$CargoArgs = @()
)
$ErrorActionPreference = "Stop"
$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..\..")
Set-Location $repoRoot
if ([string]::IsNullOrWhiteSpace($VcpkgRoot)) {
$VcpkgRoot = Join-Path (Split-Path $repoRoot -Parent) "vcpkg"
}
$VcpkgRoot = [System.IO.Path]::GetFullPath($VcpkgRoot)
if ([string]::IsNullOrWhiteSpace($VcpkgInstalledRoot)) {
$VcpkgInstalledRoot = Join-Path $VcpkgRoot "installed"
}
$VcpkgInstalledRoot = [System.IO.Path]::GetFullPath($VcpkgInstalledRoot)
$vcpkgTripletRoot = Join-Path $VcpkgInstalledRoot $Triplet
$turbojpegLibDir = Join-Path $vcpkgTripletRoot "lib"
$turbojpegIncludeDir = Join-Path $vcpkgTripletRoot "include"
if (-not (Test-Path $VcpkgRoot)) {
throw "VCPKG_ROOT does not exist: $VcpkgRoot. Run build/windows/bootstrap-vcpkg.ps1 first."
}
if (-not (Test-Path $turbojpegLibDir) -or -not (Test-Path $turbojpegIncludeDir)) {
throw "vcpkg triplet is not installed at $vcpkgTripletRoot. Run build/windows/bootstrap-vcpkg.ps1 first."
}
$env:VCPKG_ROOT = $VcpkgRoot
$env:VCPKG_DEFAULT_TRIPLET = $Triplet
$env:VCPKG_INSTALLED_DIR = $VcpkgInstalledRoot
$env:TURBOJPEG_SOURCE = "explicit"
$env:TURBOJPEG_LIB_DIR = $turbojpegLibDir
$env:TURBOJPEG_INCLUDE_DIR = $turbojpegIncludeDir
$cargoCommand = @("build", "--target", $Target)
if ($Configuration -eq "release") {
$cargoCommand += "--release"
} elseif ($Configuration -ne "debug") {
throw "Unsupported configuration '$Configuration'. Use 'debug' or 'release'."
}
if ($NoDefaultFeatures) {
$cargoCommand += "--no-default-features"
}
if ($Features.Count -gt 0) {
$cargoCommand += "--features"
$cargoCommand += ($Features -join ",")
}
$cargoCommand += $CargoArgs
cargo @cargoCommand
if ($Package) {
$metadata = cargo metadata --no-deps --format-version 1 | ConvertFrom-Json
$packageInfo = $metadata.packages | Where-Object { $_.name -eq "one-kvm" } | Select-Object -First 1
if ($null -eq $packageInfo -or [string]::IsNullOrWhiteSpace($packageInfo.version)) {
throw "Failed to resolve version from Cargo metadata"
}
$sourcePath = Join-Path $repoRoot "target/$Target/release/one-kvm.exe"
$targetName = "one-kvm_{0}_amd64.exe" -f $packageInfo.version
$targetPath = Join-Path $repoRoot "target/$Target/release/$targetName"
if (-not (Test-Path $sourcePath)) {
throw "Windows binary not found: $sourcePath"
}
Copy-Item $sourcePath $targetPath
Write-Host $targetPath
}

View File

@@ -4,6 +4,9 @@ version = "0.8.0"
edition = "2021"
description = "Hardware video codec for IP-KVM (Windows/Linux)"
[package.metadata.cargo-machete]
ignored = ["serde"]
[features]
default = []
rkmpp = []
@@ -17,6 +20,3 @@ serde_json = "1.0"
[build-dependencies]
cc = "1.0"
bindgen = "0.59"
[dev-dependencies]
env_logger = "0.10"

View File

@@ -34,7 +34,9 @@ fn build_common(builder: &mut Build) {
// system
#[cfg(windows)]
{
["d3d11", "dxgi"].map(|lib| println!("cargo:rustc-link-lib={}", lib));
for lib in ["d3d11", "dxgi"] {
println!("cargo:rustc-link-lib={}", lib);
}
}
builder.include(&common_dir);
@@ -89,8 +91,8 @@ mod ffmpeg {
ffmpeg_ffi();
// Try VCPKG first, fallback to system FFmpeg via pkg-config
if let Ok(vcpkg_root) = std::env::var("VCPKG_ROOT") {
link_vcpkg(builder, vcpkg_root.into());
if let Some(vcpkg_installed) = vcpkg_installed_root() {
link_vcpkg(builder, vcpkg_installed);
} else {
// Use system FFmpeg via pkg-config
link_system_ffmpeg(builder);
@@ -99,6 +101,23 @@ mod ffmpeg {
link_os();
build_ffmpeg_ram(builder);
build_ffmpeg_hw(builder);
build_ffmpeg_capture(builder);
}
fn vcpkg_installed_root() -> Option<PathBuf> {
println!("cargo:rerun-if-env-changed=VCPKG_INSTALLED_DIR");
println!("cargo:rerun-if-env-changed=VCPKG_ROOT");
if let Ok(path) = std::env::var("VCPKG_INSTALLED_DIR") {
if !path.trim().is_empty() {
return Some(PathBuf::from(path));
}
}
std::env::var("VCPKG_ROOT")
.ok()
.filter(|path| !path.trim().is_empty())
.map(|path| PathBuf::from(path).join("installed"))
}
/// Link system FFmpeg using pkg-config or custom path
@@ -271,7 +290,6 @@ mod ffmpeg {
target = target.replace("x64", "x86");
}
println!("cargo:info={}", target);
path.push("installed");
path.push(target);
println!(
@@ -282,15 +300,26 @@ mod ffmpeg {
)
);
{
// Only need avcodec and avutil for encoding
// avdevice/avformat are needed by the Windows DirectShow capture bridge.
let mut static_libs = vec!["avcodec", "avutil"];
if target_os == "windows" {
static_libs.push("libmfx");
static_libs.extend([
"avformat",
"avdevice",
"avfilter",
"swresample",
"swscale",
"vpx",
"libx264",
"x265-static",
]);
}
for lib in static_libs {
println!("cargo:rustc-link-lib=static={}", lib);
}
if target_os == "windows" {
link_windows_qsv_lib(&path.join("lib"));
}
static_libs
.iter()
.map(|lib| println!("cargo:rustc-link-lib=static={}", lib))
.count();
}
let include = path.join("include");
@@ -299,12 +328,25 @@ mod ffmpeg {
include
}
fn link_windows_qsv_lib(lib_dir: &Path) {
if lib_dir.join("libmfx.lib").exists() {
println!("cargo:rustc-link-lib=static=libmfx");
println!("cargo:info=Using Windows QSV support library libmfx.lib");
return;
}
println!("cargo:warning=Windows QSV support library not found in {}", lib_dir.display());
}
fn link_os() {
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap();
let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").unwrap();
let dyn_libs: Vec<&str> = if target_os == "windows" {
["User32", "bcrypt", "ole32", "advapi32"].to_vec()
[
"User32", "bcrypt", "ole32", "advapi32", "mfuuid", "strmiids",
]
.to_vec()
} else if target_os == "linux" {
// Base libraries for all Linux platforms
let mut v = vec!["drm", "stdc++"];
@@ -375,6 +417,34 @@ mod ffmpeg {
}
}
fn build_ffmpeg_capture(builder: &mut Build) {
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
if target_os != "windows" {
return;
}
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let capture_header = manifest_dir
.join("cpp")
.join("ffmpeg_capture_ffi.h")
.to_string_lossy()
.to_string();
bindgen::builder()
.header(capture_header)
.rustified_enum("*")
.generate()
.unwrap()
.write_to_file(
Path::new(&env::var_os("OUT_DIR").unwrap()).join("ffmpeg_capture_ffi.rs"),
)
.unwrap();
builder.file(manifest_dir.join("cpp").join("ffmpeg_capture.cpp"));
println!("cargo:rustc-link-lib=strmiids");
println!("cargo:rustc-link-lib=oleaut32");
println!("cargo:rustc-link-lib=quartz");
}
fn build_ffmpeg_hw(builder: &mut Build) {
let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let ffmpeg_hw_dir = manifest_dir.join("cpp").join("ffmpeg_hw");

View File

@@ -1,4 +1,5 @@
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/opt.h>
}
@@ -99,13 +100,12 @@ void set_av_codec_ctx(AVCodecContext *c, const std::string &name, int kbs,
c->color_primaries = AVCOL_PRI_SMPTE170M;
c->color_trc = AVCOL_TRC_SMPTE170M;
// Profile selection: use BASELINE for software H264 (faster, simpler)
if (is_software_h264(name)) {
c->profile = FF_PROFILE_H264_BASELINE; // Simpler profile for real-time
} else if (name.find("h264") != std::string::npos) {
c->profile = FF_PROFILE_H264_HIGH;
// WebRTC SDP advertises constrained baseline. Keep hardware and software
// encoders on the same browser-friendly H264 profile.
if (name.find("h264") != std::string::npos) {
c->profile = AV_PROFILE_H264_CONSTRAINED_BASELINE;
} else if (name.find("hevc") != std::string::npos) {
c->profile = FF_PROFILE_HEVC_MAIN;
c->profile = AV_PROFILE_HEVC_MAIN;
}
}
@@ -120,8 +120,7 @@ bool set_lantency_free(void *priv_data, const std::string &name) {
}
if (name.find("amf") != std::string::npos) {
if ((ret = av_opt_set(priv_data, "query_timeout", "1000", 0)) < 0) {
LOG_ERROR(std::string("amf set_lantency_free failed, ret = ") + av_err2str(ret));
return false;
LOG_WARN(std::string("amf query_timeout option is unavailable, ret = ") + av_err2str(ret));
}
}
if (name.find("qsv") != std::string::npos) {

View File

@@ -0,0 +1,879 @@
#define NOMINMAX
#include "ffmpeg_capture_ffi.h"
#include <Windows.h>
#include <dshow.h>
#include <dvdmedia.h>
extern "C" {
#include <libavcodec/codec_id.h>
#include <libavdevice/avdevice.h>
#include <libavformat/avformat.h>
#include <libavutil/avutil.h>
#include <libavutil/error.h>
#include <libavutil/pixfmt.h>
}
#include <atomic>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#pragma comment(lib, "strmiids")
thread_local std::string g_last_error;
struct HwcodecDshowCaptureContext {
AVFormatContext* format_ctx = nullptr;
int stream_index = -1;
int width = 0;
int height = 0;
int pixel_format = HWCODEC_CAPTURE_FMT_UNKNOWN;
int stride = 0;
int timeout_ms = 2000;
std::atomic<long long> deadline_ms{0};
std::atomic<int> timed_out{0};
uint64_t sequence = 0;
};
namespace {
struct DshowCapabilityEntry {
std::string format;
int width = 0;
int height = 0;
std::vector<int> fps;
};
const char* requested_pixel_format_name(int requested_format);
void set_last_error(const std::string& message) {
g_last_error = message;
}
std::string ffmpeg_error(int errnum) {
char buffer[AV_ERROR_MAX_STRING_SIZE] = {0};
av_make_error_string(buffer, sizeof(buffer), errnum);
return std::string(buffer);
}
long long now_ms() {
return static_cast<long long>(GetTickCount64());
}
std::string wide_to_utf8(const wchar_t* value) {
if (!value) {
return std::string();
}
int size = WideCharToMultiByte(CP_UTF8, 0, value, -1, nullptr, 0, nullptr, nullptr);
if (size <= 1) {
return std::string();
}
std::string result(static_cast<size_t>(size - 1), '\0');
WideCharToMultiByte(
CP_UTF8,
0,
value,
-1,
result.empty() ? nullptr : &result[0],
size,
nullptr,
nullptr);
return result;
}
void add_fps_candidate(std::vector<int>* fps, LONGLONG interval_100ns) {
if (!fps || interval_100ns <= 0) {
return;
}
double fps_value = 10000000.0 / static_cast<double>(interval_100ns);
int rounded = static_cast<int>(fps_value + 0.5);
if (rounded <= 0) {
return;
}
if (std::find(fps->begin(), fps->end(), rounded) == fps->end()) {
fps->push_back(rounded);
}
}
void normalize_fps(std::vector<int>* fps) {
if (!fps) {
return;
}
std::sort(fps->begin(), fps->end(), std::greater<int>());
fps->erase(std::unique(fps->begin(), fps->end()), fps->end());
}
const char* media_subtype_to_format(const GUID& subtype) {
if (subtype == MEDIASUBTYPE_MJPG) {
return "MJPEG";
}
if (subtype == MEDIASUBTYPE_YUY2) {
return "YUYV";
}
if (subtype == MEDIASUBTYPE_UYVY) {
return "UYVY";
}
if (subtype == MEDIASUBTYPE_YVYU) {
return "YVYU";
}
if (subtype == MEDIASUBTYPE_NV12) {
return "NV12";
}
if (subtype == MEDIASUBTYPE_RGB24) {
return "RGB24";
}
if (subtype == MEDIASUBTYPE_RGB32) {
return "BGR24";
}
if (subtype == MEDIASUBTYPE_IYUV) {
return "YUV420";
}
if (subtype == MEDIASUBTYPE_YV12) {
return "YVU420";
}
return nullptr;
}
void free_media_type(AM_MEDIA_TYPE* media_type) {
if (!media_type) {
return;
}
if (media_type->cbFormat != 0) {
CoTaskMemFree(media_type->pbFormat);
media_type->cbFormat = 0;
media_type->pbFormat = nullptr;
}
if (media_type->pUnk != nullptr) {
media_type->pUnk->Release();
media_type->pUnk = nullptr;
}
CoTaskMemFree(media_type);
}
bool fill_capability_entry(
const AM_MEDIA_TYPE* media_type,
const VIDEO_STREAM_CONFIG_CAPS* caps,
DshowCapabilityEntry* out_entry) {
if (!media_type || !out_entry) {
return false;
}
const char* format = media_subtype_to_format(media_type->subtype);
if (!format) {
return false;
}
LONG width = 0;
LONG height = 0;
REFERENCE_TIME avg_time_per_frame = 0;
if (media_type->formattype == FORMAT_VideoInfo && media_type->pbFormat &&
media_type->cbFormat >= sizeof(VIDEOINFOHEADER)) {
const auto* info = reinterpret_cast<const VIDEOINFOHEADER*>(media_type->pbFormat);
width = info->bmiHeader.biWidth;
height = std::abs(info->bmiHeader.biHeight);
avg_time_per_frame = info->AvgTimePerFrame;
} else if (media_type->formattype == FORMAT_VideoInfo2 && media_type->pbFormat &&
media_type->cbFormat >= sizeof(VIDEOINFOHEADER2)) {
const auto* info = reinterpret_cast<const VIDEOINFOHEADER2*>(media_type->pbFormat);
width = info->bmiHeader.biWidth;
height = std::abs(info->bmiHeader.biHeight);
avg_time_per_frame = info->AvgTimePerFrame;
}
if ((width <= 0 || height <= 0) && caps) {
width = std::max<LONG>(caps->InputSize.cx, caps->MinOutputSize.cx);
height = std::max<LONG>(caps->InputSize.cy, caps->MinOutputSize.cy);
if (width <= 0 || height <= 0) {
width = caps->MaxOutputSize.cx;
height = caps->MaxOutputSize.cy;
}
}
if (width <= 0 || height <= 0) {
return false;
}
out_entry->format = format;
out_entry->width = static_cast<int>(width);
out_entry->height = static_cast<int>(height);
out_entry->fps.clear();
add_fps_candidate(&out_entry->fps, avg_time_per_frame);
if (caps) {
add_fps_candidate(&out_entry->fps, caps->MinFrameInterval);
add_fps_candidate(&out_entry->fps, caps->MaxFrameInterval);
}
normalize_fps(&out_entry->fps);
return true;
}
void append_stream_capabilities(IAMStreamConfig* stream_config, std::vector<DshowCapabilityEntry>* entries) {
if (!stream_config || !entries) {
return;
}
int cap_count = 0;
int cap_size = 0;
HRESULT hr = stream_config->GetNumberOfCapabilities(&cap_count, &cap_size);
if (FAILED(hr) || cap_count <= 0 || cap_size < static_cast<int>(sizeof(VIDEO_STREAM_CONFIG_CAPS))) {
return;
}
std::vector<BYTE> caps_buffer(static_cast<size_t>(cap_size));
for (int index = 0; index < cap_count; ++index) {
AM_MEDIA_TYPE* media_type = nullptr;
hr = stream_config->GetStreamCaps(index, &media_type, caps_buffer.data());
if (FAILED(hr) || !media_type) {
continue;
}
DshowCapabilityEntry entry;
const auto* caps = reinterpret_cast<const VIDEO_STREAM_CONFIG_CAPS*>(caps_buffer.data());
if (fill_capability_entry(media_type, caps, &entry)) {
entries->push_back(std::move(entry));
}
free_media_type(media_type);
}
}
bool find_device_filter(const std::string& device_name, IBaseFilter** out_filter) {
if (!out_filter) {
return false;
}
*out_filter = nullptr;
ICreateDevEnum* dev_enum = nullptr;
IEnumMoniker* enum_moniker = nullptr;
HRESULT hr = CoCreateInstance(
CLSID_SystemDeviceEnum,
nullptr,
CLSCTX_INPROC_SERVER,
IID_ICreateDevEnum,
reinterpret_cast<void**>(&dev_enum));
if (FAILED(hr) || !dev_enum) {
return false;
}
hr = dev_enum->CreateClassEnumerator(CLSID_VideoInputDeviceCategory, &enum_moniker, 0);
dev_enum->Release();
if (hr != S_OK || !enum_moniker) {
return false;
}
bool found = false;
IMoniker* moniker = nullptr;
ULONG fetched = 0;
while (!found && enum_moniker->Next(1, &moniker, &fetched) == S_OK) {
IPropertyBag* bag = nullptr;
hr = moniker->BindToStorage(nullptr, nullptr, IID_IPropertyBag, reinterpret_cast<void**>(&bag));
if (SUCCEEDED(hr) && bag) {
VARIANT name;
VariantInit(&name);
if (SUCCEEDED(bag->Read(L"FriendlyName", &name, nullptr)) && name.vt == VT_BSTR) {
auto utf8_name = wide_to_utf8(name.bstrVal);
if (utf8_name == device_name) {
hr = moniker->BindToObject(nullptr, nullptr, IID_IBaseFilter, reinterpret_cast<void**>(out_filter));
found = SUCCEEDED(hr) && *out_filter != nullptr;
}
}
VariantClear(&name);
bag->Release();
}
moniker->Release();
}
enum_moniker->Release();
return found;
}
std::string build_capabilities_payload(const std::vector<DshowCapabilityEntry>& entries) {
std::string payload;
for (size_t i = 0; i < entries.size(); ++i) {
const auto& entry = entries[i];
payload += entry.format;
payload.push_back('|');
payload += std::to_string(entry.width);
payload.push_back('|');
payload += std::to_string(entry.height);
payload.push_back('|');
for (size_t fps_index = 0; fps_index < entry.fps.size(); ++fps_index) {
payload += std::to_string(entry.fps[fps_index]);
if (fps_index + 1 < entry.fps.size()) {
payload.push_back(',');
}
}
if (i + 1 < entries.size()) {
payload.push_back('\n');
}
}
return payload;
}
char* copy_payload(const std::string& payload) {
char* out = reinterpret_cast<char*>(std::malloc(payload.size() + 1));
if (!out) {
set_last_error("Failed to allocate capture payload buffer");
return nullptr;
}
std::memcpy(out, payload.c_str(), payload.size() + 1);
return out;
}
int open_dshow_input_with_options(
AVFormatContext** format_ctx,
const AVInputFormat* input,
const std::string& device_name,
int width,
int height,
int fps,
int requested_format,
bool use_video_size,
bool use_framerate,
bool use_pixel_format,
std::string* attempt_desc) {
if (!format_ctx || !input) {
return AVERROR(EINVAL);
}
AVDictionary* options = nullptr;
std::vector<std::string> parts;
if (use_video_size && width > 0 && height > 0) {
std::string video_size = std::to_string(width) + "x" + std::to_string(height);
av_dict_set(&options, "video_size", video_size.c_str(), 0);
parts.push_back("video_size=" + video_size);
}
if (use_framerate && fps > 0) {
std::string framerate = std::to_string(fps);
av_dict_set(&options, "framerate", framerate.c_str(), 0);
parts.push_back("framerate=" + framerate);
}
av_dict_set(&options, "rtbufsize", "64M", 0);
parts.push_back("rtbufsize=64M");
const char* pixel_format_name = requested_pixel_format_name(requested_format);
if (use_pixel_format && pixel_format_name) {
av_dict_set(&options, "pixel_format", pixel_format_name, 0);
parts.push_back(std::string("pixel_format=") + pixel_format_name);
}
if (attempt_desc) {
*attempt_desc = parts.empty() ? "default options" : "options{";
if (!parts.empty()) {
for (size_t i = 0; i < parts.size(); ++i) {
if (i > 0) {
attempt_desc->append(", ");
}
attempt_desc->append(parts[i]);
}
attempt_desc->append("}");
}
}
std::string input_name = "video=" + device_name;
int ret = avformat_open_input(format_ctx, input_name.c_str(), input, &options);
av_dict_free(&options);
return ret;
}
class ScopedComInit {
public:
ScopedComInit() {
HRESULT hr = CoInitializeEx(nullptr, COINIT_MULTITHREADED);
initialized_ = hr == S_OK || hr == S_FALSE;
}
~ScopedComInit() {
if (initialized_) {
CoUninitialize();
}
}
private:
bool initialized_ = false;
};
int capture_stride(int pixel_format, int width) {
switch (pixel_format) {
case HWCODEC_CAPTURE_FMT_YUYV:
case HWCODEC_CAPTURE_FMT_YVYU:
case HWCODEC_CAPTURE_FMT_UYVY:
return width * 2;
case HWCODEC_CAPTURE_FMT_RGB24:
case HWCODEC_CAPTURE_FMT_BGR24:
return width * 3;
case HWCODEC_CAPTURE_FMT_NV24:
return width * 2;
case HWCODEC_CAPTURE_FMT_NV12:
case HWCODEC_CAPTURE_FMT_NV21:
case HWCODEC_CAPTURE_FMT_NV16:
case HWCODEC_CAPTURE_FMT_YUV420:
case HWCODEC_CAPTURE_FMT_YVU420:
case HWCODEC_CAPTURE_FMT_GREY:
case HWCODEC_CAPTURE_FMT_MJPEG:
case HWCODEC_CAPTURE_FMT_JPEG:
default:
return width;
}
}
int map_raw_pixfmt(int format) {
switch (format) {
case AV_PIX_FMT_YUYV422:
return HWCODEC_CAPTURE_FMT_YUYV;
case AV_PIX_FMT_UYVY422:
return HWCODEC_CAPTURE_FMT_UYVY;
#ifdef AV_PIX_FMT_YVYU422
case AV_PIX_FMT_YVYU422:
return HWCODEC_CAPTURE_FMT_YVYU;
#endif
case AV_PIX_FMT_NV12:
return HWCODEC_CAPTURE_FMT_NV12;
case AV_PIX_FMT_NV21:
return HWCODEC_CAPTURE_FMT_NV21;
#ifdef AV_PIX_FMT_NV16
case AV_PIX_FMT_NV16:
return HWCODEC_CAPTURE_FMT_NV16;
#endif
#ifdef AV_PIX_FMT_NV24
case AV_PIX_FMT_NV24:
return HWCODEC_CAPTURE_FMT_NV24;
#endif
case AV_PIX_FMT_YUV420P:
return HWCODEC_CAPTURE_FMT_YUV420;
#ifdef AV_PIX_FMT_YVU420P
case AV_PIX_FMT_YVU420P:
return HWCODEC_CAPTURE_FMT_YVU420;
#endif
case AV_PIX_FMT_RGB24:
return HWCODEC_CAPTURE_FMT_RGB24;
case AV_PIX_FMT_BGR24:
return HWCODEC_CAPTURE_FMT_BGR24;
case AV_PIX_FMT_GRAY8:
return HWCODEC_CAPTURE_FMT_GREY;
default:
return HWCODEC_CAPTURE_FMT_UNKNOWN;
}
}
int map_codec_to_capture_format(const AVCodecParameters* codecpar) {
if (!codecpar) {
return HWCODEC_CAPTURE_FMT_UNKNOWN;
}
switch (codecpar->codec_id) {
case AV_CODEC_ID_MJPEG:
return HWCODEC_CAPTURE_FMT_MJPEG;
case AV_CODEC_ID_JPEG2000:
return HWCODEC_CAPTURE_FMT_JPEG;
case AV_CODEC_ID_RAWVIDEO:
return map_raw_pixfmt(codecpar->format);
default:
return HWCODEC_CAPTURE_FMT_UNKNOWN;
}
}
int interrupt_callback(void* opaque) {
auto* ctx = reinterpret_cast<HwcodecDshowCaptureContext*>(opaque);
if (!ctx) {
return 0;
}
auto deadline = ctx->deadline_ms.load();
if (deadline <= 0) {
return 0;
}
if (now_ms() > deadline) {
ctx->timed_out.store(1);
return 1;
}
return 0;
}
const char* requested_pixel_format_name(int requested_format) {
switch (requested_format) {
case HWCODEC_CAPTURE_FMT_YUYV:
return "yuyv422";
case HWCODEC_CAPTURE_FMT_UYVY:
return "uyvy422";
case HWCODEC_CAPTURE_FMT_NV12:
return "nv12";
case HWCODEC_CAPTURE_FMT_NV21:
return "nv21";
case HWCODEC_CAPTURE_FMT_RGB24:
return "rgb24";
case HWCODEC_CAPTURE_FMT_BGR24:
return "bgr24";
case HWCODEC_CAPTURE_FMT_GREY:
return "gray";
default:
return nullptr;
}
}
} // namespace
extern "C" const char* hwcodec_capture_last_error(void) {
return g_last_error.c_str();
}
extern "C" char* hwcodec_dshow_list_video_devices(void) {
ScopedComInit com;
ICreateDevEnum* dev_enum = nullptr;
IEnumMoniker* enum_moniker = nullptr;
HRESULT hr = CoCreateInstance(
CLSID_SystemDeviceEnum,
nullptr,
CLSCTX_INPROC_SERVER,
IID_ICreateDevEnum,
reinterpret_cast<void**>(&dev_enum));
if (FAILED(hr)) {
set_last_error("Failed to create DirectShow device enumerator");
return nullptr;
}
hr = dev_enum->CreateClassEnumerator(CLSID_VideoInputDeviceCategory, &enum_moniker, 0);
dev_enum->Release();
if (hr != S_OK || !enum_moniker) {
char* out = reinterpret_cast<char*>(std::malloc(1));
if (out) {
out[0] = '\0';
}
return out;
}
std::vector<std::string> devices;
IMoniker* moniker = nullptr;
ULONG fetched = 0;
while (enum_moniker->Next(1, &moniker, &fetched) == S_OK) {
IPropertyBag* bag = nullptr;
hr = moniker->BindToStorage(nullptr, nullptr, IID_IPropertyBag, reinterpret_cast<void**>(&bag));
if (SUCCEEDED(hr) && bag) {
VARIANT name;
VariantInit(&name);
if (SUCCEEDED(bag->Read(L"FriendlyName", &name, nullptr)) && name.vt == VT_BSTR) {
auto utf8_name = wide_to_utf8(name.bstrVal);
if (!utf8_name.empty()) {
devices.push_back(utf8_name);
}
}
VariantClear(&name);
bag->Release();
}
moniker->Release();
}
enum_moniker->Release();
std::string payload;
for (size_t i = 0; i < devices.size(); ++i) {
payload += devices[i];
if (i + 1 < devices.size()) {
payload.push_back('\n');
}
}
return copy_payload(payload);
}
extern "C" char* hwcodec_dshow_list_device_capabilities(const char* device_name) {
if (!device_name || device_name[0] == '\0') {
set_last_error("DirectShow device name is empty");
return nullptr;
}
ScopedComInit com;
IBaseFilter* filter = nullptr;
if (!find_device_filter(device_name, &filter) || !filter) {
set_last_error("Failed to find DirectShow device filter");
return nullptr;
}
std::vector<DshowCapabilityEntry> entries;
IEnumPins* enum_pins = nullptr;
HRESULT hr = filter->EnumPins(&enum_pins);
if (SUCCEEDED(hr) && enum_pins) {
IPin* pin = nullptr;
ULONG fetched = 0;
while (enum_pins->Next(1, &pin, &fetched) == S_OK) {
PIN_DIRECTION direction = PINDIR_INPUT;
if (SUCCEEDED(pin->QueryDirection(&direction)) && direction == PINDIR_OUTPUT) {
IAMStreamConfig* stream_config = nullptr;
if (SUCCEEDED(pin->QueryInterface(IID_IAMStreamConfig, reinterpret_cast<void**>(&stream_config))) &&
stream_config) {
append_stream_capabilities(stream_config, &entries);
stream_config->Release();
}
}
pin->Release();
}
enum_pins->Release();
}
filter->Release();
std::sort(entries.begin(), entries.end(), [](const DshowCapabilityEntry& left, const DshowCapabilityEntry& right) {
if (left.format != right.format) {
return left.format < right.format;
}
if (left.width != right.width) {
return left.width < right.width;
}
if (left.height != right.height) {
return left.height < right.height;
}
return left.fps > right.fps;
});
entries.erase(
std::unique(entries.begin(), entries.end(), [](const DshowCapabilityEntry& left, const DshowCapabilityEntry& right) {
return left.format == right.format && left.width == right.width && left.height == right.height && left.fps == right.fps;
}),
entries.end());
return copy_payload(build_capabilities_payload(entries));
}
extern "C" void hwcodec_capture_string_free(char* ptr) {
if (ptr) {
std::free(ptr);
}
}
extern "C" HwcodecDshowCaptureContext* hwcodec_dshow_capture_open(
const char* device_name,
int width,
int height,
int fps,
int requested_format,
int timeout_ms) {
if (!device_name || device_name[0] == '\0') {
set_last_error("Device name is empty");
return nullptr;
}
avdevice_register_all();
const AVInputFormat* input = av_find_input_format("dshow");
if (!input) {
set_last_error("FFmpeg dshow input format is unavailable");
return nullptr;
}
auto* ctx = new HwcodecDshowCaptureContext();
ctx->timeout_ms = timeout_ms > 0 ? timeout_ms : 2000;
ctx->format_ctx = avformat_alloc_context();
if (!ctx->format_ctx) {
delete ctx;
set_last_error("Failed to allocate FFmpeg format context");
return nullptr;
}
ctx->format_ctx->interrupt_callback.callback = interrupt_callback;
ctx->format_ctx->interrupt_callback.opaque = ctx;
std::string open_attempt;
int ret = open_dshow_input_with_options(
&ctx->format_ctx,
input,
device_name,
width,
height,
fps,
requested_format,
true,
true,
true,
&open_attempt);
if (ret < 0) {
avformat_free_context(ctx->format_ctx);
ctx->format_ctx = avformat_alloc_context();
if (!ctx->format_ctx) {
delete ctx;
set_last_error("Failed to allocate FFmpeg format context for fallback open");
return nullptr;
}
ctx->format_ctx->interrupt_callback.callback = interrupt_callback;
ctx->format_ctx->interrupt_callback.opaque = ctx;
std::string fallback_attempt;
ret = open_dshow_input_with_options(
&ctx->format_ctx,
input,
device_name,
width,
height,
fps,
requested_format,
true,
false,
true,
&fallback_attempt);
if (ret >= 0) {
open_attempt = fallback_attempt;
}
}
if (ret < 0) {
avformat_free_context(ctx->format_ctx);
ctx->format_ctx = avformat_alloc_context();
if (!ctx->format_ctx) {
delete ctx;
set_last_error("Failed to allocate FFmpeg format context for final fallback open");
return nullptr;
}
ctx->format_ctx->interrupt_callback.callback = interrupt_callback;
ctx->format_ctx->interrupt_callback.opaque = ctx;
std::string fallback_attempt;
ret = open_dshow_input_with_options(
&ctx->format_ctx,
input,
device_name,
width,
height,
fps,
requested_format,
false,
false,
false,
&fallback_attempt);
if (ret >= 0) {
open_attempt = fallback_attempt;
}
}
if (ret < 0) {
set_last_error("Failed to open dshow input (" + open_attempt + "): " + ffmpeg_error(ret));
avformat_free_context(ctx->format_ctx);
delete ctx;
return nullptr;
}
ret = avformat_find_stream_info(ctx->format_ctx, nullptr);
if (ret < 0) {
set_last_error("Failed to read stream info: " + ffmpeg_error(ret));
avformat_close_input(&ctx->format_ctx);
delete ctx;
return nullptr;
}
for (unsigned int i = 0; i < ctx->format_ctx->nb_streams; ++i) {
AVStream* stream = ctx->format_ctx->streams[i];
if (stream && stream->codecpar && stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
ctx->stream_index = static_cast<int>(i);
ctx->width = stream->codecpar->width > 0 ? stream->codecpar->width : width;
ctx->height = stream->codecpar->height > 0 ? stream->codecpar->height : height;
ctx->pixel_format = map_codec_to_capture_format(stream->codecpar);
ctx->stride = capture_stride(ctx->pixel_format, ctx->width);
break;
}
}
if (ctx->stream_index < 0) {
set_last_error("No video stream found on DirectShow device");
avformat_close_input(&ctx->format_ctx);
delete ctx;
return nullptr;
}
if (ctx->pixel_format == HWCODEC_CAPTURE_FMT_UNKNOWN) {
set_last_error("DirectShow stream format is unsupported in current Windows backend");
avformat_close_input(&ctx->format_ctx);
delete ctx;
return nullptr;
}
return ctx;
}
extern "C" int hwcodec_dshow_capture_info(
HwcodecDshowCaptureContext* ctx,
HwcodecCaptureStreamInfo* out_info) {
if (!ctx || !out_info) {
set_last_error("Invalid capture context");
return -1;
}
out_info->width = ctx->width;
out_info->height = ctx->height;
out_info->pixel_format = ctx->pixel_format;
out_info->stride = ctx->stride;
return 0;
}
extern "C" int hwcodec_dshow_capture_read(
HwcodecDshowCaptureContext* ctx,
uint8_t** out_data,
int* out_len,
uint64_t* out_sequence) {
if (!ctx || !out_data || !out_len || !out_sequence) {
set_last_error("Invalid capture read arguments");
return -1;
}
*out_data = nullptr;
*out_len = 0;
*out_sequence = 0;
AVPacket packet;
av_init_packet(&packet);
packet.data = nullptr;
packet.size = 0;
while (true) {
ctx->timed_out.store(0);
ctx->deadline_ms.store(now_ms() + ctx->timeout_ms);
int ret = av_read_frame(ctx->format_ctx, &packet);
ctx->deadline_ms.store(0);
if (ret < 0) {
if (ctx->timed_out.load() != 0) {
set_last_error("Timed out waiting for frame");
return -110;
}
set_last_error("Failed to read frame: " + ffmpeg_error(ret));
return ret;
}
if (packet.stream_index != ctx->stream_index) {
av_packet_unref(&packet);
continue;
}
if (packet.size <= 0 || !packet.data) {
av_packet_unref(&packet);
continue;
}
auto* buffer = reinterpret_cast<uint8_t*>(std::malloc(static_cast<size_t>(packet.size)));
if (!buffer) {
av_packet_unref(&packet);
set_last_error("Failed to allocate packet buffer");
return -12;
}
std::memcpy(buffer, packet.data, static_cast<size_t>(packet.size));
*out_data = buffer;
*out_len = packet.size;
*out_sequence = ctx->sequence++;
av_packet_unref(&packet);
return 0;
}
}
extern "C" void hwcodec_dshow_capture_packet_free(uint8_t* data) {
if (data) {
std::free(data);
}
}
extern "C" void hwcodec_dshow_capture_close(HwcodecDshowCaptureContext* ctx) {
if (!ctx) {
return;
}
if (ctx->format_ctx) {
avformat_close_input(&ctx->format_ctx);
}
delete ctx;
}

View File

@@ -0,0 +1,64 @@
#ifndef HWCODEC_FFMPEG_CAPTURE_FFI_H
#define HWCODEC_FFMPEG_CAPTURE_FFI_H
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef struct HwcodecDshowCaptureContext HwcodecDshowCaptureContext;
enum HwcodecCapturePixelFormat {
HWCODEC_CAPTURE_FMT_UNKNOWN = 0,
HWCODEC_CAPTURE_FMT_MJPEG = 1,
HWCODEC_CAPTURE_FMT_JPEG = 2,
HWCODEC_CAPTURE_FMT_YUYV = 3,
HWCODEC_CAPTURE_FMT_YVYU = 4,
HWCODEC_CAPTURE_FMT_UYVY = 5,
HWCODEC_CAPTURE_FMT_NV12 = 6,
HWCODEC_CAPTURE_FMT_NV21 = 7,
HWCODEC_CAPTURE_FMT_NV16 = 8,
HWCODEC_CAPTURE_FMT_NV24 = 9,
HWCODEC_CAPTURE_FMT_YUV420 = 10,
HWCODEC_CAPTURE_FMT_YVU420 = 11,
HWCODEC_CAPTURE_FMT_RGB24 = 12,
HWCODEC_CAPTURE_FMT_BGR24 = 13,
HWCODEC_CAPTURE_FMT_GREY = 14,
};
typedef struct HwcodecCaptureStreamInfo {
int width;
int height;
int pixel_format;
int stride;
} HwcodecCaptureStreamInfo;
const char* hwcodec_capture_last_error(void);
char* hwcodec_dshow_list_video_devices(void);
char* hwcodec_dshow_list_device_capabilities(const char* device_name);
void hwcodec_capture_string_free(char* ptr);
HwcodecDshowCaptureContext* hwcodec_dshow_capture_open(
const char* device_name,
int width,
int height,
int fps,
int requested_format,
int timeout_ms);
int hwcodec_dshow_capture_info(
HwcodecDshowCaptureContext* ctx,
HwcodecCaptureStreamInfo* out_info);
int hwcodec_dshow_capture_read(
HwcodecDshowCaptureContext* ctx,
uint8_t** out_data,
int* out_len,
uint64_t* out_sequence);
void hwcodec_dshow_capture_packet_free(uint8_t* data);
void hwcodec_dshow_capture_close(HwcodecDshowCaptureContext* ctx);
#ifdef __cplusplus
}
#endif
#endif

297
libs/hwcodec/src/capture.rs Normal file
View File

@@ -0,0 +1,297 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
use std::ffi::{CStr, CString};
use std::os::raw::c_int;
include!(concat!(env!("OUT_DIR"), "/ffmpeg_capture_ffi.rs"));
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapturePixelFormat {
Unknown,
Mjpeg,
Jpeg,
Yuyv,
Yvyu,
Uyvy,
Nv12,
Nv21,
Nv16,
Nv24,
Yuv420,
Yvu420,
Rgb24,
Bgr24,
Grey,
}
impl CapturePixelFormat {
pub fn to_ffi(self) -> c_int {
match self {
Self::Unknown => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UNKNOWN as c_int,
Self::Mjpeg => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_MJPEG as c_int,
Self::Jpeg => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_JPEG as c_int,
Self::Yuyv => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUYV as c_int,
Self::Yvyu => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVYU as c_int,
Self::Uyvy => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UYVY as c_int,
Self::Nv12 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV12 as c_int,
Self::Nv21 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV21 as c_int,
Self::Nv16 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV16 as c_int,
Self::Nv24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV24 as c_int,
Self::Yuv420 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUV420 as c_int,
Self::Yvu420 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVU420 as c_int,
Self::Rgb24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_RGB24 as c_int,
Self::Bgr24 => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_BGR24 as c_int,
Self::Grey => HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_GREY as c_int,
}
}
pub fn from_ffi(value: c_int) -> Self {
match value {
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_MJPEG as c_int => Self::Mjpeg,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_JPEG as c_int => Self::Jpeg,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUYV as c_int => Self::Yuyv,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVYU as c_int => Self::Yvyu,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UYVY as c_int => Self::Uyvy,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV12 as c_int => Self::Nv12,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV21 as c_int => Self::Nv21,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV16 as c_int => Self::Nv16,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_NV24 as c_int => Self::Nv24,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YUV420 as c_int => {
Self::Yuv420
}
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_YVU420 as c_int => {
Self::Yvu420
}
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_RGB24 as c_int => Self::Rgb24,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_BGR24 as c_int => Self::Bgr24,
x if x == HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_GREY as c_int => Self::Grey,
_ => Self::Unknown,
}
}
pub fn from_name(name: &str) -> Option<Self> {
match name.trim().to_ascii_uppercase().as_str() {
"MJPEG" | "MJPG" => Some(Self::Mjpeg),
"JPEG" => Some(Self::Jpeg),
"YUYV" => Some(Self::Yuyv),
"YVYU" => Some(Self::Yvyu),
"UYVY" => Some(Self::Uyvy),
"NV12" => Some(Self::Nv12),
"NV21" => Some(Self::Nv21),
"NV16" => Some(Self::Nv16),
"NV24" => Some(Self::Nv24),
"YUV420" | "I420" | "IYUV" => Some(Self::Yuv420),
"YVU420" | "YV12" => Some(Self::Yvu420),
"RGB24" => Some(Self::Rgb24),
"BGR24" => Some(Self::Bgr24),
"GREY" | "GRAY" | "Y800" => Some(Self::Grey),
_ => None,
}
}
}
#[derive(Debug, Clone)]
pub struct DshowCapability {
pub format: CapturePixelFormat,
pub width: u32,
pub height: u32,
pub fps: Vec<u32>,
}
#[derive(Debug, Clone, Copy)]
pub struct CaptureStreamInfo {
pub width: i32,
pub height: i32,
pub pixel_format: CapturePixelFormat,
pub stride: i32,
}
#[derive(Debug)]
pub struct CaptureError {
pub code: i32,
pub message: String,
}
impl std::fmt::Display for CaptureError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for CaptureError {}
fn last_error_message() -> String {
unsafe {
let ptr = hwcodec_capture_last_error();
if ptr.is_null() {
return String::new();
}
CStr::from_ptr(ptr).to_string_lossy().to_string()
}
}
pub fn list_dshow_video_devices() -> Result<Vec<String>, CaptureError> {
unsafe {
let ptr = hwcodec_dshow_list_video_devices();
if ptr.is_null() {
return Err(CaptureError {
code: -1,
message: last_error_message(),
});
}
let payload = CStr::from_ptr(ptr).to_string_lossy().to_string();
hwcodec_capture_string_free(ptr as *mut _);
Ok(payload
.lines()
.map(str::trim)
.filter(|line| !line.is_empty())
.map(ToOwned::to_owned)
.collect())
}
}
pub fn list_dshow_device_capabilities(device_name: &str) -> Result<Vec<DshowCapability>, CaptureError> {
let device_name = CString::new(device_name).map_err(|_| CaptureError {
code: -1,
message: "device name contains NUL byte".to_string(),
})?;
unsafe {
let ptr = hwcodec_dshow_list_device_capabilities(device_name.as_ptr());
if ptr.is_null() {
return Err(CaptureError {
code: -1,
message: last_error_message(),
});
}
let payload = CStr::from_ptr(ptr).to_string_lossy().to_string();
hwcodec_capture_string_free(ptr as *mut _);
let capabilities = payload
.lines()
.filter_map(parse_dshow_capability_line)
.collect();
Ok(capabilities)
}
}
fn parse_dshow_capability_line(line: &str) -> Option<DshowCapability> {
let mut parts = line.split('|');
let format = CapturePixelFormat::from_name(parts.next()?.trim())?;
let width = parts.next()?.trim().parse::<u32>().ok()?;
let height = parts.next()?.trim().parse::<u32>().ok()?;
let fps = parts
.next()
.unwrap_or_default()
.split(',')
.filter_map(|value| value.trim().parse::<u32>().ok())
.filter(|value| *value > 0)
.collect::<Vec<_>>();
Some(DshowCapability {
format,
width,
height,
fps,
})
}
pub struct DshowCapture {
ctx: *mut HwcodecDshowCaptureContext,
}
unsafe impl Send for DshowCapture {}
impl DshowCapture {
pub fn open(
device_name: &str,
width: i32,
height: i32,
fps: i32,
requested_format: CapturePixelFormat,
timeout_ms: i32,
) -> Result<Self, CaptureError> {
let device_name = CString::new(device_name).map_err(|_| CaptureError {
code: -1,
message: "device name contains NUL byte".to_string(),
})?;
unsafe {
let ctx = hwcodec_dshow_capture_open(
device_name.as_ptr(),
width,
height,
fps,
requested_format.to_ffi(),
timeout_ms,
);
if ctx.is_null() {
return Err(CaptureError {
code: -1,
message: last_error_message(),
});
}
Ok(Self { ctx })
}
}
pub fn info(&self) -> Result<CaptureStreamInfo, CaptureError> {
unsafe {
let mut info = HwcodecCaptureStreamInfo {
width: 0,
height: 0,
pixel_format: HwcodecCapturePixelFormat::HWCODEC_CAPTURE_FMT_UNKNOWN as c_int,
stride: 0,
};
let ret = hwcodec_dshow_capture_info(self.ctx, &mut info);
if ret != 0 {
return Err(CaptureError {
code: ret,
message: last_error_message(),
});
}
Ok(CaptureStreamInfo {
width: info.width,
height: info.height,
pixel_format: CapturePixelFormat::from_ffi(info.pixel_format),
stride: info.stride,
})
}
}
pub fn read_packet(&mut self) -> Result<(Vec<u8>, u64), CaptureError> {
unsafe {
let mut data = std::ptr::null_mut();
let mut len = 0;
let mut sequence = 0u64;
let ret = hwcodec_dshow_capture_read(self.ctx, &mut data, &mut len, &mut sequence);
if ret != 0 {
return Err(CaptureError {
code: ret,
message: last_error_message(),
});
}
if data.is_null() || len <= 0 {
return Err(CaptureError {
code: -1,
message: "empty packet returned by capture backend".to_string(),
});
}
let slice = std::slice::from_raw_parts(data, len as usize);
let vec = slice.to_vec();
hwcodec_dshow_capture_packet_free(data);
Ok((vec, sequence))
}
}
}
impl Drop for DshowCapture {
fn drop(&mut self) {
unsafe {
hwcodec_dshow_capture_close(self.ctx);
}
self.ctx = std::ptr::null_mut();
}
}

View File

@@ -257,7 +257,13 @@ struct ProbePolicy {
impl ProbePolicy {
fn for_codec(codec_name: &str) -> Self {
if codec_name.contains("v4l2m2m") {
if codec_name.contains("amf") {
Self {
max_attempts: 5,
request_keyframe: true,
accept_any_output: true,
}
} else if codec_name.contains("v4l2m2m") {
Self {
max_attempts: 5,
request_keyframe: true,

View File

@@ -1,3 +1,5 @@
#[cfg(windows)]
pub mod capture;
pub mod common;
pub mod ffmpeg;
#[cfg(any(target_arch = "aarch64", target_arch = "arm", feature = "rkmpp"))]

View File

@@ -45,4 +45,4 @@ pub use error::{Result, VentoyError};
pub use exfat::FileInfo;
pub use image::VentoyImage;
pub use partition::{parse_size, PartitionLayout};
pub use resources::{get_resource_dir, init_resources, is_initialized, required_files};
pub use resources::{init_resources, is_initialized, required_files};

View File

@@ -5,7 +5,7 @@
use crate::error::{Result, VentoyError};
use std::fs;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::sync::OnceLock;
/// Resource file names
@@ -151,13 +151,6 @@ pub fn get_ventoy_disk_img() -> Result<&'static [u8]> {
})
}
/// Get the resource directory path for a given data directory
///
/// Returns `{data_dir}/ventoy`
pub fn get_resource_dir(data_dir: &Path) -> PathBuf {
data_dir.join("ventoy")
}
/// List required resource files
pub fn required_files() -> &'static [&'static str] {
&[BOOT_IMG_NAME, CORE_IMG_NAME, VENTOY_DISK_IMG_NAME]
@@ -166,22 +159,6 @@ pub fn required_files() -> &'static [&'static str] {
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
fn create_test_resources(dir: &Path) {
// Create boot.img (512 bytes)
let mut boot = std::fs::File::create(dir.join(BOOT_IMG_NAME)).unwrap();
boot.write_all(&[0u8; 512]).unwrap();
// Create core.img (fake, 1KB)
let mut core = std::fs::File::create(dir.join(CORE_IMG_NAME)).unwrap();
core.write_all(&[0u8; 1024]).unwrap();
// Create ventoy.disk.img (fake, 1KB)
let mut ventoy = std::fs::File::create(dir.join(VENTOY_DISK_IMG_NAME)).unwrap();
ventoy.write_all(&[0u8; 1024]).unwrap();
}
#[test]
fn test_required_files() {
@@ -191,11 +168,4 @@ mod tests {
assert!(files.contains(&"core.img"));
assert!(files.contains(&"ventoy.disk.img"));
}
#[test]
fn test_get_resource_dir() {
let data_dir = Path::new("/var/lib/one-kvm");
let resource_dir = get_resource_dir(data_dir);
assert_eq!(resource_dir, PathBuf::from("/var/lib/one-kvm/ventoy"));
}
}

View File

@@ -8,5 +8,4 @@ license = "BSD-3-Clause"
[dependencies]
[build-dependencies]
cc = "1.0"
bindgen = "0.59"

View File

@@ -82,8 +82,8 @@ fn generate_bindings(cpp_dir: &Path) {
fn link_libyuv() {
// Try vcpkg first
if let Ok(vcpkg_root) = env::var("VCPKG_ROOT") {
if link_vcpkg(vcpkg_root.into()) {
if let Some(vcpkg_installed) = vcpkg_installed_root() {
if link_vcpkg(vcpkg_installed) {
return;
}
}
@@ -109,6 +109,22 @@ fn link_libyuv() {
);
}
fn vcpkg_installed_root() -> Option<PathBuf> {
println!("cargo:rerun-if-env-changed=VCPKG_INSTALLED_DIR");
println!("cargo:rerun-if-env-changed=VCPKG_ROOT");
if let Ok(path) = env::var("VCPKG_INSTALLED_DIR") {
if !path.trim().is_empty() {
return Some(PathBuf::from(path));
}
}
env::var("VCPKG_ROOT")
.ok()
.filter(|path| !path.trim().is_empty())
.map(|path| PathBuf::from(path).join("installed"))
}
fn link_vcpkg(mut path: PathBuf) -> bool {
let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
@@ -130,7 +146,6 @@ fn link_vcpkg(mut path: PathBuf) -> bool {
}
};
path.push("installed");
path.push(triplet);
let include_path = path.join("include");
@@ -154,11 +169,13 @@ fn link_vcpkg(mut path: PathBuf) -> bool {
if use_static && static_lib.exists() {
// Static linking (for deb packaging)
println!("cargo:rustc-link-lib=static=yuv");
#[cfg(target_os = "linux")]
println!("cargo:rustc-link-lib=stdc++");
println!("cargo:info=Using libyuv from vcpkg (static linking)");
} else {
// Dynamic linking (default for development)
println!("cargo:rustc-link-lib=yuv");
#[cfg(target_os = "linux")]
println!("cargo:rustc-link-lib=stdc++");
println!("cargo:info=Using libyuv from vcpkg (dynamic linking)");
}

View File

@@ -11,20 +11,14 @@ use super::led::LedSensor;
use super::types::{AtxAction, AtxKeyConfig, AtxLedConfig, AtxState, PowerStatus};
use crate::error::{AppError, Result};
/// ATX power control configuration
#[derive(Debug, Clone, Default)]
pub struct AtxControllerConfig {
/// Whether ATX is enabled
pub enabled: bool,
/// Power button configuration (used for both short and long press)
pub power: AtxKeyConfig,
/// Reset button configuration
pub reset: AtxKeyConfig,
/// LED sensing configuration
pub led: AtxLedConfig,
}
/// Internal state holding all ATX components
/// Grouped together to reduce lock acquisitions
struct AtxInner {
config: AtxControllerConfig,
@@ -33,12 +27,9 @@ struct AtxInner {
led_sensor: Option<LedSensor>,
}
/// ATX Controller
///
/// Manages ATX power control through independent executors for each action.
/// Supports hot-reload of configuration.
pub struct AtxController {
/// Single lock for all internal state to reduce lock contention
inner: RwLock<AtxInner>,
}
@@ -53,6 +44,24 @@ impl AtxController {
&& power.baud_rate == reset.baud_rate
}
async fn init_key_executor(
warn_label: &str,
info_label: &str,
config: AtxKeyConfig,
mut executor: AtxKeyExecutor,
) -> Option<AtxKeyExecutor> {
if let Err(e) = executor.init().await {
warn!("Failed to initialize {} executor: {}", warn_label, e);
return None;
}
info!(
"{} executor initialized: {:?} on {} pin {}",
info_label, config.driver, config.device, config.pin
);
Some(executor)
}
async fn init_components(inner: &mut AtxInner) {
if Self::should_share_serial_device(&inner.config.power, &inner.config.reset) {
match AtxKeyExecutor::open_shared_serial(
@@ -60,36 +69,28 @@ impl AtxController {
inner.config.power.baud_rate,
) {
Ok(shared_serial) => {
let mut power_executor = AtxKeyExecutor::new_with_shared_serial(
inner.config.power.clone(),
shared_serial.clone(),
);
if let Err(e) = power_executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver,
inner.config.power.device,
inner.config.power.pin
for (slot, warn_label, info_label, config, serial) in [
(
&mut inner.power_executor,
"power",
"Power",
inner.config.power.clone(),
shared_serial.clone(),
),
(
&mut inner.reset_executor,
"reset",
"Reset",
inner.config.reset.clone(),
shared_serial,
),
] {
let executor = AtxKeyExecutor::new_with_shared_serial(
config.clone(),
serial,
);
inner.power_executor = Some(power_executor);
}
let mut reset_executor = AtxKeyExecutor::new_with_shared_serial(
inner.config.reset.clone(),
shared_serial,
);
if let Err(e) = reset_executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver,
inner.config.reset.device,
inner.config.reset.pin
);
inner.reset_executor = Some(reset_executor);
*slot = Self::init_key_executor(warn_label, info_label, config, executor)
.await;
}
}
Err(e) => {
@@ -100,40 +101,18 @@ impl AtxController {
}
}
} else {
// Initialize power executor
if inner.config.power.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.power.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver,
inner.config.power.device,
inner.config.power.pin
);
inner.power_executor = Some(executor);
}
}
// Initialize reset executor
if inner.config.reset.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.reset.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver,
inner.config.reset.device,
inner.config.reset.pin
);
inner.reset_executor = Some(executor);
for (slot, warn_label, info_label, config) in [
(&mut inner.power_executor, "power", "Power", inner.config.power.clone()),
(&mut inner.reset_executor, "reset", "Reset", inner.config.reset.clone()),
] {
if config.is_configured() {
let executor = AtxKeyExecutor::new(config.clone());
*slot = Self::init_key_executor(warn_label, info_label, config, executor)
.await;
}
}
}
// Initialize LED sensor
if inner.config.led.is_configured() {
let mut sensor = LedSensor::new(inner.config.led.clone());
if let Err(e) = sensor.init().await {
@@ -149,19 +128,17 @@ impl AtxController {
}
async fn shutdown_components(inner: &mut AtxInner) {
if let Some(executor) = inner.power_executor.as_mut() {
if let Err(e) = executor.shutdown().await {
warn!("Failed to shutdown power executor: {}", e);
for (slot, label) in [
(&mut inner.power_executor, "power"),
(&mut inner.reset_executor, "reset"),
] {
if let Some(executor) = slot.as_mut() {
if let Err(e) = executor.shutdown().await {
warn!("Failed to shutdown {} executor: {}", label, e);
}
}
*slot = None;
}
inner.power_executor = None;
if let Some(executor) = inner.reset_executor.as_mut() {
if let Err(e) = executor.shutdown().await {
warn!("Failed to shutdown reset executor: {}", e);
}
}
inner.reset_executor = None;
if let Some(sensor) = inner.led_sensor.as_mut() {
if let Err(e) = sensor.shutdown().await {
@@ -171,7 +148,20 @@ impl AtxController {
inner.led_sensor = None;
}
/// Create a new ATX controller with the specified configuration
async fn read_power_status(sensor: Option<&LedSensor>) -> PowerStatus {
let Some(sensor) = sensor else {
return PowerStatus::Unknown;
};
match sensor.read().await {
Ok(status) => status,
Err(e) => {
debug!("Failed to read ATX LED sensor: {}", e);
PowerStatus::Unknown
}
}
}
pub fn new(config: AtxControllerConfig) -> Self {
Self {
inner: RwLock::new(AtxInner {
@@ -183,12 +173,10 @@ impl AtxController {
}
}
/// Create a disabled ATX controller
pub fn disabled() -> Self {
Self::new(AtxControllerConfig::default())
}
/// Initialize the ATX controller and its executors
pub async fn init(&self) -> Result<()> {
let mut inner = self.inner.write().await;
@@ -204,7 +192,6 @@ impl AtxController {
Ok(())
}
/// Reload ATX controller configuration
pub async fn reload(&self, config: AtxControllerConfig) -> Result<()> {
let mut inner = self.inner.write().await;
@@ -225,7 +212,6 @@ impl AtxController {
Ok(())
}
/// Shutdown ATX controller and release all resources
pub async fn shutdown(&self) -> Result<()> {
let mut inner = self.inner.write().await;
Self::shutdown_components(&mut inner).await;
@@ -233,86 +219,48 @@ impl AtxController {
Ok(())
}
/// Trigger a power action (short/long/reset)
pub async fn trigger_power_action(&self, action: AtxAction) -> Result<()> {
let inner = self.inner.read().await;
match action {
AtxAction::Short | AtxAction::Long => {
if let Some(executor) = &inner.power_executor {
let duration = match action {
AtxAction::Short => timing::SHORT_PRESS,
AtxAction::Long => timing::LONG_PRESS,
_ => unreachable!(),
};
executor.pulse(duration).await?;
} else {
return Err(AppError::Config(
"Power button not configured for ATX controller".to_string(),
));
}
}
AtxAction::Reset => {
if let Some(executor) = &inner.reset_executor {
executor.pulse(timing::RESET_PRESS).await?;
} else {
return Err(AppError::Config(
"Reset button not configured for ATX controller".to_string(),
));
}
}
}
let (executor, duration) = match action {
AtxAction::Short => (inner.power_executor.as_ref(), timing::SHORT_PRESS),
AtxAction::Long => (inner.power_executor.as_ref(), timing::LONG_PRESS),
AtxAction::Reset => (inner.reset_executor.as_ref(), timing::RESET_PRESS),
};
let Some(executor) = executor else {
return Err(AppError::Config(match action {
AtxAction::Reset => "Reset button not configured for ATX controller",
_ => "Power button not configured for ATX controller",
}
.to_string()));
};
executor.pulse(duration).await?;
Ok(())
}
/// Trigger a short power button press
pub async fn power_short(&self) -> Result<()> {
self.trigger_power_action(AtxAction::Short).await
}
/// Trigger a long power button press
pub async fn power_long(&self) -> Result<()> {
self.trigger_power_action(AtxAction::Long).await
}
/// Trigger a reset button press
pub async fn reset(&self) -> Result<()> {
self.trigger_power_action(AtxAction::Reset).await
}
/// Get the current power status using the LED sensor (if configured)
pub async fn power_status(&self) -> PowerStatus {
let inner = self.inner.read().await;
if let Some(sensor) = &inner.led_sensor {
match sensor.read().await {
Ok(status) => status,
Err(e) => {
debug!("Failed to read ATX LED sensor: {}", e);
PowerStatus::Unknown
}
}
} else {
PowerStatus::Unknown
}
Self::read_power_status(inner.led_sensor.as_ref()).await
}
/// Get a snapshot of the ATX state for API responses
pub async fn state(&self) -> AtxState {
let inner = self.inner.read().await;
let power_status = if let Some(sensor) = &inner.led_sensor {
match sensor.read().await {
Ok(status) => status,
Err(e) => {
debug!("Failed to read ATX LED sensor: {}", e);
PowerStatus::Unknown
}
}
} else {
PowerStatus::Unknown
};
let power_status = Self::read_power_status(inner.led_sensor.as_ref()).await;
AtxState {
available: inner.config.enabled,

34
src/atx/disabled_key.rs Normal file
View File

@@ -0,0 +1,34 @@
use async_trait::async_trait;
use std::time::Duration;
use super::traits::AtxKeyBackend;
use crate::error::{AppError, Result};
pub struct DisabledAtxKeyBackend {
reason: &'static str,
}
impl DisabledAtxKeyBackend {
pub fn new(reason: &'static str) -> Self {
Self { reason }
}
}
#[async_trait]
impl AtxKeyBackend for DisabledAtxKeyBackend {
async fn init(&mut self) -> Result<()> {
Err(AppError::Internal(self.reason.to_string()))
}
async fn pulse(&self, _duration: Duration) -> Result<()> {
Err(AppError::Internal(self.reason.to_string()))
}
async fn shutdown(&mut self) -> Result<()> {
Ok(())
}
fn is_initialized(&self) -> bool {
false
}
}

34
src/atx/disabled_led.rs Normal file
View File

@@ -0,0 +1,34 @@
#![allow(dead_code)]
use super::types::{AtxLedConfig, PowerStatus};
use crate::error::Result;
pub struct LedSensor {
config: AtxLedConfig,
}
impl LedSensor {
pub fn new(config: AtxLedConfig) -> Self {
Self { config }
}
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
pub fn is_initialized(&self) -> bool {
false
}
pub async fn init(&mut self) -> Result<()> {
Ok(())
}
pub async fn read(&self) -> Result<PowerStatus> {
Ok(PowerStatus::Unknown)
}
pub async fn shutdown(&mut self) -> Result<()> {
Ok(())
}
}

View File

@@ -1,441 +1,150 @@
//! ATX Key Executor
//!
//! Lightweight executor for a single ATX key operation.
//! Each executor handles one button (power or reset) with its own hardware binding.
//! ATX key executor backend selector.
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use serialport::SerialPort;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use tracing::debug;
use super::types::{ActiveLevel, AtxDriverType, AtxKeyConfig};
use super::serial_relay::SerialRelayBackend;
use super::traits::{AtxKeyBackend, AtxKeyBackendContext, SharedSerialHandle};
use super::types::{AtxDriverType, AtxKeyConfig};
use crate::error::{AppError, Result};
pub type SharedSerialHandle = Arc<Mutex<Box<dyn SerialPort>>>;
/// Timing constants for ATX operations
pub mod timing {
use std::time::Duration;
/// Short press duration (power on/graceful shutdown)
pub const SHORT_PRESS: Duration = Duration::from_millis(500);
/// Long press duration (force power off)
pub const LONG_PRESS: Duration = Duration::from_millis(5000);
/// Reset press duration
pub const RESET_PRESS: Duration = Duration::from_millis(500);
}
/// Executor for a single ATX key operation
///
/// Each executor manages one hardware button (power or reset).
/// It handles both GPIO and USB relay backends.
pub struct AtxKeyExecutor {
config: AtxKeyConfig,
gpio_handle: Mutex<Option<LineHandle>>,
/// Cached USB relay file handle to avoid repeated open/close syscalls
usb_relay_handle: Mutex<Option<File>>,
/// Cached Serial port handle (can be shared across power/reset executors)
serial_handle: Mutex<Option<SharedSerialHandle>>,
initialized: AtomicBool,
backend: Option<Box<dyn AtxKeyBackend>>,
}
impl AtxKeyExecutor {
/// Create a new executor with the given configuration
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
gpio_handle: Mutex::new(None),
usb_relay_handle: Mutex::new(None),
serial_handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
Self::with_context(config, AtxKeyBackendContext::Standalone)
}
/// Create a new executor with a pre-opened shared serial handle.
pub fn new_with_shared_serial(config: AtxKeyConfig, serial_handle: SharedSerialHandle) -> Self {
Self {
config,
gpio_handle: Mutex::new(None),
usb_relay_handle: Mutex::new(None),
serial_handle: Mutex::new(Some(serial_handle)),
initialized: AtomicBool::new(false),
}
Self::with_context(config, AtxKeyBackendContext::SharedSerial(serial_handle))
}
/// Open a serial relay device and wrap it for shared use.
pub fn open_shared_serial(device: &str, baud_rate: u32) -> Result<SharedSerialHandle> {
let port = serialport::new(device, baud_rate)
.timeout(Duration::from_millis(100))
.open()
.map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?;
Ok(Arc::new(Mutex::new(port)))
SerialRelayBackend::open_shared_serial(device, baud_rate)
}
fn with_context(config: AtxKeyConfig, context: AtxKeyBackendContext) -> Self {
let backend = build_backend(&config, context);
Self { config, backend }
}
/// Check if this executor is configured
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
/// Check if this executor is initialized
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
/// Initialize the executor
pub async fn init(&mut self) -> Result<()> {
if !self.config.is_configured() {
debug!("ATX key executor not configured, skipping init");
return Ok(());
}
self.validate_runtime_config()?;
match self.config.driver {
AtxDriverType::Gpio => self.init_gpio().await?,
AtxDriverType::UsbRelay => self.init_usb_relay().await?,
AtxDriverType::Serial => self.init_serial().await?,
AtxDriverType::None => {}
}
self.initialized.store(true, Ordering::Relaxed);
Ok(())
}
fn validate_runtime_config(&self) -> Result<()> {
match self.config.driver {
AtxDriverType::Serial => {
if self.config.pin == 0 {
return Err(AppError::Config(
"Serial ATX channel must be 1-based (>= 1)".to_string(),
));
}
if self.config.pin > u8::MAX as u32 {
return Err(AppError::Config(format!(
"Serial ATX channel must be <= {}",
u8::MAX
)));
}
if self.config.baud_rate == 0 {
return Err(AppError::Config(
"Serial ATX baud_rate must be greater than 0".to_string(),
));
}
}
AtxDriverType::UsbRelay => {
if self.config.pin > u8::MAX as u32 {
return Err(AppError::Config(format!(
"USB relay channel must be <= {}",
u8::MAX
)));
}
}
AtxDriverType::Gpio | AtxDriverType::None => {}
}
Ok(())
}
/// Initialize GPIO backend
async fn init_gpio(&mut self) -> Result<()> {
info!(
"Initializing GPIO ATX executor on {} pin {}",
self.config.device, self.config.pin
);
let mut chip = Chip::new(&self.config.device)
.map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?;
let line = chip.get_line(self.config.pin).map_err(|e| {
AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e))
let backend = self.backend.as_mut().ok_or_else(|| {
AppError::Internal(format!(
"ATX backend {:?} is unsupported on this platform",
self.config.driver
))
})?;
// Initial value depends on active level (start in inactive state)
let initial_value = match self.config.active_level {
ActiveLevel::High => 0, // Inactive = low
ActiveLevel::Low => 1, // Inactive = high
};
let handle = line
.request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx")
.map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?;
*self.gpio_handle.lock().unwrap() = Some(handle);
debug!("GPIO pin {} configured successfully", self.config.pin);
Ok(())
backend.init().await
}
/// Initialize USB relay backend
async fn init_usb_relay(&self) -> Result<()> {
info!(
"Initializing USB relay ATX executor on {} channel {}",
self.config.device, self.config.pin
);
// Open and cache the device handle
let device = OpenOptions::new()
.read(true)
.write(true)
.open(&self.config.device)
.map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?;
*self.usb_relay_handle.lock().unwrap() = Some(device);
// Ensure relay is off initially
self.send_usb_relay_command(false)?;
debug!(
"USB relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
/// Initialize Serial relay backend
async fn init_serial(&self) -> Result<()> {
info!(
"Initializing Serial relay ATX executor on {} channel {}",
self.config.device, self.config.pin
);
let existing_handle = self.serial_handle.lock().unwrap().as_ref().cloned();
if existing_handle.is_none() {
let shared = Self::open_shared_serial(&self.config.device, self.config.baud_rate)?;
*self.serial_handle.lock().unwrap() = Some(shared);
}
// Ensure relay is off initially
self.send_serial_relay_command(false)?;
debug!(
"Serial relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
/// Pulse the button for the specified duration
pub async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_configured() {
return Err(AppError::Internal("ATX key not configured".to_string()));
}
if !self.is_initialized() {
let backend = self.backend.as_ref().ok_or_else(|| {
AppError::Internal(format!(
"ATX backend {:?} is unsupported on this platform",
self.config.driver
))
})?;
if !backend.is_initialized() {
return Err(AppError::Internal("ATX key not initialized".to_string()));
}
match self.config.driver {
AtxDriverType::Gpio => self.pulse_gpio(duration).await,
AtxDriverType::UsbRelay => self.pulse_usb_relay(duration).await,
AtxDriverType::Serial => self.pulse_serial(duration).await,
AtxDriverType::None => Ok(()),
}
backend.pulse(duration).await
}
/// Pulse GPIO pin
async fn pulse_gpio(&self, duration: Duration) -> Result<()> {
let (active, inactive) = match self.config.active_level {
ActiveLevel::High => (1u8, 0u8),
ActiveLevel::Low => (0u8, 1u8),
};
// Set to active state
{
let guard = self.gpio_handle.lock().unwrap();
let handle = guard
.as_ref()
.ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?;
handle
.set_value(active)
.map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?;
}
// Wait for duration (no lock held)
sleep(duration).await;
// Set to inactive state
{
let guard = self.gpio_handle.lock().unwrap();
if let Some(handle) = guard.as_ref() {
handle.set_value(inactive).ok();
}
}
Ok(())
}
/// Pulse USB relay
async fn pulse_usb_relay(&self, duration: Duration) -> Result<()> {
// Turn relay on
self.send_usb_relay_command(true)?;
// Wait for duration
sleep(duration).await;
// Turn relay off
self.send_usb_relay_command(false)?;
Ok(())
}
/// Send USB relay command using cached handle
fn send_usb_relay_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"USB relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
// Standard HID relay command format
let cmd = if on {
[0x00, channel + 1, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00]
} else {
[0x00, channel + 1, 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00]
};
let mut guard = self.usb_relay_handle.lock().unwrap();
let device = guard
.as_mut()
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
device
.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("USB relay write failed: {}", e)))?;
Ok(())
}
/// Pulse Serial relay
async fn pulse_serial(&self, duration: Duration) -> Result<()> {
info!(
"Pulse serial relay on {} pin {}",
self.config.device, self.config.pin
);
// Turn relay on
self.send_serial_relay_command(true)?;
// Wait for duration
sleep(duration).await;
// Turn relay off
self.send_serial_relay_command(false)?;
Ok(())
}
/// Send Serial relay command using cached handle
fn send_serial_relay_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"Serial relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
if channel == 0 {
return Err(AppError::Config(
"Serial relay channel must be 1-based (>= 1)".to_string(),
));
}
// LCUS-Type Protocol
// Frame: [StopByte(A0), Channel, State, Checksum]
// Checksum = A0 + channel + state
let state = if on { 1 } else { 0 };
let checksum = 0xA0u8.wrapping_add(channel).wrapping_add(state);
// Example for Channel 1:
// ON: A0 01 01 A2
// OFF: A0 01 00 A1
let cmd = [0xA0, channel, state, checksum];
let serial_handle = self
.serial_handle
.lock()
.unwrap()
.as_ref()
.cloned()
.ok_or_else(|| AppError::Internal("Serial relay not initialized".to_string()))?;
let mut port = serial_handle.lock().unwrap();
port.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?;
Ok(())
}
/// Shutdown the executor
pub async fn shutdown(&mut self) -> Result<()> {
if !self.is_initialized() {
return Ok(());
if let Some(backend) = self.backend.as_mut() {
backend.shutdown().await?;
}
match self.config.driver {
AtxDriverType::Gpio => {
// Release GPIO handle
*self.gpio_handle.lock().unwrap() = None;
}
AtxDriverType::UsbRelay => {
// Ensure relay is off before closing handle
let _ = self.send_usb_relay_command(false);
// Release USB relay handle
*self.usb_relay_handle.lock().unwrap() = None;
}
AtxDriverType::Serial => {
// Ensure relay is off before closing handle
let _ = self.send_serial_relay_command(false);
// Release Serial relay handle
*self.serial_handle.lock().unwrap() = None;
}
AtxDriverType::None => {}
}
self.initialized.store(false, Ordering::Relaxed);
debug!("ATX key executor shutdown complete");
Ok(())
}
}
impl Drop for AtxKeyExecutor {
fn drop(&mut self) {
// Ensure GPIO lines are released
*self.gpio_handle.lock().unwrap() = None;
// Ensure USB relay is off and handle released
if self.config.driver == AtxDriverType::UsbRelay && self.is_initialized() {
let _ = self.send_usb_relay_command(false);
}
*self.usb_relay_handle.lock().unwrap() = None;
// Ensure Serial relay is off and handle released
if self.config.driver == AtxDriverType::Serial && self.is_initialized() {
let _ = self.send_serial_relay_command(false);
}
*self.serial_handle.lock().unwrap() = None;
fn build_backend(
config: &AtxKeyConfig,
context: AtxKeyBackendContext,
) -> Option<Box<dyn AtxKeyBackend>> {
match config.driver {
AtxDriverType::Serial => Some(match context {
AtxKeyBackendContext::Standalone => Box::new(SerialRelayBackend::new(config.clone())),
AtxKeyBackendContext::SharedSerial(handle) => Box::new(
SerialRelayBackend::new_with_shared_serial(config.clone(), handle),
),
}),
AtxDriverType::Gpio => build_gpio_backend(config),
AtxDriverType::UsbRelay => build_hidraw_backend(config),
AtxDriverType::None => None,
}
}
#[cfg(unix)]
fn build_gpio_backend(config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::gpio_linux::GpioLinuxBackend::new(
config.clone(),
)))
}
#[cfg(not(unix))]
fn build_gpio_backend(_config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::disabled_key::DisabledAtxKeyBackend::new(
"GPIO ATX backend is only available on Linux",
)))
}
#[cfg(unix)]
fn build_hidraw_backend(config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::hidraw_linux::HidrawLinuxRelayBackend::new(
config.clone(),
)))
}
#[cfg(not(unix))]
fn build_hidraw_backend(_config: &AtxKeyConfig) -> Option<Box<dyn AtxKeyBackend>> {
Some(Box::new(super::disabled_key::DisabledAtxKeyBackend::new(
"USB hidraw relay backend is only available on Linux",
)))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::atx::ActiveLevel;
#[test]
fn test_executor_creation() {
fn executor_creation() {
let config = AtxKeyConfig::default();
let executor = AtxKeyExecutor::new(config);
assert!(!executor.is_configured());
assert!(!executor.is_initialized());
}
#[test]
fn test_executor_with_gpio_config() {
fn executor_with_gpio_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::Gpio,
device: "/dev/gpiochip0".to_string(),
@@ -445,16 +154,15 @@ mod tests {
};
let executor = AtxKeyExecutor::new(config);
assert!(executor.is_configured());
assert!(!executor.is_initialized());
}
#[test]
fn test_executor_with_usb_relay_config() {
fn executor_with_usb_relay_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 0,
active_level: ActiveLevel::High, // Ignored for USB relay
pin: 1,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let executor = AtxKeyExecutor::new(config);
@@ -462,12 +170,12 @@ mod tests {
}
#[test]
fn test_executor_with_serial_config() {
fn executor_with_serial_config() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 1,
active_level: ActiveLevel::High, // Ignored
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let executor = AtxKeyExecutor::new(config);
@@ -475,51 +183,9 @@ mod tests {
}
#[test]
fn test_timing_constants() {
fn timing_constants() {
assert_eq!(timing::SHORT_PRESS.as_millis(), 500);
assert_eq!(timing::LONG_PRESS.as_millis(), 5000);
assert_eq!(timing::RESET_PRESS.as_millis(), 500);
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_zero() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 0,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_overflow() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 256,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_zero_serial_baud_rate() {
let config = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 1,
active_level: ActiveLevel::High,
baud_rate: 0,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
}

106
src/atx/gpio_linux.rs Normal file
View File

@@ -0,0 +1,106 @@
use async_trait::async_trait;
use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use super::traits::AtxKeyBackend;
use super::types::{ActiveLevel, AtxKeyConfig};
use crate::error::{AppError, Result};
pub struct GpioLinuxBackend {
config: AtxKeyConfig,
handle: Mutex<Option<LineHandle>>,
initialized: AtomicBool,
}
impl GpioLinuxBackend {
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
}
#[async_trait]
impl AtxKeyBackend for GpioLinuxBackend {
async fn init(&mut self) -> Result<()> {
info!(
"Initializing GPIO ATX backend on {} pin {}",
self.config.device, self.config.pin
);
let mut chip = Chip::new(&self.config.device)
.map_err(|e| AppError::Internal(format!("GPIO chip open failed: {}", e)))?;
let line = chip.get_line(self.config.pin).map_err(|e| {
AppError::Internal(format!("GPIO line {} failed: {}", self.config.pin, e))
})?;
let initial_value = match self.config.active_level {
ActiveLevel::High => 0,
ActiveLevel::Low => 1,
};
let handle = line
.request(LineRequestFlags::OUTPUT, initial_value, "one-kvm-atx")
.map_err(|e| AppError::Internal(format!("GPIO request failed: {}", e)))?;
*self.handle.lock().unwrap() = Some(handle);
self.initialized.store(true, Ordering::Relaxed);
debug!("GPIO pin {} configured successfully", self.config.pin);
Ok(())
}
async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_initialized() {
return Err(AppError::Internal("GPIO not initialized".to_string()));
}
let (active, inactive) = match self.config.active_level {
ActiveLevel::High => (1u8, 0u8),
ActiveLevel::Low => (0u8, 1u8),
};
{
let guard = self.handle.lock().unwrap();
let handle = guard
.as_ref()
.ok_or_else(|| AppError::Internal("GPIO not initialized".to_string()))?;
handle
.set_value(active)
.map_err(|e| AppError::Internal(format!("GPIO set failed: {}", e)))?;
}
sleep(duration).await;
{
let guard = self.handle.lock().unwrap();
if let Some(handle) = guard.as_ref() {
handle.set_value(inactive).ok();
}
}
Ok(())
}
async fn shutdown(&mut self) -> Result<()> {
*self.handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
Ok(())
}
fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
}
impl Drop for GpioLinuxBackend {
fn drop(&mut self) {
*self.handle.lock().unwrap() = None;
}
}

190
src/atx/hidraw_linux.rs Normal file
View File

@@ -0,0 +1,190 @@
use async_trait::async_trait;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::os::fd::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use super::traits::AtxKeyBackend;
use super::types::AtxKeyConfig;
use crate::error::{AppError, Result};
const USB_RELAY_MAX_CHANNEL: u8 = 8;
const USB_RELAY_REPORT_LEN: usize = 9;
const HIDIOCSFEATURE_9: libc::c_ulong = 0xC009_4806;
pub struct HidrawLinuxRelayBackend {
config: AtxKeyConfig,
handle: Mutex<Option<File>>,
initialized: AtomicBool,
}
impl HidrawLinuxRelayBackend {
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
fn validate_config(&self) -> Result<()> {
if self.config.pin == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if self.config.pin > USB_RELAY_MAX_CHANNEL as u32 {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
Ok(())
}
fn send_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"USB relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
if channel == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if channel > USB_RELAY_MAX_CHANNEL {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
let cmd = Self::build_command(channel, on);
let mut guard = self.handle.lock().unwrap();
let device = guard
.as_mut()
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
if let Err(feature_err) = Self::send_feature_report(device, &cmd) {
debug!(
"USB relay feature report failed ({}), falling back to hidraw write",
feature_err
);
device.write_all(&cmd).map_err(|write_err| {
AppError::Internal(format!(
"USB relay feature report failed: {}; raw write failed: {}",
feature_err, write_err
))
})?;
device
.flush()
.map_err(|e| AppError::Internal(format!("USB relay flush failed: {}", e)))?;
}
Ok(())
}
pub fn build_command(channel: u8, on: bool) -> [u8; USB_RELAY_REPORT_LEN] {
let mut cmd = [0x00; USB_RELAY_REPORT_LEN];
cmd[1] = if on { 0xFF } else { 0xFD };
cmd[2] = channel;
cmd
}
fn send_feature_report(
device: &File,
report: &[u8; USB_RELAY_REPORT_LEN],
) -> std::io::Result<()> {
let rc = unsafe { libc::ioctl(device.as_raw_fd(), HIDIOCSFEATURE_9, report.as_ptr()) };
if rc < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(())
}
}
}
#[async_trait]
impl AtxKeyBackend for HidrawLinuxRelayBackend {
async fn init(&mut self) -> Result<()> {
self.validate_config()?;
info!(
"Initializing USB relay ATX backend on {} channel {}",
self.config.device, self.config.pin
);
let device = OpenOptions::new()
.read(true)
.write(true)
.open(&self.config.device)
.map_err(|e| AppError::Internal(format!("USB relay device open failed: {}", e)))?;
*self.handle.lock().unwrap() = Some(device);
self.send_command(false)?;
self.initialized.store(true, Ordering::Relaxed);
debug!(
"USB relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_initialized() {
return Err(AppError::Internal("USB relay not initialized".to_string()));
}
self.send_command(true)?;
sleep(duration).await;
self.send_command(false)?;
Ok(())
}
async fn shutdown(&mut self) -> Result<()> {
if self.is_initialized() {
let _ = self.send_command(false);
}
*self.handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
Ok(())
}
fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
}
impl Drop for HidrawLinuxRelayBackend {
fn drop(&mut self) {
if self.is_initialized() {
let _ = self.send_command(false);
}
*self.handle.lock().unwrap() = None;
}
}
#[cfg(test)]
mod tests {
use super::HidrawLinuxRelayBackend;
#[test]
fn usb_relay_command_format() {
assert_eq!(
HidrawLinuxRelayBackend::build_command(1, true),
[0x00, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
assert_eq!(
HidrawLinuxRelayBackend::build_command(1, false),
[0x00, 0xFD, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
}
}

View File

@@ -10,9 +10,6 @@ use tracing::{debug, info};
use super::types::{AtxLedConfig, PowerStatus};
use crate::error::{AppError, Result};
/// LED sensor for reading power status
///
/// Uses GPIO to read the power LED state and determine if the system is on or off.
pub struct LedSensor {
config: AtxLedConfig,
handle: Mutex<Option<LineHandle>>,
@@ -20,7 +17,6 @@ pub struct LedSensor {
}
impl LedSensor {
/// Create a new LED sensor with the given configuration
pub fn new(config: AtxLedConfig) -> Self {
Self {
config,
@@ -29,17 +25,6 @@ impl LedSensor {
}
}
/// Check if the sensor is configured
pub fn is_configured(&self) -> bool {
self.config.is_configured()
}
/// Check if the sensor is initialized
pub fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
/// Initialize the LED sensor
pub async fn init(&mut self) -> Result<()> {
if !self.config.is_configured() {
debug!("LED sensor not configured, skipping init");
@@ -72,9 +57,8 @@ impl LedSensor {
Ok(())
}
/// Read the current power status
pub async fn read(&self) -> Result<PowerStatus> {
if !self.is_configured() || !self.is_initialized() {
if !self.config.is_configured() || !self.initialized.load(Ordering::Relaxed) {
return Ok(PowerStatus::Unknown);
}
@@ -85,11 +69,10 @@ impl LedSensor {
.get_value()
.map_err(|e| AppError::Internal(format!("LED read failed: {}", e)))?;
// Apply inversion if configured
let is_on = if self.config.inverted {
value == 0 // Active low: 0 means on
value == 0
} else {
value == 1 // Active high: 1 means on
value == 1
};
Ok(if is_on {
@@ -102,7 +85,6 @@ impl LedSensor {
}
}
/// Shutdown the LED sensor
pub async fn shutdown(&mut self) -> Result<()> {
*self.handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
@@ -125,8 +107,8 @@ mod tests {
fn test_led_sensor_creation() {
let config = AtxLedConfig::default();
let sensor = LedSensor::new(config);
assert!(!sensor.is_configured());
assert!(!sensor.is_initialized());
assert!(!sensor.config.is_configured());
assert!(!sensor.initialized.load(Ordering::Relaxed));
}
#[test]
@@ -138,8 +120,8 @@ mod tests {
inverted: false,
};
let sensor = LedSensor::new(config);
assert!(sensor.is_configured());
assert!(!sensor.is_initialized());
assert!(sensor.config.is_configured());
assert!(!sensor.initialized.load(Ordering::Relaxed));
}
#[test]
@@ -151,7 +133,6 @@ mod tests {
inverted: true,
};
let sensor = LedSensor::new(config);
assert!(sensor.is_configured());
assert!(sensor.config.inverted);
}
}

View File

@@ -2,52 +2,22 @@
//!
//! Provides ATX power management functionality for IP-KVM.
//! Supports flexible hardware binding with independent configuration for each action.
//!
//! # Features
//!
//! - Power button control (short press for on/graceful shutdown, long press for force off)
//! - Reset button control
//! - Power status monitoring via LED sensing (GPIO only)
//! - Independent hardware binding for each action (GPIO or USB relay)
//! - Hot-reload configuration support
//!
//! # Hardware Support
//!
//! - **GPIO**: Uses Linux GPIO character device (/dev/gpiochipX) for direct hardware control
//! - **USB Relay**: Uses HID USB relay modules for isolated switching
//!
//! # Example
//!
//! ```ignore
//! use one_kvm::atx::{AtxController, AtxControllerConfig, AtxKeyConfig, AtxDriverType, ActiveLevel};
//!
//! let config = AtxControllerConfig {
//! enabled: true,
//! power: AtxKeyConfig {
//! driver: AtxDriverType::Gpio,
//! device: "/dev/gpiochip0".to_string(),
//! pin: 5,
//! active_level: ActiveLevel::High,
//! baud_rate: 9600,
//! },
//! reset: AtxKeyConfig {
//! driver: AtxDriverType::UsbRelay,
//! device: "/dev/hidraw0".to_string(),
//! pin: 0,
//! active_level: ActiveLevel::High,
//! baud_rate: 9600,
//! },
//! led: Default::default(),
//! };
//!
//! let controller = AtxController::new(config);
//! controller.init().await?;
//! controller.power_short().await?; // Turn on or graceful shutdown
//! ```
mod controller;
#[cfg(not(unix))]
mod disabled_key;
mod executor;
#[cfg(unix)]
mod gpio_linux;
#[cfg(unix)]
mod hidraw_linux;
#[cfg(unix)]
mod led;
#[cfg(not(unix))]
#[path = "disabled_led.rs"]
mod led;
mod serial_relay;
mod traits;
mod types;
mod wol;
@@ -57,22 +27,45 @@ pub use types::{
ActiveLevel, AtxAction, AtxDevices, AtxDriverType, AtxKeyConfig, AtxLedConfig, AtxPowerRequest,
AtxState, PowerStatus,
};
pub use wol::send_wol;
pub use wol::{list_wol_history, record_wol_history, send_wol};
#[cfg(any(unix, test))]
fn hidraw_uevent_is_usb_relay(uevent: &str) -> bool {
let upper = uevent.to_ascii_uppercase();
upper.contains("000016C0:000005DF")
|| upper.contains("00005131:00002007")
|| upper.contains("16C0:05DF")
|| upper.contains("5131:2007")
|| upper.contains("PRODUCT=16C0/5DF")
|| upper.contains("PRODUCT=5131/2007")
|| upper.contains("USBRELAY")
|| upper.contains("USB RELAY")
}
#[cfg(unix)]
fn is_usb_relay_hidraw(name: &str) -> bool {
let uevent_path = format!("/sys/class/hidraw/{}/device/uevent", name);
std::fs::read_to_string(uevent_path)
.map(|uevent| hidraw_uevent_is_usb_relay(&uevent))
.unwrap_or(false)
}
/// Discover available ATX devices on the system
///
/// Scans for GPIO chips and USB HID relay devices in a single pass.
/// Scans for GPIO chips, LCUS USB HID relay devices, and serial relay ports.
pub fn discover_devices() -> AtxDevices {
let mut devices = AtxDevices::default();
// Single pass through /dev directory
devices.serial_ports = crate::utils::list_serial_ports();
#[cfg(unix)]
if let Ok(entries) = std::fs::read_dir("/dev") {
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with("gpiochip") {
devices.gpio_chips.push(format!("/dev/{}", name_str));
} else if name_str.starts_with("hidraw") {
} else if name_str.starts_with("hidraw") && is_usb_relay_hidraw(&name_str) {
devices.usb_relays.push(format!("/dev/{}", name_str));
} else if name_str.starts_with("ttyUSB") || name_str.starts_with("ttyACM") {
devices.serial_ports.push(format!("/dev/{}", name_str));
@@ -83,6 +76,7 @@ pub fn discover_devices() -> AtxDevices {
devices.gpio_chips.sort();
devices.usb_relays.sort();
devices.serial_ports.sort();
devices.serial_ports.dedup();
devices
}
@@ -96,9 +90,30 @@ mod tests {
let _devices = discover_devices();
}
#[test]
fn test_hidraw_uevent_detects_usb_relay_id() {
assert!(hidraw_uevent_is_usb_relay(
"HID_ID=0003:000016C0:000005DF\nHID_NAME=www.dcttech.com USBRelay2\n"
));
}
#[test]
fn test_hidraw_uevent_detects_5131_usb_relay_id() {
assert!(hidraw_uevent_is_usb_relay(
"HID_ID=0003:00005131:00002007\n"
));
assert!(hidraw_uevent_is_usb_relay("PRODUCT=5131/2007/100"));
}
#[test]
fn test_hidraw_uevent_rejects_unrelated_hid() {
assert!(!hidraw_uevent_is_usb_relay(
"HID_ID=0003:0000046D:0000C534\nHID_NAME=Logitech USB Receiver\n"
));
}
#[test]
fn test_module_exports() {
// Verify all public exports are accessible
let _: AtxDriverType = AtxDriverType::None;
let _: ActiveLevel = ActiveLevel::High;
let _: AtxKeyConfig = AtxKeyConfig::default();

141
src/atx/serial_relay.rs Normal file
View File

@@ -0,0 +1,141 @@
use async_trait::async_trait;
use std::io::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
use super::traits::{validate_serial_config, AtxKeyBackend, SharedSerialHandle};
use super::types::AtxKeyConfig;
use crate::error::{AppError, Result};
pub struct SerialRelayBackend {
config: AtxKeyConfig,
serial_handle: Mutex<Option<SharedSerialHandle>>,
initialized: AtomicBool,
}
impl SerialRelayBackend {
pub fn new(config: AtxKeyConfig) -> Self {
Self {
config,
serial_handle: Mutex::new(None),
initialized: AtomicBool::new(false),
}
}
pub fn new_with_shared_serial(config: AtxKeyConfig, serial_handle: SharedSerialHandle) -> Self {
Self {
config,
serial_handle: Mutex::new(Some(serial_handle)),
initialized: AtomicBool::new(false),
}
}
pub fn open_shared_serial(device: &str, baud_rate: u32) -> Result<SharedSerialHandle> {
let port = serialport::new(device, baud_rate)
.timeout(Duration::from_millis(100))
.open()
.map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?;
Ok(Arc::new(Mutex::new(port)))
}
fn send_command(&self, on: bool) -> Result<()> {
let channel = u8::try_from(self.config.pin).map_err(|_| {
AppError::Config(format!(
"Serial relay channel {} exceeds max {}",
self.config.pin,
u8::MAX
))
})?;
let state = if on { 1 } else { 0 };
let checksum = 0xA0u8.wrapping_add(channel).wrapping_add(state);
let cmd = [0xA0, channel, state, checksum];
let serial_handle = self
.serial_handle
.lock()
.unwrap()
.as_ref()
.cloned()
.ok_or_else(|| AppError::Internal("Serial relay not initialized".to_string()))?;
let mut port = serial_handle.lock().unwrap();
port.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?;
port.flush()
.map_err(|e| AppError::Internal(format!("Serial relay flush failed: {}", e)))?;
Ok(())
}
}
#[async_trait]
impl AtxKeyBackend for SerialRelayBackend {
async fn init(&mut self) -> Result<()> {
validate_serial_config(&self.config)?;
info!(
"Initializing Serial relay ATX backend on {} channel {}",
self.config.device, self.config.pin
);
let existing_handle = self.serial_handle.lock().unwrap().as_ref().cloned();
if existing_handle.is_none() {
let shared = Self::open_shared_serial(&self.config.device, self.config.baud_rate)?;
*self.serial_handle.lock().unwrap() = Some(shared);
}
self.send_command(false)?;
self.initialized.store(true, Ordering::Relaxed);
debug!(
"Serial relay channel {} configured successfully",
self.config.pin
);
Ok(())
}
async fn pulse(&self, duration: Duration) -> Result<()> {
if !self.is_initialized() {
return Err(AppError::Internal(
"Serial relay not initialized".to_string(),
));
}
info!(
"Pulse serial relay on {} pin {}",
self.config.device, self.config.pin
);
self.send_command(true)?;
sleep(duration).await;
self.send_command(false)?;
Ok(())
}
async fn shutdown(&mut self) -> Result<()> {
if !self.is_initialized() {
return Ok(());
}
let _ = self.send_command(false);
*self.serial_handle.lock().unwrap() = None;
self.initialized.store(false, Ordering::Relaxed);
Ok(())
}
fn is_initialized(&self) -> bool {
self.initialized.load(Ordering::Relaxed)
}
}
impl Drop for SerialRelayBackend {
fn drop(&mut self) {
if self.is_initialized() {
let _ = self.send_command(false);
}
*self.serial_handle.lock().unwrap() = None;
}
}

51
src/atx/traits.rs Normal file
View File

@@ -0,0 +1,51 @@
use async_trait::async_trait;
use serialport::SerialPort;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use super::types::AtxKeyConfig;
use crate::error::Result;
pub type SharedSerialHandle = Arc<Mutex<Box<dyn SerialPort>>>;
#[async_trait]
pub trait AtxKeyBackend: Send + Sync {
async fn init(&mut self) -> Result<()>;
async fn pulse(&self, duration: Duration) -> Result<()>;
async fn shutdown(&mut self) -> Result<()>;
fn is_initialized(&self) -> bool;
}
#[derive(Debug, Clone)]
pub enum AtxKeyBackendContext {
Standalone,
SharedSerial(SharedSerialHandle),
}
pub fn validate_serial_config(config: &AtxKeyConfig) -> Result<()> {
if config.device.trim().is_empty() {
return Err(crate::error::AppError::Config(
"Serial ATX device cannot be empty".to_string(),
));
}
if config.pin == 0 {
return Err(crate::error::AppError::Config(
"Serial ATX channel must be 1-based (>= 1)".to_string(),
));
}
if config.pin > u8::MAX as u32 {
return Err(crate::error::AppError::Config(format!(
"Serial ATX channel must be <= {}",
u8::MAX
)));
}
if config.baud_rate == 0 {
return Err(crate::error::AppError::Config(
"Serial ATX baud_rate must be greater than 0".to_string(),
));
}
Ok(())
}

View File

@@ -6,67 +6,43 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Power status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PowerStatus {
/// Power is on
On,
/// Power is off
Off,
/// Power status unknown (no LED connected)
#[default]
Unknown,
}
/// Driver type for ATX key operations
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AtxDriverType {
/// GPIO control via Linux character device
Gpio,
/// USB HID relay module
UsbRelay,
/// Serial/COM port relay (taobao LCUS type)
Serial,
/// Disabled / Not configured
#[default]
None,
}
/// Active level for GPIO pins
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ActiveLevel {
/// Active high (default for most cases)
#[default]
High,
/// Active low (inverted)
Low,
}
/// Configuration for a single ATX key (power or reset)
/// This is the "four-tuple" configuration: (driver, device, pin/channel, level)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct AtxKeyConfig {
/// Driver type (GPIO or USB Relay)
pub driver: AtxDriverType,
/// Device path:
/// - For GPIO: /dev/gpiochipX
/// - For USB Relay: /dev/hidrawX
pub device: String,
/// Pin or channel number:
/// - For GPIO: GPIO pin number
/// - For USB Relay: relay channel (0-based)
/// - For Serial Relay (LCUS): relay channel (1-based)
pub pin: u32,
/// Active level (only applicable to GPIO, ignored for USB Relay)
pub active_level: ActiveLevel,
/// Baud rate for serial relay (start with 9600)
pub baud_rate: u32,
}
@@ -83,77 +59,54 @@ impl Default for AtxKeyConfig {
}
impl AtxKeyConfig {
/// Check if this key is configured
pub fn is_configured(&self) -> bool {
self.driver != AtxDriverType::None && !self.device.is_empty()
}
}
/// LED sensing configuration (optional)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct AtxLedConfig {
/// Whether LED sensing is enabled
pub enabled: bool,
/// GPIO chip for LED sensing
pub gpio_chip: String,
/// GPIO pin for LED input
pub gpio_pin: u32,
/// Whether LED is active low (inverted logic)
pub inverted: bool,
}
impl AtxLedConfig {
/// Check if LED sensing is configured
pub fn is_configured(&self) -> bool {
self.enabled && !self.gpio_chip.is_empty()
}
}
/// ATX state information
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AtxState {
/// Whether ATX feature is available/enabled
pub available: bool,
/// Whether power button is configured
pub power_configured: bool,
/// Whether reset button is configured
pub reset_configured: bool,
/// Current power status
pub power_status: PowerStatus,
/// Whether power LED sensing is supported
pub led_supported: bool,
}
/// ATX power action request
#[derive(Debug, Clone, Deserialize)]
pub struct AtxPowerRequest {
/// Action to perform: "short", "long", "reset"
pub action: AtxAction,
}
/// ATX power action
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AtxAction {
/// Short press power button (turn on or graceful shutdown)
Short,
/// Long press power button (force power off)
Long,
/// Press reset button
Reset,
}
/// Available ATX devices for discovery
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxDevices {
/// Available GPIO chips (/dev/gpiochip*)
pub gpio_chips: Vec<String>,
/// Available USB HID relay devices (/dev/hidraw*)
pub usb_relays: Vec<String>,
/// Available Serial ports (/dev/ttyUSB*)
pub serial_ports: Vec<String>,
}
@@ -201,13 +154,13 @@ mod tests {
assert!(!config.is_configured());
config.driver = AtxDriverType::Gpio;
assert!(!config.is_configured()); // device still empty
assert!(!config.is_configured());
config.device = "/dev/gpiochip0".to_string();
assert!(config.is_configured());
config.driver = AtxDriverType::None;
assert!(!config.is_configured()); // driver is None
assert!(!config.is_configured());
}
#[test]
@@ -224,7 +177,7 @@ mod tests {
assert!(!config.is_configured());
config.enabled = true;
assert!(!config.is_configured()); // gpio_chip still empty
assert!(!config.is_configured());
config.gpio_chip = "/dev/gpiochip0".to_string();
assert!(config.is_configured());

View File

@@ -3,18 +3,14 @@
//! Sends magic packets to wake up remote machines.
use std::net::{SocketAddr, UdpSocket};
use tracing::{debug, info};
use tracing::info;
use crate::error::{AppError, Result};
/// WOL magic packet structure:
/// - 6 bytes of 0xFF
/// - 16 repetitions of the target MAC address (6 bytes each)
/// Total: 6 + 16 * 6 = 102 bytes
const WOL_HISTORY_MAX_ENTRIES: i64 = 50;
const MAGIC_PACKET_SIZE: usize = 102;
/// Parse MAC address string into bytes
/// Supports formats: "AA:BB:CC:DD:EE:FF" or "AA-BB-CC-DD-EE-FF"
fn parse_mac_address(mac: &str) -> Result<[u8; 6]> {
let mac = mac.trim().to_uppercase();
let parts: Vec<&str> = if mac.contains(':') {
@@ -44,16 +40,13 @@ fn parse_mac_address(mac: &str) -> Result<[u8; 6]> {
Ok(bytes)
}
/// Build WOL magic packet
fn build_magic_packet(mac: &[u8; 6]) -> [u8; MAGIC_PACKET_SIZE] {
let mut packet = [0u8; MAGIC_PACKET_SIZE];
// First 6 bytes are 0xFF
for byte in packet.iter_mut().take(6) {
*byte = 0xFF;
}
// Next 96 bytes are 16 repetitions of the MAC address
for i in 0..16 {
let offset = 6 + i * 6;
packet[offset..offset + 6].copy_from_slice(mac);
@@ -73,16 +66,13 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
info!("Sending WOL packet to {} via {:?}", mac_address, interface);
// Create UDP socket
let socket = UdpSocket::bind("0.0.0.0:0")
.map_err(|e| AppError::Internal(format!("Failed to create UDP socket: {}", e)))?;
// Enable broadcast
socket
.set_broadcast(true)
.map_err(|e| AppError::Internal(format!("Failed to enable broadcast: {}", e)))?;
// Bind to specific interface if specified
#[cfg(target_os = "linux")]
if let Some(iface) = interface {
if !iface.is_empty() {
@@ -90,8 +80,7 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
let fd = socket.as_raw_fd();
let iface_bytes = iface.as_bytes();
// SO_BINDTODEVICE requires interface name as null-terminated string
let mut iface_buf = [0u8; 16]; // IFNAMSIZ is typically 16
let mut iface_buf = [0u8; 16];
let len = iface_bytes.len().min(15);
iface_buf[..len].copy_from_slice(&iface_bytes[..len]);
@@ -112,18 +101,16 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
iface, err
)));
}
debug!("Bound to interface: {}", iface);
tracing::debug!("Bound to interface: {}", iface);
}
}
// Send to broadcast address on port 9 (discard protocol, commonly used for WOL)
let broadcast_addr: SocketAddr = "255.255.255.255:9".parse().unwrap();
socket
.send_to(&packet, broadcast_addr)
.map_err(|e| AppError::Internal(format!("Failed to send WOL packet: {}", e)))?;
// Also try sending to port 7 (echo protocol, alternative WOL port)
let broadcast_addr_7: SocketAddr = "255.255.255.255:7".parse().unwrap();
let _ = socket.send_to(&packet, broadcast_addr_7);
@@ -131,6 +118,55 @@ pub fn send_wol(mac_address: &str, interface: Option<&str>) -> Result<()> {
Ok(())
}
pub async fn record_wol_history(pool: &sqlx::Pool<sqlx::Sqlite>, mac_address: &str) -> Result<()> {
sqlx::query(
r#"
INSERT INTO wol_history (mac_address, updated_at)
VALUES (?1, CAST(strftime('%s', 'now') AS INTEGER))
ON CONFLICT(mac_address) DO UPDATE SET
updated_at = excluded.updated_at
"#,
)
.bind(mac_address)
.execute(pool)
.await?;
sqlx::query(
r#"
DELETE FROM wol_history
WHERE mac_address NOT IN (
SELECT mac_address FROM wol_history
ORDER BY updated_at DESC
LIMIT ?1
)
"#,
)
.bind(WOL_HISTORY_MAX_ENTRIES)
.execute(pool)
.await?;
Ok(())
}
pub async fn list_wol_history(
pool: &sqlx::Pool<sqlx::Sqlite>,
limit: usize,
) -> Result<Vec<(String, i64)>> {
let rows = sqlx::query_as(
r#"
SELECT mac_address, updated_at
FROM wol_history
ORDER BY updated_at DESC
LIMIT ?1
"#,
)
.bind(limit as i64)
.fetch_all(pool)
.await?;
Ok(rows)
}
#[cfg(test)]
mod tests {
use super::*;
@@ -159,12 +195,10 @@ mod tests {
let mac = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF];
let packet = build_magic_packet(&mac);
// Check header (6 bytes of 0xFF)
for byte in packet.iter().take(6) {
assert_eq!(*byte, 0xFF);
}
// Check MAC repetitions
for i in 0..16 {
let offset = 6 + i * 6;
assert_eq!(&packet[offset..offset + 6], &mac);

View File

@@ -1,390 +1,9 @@
//! ALSA audio capture implementation
#[cfg(unix)]
#[path = "capture_linux.rs"]
mod imp;
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
use alsa::{Direction, ValueOr, PCM};
use bytes::Bytes;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, info};
#[cfg(windows)]
#[path = "capture_windows.rs"]
mod imp;
use super::device::AudioDeviceInfo;
use crate::error::{AppError, Result};
use crate::utils::LogThrottler;
use crate::{error_throttled, warn_throttled};
/// Audio capture configuration
#[derive(Debug, Clone)]
pub struct AudioConfig {
/// ALSA device name (e.g., "hw:0,0" or "default")
pub device_name: String,
/// Sample rate in Hz
pub sample_rate: u32,
/// Number of channels (1 = mono, 2 = stereo)
pub channels: u32,
/// Samples per frame (for Opus, typically 480 for 10ms at 48kHz)
pub frame_size: u32,
/// Buffer size in frames
pub buffer_frames: u32,
/// Period size in frames
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: "default".to_string(),
sample_rate: 48000,
channels: 2,
frame_size: 960, // 20ms at 48kHz (good for Opus)
buffer_frames: 4096,
period_frames: 960,
}
}
}
impl AudioConfig {
/// Create config for a specific device
pub fn for_device(device: &AudioDeviceInfo) -> Self {
let sample_rate = if device.sample_rates.contains(&48000) {
48000
} else {
*device.sample_rates.first().unwrap_or(&48000)
};
let channels = if device.channels.contains(&2) {
2
} else {
*device.channels.first().unwrap_or(&2)
};
Self {
device_name: device.name.clone(),
sample_rate,
channels,
frame_size: sample_rate / 50, // 20ms
..Default::default()
}
}
/// Bytes per sample (16-bit signed)
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
/// Bytes per frame
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
/// Audio frame data
#[derive(Debug, Clone)]
pub struct AudioFrame {
/// Raw PCM data (S16LE interleaved)
pub data: Bytes,
/// Sample rate
pub sample_rate: u32,
/// Number of channels
pub channels: u32,
/// Number of samples per channel
pub samples: u32,
/// Frame sequence number
pub sequence: u64,
/// Capture timestamp
pub timestamp: Instant,
}
impl AudioFrame {
/// One capture block: `sample_rate` must be the **hardware** rate (e.g. ALSA `actual_rate`).
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
samples: data.len() as u32 / bps,
data,
sample_rate,
channels,
sequence,
timestamp: Instant::now(),
}
}
}
/// Audio capture state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
Running,
Error,
}
/// ALSA audio capturer
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Log throttler to prevent log flooding
log_throttler: LogThrottler,
}
impl AudioCapturer {
/// Create a new audio capturer
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency
Self {
config,
state: Arc::new(state_tx),
state_rx,
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
log_throttler: LogThrottler::with_secs(5),
}
}
/// Get current state
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
/// Subscribe to audio frames
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
/// Start capturing
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
info!(
"Starting audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
capture_loop(config, state, frame_tx, stop_flag, sequence, log_throttler);
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
/// Stop capturing
pub async fn stop(&self) -> Result<()> {
info!("Stopping audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
/// Check if running
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
/// Main capture loop
fn capture_loop(
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
log_throttler: LogThrottler,
) {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
frame_tx: &broadcast::Sender<AudioFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
// Open ALSA device
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
AppError::AudioError(format!(
"Failed to open audio device {}: {}",
config.device_name, e
))
})?;
// Configure hardware parameters
{
let hwp = HwParams::any(&pcm)
.map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?;
hwp.set_channels(config.channels)
.map_err(|e| AppError::AudioError(format!("Failed to set channels: {}", e)))?;
hwp.set_rate(config.sample_rate, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set sample rate: {}", e)))?;
hwp.set_format(Format::s16())
.map_err(|e| AppError::AudioError(format!("Failed to set format: {}", e)))?;
hwp.set_access(Access::RWInterleaved)
.map_err(|e| AppError::AudioError(format!("Failed to set access: {}", e)))?;
hwp.set_buffer_size_near(config.buffer_frames as Frames)
.map_err(|e| AppError::AudioError(format!("Failed to set buffer size: {}", e)))?;
hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?;
pcm.hw_params(&hwp)
.map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?;
}
// Get actual configuration
let actual_rate = pcm
.hw_params_current()
.map(|h| h.get_rate().unwrap_or(config.sample_rate))
.unwrap_or(config.sample_rate);
if actual_rate != config.sample_rate {
info!(
"ALSA sample rate differs from requested ({}Hz vs {}Hz); streamer will resample to 48000Hz for Opus",
actual_rate, config.sample_rate
);
} else {
info!(
"Audio capture configured: {}Hz {}ch (requested {}Hz)",
actual_rate, config.channels, config.sample_rate
);
}
// Prepare for capture
pcm.prepare()
.map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?;
let _ = state.send(CaptureState::Running);
// Sized from actual period — `readi` may return up to ~one period of frames per call.
let period_frames = pcm
.hw_params_current()
.ok()
.and_then(|h| h.get_period_size().ok())
.map(|f| f as usize)
.unwrap_or(1024)
.max(256);
let buf_frames = period_frames.saturating_mul(4).max(2048);
let bytes_per_frame = (config.channels as usize) * 2;
let mut buffer = vec![0u8; buf_frames * bytes_per_frame];
// Capture loop
while !stop_flag.load(Ordering::Relaxed) {
// Check PCM state
match pcm.state() {
State::XRun => {
warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering");
let _ = pcm.prepare();
continue;
}
State::Suspended => {
warn_throttled!(
log_throttler,
"suspended",
"Audio device suspended, recovering"
);
let _ = pcm.resume();
continue;
}
_ => {}
}
// Get IO handle and read audio data directly as bytes
// Note: Use io() instead of io_checked() because USB audio devices
// typically don't support mmap, which io_checked() requires
let io: IO<u8> = pcm.io_bytes();
match io.readi(&mut buffer) {
Ok(frames_read) => {
if frames_read == 0 {
continue;
}
// Calculate actual byte count
let byte_count = frames_read * config.channels as usize * 2;
// Directly use the buffer slice (already in correct byte format)
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(&buffer[..byte_count]),
config.channels,
actual_rate,
seq,
);
// Send to subscribers
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
}
}
}
Err(e) => {
// Check for buffer overrun (EPIPE = 32 on Linux)
let desc = e.to_string();
if desc.contains("EPIPE") || desc.contains("Broken pipe") {
// Buffer overrun
warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun");
let _ = pcm.prepare();
} else if desc.contains("No such device") || desc.contains("ENODEV") {
// Device disconnected - use longer throttle for this
error_throttled!(log_throttler, "no_device", "Audio read error: {}", e);
} else {
error_throttled!(log_throttler, "read_error", "Audio read error: {}", e);
}
}
}
}
info!("Audio capture stopped");
Ok(())
}
pub use imp::*;

334
src/audio/capture_linux.rs Normal file
View File

@@ -0,0 +1,334 @@
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
use alsa::{Direction, ValueOr, PCM};
use bytes::Bytes;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, info};
use crate::audio::device::AudioDeviceInfo;
use crate::error::{AppError, Result};
use crate::utils::LogThrottler;
use crate::{error_throttled, warn_throttled};
#[derive(Debug, Clone)]
pub struct AudioConfig {
pub device_name: String,
pub sample_rate: u32,
pub channels: u32,
pub frame_size: u32,
pub buffer_frames: u32,
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: String::new(),
sample_rate: 48000,
channels: 2,
frame_size: 960,
buffer_frames: 4096,
period_frames: 960,
}
}
}
impl AudioConfig {
pub fn for_device(device: &AudioDeviceInfo) -> Self {
Self {
device_name: device.name.clone(),
..Default::default()
}
}
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
#[derive(Debug, Clone)]
pub struct AudioFrame {
pub data: Bytes,
pub sample_rate: u32,
pub channels: u32,
pub samples: u32,
pub sequence: u64,
pub timestamp: Instant,
}
impl AudioFrame {
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
samples: data.len() as u32 / bps,
data,
sample_rate,
channels,
sequence,
timestamp: Instant::now(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
Running,
Error,
}
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
log_throttler: LogThrottler,
}
impl AudioCapturer {
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16);
Self {
config,
state: Arc::new(state_tx),
state_rx,
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
log_throttler: LogThrottler::with_secs(5),
}
}
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
debug!(
"Starting audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
pub async fn stop(&self) -> Result<()> {
info!("Stopping audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
frame_tx: &broadcast::Sender<AudioFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
AppError::AudioError(format!(
"Failed to open audio device {}: {}",
config.device_name, e
))
})?;
{
let hwp = HwParams::any(&pcm)
.map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?;
hwp.set_channels(config.channels)
.map_err(|e| AppError::AudioError(format!("Failed to set channels: {}", e)))?;
hwp.set_rate(config.sample_rate, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set sample rate: {}", e)))?;
hwp.set_format(Format::s16())
.map_err(|e| AppError::AudioError(format!("Failed to set format: {}", e)))?;
hwp.set_access(Access::RWInterleaved)
.map_err(|e| AppError::AudioError(format!("Failed to set access: {}", e)))?;
hwp.set_buffer_size_near(config.buffer_frames as Frames)
.map_err(|e| AppError::AudioError(format!("Failed to set buffer size: {}", e)))?;
hwp.set_period_size_near(config.period_frames as Frames, ValueOr::Nearest)
.map_err(|e| AppError::AudioError(format!("Failed to set period size: {}", e)))?;
pcm.hw_params(&hwp)
.map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?;
}
let hw_now = pcm.hw_params_current().map_err(|e| {
AppError::AudioError(format!("Failed to read hw_params after apply: {}", e))
})?;
let actual_rate = hw_now
.get_rate()
.map_err(|e| AppError::AudioError(format!("Failed to read sample rate: {}", e)))?;
let actual_ch = hw_now
.get_channels()
.map_err(|e| AppError::AudioError(format!("Failed to read channels: {}", e)))?;
if actual_rate != 48_000 {
return Err(AppError::AudioError(format!(
"Audio capture requires 48000 Hz; device is {} Hz",
actual_rate
)));
}
if actual_ch != 2 {
return Err(AppError::AudioError(format!(
"Audio capture requires 2 channels (stereo); device has {}",
actual_ch
)));
}
debug!("Audio capture: 48000 Hz, 2 ch");
pcm.prepare()
.map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?;
let _ = state.send(CaptureState::Running);
let period_frames = pcm
.hw_params_current()
.ok()
.and_then(|h| h.get_period_size().ok())
.map(|f| f as usize)
.unwrap_or(1024)
.max(256);
let buf_frames = period_frames.saturating_mul(4).max(2048);
let bytes_per_frame = (config.channels as usize) * 2;
let mut buffer = vec![0u8; buf_frames * bytes_per_frame];
while !stop_flag.load(Ordering::Relaxed) {
match pcm.state() {
State::XRun => {
warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering");
let _ = pcm.prepare();
continue;
}
State::Suspended => {
warn_throttled!(
log_throttler,
"suspended",
"Audio device suspended, recovering"
);
let _ = pcm.resume();
continue;
}
_ => {}
}
// io_bytes: USB capture often lacks mmap (io_checked requires it).
let io: IO<u8> = pcm.io_bytes();
match io.readi(&mut buffer) {
Ok(frames_read) => {
if frames_read == 0 {
continue;
}
let byte_count = frames_read * config.channels as usize * 2;
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(&buffer[..byte_count]),
config.channels,
48_000,
seq,
);
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
}
}
}
Err(e) => {
let desc = e.to_string();
if is_device_lost_error(&desc) {
return Err(AppError::AudioError(format!(
"Audio device lost while reading {}: {}",
config.device_name, e
)));
} else if desc.contains("EPIPE") || desc.contains("Broken pipe") {
warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun");
let _ = pcm.prepare();
} else {
error_throttled!(log_throttler, "read_error", "Audio read error: {}", e);
}
}
}
}
info!("Audio capture stopped");
Ok(())
}
fn is_device_lost_error(desc: &str) -> bool {
desc.contains("No such device")
|| desc.contains("ENODEV")
|| desc.contains("ENXIO")
|| desc.contains("ESHUTDOWN")
}

View File

@@ -0,0 +1,516 @@
use bytes::Bytes;
use cpal::traits::{DeviceTrait, StreamTrait};
use cpal::{BufferSize, SampleFormat, StreamConfig};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, watch, Mutex};
use tracing::{debug, info};
use crate::audio::device::{find_wasapi_device, AudioDeviceInfo};
use crate::error::{AppError, Result};
use crate::error_throttled;
use crate::utils::LogThrottler;
#[derive(Debug, Clone)]
pub struct AudioConfig {
pub device_name: String,
pub sample_rate: u32,
pub channels: u32,
pub frame_size: u32,
pub buffer_frames: u32,
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: String::new(),
sample_rate: 48000,
channels: 2,
frame_size: 960,
buffer_frames: 4096,
period_frames: 960,
}
}
}
impl AudioConfig {
pub fn for_device(device: &AudioDeviceInfo) -> Self {
Self {
device_name: device.name.clone(),
..Default::default()
}
}
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
#[derive(Debug, Clone)]
pub struct AudioFrame {
pub data: Bytes,
pub sample_rate: u32,
pub channels: u32,
pub samples: u32,
pub sequence: u64,
pub timestamp: Instant,
}
impl AudioFrame {
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
samples: data.len() as u32 / bps,
data,
sample_rate,
channels,
sequence,
timestamp: Instant::now(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
Running,
Error,
}
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
state_rx: watch::Receiver<CaptureState>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
log_throttler: LogThrottler,
}
impl AudioCapturer {
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16);
Self {
config,
state: Arc::new(state_tx),
state_rx,
frame_tx,
stop_flag: Arc::new(AtomicBool::new(false)),
sequence: Arc::new(AtomicU64::new(0)),
capture_handle: Mutex::new(None),
log_throttler: LogThrottler::with_secs(5),
}
}
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
debug!(
"Starting WASAPI audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
self.stop_flag.store(false, Ordering::SeqCst);
let config = self.config.clone();
let state = self.state.clone();
let frame_tx = self.frame_tx.clone();
let stop_flag = self.stop_flag.clone();
let sequence = self.sequence.clone();
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(
log_throttler,
"capture_error",
"WASAPI audio capture error: {}",
e
);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
pub async fn stop(&self) -> Result<()> {
info!("Stopping WASAPI audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
if let Some(handle) = self.capture_handle.lock().await.take() {
let _ = handle.await;
}
let _ = self.state.send(CaptureState::Stopped);
Ok(())
}
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
frame_tx: &broadcast::Sender<AudioFrame>,
stop_flag: &AtomicBool,
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
let device = find_wasapi_device(&config.device_name)?;
let device_label = device_label(&device);
let supported = select_input_config(&device, config)?;
let sample_format = supported.sample_format();
let input_channels = supported.channels() as u32;
let input_rate = supported.sample_rate();
let stream_config = StreamConfig {
channels: supported.channels(),
sample_rate: supported.sample_rate(),
buffer_size: BufferSize::Fixed(config.period_frames.max(128)),
};
debug!(
"WASAPI capture selected: {} @ {}Hz {}ch {:?}",
device_label, input_rate, input_channels, sample_format
);
let (tx, rx) = mpsc::sync_channel::<Vec<i16>>(8);
let (err_tx, err_rx) = mpsc::sync_channel::<String>(1);
let callback_stop = Arc::new(AtomicBool::new(false));
let stream = match sample_format {
SampleFormat::F32 => build_stream::<f32>(
&device,
&stream_config,
input_channels,
input_rate,
tx.clone(),
err_tx.clone(),
callback_stop.clone(),
),
SampleFormat::I16 => build_stream::<i16>(
&device,
&stream_config,
input_channels,
input_rate,
tx.clone(),
err_tx.clone(),
callback_stop.clone(),
),
SampleFormat::U16 => build_stream::<u16>(
&device,
&stream_config,
input_channels,
input_rate,
tx.clone(),
err_tx.clone(),
callback_stop.clone(),
),
other => {
return Err(AppError::AudioError(format!(
"Unsupported WASAPI sample format: {:?}",
other
)));
}
}?;
stream
.play()
.map_err(|e| AppError::AudioError(format!("Failed to start WASAPI stream: {}", e)))?;
let _ = state.send(CaptureState::Running);
while !stop_flag.load(Ordering::Relaxed) {
if let Ok(err) = err_rx.try_recv() {
return Err(AppError::AudioError(format!(
"WASAPI stream error for {}: {}",
device_label, err
)));
}
match rx.recv_timeout(Duration::from_millis(100)) {
Ok(samples) => {
if samples.is_empty() {
continue;
}
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(bytemuck::cast_slice(&samples)),
2,
48_000,
seq,
);
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
}
}
}
Err(mpsc::RecvTimeoutError::Timeout) => {}
Err(mpsc::RecvTimeoutError::Disconnected) => {
return Err(AppError::AudioError(format!(
"WASAPI capture callback stopped for {}",
device_label
)));
}
}
}
callback_stop.store(true, Ordering::SeqCst);
drop(stream);
info!("WASAPI audio capture stopped");
let _ = log_throttler;
Ok(())
}
fn select_input_config(
device: &cpal::Device,
config: &AudioConfig,
) -> Result<cpal::SupportedStreamConfig> {
let requested_rate = config.sample_rate;
let mut fallback = None;
let configs = device.supported_input_configs().map_err(|e| {
AppError::AudioError(format!("Failed to query WASAPI input configs: {}", e))
})?;
for range in configs {
let sample_format = range.sample_format();
if !matches!(
sample_format,
SampleFormat::F32 | SampleFormat::I16 | SampleFormat::U16
) {
continue;
}
if fallback
.as_ref()
.is_none_or(|best: &cpal::SupportedStreamConfigRange| {
range.cmp_default_heuristics(best).is_gt()
})
{
fallback = Some(range);
}
if range.channels() >= 2
&& range.min_sample_rate() <= requested_rate
&& requested_rate <= range.max_sample_rate()
{
return Ok(range.with_sample_rate(requested_rate));
}
}
if let Some(range) = fallback {
let rate = if range.min_sample_rate() <= requested_rate
&& requested_rate <= range.max_sample_rate()
{
requested_rate
} else {
range.with_max_sample_rate().sample_rate()
};
return Ok(range.with_sample_rate(rate));
}
device.default_input_config().map_err(|e| {
AppError::AudioError(format!(
"No supported WASAPI input format found, and default config failed: {}",
e
))
})
}
fn build_stream<T>(
device: &cpal::Device,
config: &StreamConfig,
input_channels: u32,
input_rate: u32,
tx: mpsc::SyncSender<Vec<i16>>,
err_tx: mpsc::SyncSender<String>,
stop_flag: Arc<AtomicBool>,
) -> Result<cpal::Stream>
where
T: cpal::SizedSample + SampleToI16,
{
let mut converter = PcmConverter::new(input_channels, input_rate, 2, 48_000);
let data_tx = tx.clone();
let stream = device
.build_input_stream(
config,
move |data: &[T], _| {
if stop_flag.load(Ordering::Relaxed) {
return;
}
let pcm = converter.convert(data);
if !pcm.is_empty() {
let _ = data_tx.try_send(pcm);
}
},
move |err| {
let _ = err_tx.try_send(err.to_string());
},
Some(Duration::from_secs(2)),
)
.map_err(|e| AppError::AudioError(format!("Failed to build WASAPI input stream: {}", e)))?;
Ok(stream)
}
trait SampleToI16: Copy + Send + 'static {
fn to_i16_sample(self) -> i16;
}
impl SampleToI16 for i16 {
fn to_i16_sample(self) -> i16 {
self
}
}
impl SampleToI16 for u16 {
fn to_i16_sample(self) -> i16 {
(self as i32 - 32768).clamp(i16::MIN as i32, i16::MAX as i32) as i16
}
}
impl SampleToI16 for f32 {
fn to_i16_sample(self) -> i16 {
(self.clamp(-1.0, 1.0) * i16::MAX as f32).round() as i16
}
}
struct PcmConverter {
input_channels: usize,
input_rate: u32,
output_channels: usize,
output_rate: u32,
input_position: u64,
next_output_position: u64,
}
impl PcmConverter {
fn new(input_channels: u32, input_rate: u32, output_channels: u32, output_rate: u32) -> Self {
Self {
input_channels: input_channels.max(1) as usize,
input_rate: input_rate.max(1),
output_channels: output_channels.max(1) as usize,
output_rate: output_rate.max(1),
input_position: 0,
next_output_position: 0,
}
}
fn convert<T: SampleToI16>(&mut self, input: &[T]) -> Vec<i16> {
let frames = input.len() / self.input_channels;
if frames == 0 {
return Vec::new();
}
if self.input_rate == self.output_rate {
self.input_position = self.input_position.saturating_add(frames as u64);
return self.convert_channels(input, frames);
}
let start = self.input_position;
let end = start.saturating_add(frames as u64);
let mut out = Vec::with_capacity(
((frames as u64 * self.output_rate as u64 / self.input_rate as u64 + 2) as usize)
* self.output_channels,
);
while self.source_position_for_output(self.next_output_position) < end {
let src = self.source_position_for_output(self.next_output_position);
if src >= start {
let local = (src - start) as usize;
self.push_frame(input, local.min(frames - 1), &mut out);
}
self.next_output_position = self.next_output_position.saturating_add(1);
}
self.input_position = end;
out
}
fn source_position_for_output(&self, output_position: u64) -> u64 {
output_position.saturating_mul(self.input_rate as u64) / self.output_rate as u64
}
fn convert_channels<T: SampleToI16>(&self, input: &[T], frames: usize) -> Vec<i16> {
let mut out = Vec::with_capacity(frames * self.output_channels);
for frame in 0..frames {
self.push_frame(input, frame, &mut out);
}
out
}
fn push_frame<T: SampleToI16>(&self, input: &[T], frame: usize, out: &mut Vec<i16>) {
let base = frame * self.input_channels;
let left = input
.get(base)
.copied()
.map(SampleToI16::to_i16_sample)
.unwrap_or(0);
let right = if self.input_channels > 1 {
input
.get(base + 1)
.copied()
.map(SampleToI16::to_i16_sample)
.unwrap_or(left)
} else {
left
};
out.push(left);
if self.output_channels > 1 {
out.push(right);
}
}
}
fn device_label(device: &cpal::Device) -> String {
device
.description()
.map(|desc| desc.to_string())
.or_else(|_| {
#[allow(deprecated)]
device.name()
})
.unwrap_or_else(|_| "Unknown WASAPI capture device".to_string())
}

View File

@@ -1,159 +1,85 @@
//! Audio controller for high-level audio management
//!
//! Provides device enumeration, selection, quality control, and streaming management.
//! Device selection, quality presets, streaming.
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::info;
use tracing::{debug, info};
use super::capture::AudioConfig;
use super::device::{enumerate_audio_devices_with_current, AudioDeviceInfo};
use super::encoder::{OpusConfig, OpusFrame};
use super::monitor::{AudioHealthMonitor, AudioHealthStatus};
use super::device::{enumerate_audio_devices_with_current, find_best_audio_device, AudioDeviceInfo};
use super::encoder::OpusFrame;
use super::monitor::AudioHealthMonitor;
use super::streamer::{AudioStreamer, AudioStreamerConfig};
use super::recovery;
use super::types::{AudioControllerConfig, AudioQuality, AudioStatus};
use crate::error::{AppError, Result};
use crate::events::EventBus;
/// Audio quality presets
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioQuality {
/// Low bandwidth voice (32kbps)
Voice,
/// Balanced quality (64kbps) - default
#[default]
Balanced,
/// High quality audio (128kbps)
High,
}
pub(super) type AudioRecoveredCallback = Arc<dyn Fn() + Send + Sync>;
impl AudioQuality {
/// Get the bitrate for this quality level
pub fn bitrate(&self) -> u32 {
match self {
AudioQuality::Voice => 32000,
AudioQuality::Balanced => 64000,
AudioQuality::High => 128000,
}
}
/// Parse from string
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"voice" | "low" => AudioQuality::Voice,
"high" | "music" => AudioQuality::High,
_ => AudioQuality::Balanced,
}
}
/// Convert to OpusConfig
pub fn to_opus_config(&self) -> OpusConfig {
match self {
AudioQuality::Voice => OpusConfig::voice(),
AudioQuality::Balanced => OpusConfig::default(),
AudioQuality::High => OpusConfig::music(),
}
}
}
impl std::fmt::Display for AudioQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AudioQuality::Voice => write!(f, "voice"),
AudioQuality::Balanced => write!(f, "balanced"),
AudioQuality::High => write!(f, "high"),
}
}
}
/// Audio controller configuration
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
/// These are optimal for Opus encoding and match WebRTC requirements.
#[derive(Debug, Clone)]
pub struct AudioControllerConfig {
/// Whether audio is enabled
pub enabled: bool,
/// Selected device name
pub device: String,
/// Audio quality preset
pub quality: AudioQuality,
}
impl Default for AudioControllerConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
quality: AudioQuality::Balanced,
}
}
}
/// Current audio status
#[derive(Debug, Clone, Serialize)]
pub struct AudioStatus {
/// Whether audio feature is enabled
pub enabled: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Currently selected device
pub device: Option<String>,
/// Current quality preset
pub quality: AudioQuality,
/// Number of connected subscribers
pub subscriber_count: usize,
/// Error message if any
pub error: Option<String>,
}
/// Audio controller
///
/// High-level interface for audio management, providing:
/// - Device enumeration and selection
/// - Quality control
/// - Stream start/stop
/// - Status reporting
pub struct AudioController {
config: RwLock<AudioControllerConfig>,
streamer: RwLock<Option<Arc<AudioStreamer>>>,
devices: RwLock<Vec<AudioDeviceInfo>>,
event_bus: RwLock<Option<Arc<EventBus>>>,
last_error: RwLock<Option<String>>,
/// Health monitor for error tracking and recovery
config: Arc<RwLock<AudioControllerConfig>>,
streamer: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
devices: Arc<RwLock<Vec<AudioDeviceInfo>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
}
impl AudioController {
/// Create a new audio controller with configuration
pub fn new(config: AudioControllerConfig) -> Self {
Self {
config: RwLock::new(config),
streamer: RwLock::new(None),
devices: RwLock::new(Vec::new()),
event_bus: RwLock::new(None),
last_error: RwLock::new(None),
monitor: Arc::new(AudioHealthMonitor::with_defaults()),
config: Arc::new(RwLock::new(config)),
streamer: Arc::new(RwLock::new(None)),
devices: Arc::new(RwLock::new(Vec::new())),
event_bus: Arc::new(RwLock::new(None)),
monitor: Arc::new(AudioHealthMonitor::new()),
recovery_in_progress: Arc::new(AtomicBool::new(false)),
recovered_callback: Arc::new(RwLock::new(None)),
}
}
/// Set event bus for internal state notifications.
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
*self.event_bus.write().await = Some(event_bus);
}
/// Mark the device-info snapshot as stale.
pub async fn set_recovered_callback(&self, callback: Arc<dyn Fn() + Send + Sync>) {
*self.recovered_callback.write().await = Some(callback);
}
async fn mark_device_info_dirty(&self) {
if let Some(ref bus) = *self.event_bus.read().await {
if let Some(bus) = self.event_bus.read().await.as_ref() {
bus.mark_device_info_dirty();
}
}
fn spawn_recovery_task(&self, lost_device: String, reason: String) {
recovery::spawn_recovery_task(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
self.monitor.clone(),
self.recovery_in_progress.clone(),
self.recovered_callback.clone(),
lost_device,
reason,
);
}
fn spawn_stream_monitor(&self, streamer: Arc<AudioStreamer>, device: String) {
recovery::spawn_stream_monitor(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
self.monitor.clone(),
self.recovery_in_progress.clone(),
self.recovered_callback.clone(),
streamer,
device,
);
}
/// List available audio capture devices
pub async fn list_devices(&self) -> Result<Vec<AudioDeviceInfo>> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
@@ -165,41 +91,23 @@ impl AudioController {
Ok(devices)
}
/// Refresh device list and cache it
pub async fn refresh_devices(&self) -> Result<()> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
None
};
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
*self.devices.write().await = devices;
Ok(())
}
/// Get cached device list
pub async fn get_cached_devices(&self) -> Vec<AudioDeviceInfo> {
self.devices.read().await.clone()
}
/// Select audio device
pub async fn select_device(&self, device: &str) -> Result<()> {
// Validate device exists
let devices = self.list_devices().await?;
let found = devices
.iter()
.any(|d| d.name == device || d.description.contains(device));
if !found && device != "default" {
if !found {
return Err(AppError::AudioError(format!(
"Audio device not found: {}",
device
)));
}
// Update config
{
let mut config = self.config.write().await;
config.device = device.to_string();
@@ -207,7 +115,6 @@ impl AudioController {
info!("Audio device selected: {}", device);
// If streaming, restart with new device
if self.is_streaming().await {
self.stop_streaming().await?;
self.start_streaming().await?;
@@ -216,16 +123,13 @@ impl AudioController {
Ok(())
}
/// Set audio quality
pub async fn set_quality(&self, quality: AudioQuality) -> Result<()> {
// Update config
{
let mut config = self.config.write().await;
config.quality = quality;
}
// Update streamer if running
if let Some(ref streamer) = *self.streamer.read().await {
if let Some(streamer) = self.streamer.read().await.as_ref() {
streamer.set_bitrate(quality.bitrate()).await?;
}
@@ -237,124 +141,133 @@ impl AudioController {
Ok(())
}
/// Start audio streaming
pub async fn start_streaming(&self) -> Result<()> {
let config = self.config.read().await.clone();
if !config.enabled {
return Err(AppError::AudioError("Audio is disabled".to_string()));
{
let config = self.config.read().await;
if !config.enabled {
return Err(AppError::AudioError("Audio is disabled".to_string()));
}
}
// Check if already streaming
if self.is_streaming().await {
return Ok(());
}
info!("Starting audio streaming with device: {}", config.device);
// Clear any previous error
*self.last_error.write().await = None;
// Create streamer config (fixed 48kHz stereo)
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: config.device.clone(),
..Default::default()
},
opus: config.quality.to_opus_config(),
let mut select_error = None;
let (device_name, quality) = {
let mut cfg = self.config.write().await;
if cfg.device.trim().is_empty() {
match find_best_audio_device() {
Ok(best) => cfg.device = best.name,
Err(e) => {
select_error = Some(format!("Failed to select audio device: {}", e));
}
}
}
(cfg.device.clone(), cfg.quality)
};
if let Some(error_msg) = select_error {
self.monitor.report_error(&error_msg, "start_failed").await;
self.spawn_recovery_task("auto".to_string(), error_msg.clone());
self.mark_device_info_dirty().await;
return Err(AppError::AudioError(error_msg));
}
debug!("Starting audio streaming with device: {}", device_name);
self.monitor.prepare_retry_attempt();
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device_name.clone(),
..Default::default()
},
opus: quality.to_opus_config(),
};
// Create and start streamer
let streamer = Arc::new(AudioStreamer::with_config(streamer_config));
if let Err(e) = streamer.start().await {
let error_msg = format!("Failed to start audio: {}", e);
*self.last_error.write().await = Some(error_msg.clone());
// Report error to health monitor
self.monitor
.report_error(Some(&config.device), &error_msg, "start_failed")
.await;
self.monitor.report_error(&error_msg, "start_failed").await;
self.spawn_recovery_task(device_name.clone(), error_msg.clone());
self.mark_device_info_dirty().await;
return Err(AppError::AudioError(error_msg));
}
let streamer_for_monitor = streamer.clone();
*self.streamer.write().await = Some(streamer);
self.spawn_stream_monitor(streamer_for_monitor, device_name.clone());
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered(Some(&config.device)).await;
self.monitor.report_recovered().await;
}
self.recovery_in_progress.store(false, Ordering::SeqCst);
self.mark_device_info_dirty().await;
info!("Audio streaming started");
Ok(())
}
/// Stop audio streaming
pub async fn stop_streaming(&self) -> Result<()> {
self.recovery_in_progress.store(false, Ordering::SeqCst);
if let Some(streamer) = self.streamer.write().await.take() {
streamer.stop().await?;
}
self.monitor.reset().await;
self.mark_device_info_dirty().await;
info!("Audio streaming stopped");
Ok(())
}
/// Check if currently streaming
pub async fn is_streaming(&self) -> bool {
if let Some(ref streamer) = *self.streamer.read().await {
streamer.is_running()
} else {
false
}
self.streamer
.read()
.await
.as_ref()
.is_some_and(|streamer| streamer.is_running())
}
/// Get current status
pub async fn status(&self) -> AudioStatus {
let config = self.config.read().await;
let streaming = self.is_streaming().await;
let error = self.last_error.read().await.clone();
let (enabled, device_str, quality) = {
let c = self.config.read().await;
(c.enabled, c.device.clone(), c.quality)
};
let error = self.monitor.error_message().await;
let subscriber_count = if let Some(ref streamer) = *self.streamer.read().await {
streamer.stats().await.subscriber_count
let (streaming, subscriber_count) = if let Some(ref streamer) = *self.streamer.read().await
{
let streaming = streamer.is_running();
let subscriber_count = streamer.stats().subscriber_count;
(streaming, subscriber_count)
} else {
0
(false, 0)
};
AudioStatus {
enabled: config.enabled,
enabled,
streaming,
device: if streaming || config.enabled {
Some(config.device.clone())
device: if streaming || enabled {
Some(device_str)
} else {
None
},
quality: config.quality,
quality,
subscriber_count,
error,
}
}
/// Subscribe to Opus frames (for WebSocket clients)
pub fn subscribe_opus(&self) -> Option<tokio::sync::watch::Receiver<Option<Arc<OpusFrame>>>> {
// Use try_read to avoid blocking - this is called from sync context sometimes
if let Ok(guard) = self.streamer.try_read() {
guard.as_ref().map(|s| s.subscribe_opus())
} else {
None
}
}
/// Subscribe to Opus frames (async version)
pub async fn subscribe_opus_async(
&self,
) -> Option<tokio::sync::watch::Receiver<Option<Arc<OpusFrame>>>> {
pub async fn subscribe_opus(&self) -> Option<tokio::sync::mpsc::Receiver<Arc<OpusFrame>>> {
self.streamer
.read()
.await
@@ -362,7 +275,6 @@ impl AudioController {
.map(|s| s.subscribe_opus())
}
/// Enable or disable audio
pub async fn set_enabled(&self, enabled: bool) -> Result<()> {
{
let mut config = self.config.write().await;
@@ -377,21 +289,15 @@ impl AudioController {
Ok(())
}
/// Update full configuration
pub async fn update_config(&self, new_config: AudioControllerConfig) -> Result<()> {
let was_streaming = self.is_streaming().await;
// Stop streaming if running (device/quality/enabled may all change)
if was_streaming {
self.stop_streaming().await?;
}
// Update config
*self.config.write().await = new_config.clone();
// Start whenever audio is enabled — not only when we were already streaming.
// Otherwise PATCH /config/audio alone leaves enabled=true with no capture until
// POST /audio/start, which races WebRTC reconnect and matches "apply twice" reports.
if new_config.enabled {
self.start_streaming().await?;
}
@@ -399,25 +305,9 @@ impl AudioController {
Ok(())
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
self.stop_streaming().await
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<AudioHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> AudioHealthStatus {
self.monitor.status().await
}
/// Check if the audio is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Default for AudioController {
@@ -439,12 +329,23 @@ mod tests {
#[test]
fn test_audio_quality_from_str() {
assert_eq!(AudioQuality::from_str("voice"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("low"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("balanced"), AudioQuality::Balanced);
assert_eq!(AudioQuality::from_str("high"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("music"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("unknown"), AudioQuality::Balanced);
assert_eq!(
"voice".parse::<AudioQuality>().unwrap(),
AudioQuality::Voice
);
assert_eq!(
"balanced".parse::<AudioQuality>().unwrap(),
AudioQuality::Balanced
);
assert_eq!("high".parse::<AudioQuality>().unwrap(), AudioQuality::High);
}
#[test]
fn test_audio_quality_from_str_rejects_aliases_and_unknown() {
assert!("low".parse::<AudioQuality>().is_err());
assert!("music".parse::<AudioQuality>().is_err());
assert!("unknown".parse::<AudioQuality>().is_err());
assert!("".parse::<AudioQuality>().is_err());
}
#[tokio::test]

View File

@@ -1,271 +1,9 @@
//! Audio device enumeration using ALSA
#[cfg(unix)]
#[path = "device_linux.rs"]
mod imp;
use alsa::pcm::HwParams;
use alsa::{Direction, PCM};
use serde::Serialize;
use tracing::{debug, info, warn};
#[cfg(windows)]
#[path = "device_windows.rs"]
mod imp;
use crate::error::{AppError, Result};
/// Audio device information
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
/// Device name (e.g., "hw:0,0" or "default")
pub name: String,
/// Human-readable description
pub description: String,
/// Card index
pub card_index: i32,
/// Device index
pub device_index: i32,
/// Supported sample rates
pub sample_rates: Vec<u32>,
/// Supported channel counts
pub channels: Vec<u32>,
/// Is this a capture device
pub is_capture: bool,
/// Is this an HDMI audio device (likely from capture card)
pub is_hdmi: bool,
/// USB bus info for matching with video devices (e.g., "1-1" from USB path)
pub usb_bus: Option<String>,
}
impl AudioDeviceInfo {
/// Get ALSA device name
pub fn alsa_name(&self) -> String {
format!("hw:{},{}", self.card_index, self.device_index)
}
}
/// Get USB bus info for an audio card by reading sysfs
/// Returns the USB port path like "1-1" or "1-2.3"
fn get_usb_bus_info(card_index: i32) -> Option<String> {
if card_index < 0 {
return None;
}
// Read the device symlink: /sys/class/sound/cardX/device -> ../../usb1/1-1/1-1:1.0
let device_path = format!("/sys/class/sound/card{}/device", card_index);
let link_target = std::fs::read_link(&device_path).ok()?;
let link_str = link_target.to_string_lossy();
// Extract USB port from path like "../../usb1/1-1/1-1:1.0" or "../../1-1/1-1:1.0"
// We want the "1-1" part (USB bus-port)
for component in link_str.split('/') {
// Match patterns like "1-1", "1-2", "1-1.2", "2-1.3.1"
if component.contains('-') && !component.contains(':') {
// Verify it looks like a USB port (starts with digit)
if component
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
{
return Some(component.to_string());
}
}
}
None
}
/// Enumerate available audio capture devices
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
/// Enumerate available audio capture devices, with option to include a currently-in-use device
///
/// # Arguments
/// * `current_device` - Optional device name that is currently in use. This device will be
/// included in the list even if it cannot be opened (because it's already open by us).
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let mut devices = Vec::new();
// Try to enumerate cards
let cards = alsa::card::Iter::new();
for card_result in cards {
let card = match card_result {
Ok(c) => c,
Err(e) => {
debug!("Error iterating card: {}", e);
continue;
}
};
let card_index = card.get_index();
let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string());
let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone());
debug!("Found audio card {}: {}", card_index, card_longname);
// Check if this looks like an HDMI capture device
let is_hdmi = card_longname.to_lowercase().contains("hdmi")
|| card_longname.to_lowercase().contains("capture")
|| card_longname.to_lowercase().contains("usb");
// Get USB bus info for this card
let usb_bus = get_usb_bus_info(card_index);
// Try to open each device on this card for capture
for device_index in 0..8 {
let device_name = format!("hw:{},{}", card_index, device_index);
// Check if this is the currently-in-use device
let is_current_device = current_device == Some(device_name.as_str());
// Try to open for capture
match PCM::new(&device_name, Direction::Capture, false) {
Ok(pcm) => {
// Query capabilities
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
devices.push(AudioDeviceInfo {
name: device_name,
description: format!("{} - Device {}", card_longname, device_index),
card_index,
device_index,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
}
}
Err(_) => {
// Device doesn't exist or can't be opened for capture
// But if it's the current device, include it anyway (it's busy because we're using it)
if is_current_device {
debug!(
"Device {} is busy (in use by us), adding with default caps",
device_name
);
devices.push(AudioDeviceInfo {
name: device_name,
description: format!(
"{} - Device {} (in use)",
card_longname, device_index
),
card_index,
device_index,
// Use common default capabilities for HDMI capture devices
sample_rates: vec![44100, 48000],
channels: vec![2],
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
}
continue;
}
}
}
}
// Also check for "default" device
if let Ok(pcm) = PCM::new("default", Direction::Capture, false) {
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() {
devices.insert(
0,
AudioDeviceInfo {
name: "default".to_string(),
description: "Default Audio Device".to_string(),
card_index: -1,
device_index: -1,
sample_rates,
channels,
is_capture: true,
is_hdmi: false,
usb_bus: None,
},
);
}
}
info!("Found {} audio capture devices", devices.len());
Ok(devices)
}
/// Query device capabilities
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
let hwp = match HwParams::any(pcm) {
Ok(h) => h,
Err(_) => return (vec![], vec![]),
};
// Common sample rates to check
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
let mut supported_rates = Vec::new();
for rate in &common_rates {
if hwp.test_rate(*rate).is_ok() {
supported_rates.push(*rate);
}
}
// Check channel counts
let mut supported_channels = Vec::new();
for ch in 1..=8 {
if hwp.test_channels(ch).is_ok() {
supported_channels.push(ch);
}
}
(supported_rates, supported_channels)
}
/// Find the best audio device for capture
/// Prefers HDMI/capture devices over built-in microphones
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
if devices.is_empty() {
return Err(AppError::AudioError(
"No audio capture devices found".to_string(),
));
}
// First, look for HDMI/capture card devices that support 48kHz stereo
for device in &devices {
if device.is_hdmi && device.sample_rates.contains(&48000) && device.channels.contains(&2) {
info!("Selected HDMI audio device: {}", device.description);
return Ok(device.clone());
}
}
// Then look for any device supporting 48kHz stereo
for device in &devices {
if device.sample_rates.contains(&48000) && device.channels.contains(&2) {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
}
}
// Fall back to first device
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback audio device: {} (may not support optimal settings)",
device.description
);
Ok(device)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enumerate_devices() {
// This test may not find devices in CI environment
let result = enumerate_audio_devices();
println!("Audio devices: {:?}", result);
// Just verify it doesn't panic
assert!(result.is_ok());
}
}
pub use imp::*;

201
src/audio/device_linux.rs Normal file
View File

@@ -0,0 +1,201 @@
use alsa::pcm::HwParams;
use alsa::{Direction, PCM};
use serde::Serialize;
use tracing::{debug, info, warn};
use crate::error::{AppError, Result};
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
pub name: String,
pub description: String,
pub card_index: i32,
pub device_index: i32,
pub sample_rates: Vec<u32>,
pub channels: Vec<u32>,
pub is_capture: bool,
pub is_hdmi: bool,
pub usb_bus: Option<String>,
}
fn get_usb_bus_info(card_index: i32) -> Option<String> {
if card_index < 0 {
return None;
}
let device_path = format!("/sys/class/sound/card{}/device", card_index);
let link_target = std::fs::read_link(&device_path).ok()?;
let link_str = link_target.to_string_lossy();
for component in link_str.split('/') {
if component.contains('-') && !component.contains(':') {
if component
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
{
return Some(component.to_string());
}
}
}
None
}
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let mut devices = Vec::new();
let cards = alsa::card::Iter::new();
for card_result in cards {
let card = match card_result {
Ok(c) => c,
Err(e) => {
debug!("Error iterating card: {}", e);
continue;
}
};
let card_index = card.get_index();
let card_name = card.get_name().unwrap_or_else(|_| "Unknown".to_string());
let card_longname = card.get_longname().unwrap_or_else(|_| card_name.clone());
debug!("Found audio card {}: {}", card_index, card_longname);
let long_lower = card_longname.to_lowercase();
let is_hdmi = long_lower.contains("hdmi")
|| long_lower.contains("capture")
|| long_lower.contains("usb");
let usb_bus = get_usb_bus_info(card_index);
for device_index in 0..8 {
let device_name = format!("hw:{},{}", card_index, device_index);
let is_current_device = current_device == Some(device_name.as_str());
let mut push_info =
|sample_rates: Vec<u32>, channels: Vec<u32>, description: String| {
devices.push(AudioDeviceInfo {
name: device_name.clone(),
description,
card_index,
device_index,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
};
match PCM::new(&device_name, Direction::Capture, false) {
Ok(pcm) => {
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
push_info(
sample_rates,
channels,
format!("{} - Device {}", card_longname, device_index),
);
}
}
Err(_) => {
if is_current_device {
debug!(
"Device {} is busy (in use by us), adding with default caps",
device_name
);
push_info(
vec![44100, 48000],
vec![2],
format!("{} - Device {} (in use)", card_longname, device_index),
);
}
}
}
}
}
info!("Found {} audio capture devices", devices.len());
Ok(devices)
}
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
let hwp = match HwParams::any(pcm) {
Ok(h) => h,
Err(_) => return (vec![], vec![]),
};
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
let mut supported_rates = Vec::new();
for rate in &common_rates {
if hwp.test_rate(*rate).is_ok() {
supported_rates.push(*rate);
}
}
let mut supported_channels = Vec::new();
for ch in 1..=8 {
if hwp.test_channels(ch).is_ok() {
supported_channels.push(ch);
}
}
(supported_rates, supported_channels)
}
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
if devices.is_empty() {
return Err(AppError::AudioError(
"No audio capture devices found".to_string(),
));
}
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
for device in &devices {
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
continue;
}
if device.is_hdmi {
info!("Selected HDMI audio device: {}", device.description);
return Ok(device.clone());
}
if first_48k_stereo.is_none() {
first_48k_stereo = Some(device);
}
}
if let Some(device) = first_48k_stereo {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
}
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback audio device: {} (may not support optimal settings)",
device.description
);
Ok(device)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enumerate_devices() {
let result = enumerate_audio_devices();
println!("Audio devices: {:?}", result);
assert!(result.is_ok());
}
}

232
src/audio/device_windows.rs Normal file
View File

@@ -0,0 +1,232 @@
use cpal::traits::{DeviceTrait, HostTrait};
use cpal::DeviceId;
use serde::Serialize;
use std::str::FromStr;
use tracing::{debug, info, warn};
use crate::error::{AppError, Result};
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
pub name: String,
pub description: String,
pub card_index: i32,
pub device_index: i32,
pub sample_rates: Vec<u32>,
pub channels: Vec<u32>,
pub is_capture: bool,
pub is_hdmi: bool,
pub usb_bus: Option<String>,
}
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let host = cpal::default_host();
let devices = host
.input_devices()
.map_err(|e| AppError::AudioError(format!("Failed to enumerate WASAPI devices: {}", e)))?;
let mut result = Vec::new();
for (index, device) in devices.enumerate() {
let labels = device_labels(&device);
let id = device
.id()
.map(|id| id.to_string())
.unwrap_or_else(|_| format!("wasapi-index:{}", index));
let (sample_rates, channels) = query_device_caps(&device);
if sample_rates.is_empty() || channels.is_empty() {
debug!(
"Skipping WASAPI endpoint without usable input caps: {}",
labels.search_text
);
continue;
}
let is_current =
current_device == Some(id.as_str()) || current_device == Some(labels.display.as_str());
let description = if is_current {
format!("{} (in use)", labels.display)
} else {
labels.display.clone()
};
let lower = labels.search_text.to_lowercase();
let is_hdmi = lower.contains("hdmi")
|| lower.contains("capture")
|| lower.contains("usb")
|| lower.contains("digital");
result.push(AudioDeviceInfo {
name: id,
description,
card_index: index as i32,
device_index: 0,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: None,
});
}
info!("Found {} WASAPI audio capture devices", result.len());
Ok(result)
}
fn query_device_caps(device: &cpal::Device) -> (Vec<u32>, Vec<u32>) {
let mut sample_rates = Vec::new();
let mut channels = Vec::new();
if let Ok(configs) = device.supported_input_configs() {
for cfg in configs {
for rate in [8000, 16000, 22050, 44100, 48000, 96000] {
if cfg.min_sample_rate() <= rate
&& rate <= cfg.max_sample_rate()
&& !sample_rates.contains(&rate)
{
sample_rates.push(rate);
}
}
let ch = cfg.channels() as u32;
if !channels.contains(&ch) {
channels.push(ch);
}
}
}
if (sample_rates.is_empty() || channels.is_empty()) && device.default_input_config().is_ok() {
if let Ok(default_cfg) = device.default_input_config() {
if !sample_rates.contains(&default_cfg.sample_rate()) {
sample_rates.push(default_cfg.sample_rate());
}
let ch = default_cfg.channels() as u32;
if !channels.contains(&ch) {
channels.push(ch);
}
}
}
sample_rates.sort_unstable();
channels.sort_unstable();
(sample_rates, channels)
}
struct DeviceLabels {
display: String,
search_text: String,
}
fn device_labels(device: &cpal::Device) -> DeviceLabels {
match device.description() {
Ok(desc) => {
let formatted = desc.to_string();
let display = desc
.extended()
.first()
.cloned()
.unwrap_or_else(|| formatted.clone());
let mut parts = vec![formatted, desc.name().to_string(), display.clone()];
parts.extend(desc.extended().iter().cloned());
DeviceLabels {
display,
search_text: parts.join(" "),
}
}
Err(_) => {
#[allow(deprecated)]
let display = device
.name()
.unwrap_or_else(|_| "Unknown WASAPI capture device".to_string());
DeviceLabels {
display: display.clone(),
search_text: display,
}
}
}
}
pub(crate) fn find_wasapi_device(requested_device: &str) -> Result<cpal::Device> {
let host = cpal::default_host();
let trimmed = requested_device.trim();
if trimmed.is_empty()
|| trimmed.eq_ignore_ascii_case("auto")
|| trimmed.eq_ignore_ascii_case("default")
{
return host.default_input_device().ok_or_else(|| {
AppError::AudioError("No default WASAPI input device found".to_string())
});
}
if let Ok(id) = DeviceId::from_str(trimmed) {
if let Some(device) = host.device_by_id(&id) {
return Ok(device);
}
}
let needle = trimmed.to_lowercase();
let devices = host
.input_devices()
.map_err(|e| AppError::AudioError(format!("Failed to enumerate WASAPI devices: {}", e)))?;
for device in devices {
let id_match = device
.id()
.map(|id| id.to_string() == trimmed)
.unwrap_or(false);
let labels = device_labels(&device);
if id_match || labels.search_text.to_lowercase().contains(&needle) {
return Ok(device);
}
}
Err(AppError::AudioError(format!(
"WASAPI audio device not found: {}",
requested_device
)))
}
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
if devices.is_empty() {
return Err(AppError::AudioError(
"No WASAPI audio capture devices found".to_string(),
));
}
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
for device in &devices {
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
continue;
}
if device.is_hdmi {
info!("Selected WASAPI capture device: {}", device.description);
return Ok(device.clone());
}
if first_48k_stereo.is_none() {
first_48k_stereo = Some(device);
}
}
if let Some(device) = first_48k_stereo {
info!("Selected WASAPI capture device: {}", device.description);
return Ok(device.clone());
}
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback WASAPI audio device: {} (will resample if needed)",
device.description
);
Ok(device)
}

View File

@@ -1,26 +1,19 @@
//! Opus audio encoder for WebRTC
//! Opus encoder.
use audiopus::coder::GenericCtl;
use audiopus::{coder::Encoder, Application, Bitrate, Channels, SampleRate};
use bytes::Bytes;
use std::time::Instant;
use tracing::info;
use tracing::debug;
use super::capture::AudioFrame;
use crate::error::{AppError, Result};
/// Opus encoder configuration
#[derive(Debug, Clone)]
pub struct OpusConfig {
/// Sample rate (must be 8000, 12000, 16000, 24000, or 48000)
pub sample_rate: u32,
/// Channels (1 or 2)
pub channels: u32,
/// Target bitrate in bps
pub bitrate: u32,
/// Application mode
pub application: OpusApplication,
/// Enable forward error correction
pub fec: bool,
}
@@ -29,7 +22,7 @@ impl Default for OpusConfig {
Self {
sample_rate: 48000,
channels: 2,
bitrate: 64000, // 64 kbps
bitrate: 64000,
application: OpusApplication::Audio,
fec: true,
}
@@ -37,7 +30,6 @@ impl Default for OpusConfig {
}
impl OpusConfig {
/// Create config for voice (lower latency)
pub fn voice() -> Self {
Self {
application: OpusApplication::Voip,
@@ -46,7 +38,6 @@ impl OpusConfig {
}
}
/// Create config for music (higher quality)
pub fn music() -> Self {
Self {
application: OpusApplication::Audio,
@@ -82,30 +73,18 @@ impl OpusConfig {
}
}
/// Opus application mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpusApplication {
/// Voice over IP
Voip,
/// General audio
Audio,
/// Low delay mode
LowDelay,
}
/// Encoded Opus frame
#[derive(Debug, Clone)]
pub struct OpusFrame {
/// Encoded Opus data
pub data: Bytes,
/// Duration in milliseconds
pub duration_ms: u32,
/// Sequence number
pub sequence: u64,
/// Timestamp
pub timestamp: Instant,
/// RTP timestamp (samples)
pub rtp_timestamp: u32,
}
impl OpusFrame {
@@ -118,20 +97,14 @@ impl OpusFrame {
}
}
/// Opus encoder
pub struct OpusEncoder {
config: OpusConfig,
encoder: Encoder,
/// Output buffer
output_buffer: Vec<u8>,
/// Frame counter for RTP timestamp
frame_count: u64,
/// Samples per frame
samples_per_frame: u32,
}
impl OpusEncoder {
/// Create a new Opus encoder
pub fn new(config: OpusConfig) -> Result<Self> {
let sample_rate = config.to_audiopus_sample_rate();
let channels = config.to_audiopus_channels();
@@ -140,7 +113,6 @@ impl OpusEncoder {
let mut encoder = Encoder::new(sample_rate, channels, application)
.map_err(|e| AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e)))?;
// Configure encoder
encoder
.set_bitrate(Bitrate::BitsPerSecond(config.bitrate as i32))
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
@@ -151,10 +123,7 @@ impl OpusEncoder {
.map_err(|e| AppError::AudioError(format!("Failed to enable FEC: {:?}", e)))?;
}
// Calculate samples per frame (20ms at sample_rate)
let samples_per_frame = config.sample_rate / 50;
info!(
debug!(
"Opus encoder created: {}Hz {}ch {}bps",
config.sample_rate, config.channels, config.bitrate
);
@@ -162,18 +131,11 @@ impl OpusEncoder {
Ok(Self {
config,
encoder,
output_buffer: vec![0u8; 4000], // Max Opus frame size
output_buffer: vec![0u8; 4000],
frame_count: 0,
samples_per_frame,
})
}
/// Create with default configuration
pub fn default_config() -> Result<Self> {
Self::new(OpusConfig::default())
}
/// Encode PCM audio data (S16LE interleaved)
pub fn encode(&mut self, pcm_data: &[i16]) -> Result<OpusFrame> {
let encoded_len = self
.encoder
@@ -182,7 +144,6 @@ impl OpusEncoder {
let samples = pcm_data.len() as u32 / self.config.channels;
let duration_ms = (samples * 1000) / self.config.sample_rate;
let rtp_timestamp = (self.frame_count * self.samples_per_frame as u64) as u32;
self.frame_count += 1;
@@ -190,27 +151,18 @@ impl OpusEncoder {
data: Bytes::copy_from_slice(&self.output_buffer[..encoded_len]),
duration_ms,
sequence: self.frame_count - 1,
timestamp: Instant::now(),
rtp_timestamp,
})
}
/// Encode from AudioFrame
///
/// Uses zero-copy conversion from bytes to i16 samples via bytemuck.
pub fn encode_frame(&mut self, frame: &AudioFrame) -> Result<OpusFrame> {
// Zero-copy: directly cast bytes to i16 slice
// AudioFrame.data is S16LE format, which matches native little-endian i16
let samples: &[i16] = bytemuck::cast_slice(&frame.data);
self.encode(samples)
}
/// Get encoder configuration
pub fn config(&self) -> &OpusConfig {
&self.config
}
/// Reset encoder state
pub fn reset(&mut self) -> Result<()> {
self.encoder
.reset_state()
@@ -219,7 +171,6 @@ impl OpusEncoder {
Ok(())
}
/// Set bitrate dynamically
pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> {
self.encoder
.set_bitrate(Bitrate::BitsPerSecond(bitrate as i32))
@@ -228,15 +179,6 @@ impl OpusEncoder {
}
}
/// Audio encoder statistics
#[derive(Debug, Clone, Default)]
pub struct EncoderStats {
pub frames_encoded: u64,
pub bytes_output: u64,
pub avg_frame_size: usize,
pub current_bitrate: u32,
}
#[cfg(test)]
mod tests {
use super::*;
@@ -261,13 +203,12 @@ mod tests {
let config = OpusConfig::default();
let mut encoder = OpusEncoder::new(config).unwrap();
// 20ms of stereo silence at 48kHz
let silence = vec![0i16; 960 * 2];
let result = encoder.encode(&silence);
assert!(result.is_ok());
let frame = result.unwrap();
assert!(!frame.is_empty());
assert!(frame.len() < silence.len() * 2); // Should be compressed
assert!(frame.len() < silence.len() * 2);
}
}

View File

@@ -1,24 +1,21 @@
//! Audio capture and encoding module
//!
//! This module provides:
//! - ALSA audio capture
//! - Opus encoding for WebRTC
//! - Audio device enumeration
//! - Audio streaming pipeline
//! - High-level audio controller
//! - Device health monitoring
//! Platform audio capture, Opus encode, device enumeration, streaming, controller, health monitor.
#[cfg(any(unix, windows))]
pub mod capture;
pub mod controller;
#[cfg(any(unix, windows))]
pub mod device;
#[cfg(any(unix, windows))]
pub mod encoder;
pub mod monitor;
pub mod resample;
pub mod recovery;
pub mod streamer;
pub mod types;
pub use capture::{AudioCapturer, AudioConfig, AudioFrame};
pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus};
pub use controller::AudioController;
pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo};
pub use encoder::{OpusConfig, OpusEncoder, OpusFrame};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus};
pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
pub use types::{AudioControllerConfig, AudioQuality, AudioStatus};

View File

@@ -1,114 +1,58 @@
//! Audio device health monitoring
//!
//! This module provides health monitoring for audio capture devices, including:
//! - Device connectivity checks
//! - Automatic reconnection on failure
//! - Error tracking
//! - Log throttling to prevent log flooding
//! Audio device health and logging throttle for repeated failures.
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::utils::LogThrottler;
/// Audio health status
const LOG_THROTTLE_SECS: u64 = 5;
#[derive(Debug, Clone, PartialEq, Default)]
pub enum AudioHealthStatus {
/// Device is healthy and operational
#[default]
Healthy,
/// Device has an error, attempting recovery
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
/// Number of recovery attempts made
retry_count: u32,
},
/// Device is disconnected or not available
Disconnected,
}
/// Audio health monitor configuration
#[derive(Debug, Clone)]
pub struct AudioMonitorConfig {
/// Retry interval when device is lost (milliseconds)
pub retry_interval_ms: u64,
/// Maximum retry attempts before giving up (0 = infinite)
pub max_retries: u32,
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for AudioMonitorConfig {
fn default() -> Self {
Self {
retry_interval_ms: 1000,
max_retries: 0, // infinite retry
log_throttle_secs: 5,
}
}
}
/// Audio health monitor
///
/// Monitors audio device health and manages error recovery.
pub struct AudioHealthMonitor {
/// Current health status
status: RwLock<AudioHealthStatus>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
config: AudioMonitorConfig,
/// Current retry count
retry_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
/// Hide `error_message` while a new capture attempt is in flight (internal error state unchanged).
suppress_display: AtomicBool,
}
impl AudioHealthMonitor {
/// Create a new audio health monitor with the specified configuration
pub fn new(config: AudioMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
pub fn new() -> Self {
Self {
status: RwLock::new(AudioHealthStatus::Healthy),
throttler: LogThrottler::with_secs(throttle_secs),
config,
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
retry_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
suppress_display: AtomicBool::new(false),
}
}
/// Create a new audio health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(AudioMonitorConfig::default())
/// Clears the error string exposed via [`Self::error_message`] until the next outcome (`report_error` or recovery).
pub fn prepare_retry_attempt(&self) {
self.suppress_display.store(true, Ordering::Relaxed);
}
/// Report an error from audio operations
///
/// This method is called when an audio operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Updates in-memory error state
///
/// # Arguments
///
/// * `device` - The audio device name (if known)
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, _device: Option<&str>, reason: &str, error_code: &str) {
pub async fn report_error(&self, reason: &str, error_code: &str) {
self.suppress_display.store(false, Ordering::Relaxed);
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("audio_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!(
@@ -117,34 +61,22 @@ impl AudioHealthMonitor {
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = AudioHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
retry_count: count,
};
}
/// Report that the device has recovered
///
/// This method is called when the audio device successfully reconnects.
/// It resets the error state.
///
/// # Arguments
///
/// * `device` - The audio device name
pub async fn report_recovered(&self, _device: Option<&str>) {
pub async fn report_recovered(&self) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != AudioHealthStatus::Healthy {
let retry_count = self.retry_count.load(Ordering::Relaxed);
info!("Audio recovered after {} retries", retry_count);
// Reset state
self.suppress_display.store(false, Ordering::Relaxed);
self.retry_count.store(0, Ordering::Relaxed);
self.throttler.clear("audio_");
*self.last_error_code.write().await = None;
@@ -152,58 +84,30 @@ impl AudioHealthMonitor {
}
}
/// Get the current health status
pub async fn status(&self) -> AudioHealthStatus {
self.status.read().await.clone()
}
/// Get the current retry count
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.suppress_display.store(false, Ordering::Relaxed);
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = AudioHealthStatus::Healthy;
self.throttler.clear_all();
}
/// Get the configuration
pub fn config(&self) -> &AudioMonitorConfig {
&self.config
pub async fn status(&self) -> AudioHealthStatus {
self.status.read().await.clone()
}
/// Check if we should continue retrying
///
/// Returns `false` if max_retries is set and we've exceeded it.
pub fn should_retry(&self) -> bool {
if self.config.max_retries == 0 {
return true; // Infinite retry
}
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Get the retry interval
pub fn retry_interval(&self) -> Duration {
Duration::from_millis(self.config.retry_interval_ms)
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
if self.suppress_display.load(Ordering::Relaxed) {
return None;
}
match &*self.status.read().await {
AudioHealthStatus::Error { reason, .. } => Some(reason.clone()),
_ => None,
@@ -213,7 +117,7 @@ impl AudioHealthMonitor {
impl Default for AudioHealthMonitor {
fn default() -> Self {
Self::with_defaults()
Self::new()
}
}
@@ -223,32 +127,25 @@ mod tests {
#[tokio::test]
async fn test_initial_status() {
let monitor = AudioHealthMonitor::with_defaults();
assert!(monitor.is_healthy().await);
let monitor = AudioHealthMonitor::new();
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_report_error() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
monitor
.report_error(Some("hw:0,0"), "Device not found", "device_disconnected")
.report_error("Device not found", "device_disconnected")
.await;
assert!(monitor.is_error().await);
assert_eq!(monitor.retry_count(), 1);
if let AudioHealthStatus::Error {
reason,
error_code,
retry_count,
} = monitor.status().await
{
if let AudioHealthStatus::Error { reason, error_code } = monitor.status().await {
assert_eq!(reason, "Device not found");
assert_eq!(error_code, "device_disconnected");
assert_eq!(retry_count, 1);
} else {
panic!("Expected Error status");
}
@@ -256,39 +153,52 @@ mod tests {
#[tokio::test]
async fn test_report_recovered() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
// First report an error
monitor
.report_error(Some("default"), "Capture failed", "capture_error")
.report_error("Capture failed", "capture_error")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered(Some("default")).await;
assert!(monitor.is_healthy().await);
monitor.report_recovered().await;
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_retry_count_increments() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
for i in 1..=5 {
monitor.report_error(None, "Error", "io_error").await;
monitor.report_error("Error", "io_error").await;
assert_eq!(monitor.retry_count(), i);
}
}
#[tokio::test]
async fn test_reset() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
monitor.report_error(None, "Error", "io_error").await;
monitor.report_error("Error", "io_error").await;
assert!(monitor.is_error().await);
monitor.reset().await;
assert!(monitor.is_healthy().await);
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_prepare_retry_hides_error_until_next_failure() {
let monitor = AudioHealthMonitor::new();
monitor.report_error("bad", "e").await;
assert_eq!(monitor.error_message().await.as_deref(), Some("bad"));
monitor.prepare_retry_attempt();
assert!(monitor.is_error().await);
assert!(monitor.error_message().await.is_none());
monitor.report_error("still bad", "e").await;
assert_eq!(monitor.error_message().await.as_deref(), Some("still bad"));
}
}

320
src/audio/recovery.rs Normal file
View File

@@ -0,0 +1,320 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use super::capture::AudioConfig;
use super::device::{enumerate_audio_devices, AudioDeviceInfo};
use super::monitor::AudioHealthMonitor;
use super::streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
use super::types::AudioControllerConfig;
use super::controller::AudioRecoveredCallback;
use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent};
const AUDIO_RECOVERY_RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(1);
pub(super) fn select_recovery_device(
devices: &[AudioDeviceInfo],
preferred: &str,
) -> Option<AudioDeviceInfo> {
if let Some(device) = devices
.iter()
.find(|d| !preferred.trim().is_empty() && d.name == preferred)
{
return Some(device.clone());
}
devices
.iter()
.find(|d| d.is_hdmi && d.sample_rates.contains(&48_000) && d.channels.contains(&2))
.or_else(|| {
devices
.iter()
.find(|d| d.sample_rates.contains(&48_000) && d.channels.contains(&2))
})
.or_else(|| devices.first())
.cloned()
}
async fn publish_state(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
state: &str,
device: Option<String>,
reason: Option<&str>,
next_retry_ms: Option<u64>,
) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamStateChanged {
state: state.to_string(),
device,
reason: reason.map(str::to_string),
next_retry_ms,
});
bus.mark_device_info_dirty();
}
}
async fn publish_device_lost(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
reason: &str,
) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: device.to_string(),
reason: reason.to_string(),
});
}
}
async fn publish_reconnecting(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
attempt: u32,
) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamReconnecting {
device: device.to_string(),
attempt,
});
}
}
async fn publish_recovered(event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>, device: &str) {
if let Some(bus) = event_bus.read().await.as_ref() {
bus.publish(SystemEvent::StreamRecovered {
device: device.to_string(),
});
}
}
fn spawn_stream_monitor_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
streamer: Arc<AudioStreamer>,
device: String,
) {
let mut state_rx = streamer.state_watch();
tokio::spawn(async move {
loop {
if state_rx.changed().await.is_err() {
return;
}
if *state_rx.borrow() != AudioStreamState::Error {
continue;
}
{
let current = streamer_slot.read().await;
if !current
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &streamer))
{
return;
}
}
let reason = format!("Audio device lost: {}", device);
monitor.report_error(&reason, "device_lost").await;
spawn_recovery_task_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
device,
reason,
);
return;
}
});
}
fn spawn_recovery_task_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
lost_device: String,
reason: String,
) {
if recovery_in_progress.swap(true, Ordering::SeqCst) {
debug!("Audio recovery already in progress");
return;
}
tokio::spawn(async move {
warn!("Audio recovery started for {}: {}", lost_device, reason);
publish_device_lost(&event_bus, &lost_device, &reason).await;
publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_device_lost"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
let mut attempt = 0u32;
loop {
if !recovery_in_progress.load(Ordering::SeqCst) {
debug!("Audio recovery canceled");
return;
}
if streamer_slot
.read()
.await
.as_ref()
.is_some_and(|s| s.is_running())
{
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
let cfg: AudioControllerConfig = config.read().await.clone();
if !cfg.enabled {
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
attempt = attempt.saturating_add(1);
publish_reconnecting(&event_bus, &lost_device, attempt).await;
publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_reconnecting"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
tokio::time::sleep(AUDIO_RECOVERY_RETRY_DELAY).await;
let devices = match enumerate_audio_devices() {
Ok(devices) => devices,
Err(e) => {
debug!(
"Audio recovery enumerate failed (attempt {}): {}",
attempt, e
);
continue;
}
};
let Some(device) = select_recovery_device(&devices, &cfg.device) else {
debug!("No audio devices found during recovery attempt {}", attempt);
continue;
};
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device.name.clone(),
..Default::default()
},
opus: cfg.quality.to_opus_config(),
};
let new_streamer = Arc::new(AudioStreamer::with_config(streamer_config));
match new_streamer.start().await {
Ok(()) => {
{
let mut cfg = config.write().await;
cfg.device = device.name.clone();
}
*streamer_slot.write().await = Some(new_streamer.clone());
monitor.report_recovered().await;
publish_recovered(&event_bus, &device.name).await;
if let Some(callback) = recovered_callback.read().await.clone() {
callback();
}
publish_state(
&event_bus,
"streaming",
Some(device.name.clone()),
None,
None,
)
.await;
recovery_in_progress.store(false, Ordering::SeqCst);
info!(
"Audio device recovered with {} after {} attempts",
device.name, attempt
);
spawn_stream_monitor_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
new_streamer,
device.name,
);
return;
}
Err(e) => {
debug!(
"Audio recovery start failed with {} (attempt {}): {}",
device.name, attempt, e
);
}
}
}
});
}
pub(super) fn spawn_stream_monitor(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
streamer: Arc<AudioStreamer>,
device: String,
) {
spawn_stream_monitor_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
streamer,
device,
);
}
pub(super) fn spawn_recovery_task(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
lost_device: String,
reason: String,
) {
spawn_recovery_task_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
lost_device,
reason,
);
}

View File

@@ -1,202 +0,0 @@
//! Resample capture PCM to 48 kHz stereo for Opus (fixed 20 ms / 960×2 samples).
const OUT_RATE: f64 = 48000.0;
const OPUS_STEREO_SAMPLES: usize = 960 * 2;
enum PipelineState {
/// Native 48 kHz interleaved stereo: only buffer and slice into 20 ms blocks (no float work).
Stereo48kPassthrough,
/// Other rates / mono: linear interpolation to 48 kHz stereo.
Resample {
in_rate: u32,
in_channels: u32,
next_out_frame: u64,
buffer_start_frame: u64,
},
}
/// Converts incoming interleaved PCM to 48 kHz stereo, then exposes fixed 960×2-sample chunks.
pub struct Opus48kPcmBuffer {
state: PipelineState,
pending: Vec<i16>,
}
impl Opus48kPcmBuffer {
pub fn new(in_rate: u32, in_channels: u32) -> Self {
let ch = in_channels.max(1);
let rate = in_rate.max(1);
let state = if rate == 48000 && ch == 2 {
PipelineState::Stereo48kPassthrough
} else {
PipelineState::Resample {
in_rate: rate,
in_channels: ch,
next_out_frame: 0,
buffer_start_frame: 0,
}
};
Self {
state,
pending: Vec::new(),
}
}
/// True when input is already 48 kHz stereo (no interpolation loop).
#[cfg(test)]
pub fn is_passthrough(&self) -> bool {
matches!(self.state, PipelineState::Stereo48kPassthrough)
}
/// Append one capture block (`sample_rate` must match the rate this buffer was built for).
pub fn push_interleaved(&mut self, data: &[i16]) {
self.pending.extend_from_slice(data);
}
/// Drain as many 960×2 stereo S16LE samples (20 ms @ 48 kHz) as possible.
pub fn pop_opus_frames(&mut self, out: &mut Vec<i16>) {
match &mut self.state {
PipelineState::Stereo48kPassthrough => {
while self.pending.len() >= OPUS_STEREO_SAMPLES {
out.extend_from_slice(&self.pending[..OPUS_STEREO_SAMPLES]);
self.pending.drain(..OPUS_STEREO_SAMPLES);
}
}
PipelineState::Resample {
in_rate,
in_channels,
next_out_frame,
buffer_start_frame,
} => {
let ch = *in_channels as usize;
if ch == 0 {
return;
}
loop {
let batch_start = *next_out_frame;
let mut block = Vec::with_capacity(OPUS_STEREO_SAMPLES);
let mut complete = true;
for i in 0u64..960 {
let k = batch_start + i;
let p_abs = (k as f64) * (*in_rate as f64) / OUT_RATE;
let f_abs = p_abs.floor() as u64;
let frac = p_abs - f_abs as f64;
let f_rel = f_abs.saturating_sub(*buffer_start_frame) as usize;
if f_rel + 1 >= self.pending.len() / ch {
complete = false;
break;
}
let base0 = f_rel * ch;
let base1 = (f_rel + 1) * ch;
let (l, r) = if *in_channels >= 2 {
let l0 = self.pending[base0] as f64;
let l1 = self.pending[base1] as f64;
let r0 = self.pending[base0 + 1] as f64;
let r1 = self.pending[base1 + 1] as f64;
(l0 + frac * (l1 - l0), r0 + frac * (r1 - r0))
} else {
let m0 = self.pending[base0] as f64;
let m1 = self.pending[base1] as f64;
let v = m0 + frac * (m1 - m0);
(v, v)
};
block.push(clamp_f64_to_i16(l));
block.push(clamp_f64_to_i16(r));
}
if !complete || block.len() != OPUS_STEREO_SAMPLES {
break;
}
out.extend_from_slice(&block);
*next_out_frame = batch_start + 960;
trim_resample_prefix(
&mut self.pending,
*in_rate,
*next_out_frame,
buffer_start_frame,
ch,
);
}
}
}
}
}
fn trim_resample_prefix(
pending: &mut Vec<i16>,
in_rate: u32,
next_out_frame: u64,
buffer_start_frame: &mut u64,
ch: usize,
) {
if pending.is_empty() {
return;
}
let p_next = (next_out_frame as f64) * (in_rate as f64) / OUT_RATE;
let need_abs = p_next.floor() as u64;
let keep_from_abs = need_abs.saturating_sub(1);
if keep_from_abs <= *buffer_start_frame {
return;
}
let drop_frames = (keep_from_abs - *buffer_start_frame) as usize;
let drop_samples = drop_frames.saturating_mul(ch).min(pending.len());
if drop_samples > 0 {
pending.drain(0..drop_samples);
*buffer_start_frame += drop_frames as u64;
}
}
#[inline]
fn clamp_f64_to_i16(v: f64) -> i16 {
v.round().clamp(i16::MIN as f64, i16::MAX as f64) as i16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn passthrough_48k_identity_tone_length() {
let mut buf = Opus48kPcmBuffer::new(48000, 2);
assert!(buf.is_passthrough());
let mut chunk = vec![0i16; 960 * 2];
for i in 0..960 {
let s = (i as f32 * 0.1).sin() * 3000.0;
chunk[2 * i] = s as i16;
chunk[2 * i + 1] = s as i16;
}
buf.push_interleaved(&chunk);
let mut out = Vec::new();
buf.pop_opus_frames(&mut out);
assert_eq!(out.len(), 960 * 2);
}
#[test]
fn upsample_44k_to_48k_chunk() {
let mut buf = Opus48kPcmBuffer::new(44100, 2);
assert!(!buf.is_passthrough());
let mut chunk = vec![0i16; 882 * 2];
for i in 0..882 {
chunk[2 * i] = (i as i16).wrapping_mul(10);
chunk[2 * i + 1] = (i as i16).wrapping_mul(-7);
}
buf.push_interleaved(&chunk);
let mut out = Vec::new();
buf.pop_opus_frames(&mut out);
assert_eq!(out.len(), 960 * 2, "expected one 20ms Opus block");
}
#[test]
fn mono_48k_not_passthrough() {
let buf = Opus48kPcmBuffer::new(48000, 1);
assert!(!buf.is_passthrough());
}
}

View File

@@ -1,46 +1,36 @@
//! Audio streaming pipeline
//!
//! Coordinates audio capture and Opus encoding, distributing encoded
//! frames to multiple subscribers via broadcast channel.
//! ALSA 48 kHz stereo → Opus 20 ms frames, fan-out per subscriber.
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex, RwLock};
use tracing::{error, info, warn};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::{broadcast, mpsc, watch, Mutex as AsyncMutex, RwLock};
use tracing::{debug, error, info, warn};
use super::capture::{AudioCapturer, AudioConfig, AudioFrame, CaptureState};
use super::encoder::{OpusConfig, OpusEncoder, OpusFrame};
use super::resample::Opus48kPcmBuffer;
use crate::error::{AppError, Result};
use bytemuck;
use bytes::Bytes;
use std::time::Duration;
/// 48 kHz stereo: 20 ms = 960 × 2 samples (S16LE).
const OPUS_STEREO_SAMPLES: usize = 960 * 2;
/// Audio stream state
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AudioStreamState {
/// Stream is stopped
#[default]
Stopped,
/// Stream is starting up
Starting,
/// Stream is running
Running,
/// Stream encountered an error
Error,
}
/// Audio streamer configuration
#[derive(Debug, Clone, Default)]
pub struct AudioStreamerConfig {
/// Audio capture configuration
pub capture: AudioConfig,
/// Opus encoder configuration
pub opus: OpusConfig,
}
impl AudioStreamerConfig {
/// Create config for a specific device with default quality
pub fn for_device(device_name: &str) -> Self {
Self {
capture: AudioConfig {
@@ -51,90 +41,75 @@ impl AudioStreamerConfig {
}
}
/// Create config with specified bitrate
pub fn with_bitrate(mut self, bitrate: u32) -> Self {
self.opus.bitrate = bitrate;
self
}
}
/// Audio stream statistics
#[derive(Debug, Clone, Default)]
pub struct AudioStreamStats {
/// Frames encoded to Opus
/// Number of active subscribers
pub subscriber_count: usize,
}
/// Audio streamer
///
/// Manages the audio capture -> encode -> broadcast pipeline.
pub struct AudioStreamer {
config: RwLock<AudioStreamerConfig>,
state: watch::Sender<AudioStreamState>,
state_rx: watch::Receiver<AudioStreamState>,
capturer: RwLock<Option<Arc<AudioCapturer>>>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: watch::Sender<Option<Arc<OpusFrame>>>,
stats: Arc<Mutex<AudioStreamStats>>,
sequence: AtomicU64,
stream_start_time: RwLock<Option<Instant>>,
encoder: Arc<AsyncMutex<Option<OpusEncoder>>>,
opus_subscribers: Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
stop_flag: Arc<AtomicBool>,
}
impl AudioStreamer {
/// Create a new audio streamer with default configuration
pub fn new() -> Self {
Self::with_config(AudioStreamerConfig::default())
}
/// Create a new audio streamer with specified configuration
pub fn with_config(config: AudioStreamerConfig) -> Self {
let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped);
let (opus_tx, _opus_rx) = watch::channel(None);
Self {
config: RwLock::new(config),
state: state_tx,
state_rx,
capturer: RwLock::new(None),
encoder: Arc::new(Mutex::new(None)),
opus_tx,
stats: Arc::new(Mutex::new(AudioStreamStats::default())),
sequence: AtomicU64::new(0),
stream_start_time: RwLock::new(None),
encoder: Arc::new(AsyncMutex::new(None)),
opus_subscribers: Arc::new(Mutex::new(Vec::new())),
stop_flag: Arc::new(AtomicBool::new(false)),
}
}
/// Get current state
pub fn state(&self) -> AudioStreamState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<AudioStreamState> {
self.state_rx.clone()
}
/// Subscribe to Opus frames
pub fn subscribe_opus(&self) -> watch::Receiver<Option<Arc<OpusFrame>>> {
self.opus_tx.subscribe()
pub fn subscribe_opus(&self) -> mpsc::Receiver<Arc<OpusFrame>> {
let (tx, rx) = mpsc::channel::<Arc<OpusFrame>>(128);
self.opus_subscribers.lock().unwrap().push(tx);
rx
}
/// Get number of active subscribers
pub fn subscriber_count(&self) -> usize {
self.opus_tx.receiver_count()
self.opus_subscribers
.lock()
.unwrap()
.iter()
.filter(|s| !s.is_closed())
.count()
}
/// Get current statistics
pub async fn stats(&self) -> AudioStreamStats {
let mut stats = self.stats.lock().await.clone();
stats.subscriber_count = self.subscriber_count();
stats
pub fn stats(&self) -> AudioStreamStats {
AudioStreamStats {
subscriber_count: self.subscriber_count(),
}
}
/// Update configuration (only when stopped)
pub async fn set_config(&self, config: AudioStreamerConfig) -> Result<()> {
if self.state() != AudioStreamState::Stopped {
return Err(AppError::AudioError(
@@ -145,12 +120,9 @@ impl AudioStreamer {
Ok(())
}
/// Update bitrate dynamically (can be done while streaming)
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
// Update config
self.config.write().await.opus.bitrate = bitrate;
// Update encoder if running
if let Some(ref mut encoder) = *self.encoder.lock().await {
encoder.set_bitrate(bitrate)?;
}
@@ -159,7 +131,6 @@ impl AudioStreamer {
Ok(())
}
/// Start the audio stream
pub async fn start(&self) -> Result<()> {
if self.state() == AudioStreamState::Running {
return Ok(());
@@ -178,42 +149,77 @@ impl AudioStreamer {
config.opus.bitrate
);
// Create capturer
let capturer = Arc::new(AudioCapturer::new(config.capture.clone()));
*self.capturer.write().await = Some(capturer.clone());
// Create encoder
let encoder = OpusEncoder::new(config.opus.clone())?;
*self.encoder.lock().await = Some(encoder);
// Start capture
capturer.start().await?;
// Reset stats
{
let mut stats = self.stats.lock().await;
*stats = AudioStreamStats::default();
let mut capture_state = capturer.state_watch();
let startup_result = tokio::time::timeout(Duration::from_secs(2), async {
loop {
let current_state = *capture_state.borrow();
match current_state {
CaptureState::Running => return Ok(()),
CaptureState::Error => {
return Err(AppError::AudioError(
"Audio capture failed to start".to_string(),
))
}
CaptureState::Stopped => {
if capture_state.changed().await.is_err() {
return Err(AppError::AudioError(
"Audio capture stopped during startup".to_string(),
));
}
}
}
}
})
.await;
match startup_result {
Ok(Ok(())) => {}
Ok(Err(e)) => {
let _ = capturer.stop().await;
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
let _ = self.state.send(AudioStreamState::Error);
return Err(e);
}
Err(_) => {
let _ = capturer.stop().await;
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
let _ = self.state.send(AudioStreamState::Error);
return Err(AppError::AudioError(
"Timed out waiting for audio capture to start".to_string(),
));
}
}
// Record start time
*self.stream_start_time.write().await = Some(Instant::now());
self.sequence.store(0, Ordering::SeqCst);
// Start encoding task
let capturer_for_task = capturer.clone();
let encoder = self.encoder.clone();
let opus_tx = self.opus_tx.clone();
let opus_subscribers = self.opus_subscribers.clone();
let state = self.state.clone();
let stop_flag = self.stop_flag.clone();
tokio::spawn(async move {
Self::stream_task(capturer_for_task, encoder, opus_tx, state, stop_flag).await;
Self::stream_task(
capturer_for_task,
encoder,
opus_subscribers,
state,
stop_flag,
)
.await;
});
Ok(())
}
/// Stop the audio stream
pub async fn stop(&self) -> Result<()> {
if self.state() == AudioStreamState::Stopped {
return Ok(());
@@ -221,74 +227,82 @@ impl AudioStreamer {
info!("Stopping audio stream");
// Signal stop
self.stop_flag.store(true, Ordering::SeqCst);
// Stop capturer
if let Some(ref capturer) = *self.capturer.read().await {
capturer.stop().await?;
}
// Clear resources
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
*self.stream_start_time.write().await = None;
self.opus_subscribers.lock().unwrap().clear();
let _ = self.state.send(AudioStreamState::Stopped);
info!("Audio stream stopped");
Ok(())
}
/// Check if streaming
pub fn is_running(&self) -> bool {
self.state() == AudioStreamState::Running
}
/// Internal streaming task
async fn fanout_opus(
subscribers: &Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
frame: Arc<OpusFrame>,
) {
let txs: Vec<_> = {
let g = subscribers.lock().unwrap();
if g.is_empty() {
return;
}
g.clone()
};
for tx in &txs {
let _ = tx.send(frame.clone()).await;
}
if txs.iter().any(|tx| tx.is_closed()) {
let mut g = subscribers.lock().unwrap();
g.retain(|tx| !tx.is_closed());
}
}
async fn stream_task(
capturer: Arc<AudioCapturer>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: watch::Sender<Option<Arc<OpusFrame>>>,
encoder: Arc<AsyncMutex<Option<OpusEncoder>>>,
opus_subscribers: Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
state: watch::Sender<AudioStreamState>,
stop_flag: Arc<AtomicBool>,
) {
let mut pcm_rx = capturer.subscribe();
let _ = state.send(AudioStreamState::Running);
info!("Audio stream task started");
debug!("Audio stream task started (48 kHz stereo → Opus, mpsc fan-out)");
let mut to_48k: Option<Opus48kPcmBuffer> = None;
let mut queued_48k: Vec<i16> = Vec::new();
let mut pending: Vec<i16> = Vec::new();
loop {
// Check stop flag (atomic, no async lock needed)
if stop_flag.load(Ordering::Relaxed) {
break;
}
// Check capturer state
if capturer.state() == CaptureState::Error {
error!("Audio capture error, stopping stream");
let _ = state.send(AudioStreamState::Error);
break;
}
// Receive PCM frame with timeout
let recv_result =
tokio::time::timeout(std::time::Duration::from_secs(2), pcm_rx.recv()).await;
match recv_result {
Ok(Ok(audio_frame)) => {
if to_48k.is_none() {
to_48k = Some(Opus48kPcmBuffer::new(
audio_frame.sample_rate,
audio_frame.channels,
));
if audio_frame.sample_rate != 48_000 || audio_frame.channels != 2 {
warn!(
"Skip non48 kHz/stereo PCM ({} Hz, {} ch)",
audio_frame.sample_rate, audio_frame.channels
);
continue;
}
let pipeline = match to_48k.as_mut() {
Some(p) => p,
None => continue,
};
let samples: &[i16] = match bytemuck::try_cast_slice(&audio_frame.data) {
Ok(s) => s,
@@ -298,16 +312,16 @@ impl AudioStreamer {
}
};
if !samples.is_empty() {
pipeline.push_interleaved(samples);
pending.extend_from_slice(samples);
}
pipeline.pop_opus_frames(&mut queued_48k);
while queued_48k.len() >= 960 * 2 {
let pcm_20ms =
Bytes::copy_from_slice(bytemuck::cast_slice(&queued_48k[..960 * 2]));
queued_48k.drain(..960 * 2);
while pending.len() >= OPUS_STEREO_SAMPLES {
let pcm_20ms = Bytes::copy_from_slice(bytemuck::cast_slice(
&pending[..OPUS_STEREO_SAMPLES],
));
pending.drain(..OPUS_STEREO_SAMPLES);
let frame_48k = AudioFrame::new_interleaved(pcm_20ms, 2, 48000, 0);
let frame_48k = AudioFrame::new_interleaved(pcm_20ms, 2, 48_000, 0);
let opus_result = {
let mut enc_guard = encoder.lock().await;
@@ -318,9 +332,7 @@ impl AudioStreamer {
match opus_result {
Some(Ok(opus_frame)) => {
if opus_tx.receiver_count() > 0 {
let _ = opus_tx.send(Some(Arc::new(opus_frame)));
}
Self::fanout_opus(&opus_subscribers, Arc::new(opus_frame)).await;
}
Some(Err(e)) => {
error!("Opus encode error: {}", e);
@@ -337,19 +349,23 @@ impl AudioStreamer {
break;
}
Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
warn!("Audio receiver lagged by {} frames", n);
warn!("PCM receiver lagged by {} frames", n);
}
Err(_) => {
// Timeout - check if still capturing
if capturer.state() != CaptureState::Running {
info!("Audio capture stopped, ending stream task");
let _ = state.send(AudioStreamState::Error);
break;
}
}
}
}
let _ = state.send(AudioStreamState::Stopped);
if stop_flag.load(Ordering::Relaxed) {
let _ = state.send(AudioStreamState::Stopped);
} else {
opus_subscribers.lock().unwrap().clear();
}
info!("Audio stream task ended");
}
}

85
src/audio/types.rs Normal file
View File

@@ -0,0 +1,85 @@
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use super::encoder::OpusConfig;
use crate::error::AppError;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioQuality {
Voice,
#[default]
Balanced,
High,
}
impl AudioQuality {
pub fn bitrate(&self) -> u32 {
match self {
AudioQuality::Voice => 32000,
AudioQuality::Balanced => 64000,
AudioQuality::High => 128000,
}
}
pub fn to_opus_config(&self) -> OpusConfig {
match self {
AudioQuality::Voice => OpusConfig::voice(),
AudioQuality::Balanced => OpusConfig::default(),
AudioQuality::High => OpusConfig::music(),
}
}
}
impl FromStr for AudioQuality {
type Err = AppError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.trim().to_lowercase().as_str() {
"voice" => Ok(Self::Voice),
"balanced" => Ok(Self::Balanced),
"high" => Ok(Self::High),
_ => Err(AppError::BadRequest(format!(
"invalid audio quality {:?} (expected voice, balanced, or high)",
s.trim()
))),
}
}
}
impl std::fmt::Display for AudioQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AudioQuality::Voice => write!(f, "voice"),
AudioQuality::Balanced => write!(f, "balanced"),
AudioQuality::High => write!(f, "high"),
}
}
}
#[derive(Debug, Clone)]
pub struct AudioControllerConfig {
pub enabled: bool,
pub device: String,
pub quality: AudioQuality,
}
impl Default for AudioControllerConfig {
fn default() -> Self {
Self {
enabled: false,
device: String::new(),
quality: AudioQuality::Balanced,
}
}
}
#[derive(Debug, Clone, Serialize)]
pub struct AudioStatus {
pub enabled: bool,
pub streaming: bool,
pub device: Option<String>,
pub quality: AudioQuality,
pub subscriber_count: usize,
pub error: Option<String>,
}

View File

@@ -8,20 +8,16 @@ use axum::{
use axum_extra::extract::CookieJar;
use std::sync::Arc;
use crate::error::ErrorResponse;
use crate::state::AppState;
use crate::web::ErrorResponse;
/// Session cookie name
pub const SESSION_COOKIE: &str = "one_kvm_session";
/// Extract session ID from request
pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option<String> {
// First try cookie
if let Some(cookie) = cookies.get(SESSION_COOKIE) {
return Some(cookie.value().to_string());
}
// Then try Authorization header (Bearer token)
if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
if let Ok(auth_str) = auth_header.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
@@ -33,7 +29,6 @@ pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap)
None
}
/// Authentication middleware
pub async fn auth_middleware(
State(state): State<Arc<AppState>>,
cookies: CookieJar,
@@ -41,29 +36,23 @@ pub async fn auth_middleware(
next: Next,
) -> Result<Response, StatusCode> {
let raw_path = request.uri().path();
// When this middleware is mounted under /api, Axum strips the prefix for the inner router.
// Normalize the path so checks work whether it is mounted or not.
// Mounted under /api: inner path may lack prefix; normalize for whitelist checks.
let path = raw_path.strip_prefix("/api").unwrap_or(raw_path);
// Check if system is initialized
if !state.config.is_initialized() {
// Allow only setup-related endpoints when not initialized
if is_setup_public_endpoint(path) {
return Ok(next.run(request).await);
}
}
// Public endpoints that don't require auth
if is_public_endpoint(path) {
return Ok(next.run(request).await);
}
// Extract session ID
let session_id = extract_session_id(&cookies, request.headers());
if let Some(session_id) = session_id {
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
// Add session to request extensions
request.extensions_mut().insert(session);
return Ok(next.run(request).await);
}
@@ -87,9 +76,7 @@ fn unauthorized_response(message: &str) -> Response {
(StatusCode::UNAUTHORIZED, Json(body)).into_response()
}
/// Check if endpoint is public (no auth required)
fn is_public_endpoint(path: &str) -> bool {
// Note: paths here are relative to /api since middleware is applied within the nested router
matches!(
path,
"/" | "/auth/login" | "/health" | "/setup" | "/setup/init"
@@ -102,7 +89,6 @@ fn is_public_endpoint(path: &str) -> bool {
|| path.ends_with(".svg")
}
/// Setup-only endpoints allowed before initialization.
fn is_setup_public_endpoint(path: &str) -> bool {
matches!(
path,

View File

@@ -5,7 +5,6 @@ use argon2::{
use crate::error::{AppError, Result};
/// Hash a password using Argon2
pub fn hash_password(password: &str) -> Result<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
@@ -16,7 +15,6 @@ pub fn hash_password(password: &str) -> Result<String> {
.map_err(|e| AppError::Internal(format!("Password hashing failed: {}", e)))
}
/// Verify a password against a hash
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| AppError::Internal(format!("Invalid password hash: {}", e)))?;

View File

@@ -1,156 +1,104 @@
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use std::collections::HashMap;
use std::sync::Arc;
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::error::Result;
/// Session data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
pub id: String,
pub user_id: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub expires_at: OffsetDateTime,
pub data: Option<serde_json::Value>,
}
impl Session {
/// Check if session is expired
pub fn is_expired(&self) -> bool {
Utc::now() > self.expires_at
OffsetDateTime::now_utc() > self.expires_at
}
}
/// Session store backed by SQLite
#[derive(Clone)]
pub struct SessionStore {
pool: Pool<Sqlite>,
inner: Arc<RwLock<HashMap<String, Session>>>,
default_ttl: Duration,
}
impl SessionStore {
/// Create a new session store
pub fn new(pool: Pool<Sqlite>, ttl_secs: i64) -> Self {
pub fn new(ttl_secs: i64) -> Self {
Self {
pool,
inner: Arc::new(RwLock::new(HashMap::new())),
default_ttl: Duration::seconds(ttl_secs),
}
}
/// Create a new session
pub async fn create(&self, user_id: &str) -> Result<Session> {
let now = OffsetDateTime::now_utc();
let session = Session {
id: Uuid::new_v4().to_string(),
user_id: user_id.to_string(),
created_at: Utc::now(),
expires_at: Utc::now() + self.default_ttl,
created_at: now,
expires_at: now + self.default_ttl,
data: None,
};
sqlx::query(
r#"
INSERT INTO sessions (id, user_id, created_at, expires_at, data)
VALUES (?1, ?2, ?3, ?4, ?5)
"#,
)
.bind(&session.id)
.bind(&session.user_id)
.bind(session.created_at.to_rfc3339())
.bind(session.expires_at.to_rfc3339())
.bind(session.data.as_ref().map(|d| d.to_string()))
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
guard.insert(session.id.clone(), session.clone());
Ok(session)
}
/// Get a session by ID
pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
let row: Option<(String, String, String, String, Option<String>)> = sqlx::query_as(
"SELECT id, user_id, created_at, expires_at, data FROM sessions WHERE id = ?1",
)
.bind(session_id)
.fetch_optional(&self.pool)
.await?;
match row {
Some((id, user_id, created_at, expires_at, data)) => {
let session = Session {
id,
user_id,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
expires_at: DateTime::parse_from_rfc3339(&expires_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
data: data.and_then(|d| serde_json::from_str(&d).ok()),
};
if session.is_expired() {
self.delete(&session.id).await?;
Ok(None)
} else {
Ok(Some(session))
}
}
None => Ok(None),
let mut guard = self.inner.write().await;
let Some(session) = guard.get(session_id).cloned() else {
return Ok(None);
};
if session.is_expired() {
guard.remove(session_id);
return Ok(None);
}
Ok(Some(session))
}
/// Delete a session
pub async fn delete(&self, session_id: &str) -> Result<()> {
sqlx::query("DELETE FROM sessions WHERE id = ?1")
.bind(session_id)
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
guard.remove(session_id);
Ok(())
}
/// Delete all expired sessions
pub async fn cleanup_expired(&self) -> Result<u64> {
let now = Utc::now().to_rfc3339();
let result = sqlx::query("DELETE FROM sessions WHERE expires_at < ?1")
.bind(now)
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
let mut guard = self.inner.write().await;
let before = guard.len();
guard.retain(|_, s| !s.is_expired());
Ok((before - guard.len()) as u64)
}
/// Delete all sessions
pub async fn delete_all(&self) -> Result<u64> {
let result = sqlx::query("DELETE FROM sessions")
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
let mut guard = self.inner.write().await;
let n = guard.len() as u64;
guard.clear();
Ok(n)
}
/// Delete all sessions for a specific user
pub async fn delete_by_user_id(&self, user_id: &str) -> Result<u64> {
let result = sqlx::query("DELETE FROM sessions WHERE user_id = ?1")
.bind(user_id)
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
}
/// List all session IDs
pub async fn list_ids(&self) -> Result<Vec<String>> {
let rows: Vec<(String,)> = sqlx::query_as("SELECT id FROM sessions")
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(|(id,)| id).collect())
let guard = self.inner.read().await;
Ok(guard.keys().cloned().collect())
}
/// Extend session expiration
pub async fn extend(&self, session_id: &str) -> Result<()> {
let new_expires = Utc::now() + self.default_ttl;
sqlx::query("UPDATE sessions SET expires_at = ?1 WHERE id = ?2")
.bind(new_expires.to_rfc3339())
.bind(session_id)
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
if let Some(session) = guard.get_mut(session_id) {
if session.is_expired() {
guard.remove(session_id);
} else {
session.expires_at = OffsetDateTime::now_utc() + self.default_ttl;
}
}
Ok(())
}
}

View File

@@ -1,123 +1,99 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime;
use uuid::Uuid;
use super::password::{hash_password, verify_password};
use crate::error::{AppError, Result};
/// User row type from database
type UserRow = (String, String, String, String, String);
type UserRow = (String, String, String);
/// User data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub id: String,
pub username: String,
#[serde(skip_serializing)]
pub password_hash: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl User {
/// Convert from database row to User
fn from_row(row: UserRow) -> Self {
let (id, username, password_hash, created_at, updated_at) = row;
let (id, username, password_hash) = row;
Self {
id,
username,
password_hash,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
updated_at: DateTime::parse_from_rfc3339(&updated_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
}
}
}
/// User store backed by SQLite
#[derive(Clone)]
pub struct UserStore {
pool: Pool<Sqlite>,
}
impl UserStore {
/// Create a new user store
pub fn new(pool: Pool<Sqlite>) -> Self {
Self { pool }
}
/// Create a new user
pub async fn create(&self, username: &str, password: &str) -> Result<User> {
// Check if username already exists
if self.get_by_username(username).await?.is_some() {
return Err(AppError::BadRequest(format!(
"Username '{}' already exists",
username
)));
/// The single local user, or `None` if none exists. Errors if more than one row is present.
pub async fn single_user(&self) -> Result<Option<User>> {
let mut rows: Vec<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash FROM users ORDER BY rowid ASC LIMIT 2",
)
.fetch_all(&self.pool)
.await?;
match rows.len() {
0 => Ok(None),
1 => Ok(Some(User::from_row(rows.remove(0)))),
_ => Err(AppError::Internal(
"Multiple user accounts in database; this build supports only one".to_string(),
)),
}
}
pub async fn create_first_user(&self, username: &str, password: &str) -> Result<User> {
if self.single_user().await?.is_some() {
return Err(AppError::BadRequest(
"A user account already exists".to_string(),
));
}
let password_hash = hash_password(password)?;
let now = Utc::now();
let user = User {
id: Uuid::new_v4().to_string(),
username: username.to_string(),
password_hash,
created_at: now,
updated_at: now,
};
sqlx::query(
r#"
INSERT INTO users (id, username, password_hash, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5)
INSERT INTO users (id, username, password_hash)
VALUES (?1, ?2, ?3)
"#,
)
.bind(&user.id)
.bind(&user.username)
.bind(&user.password_hash)
.bind(user.created_at.to_rfc3339())
.bind(user.updated_at.to_rfc3339())
.execute(&self.pool)
.await?;
Ok(user)
}
/// Get user by ID
pub async fn get(&self, user_id: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE id = ?1",
)
.bind(user_id)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Get user by username
pub async fn get_by_username(&self, username: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE username = ?1",
)
.bind(username)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Verify user credentials
pub async fn verify(&self, username: &str, password: &str) -> Result<Option<User>> {
let user = match self.get_by_username(username).await? {
Some(user) => user,
let user = match self.single_user().await? {
Some(u) => u,
None => return Ok(None),
};
if user.username != username {
return Ok(None);
}
if verify_password(password, &user.password_hash)? {
Ok(Some(user))
} else {
@@ -125,15 +101,23 @@ impl UserStore {
}
}
/// Update user password
pub async fn update_password(&self, user_id: &str, new_password: &str) -> Result<()> {
let user = self
.single_user()
.await?
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
if user.id != user_id {
return Err(AppError::AuthError("Invalid session".to_string()));
}
let password_hash = hash_password(new_password)?;
let now = Utc::now();
let now = OffsetDateTime::now_utc();
let result =
sqlx::query("UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3")
.bind(&password_hash)
.bind(now.to_rfc3339())
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
.bind(user_id)
.execute(&self.pool)
.await?;
@@ -145,21 +129,24 @@ impl UserStore {
Ok(())
}
/// Update username
pub async fn update_username(&self, user_id: &str, new_username: &str) -> Result<()> {
if let Some(existing) = self.get_by_username(new_username).await? {
if existing.id != user_id {
return Err(AppError::BadRequest(format!(
"Username '{}' already exists",
new_username
)));
}
let user = self
.single_user()
.await?
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
if user.id != user_id {
return Err(AppError::AuthError("Invalid session".to_string()));
}
let now = Utc::now();
if new_username == user.username {
return Ok(());
}
let now = OffsetDateTime::now_utc();
let result = sqlx::query("UPDATE users SET username = ?1, updated_at = ?2 WHERE id = ?3")
.bind(new_username)
.bind(now.to_rfc3339())
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
.bind(user_id)
.execute(&self.pool)
.await?;
@@ -170,37 +157,4 @@ impl UserStore {
Ok(())
}
/// List all users
pub async fn list(&self) -> Result<Vec<User>> {
let rows: Vec<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users ORDER BY created_at",
)
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(User::from_row).collect())
}
/// Delete user by ID
pub async fn delete(&self, user_id: &str) -> Result<()> {
let result = sqlx::query("DELETE FROM users WHERE id = ?1")
.bind(user_id)
.execute(&self.pool)
.await?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("User not found".to_string()));
}
Ok(())
}
/// Check if any users exist
pub async fn has_users(&self) -> Result<bool> {
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
.fetch_one(&self.pool)
.await?;
Ok(count.0 > 0)
}
}

View File

@@ -1,5 +1,11 @@
mod schema;
mod store;
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}
pub use schema::*;
pub use store::ConfigStore;

View File

@@ -1,741 +0,0 @@
use crate::video::encoder::BitratePreset;
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
// Re-export ExtensionsConfig from extensions module
pub use crate::extensions::ExtensionsConfig;
// Re-export RustDeskConfig from rustdesk module
pub use crate::rustdesk::config::RustDeskConfig;
/// Main application configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AppConfig {
/// Whether initial setup has been completed
pub initialized: bool,
/// Authentication settings
pub auth: AuthConfig,
/// Video capture settings
pub video: VideoConfig,
/// HID (keyboard/mouse) settings
pub hid: HidConfig,
/// Mass Storage Device settings
pub msd: MsdConfig,
/// ATX power control settings
pub atx: AtxConfig,
/// Audio settings
pub audio: AudioConfig,
/// Streaming settings
pub stream: StreamConfig,
/// Web server settings
pub web: WebConfig,
/// Extensions settings (ttyd, gostc, easytier)
pub extensions: ExtensionsConfig,
/// RustDesk remote access settings
pub rustdesk: RustDeskConfig,
/// RTSP streaming settings
pub rtsp: RtspConfig,
}
/// Authentication configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AuthConfig {
/// Session timeout in seconds
pub session_timeout_secs: u32,
/// Allow multiple concurrent web sessions (single-user mode)
pub single_user_allow_multiple_sessions: bool,
/// Enable 2FA
pub totp_enabled: bool,
/// TOTP secret (encrypted)
pub totp_secret: Option<String>,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
session_timeout_secs: 3600 * 24, // 24 hours
single_user_allow_multiple_sessions: false,
totp_enabled: false,
totp_secret: None,
}
}
}
/// Video capture configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct VideoConfig {
/// Video device path (e.g., /dev/video0)
pub device: Option<String>,
/// Video pixel format (e.g., "MJPEG", "YUYV", "NV12")
pub format: Option<String>,
/// Resolution width
pub width: u32,
/// Resolution height
pub height: u32,
/// Frame rate
pub fps: u32,
/// JPEG quality (1-100)
pub quality: u32,
}
impl Default for VideoConfig {
fn default() -> Self {
Self {
device: None,
format: None, // Auto-detect or use MJPEG as default
width: 1920,
height: 1080,
fps: 30,
quality: 80,
}
}
}
/// HID backend type
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum HidBackend {
/// USB OTG HID gadget
Otg,
/// CH9329 serial HID controller
Ch9329,
/// Disabled
#[default]
None,
}
/// OTG USB device descriptor configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OtgDescriptorConfig {
/// USB Vendor ID (e.g., 0x1d6b)
pub vendor_id: u16,
/// USB Product ID (e.g., 0x0104)
pub product_id: u16,
/// Manufacturer string
pub manufacturer: String,
/// Product string
pub product: String,
/// Serial number (optional, auto-generated if not set)
pub serial_number: Option<String>,
}
impl Default for OtgDescriptorConfig {
fn default() -> Self {
Self {
vendor_id: 0x1d6b, // Linux Foundation
product_id: 0x0104, // Multifunction Composite Gadget
manufacturer: "One-KVM".to_string(),
product: "One-KVM USB Device".to_string(),
serial_number: None,
}
}
}
/// OTG HID function profile
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgHidProfile {
/// Full HID device set (keyboard + relative mouse + absolute mouse + consumer control)
#[default]
#[serde(alias = "full_no_msd")]
Full,
/// Full HID device set without consumer control
#[serde(alias = "full_no_consumer_no_msd")]
FullNoConsumer,
/// Legacy profile: only keyboard
LegacyKeyboard,
/// Legacy profile: only relative mouse
LegacyMouseRelative,
/// Custom function selection
Custom,
}
/// OTG endpoint budget policy.
#[typeshare]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgEndpointBudget {
/// Derive a safe default from the selected UDC.
#[default]
Auto,
/// Limit OTG gadget functions to 5 endpoints.
Five,
/// Limit OTG gadget functions to 6 endpoints.
Six,
/// Do not impose a software endpoint budget.
Unlimited,
}
impl OtgEndpointBudget {
pub fn default_for_udc_name(udc: Option<&str>) -> Self {
if udc.is_some_and(crate::otg::configfs::is_low_endpoint_udc) {
Self::Five
} else {
Self::Six
}
}
pub fn resolved(self, udc: Option<&str>) -> Self {
match self {
Self::Auto => Self::default_for_udc_name(udc),
other => other,
}
}
pub fn endpoint_limit(self, udc: Option<&str>) -> Option<u8> {
match self.resolved(udc) {
Self::Five => Some(5),
Self::Six => Some(6),
Self::Unlimited => None,
Self::Auto => unreachable!("auto budget must be resolved before use"),
}
}
}
/// OTG HID function selection (used when profile is Custom)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct OtgHidFunctions {
pub keyboard: bool,
pub mouse_relative: bool,
pub mouse_absolute: bool,
pub consumer: bool,
}
impl OtgHidFunctions {
pub fn full() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: true,
}
}
pub fn full_no_consumer() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: false,
}
}
pub fn legacy_keyboard() -> Self {
Self {
keyboard: true,
mouse_relative: false,
mouse_absolute: false,
consumer: false,
}
}
pub fn legacy_mouse_relative() -> Self {
Self {
keyboard: false,
mouse_relative: true,
mouse_absolute: false,
consumer: false,
}
}
pub fn is_empty(&self) -> bool {
!self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer
}
pub fn endpoint_cost(&self, keyboard_leds: bool) -> u8 {
let mut endpoints = 0;
if self.keyboard {
endpoints += 1;
if keyboard_leds {
endpoints += 1;
}
}
if self.mouse_relative {
endpoints += 1;
}
if self.mouse_absolute {
endpoints += 1;
}
if self.consumer {
endpoints += 1;
}
endpoints
}
}
impl Default for OtgHidFunctions {
fn default() -> Self {
Self::full()
}
}
impl OtgHidProfile {
pub fn from_legacy_str(value: &str) -> Option<Self> {
match value {
"full" | "full_no_msd" => Some(Self::Full),
"full_no_consumer" | "full_no_consumer_no_msd" => Some(Self::FullNoConsumer),
"legacy_keyboard" => Some(Self::LegacyKeyboard),
"legacy_mouse_relative" => Some(Self::LegacyMouseRelative),
"custom" => Some(Self::Custom),
_ => None,
}
}
pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions {
match self {
Self::Full => OtgHidFunctions::full(),
Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(),
Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(),
Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(),
Self::Custom => custom.clone(),
}
}
}
/// HID configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct HidConfig {
/// HID backend type
pub backend: HidBackend,
/// OTG UDC (USB Device Controller) name
pub otg_udc: Option<String>,
/// OTG USB device descriptor configuration
#[serde(default)]
pub otg_descriptor: OtgDescriptorConfig,
/// OTG HID function profile
#[serde(default)]
pub otg_profile: OtgHidProfile,
/// OTG endpoint budget policy
#[serde(default)]
pub otg_endpoint_budget: OtgEndpointBudget,
/// OTG HID function selection (used when profile is Custom)
#[serde(default)]
pub otg_functions: OtgHidFunctions,
/// Enable keyboard LED/status feedback for OTG keyboard
#[serde(default)]
pub otg_keyboard_leds: bool,
/// CH9329 serial port
pub ch9329_port: String,
/// CH9329 baud rate
pub ch9329_baudrate: u32,
/// Mouse mode: absolute or relative
pub mouse_absolute: bool,
}
impl Default for HidConfig {
fn default() -> Self {
Self {
backend: HidBackend::None,
otg_udc: None,
otg_descriptor: OtgDescriptorConfig::default(),
otg_profile: OtgHidProfile::default(),
otg_endpoint_budget: OtgEndpointBudget::default(),
otg_functions: OtgHidFunctions::default(),
otg_keyboard_leds: false,
ch9329_port: "/dev/ttyUSB0".to_string(),
ch9329_baudrate: 9600,
mouse_absolute: true,
}
}
}
impl HidConfig {
pub fn effective_otg_functions(&self) -> OtgHidFunctions {
self.otg_profile.resolve_functions(&self.otg_functions)
}
pub fn resolved_otg_udc(&self) -> Option<String> {
crate::otg::configfs::resolve_udc_name(self.otg_udc.as_deref())
}
pub fn resolved_otg_endpoint_budget(&self) -> OtgEndpointBudget {
self.otg_endpoint_budget
.resolved(self.resolved_otg_udc().as_deref())
}
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
self.otg_endpoint_budget
.endpoint_limit(self.resolved_otg_udc().as_deref())
}
pub fn effective_otg_keyboard_leds(&self) -> bool {
self.otg_keyboard_leds && self.effective_otg_functions().keyboard
}
pub fn constrained_otg_functions(&self) -> OtgHidFunctions {
self.effective_otg_functions()
}
pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 {
let functions = self.effective_otg_functions();
let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds());
if msd_enabled {
endpoints += 2;
}
endpoints
}
pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> {
if self.backend != HidBackend::Otg {
return Ok(());
}
let functions = self.effective_otg_functions();
if functions.is_empty() {
return Err(crate::error::AppError::BadRequest(
"OTG HID functions cannot be empty".to_string(),
));
}
let required = self.effective_otg_required_endpoints(msd_enabled);
if let Some(limit) = self.resolved_otg_endpoint_limit() {
if required > limit {
return Err(crate::error::AppError::BadRequest(format!(
"OTG selection requires {} endpoints, but the configured limit is {}",
required, limit
)));
}
}
Ok(())
}
}
/// MSD configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct MsdConfig {
/// Enable MSD functionality
pub enabled: bool,
/// MSD base directory (absolute path)
pub msd_dir: String,
}
impl Default for MsdConfig {
fn default() -> Self {
Self {
enabled: true,
msd_dir: String::new(),
}
}
}
impl MsdConfig {
pub fn msd_dir_path(&self) -> std::path::PathBuf {
std::path::PathBuf::from(&self.msd_dir)
}
pub fn images_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("images")
}
pub fn ventoy_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("ventoy")
}
pub fn drive_path(&self) -> std::path::PathBuf {
self.ventoy_dir().join("ventoy.img")
}
}
// Re-export ATX types from atx module for configuration
pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig};
/// ATX power control configuration
///
/// Each ATX action (power, reset) can be independently configured with its own
/// hardware binding using the four-tuple: (driver, device, pin, active_level).
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AtxConfig {
/// Enable ATX functionality
pub enabled: bool,
/// Power button configuration (used for both short and long press)
pub power: AtxKeyConfig,
/// Reset button configuration
pub reset: AtxKeyConfig,
/// LED sensing configuration (optional)
pub led: AtxLedConfig,
/// Network interface for WOL packets (empty = auto)
pub wol_interface: String,
}
impl AtxConfig {
/// Convert to AtxControllerConfig for the controller
pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig {
crate::atx::AtxControllerConfig {
enabled: self.enabled,
power: self.power.clone(),
reset: self.reset.clone(),
led: self.led.clone(),
}
}
}
/// Audio configuration
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
/// These are optimal for Opus encoding and match WebRTC requirements.
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AudioConfig {
/// Enable audio capture
pub enabled: bool,
/// ALSA device name
pub device: String,
/// Audio quality preset: "voice", "balanced", "high"
pub quality: String,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
quality: "balanced".to_string(),
}
}
}
/// Stream mode
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum StreamMode {
/// WebRTC with H264/H265
WebRTC,
/// MJPEG over HTTP
#[default]
Mjpeg,
}
/// RTSP output codec
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum RtspCodec {
#[default]
H264,
H265,
}
/// RTSP configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RtspConfig {
/// Enable RTSP output
pub enabled: bool,
/// Bind IP address
pub bind: String,
/// RTSP TCP listen port
pub port: u16,
/// Stream path (without leading slash)
pub path: String,
/// Allow only one client connection at a time
pub allow_one_client: bool,
/// Output codec (H264/H265)
pub codec: RtspCodec,
/// Optional username for authentication
pub username: Option<String>,
/// Optional password for authentication
#[typeshare(skip)]
pub password: Option<String>,
}
impl Default for RtspConfig {
fn default() -> Self {
Self {
enabled: false,
bind: "0.0.0.0".to_string(),
port: 8554,
path: "live".to_string(),
allow_one_client: true,
codec: RtspCodec::H264,
username: None,
password: None,
}
}
}
/// Encoder type
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum EncoderType {
/// Auto-detect best encoder
#[default]
Auto,
/// Software encoder (libx264)
Software,
/// VAAPI hardware encoder
Vaapi,
/// NVIDIA NVENC hardware encoder
Nvenc,
/// Intel Quick Sync hardware encoder
Qsv,
/// AMD AMF hardware encoder
Amf,
/// Rockchip MPP hardware encoder
Rkmpp,
/// V4L2 M2M hardware encoder
V4l2m2m,
}
impl EncoderType {
/// Convert to EncoderBackend for registry queries
pub fn to_backend(&self) -> Option<crate::video::encoder::registry::EncoderBackend> {
use crate::video::encoder::registry::EncoderBackend;
match self {
EncoderType::Auto => None,
EncoderType::Software => Some(EncoderBackend::Software),
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
EncoderType::Qsv => Some(EncoderBackend::Qsv),
EncoderType::Amf => Some(EncoderBackend::Amf),
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
}
}
/// Get display name for UI
pub fn display_name(&self) -> &'static str {
match self {
EncoderType::Auto => "Auto (Recommended)",
EncoderType::Software => "Software (CPU)",
EncoderType::Vaapi => "VAAPI",
EncoderType::Nvenc => "NVIDIA NVENC",
EncoderType::Qsv => "Intel Quick Sync",
EncoderType::Amf => "AMD AMF",
EncoderType::Rkmpp => "Rockchip MPP",
EncoderType::V4l2m2m => "V4L2 M2M",
}
}
}
/// Streaming configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
/// Stream mode
pub mode: StreamMode,
/// Encoder type for H264/H265
pub encoder: EncoderType,
/// Bitrate preset (Speed/Balanced/Quality)
pub bitrate_preset: BitratePreset,
/// Custom STUN server (e.g., "stun:stun.l.google.com:19302")
/// If empty, uses public ICE servers from secrets.toml
pub stun_server: Option<String>,
/// Custom TURN server (e.g., "turn:turn.example.com:3478")
/// If empty, uses public ICE servers from secrets.toml
pub turn_server: Option<String>,
/// TURN username
pub turn_username: Option<String>,
/// TURN password (stored encrypted in DB, not exposed via API)
pub turn_password: Option<String>,
/// Auto-pause when no clients connected
#[typeshare(skip)]
pub auto_pause_enabled: bool,
/// Auto-pause delay (seconds)
#[typeshare(skip)]
pub auto_pause_delay_secs: u64,
/// Client timeout for cleanup (seconds)
#[typeshare(skip)]
pub client_timeout_secs: u64,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
mode: StreamMode::Mjpeg,
encoder: EncoderType::Auto,
bitrate_preset: BitratePreset::Balanced,
// Empty means use public ICE servers (like RustDesk)
stun_server: None,
turn_server: None,
turn_username: None,
turn_password: None,
auto_pause_enabled: false,
auto_pause_delay_secs: 10,
client_timeout_secs: 30,
}
}
}
impl StreamConfig {
/// Check if using public ICE servers (user left fields empty)
pub fn is_using_public_ice_servers(&self) -> bool {
use crate::webrtc::config::public_ice;
self.stun_server
.as_ref()
.map(|s| s.is_empty())
.unwrap_or(true)
&& self
.turn_server
.as_ref()
.map(|s| s.is_empty())
.unwrap_or(true)
&& public_ice::is_configured()
}
}
/// Web server configuration persisted in the database (includes on-disk TLS paths).
///
/// The HTTP API for `/api/config/web` uses `WebConfigResponse` instead: no path fields, includes `has_custom_cert`.
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WebConfig {
/// HTTP port
pub http_port: u16,
/// HTTPS port
pub https_port: u16,
/// Bind addresses (preferred)
pub bind_addresses: Vec<String>,
/// Bind address (legacy)
pub bind_address: String,
/// Enable HTTPS
pub https_enabled: bool,
/// Custom SSL certificate path
pub ssl_cert_path: Option<String>,
/// Custom SSL key path
pub ssl_key_path: Option<String>,
}
impl Default for WebConfig {
fn default() -> Self {
Self {
http_port: 8080,
https_port: 8443,
bind_addresses: Vec::new(),
bind_address: "0.0.0.0".to_string(),
https_enabled: false,
ssl_cert_path: None,
ssl_key_path: None,
}
}
}

28
src/config/schema/atx.rs Normal file
View File

@@ -0,0 +1,28 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
pub use crate::atx::{ActiveLevel, AtxDriverType, AtxKeyConfig, AtxLedConfig};
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AtxConfig {
pub enabled: bool,
pub power: AtxKeyConfig,
pub reset: AtxKeyConfig,
pub led: AtxLedConfig,
pub wol_interface: String,
}
impl AtxConfig {
pub fn to_controller_config(&self) -> crate::atx::AtxControllerConfig {
crate::atx::AtxControllerConfig {
enabled: self.enabled,
power: self.power.clone(),
reset: self.reset.clone(),
led: self.led.clone(),
}
}
}

View File

@@ -0,0 +1,64 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
#[derive(Default)]
pub enum BitratePreset {
Speed,
#[default]
Balanced,
Quality,
Custom(u32),
}
impl BitratePreset {
pub fn bitrate_kbps(&self) -> u32 {
match self {
Self::Speed => 1000,
Self::Balanced => 4000,
Self::Quality => 8000,
Self::Custom(kbps) => *kbps,
}
}
pub fn gop_size(&self, fps: u32) -> u32 {
match self {
Self::Speed => (fps / 2).max(15),
Self::Balanced => fps,
Self::Quality => fps * 2,
Self::Custom(_) => fps,
}
}
pub fn quality_level(&self) -> &'static str {
match self {
Self::Speed => "low",
Self::Balanced => "medium",
Self::Quality => "high",
Self::Custom(_) => "medium",
}
}
pub fn from_kbps(kbps: u32) -> Self {
match kbps {
0..=1500 => Self::Speed,
1501..=6000 => Self::Balanced,
6001..=10000 => Self::Quality,
_ => Self::Custom(kbps),
}
}
}
impl std::fmt::Display for BitratePreset {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Speed => write!(f, "Speed (1 Mbps)"),
Self::Balanced => write!(f, "Balanced (4 Mbps)"),
Self::Quality => write!(f, "Quality (8 Mbps)"),
Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps),
}
}
}

309
src/config/schema/hid.rs Normal file
View File

@@ -0,0 +1,309 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum HidBackend {
Otg,
Ch9329,
#[default]
None,
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OtgDescriptorConfig {
pub vendor_id: u16,
pub product_id: u16,
pub manufacturer: String,
pub product: String,
pub serial_number: Option<String>,
}
impl Default for OtgDescriptorConfig {
fn default() -> Self {
Self {
vendor_id: 0x1d6b,
product_id: 0x0104,
manufacturer: "One-KVM".to_string(),
product: "One-KVM USB Device".to_string(),
serial_number: None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgHidProfile {
#[default]
#[serde(alias = "full_no_msd")]
Full,
#[serde(alias = "full_no_consumer_no_msd")]
FullNoConsumer,
LegacyKeyboard,
LegacyMouseRelative,
Custom,
}
#[typeshare]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgEndpointBudget {
#[default]
Auto,
Five,
Six,
Unlimited,
}
impl OtgEndpointBudget {
pub fn endpoint_limit_raw(&self) -> Option<u8> {
match self {
Self::Five => Some(5),
Self::Six => Some(6),
Self::Unlimited => None,
Self::Auto => None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct OtgHidFunctions {
pub keyboard: bool,
pub mouse_relative: bool,
pub mouse_absolute: bool,
pub consumer: bool,
}
impl OtgHidFunctions {
pub fn full() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: true,
}
}
pub fn full_no_consumer() -> Self {
Self {
keyboard: true,
mouse_relative: true,
mouse_absolute: true,
consumer: false,
}
}
pub fn legacy_keyboard() -> Self {
Self {
keyboard: true,
mouse_relative: false,
mouse_absolute: false,
consumer: false,
}
}
pub fn legacy_mouse_relative() -> Self {
Self {
keyboard: false,
mouse_relative: true,
mouse_absolute: false,
consumer: false,
}
}
pub fn is_empty(&self) -> bool {
!self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer
}
pub fn endpoint_cost(&self, keyboard_leds: bool) -> u8 {
let mut endpoints = 0;
if self.keyboard {
endpoints += 1;
if keyboard_leds {
endpoints += 1;
}
}
if self.mouse_relative {
endpoints += 1;
}
if self.mouse_absolute {
endpoints += 1;
}
if self.consumer {
endpoints += 1;
}
endpoints
}
}
impl Default for OtgHidFunctions {
fn default() -> Self {
Self::full()
}
}
impl OtgHidProfile {
pub fn from_legacy_str(value: &str) -> Option<Self> {
match value {
"full" | "full_no_msd" => Some(Self::Full),
"full_no_consumer" | "full_no_consumer_no_msd" => Some(Self::FullNoConsumer),
"legacy_keyboard" => Some(Self::LegacyKeyboard),
"legacy_mouse_relative" => Some(Self::LegacyMouseRelative),
"custom" => Some(Self::Custom),
_ => None,
}
}
pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions {
match self {
Self::Full => OtgHidFunctions::full(),
Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(),
Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(),
Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(),
Self::Custom => custom.clone(),
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct HidConfig {
pub backend: HidBackend,
pub otg_udc: Option<String>,
#[serde(default)]
pub otg_descriptor: OtgDescriptorConfig,
#[serde(default)]
pub otg_profile: OtgHidProfile,
#[serde(default)]
pub otg_endpoint_budget: OtgEndpointBudget,
#[serde(default)]
pub otg_functions: OtgHidFunctions,
#[serde(default)]
pub otg_keyboard_leds: bool,
pub ch9329_port: String,
pub ch9329_baudrate: u32,
pub mouse_absolute: bool,
}
impl Default for HidConfig {
fn default() -> Self {
Self {
backend: HidBackend::None,
otg_udc: None,
otg_descriptor: OtgDescriptorConfig::default(),
otg_profile: OtgHidProfile::default(),
otg_endpoint_budget: OtgEndpointBudget::default(),
otg_functions: OtgHidFunctions::default(),
otg_keyboard_leds: false,
ch9329_port: "/dev/ttyUSB0".to_string(),
ch9329_baudrate: 9600,
mouse_absolute: true,
}
}
}
impl HidConfig {
pub fn effective_otg_functions(&self) -> OtgHidFunctions {
self.otg_profile.resolve_functions(&self.otg_functions)
}
pub fn effective_otg_keyboard_leds(&self) -> bool {
self.otg_keyboard_leds && self.effective_otg_functions().keyboard
}
pub fn constrained_otg_functions(&self) -> OtgHidFunctions {
self.effective_otg_functions()
}
pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 {
let functions = self.effective_otg_functions();
let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds());
if msd_enabled {
endpoints += 2;
}
endpoints
}
pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> {
if self.backend != HidBackend::Otg {
return Ok(());
}
let functions = self.effective_otg_functions();
if functions.is_empty() {
return Err(crate::error::AppError::BadRequest(
"OTG HID functions cannot be empty".to_string(),
));
}
let resolved_limit = self.resolved_otg_endpoint_limit();
let required = self.effective_otg_required_endpoints(msd_enabled);
if let Some(limit) = resolved_limit {
if required > limit {
return Err(crate::error::AppError::BadRequest(format!(
"OTG selection requires {} endpoints, but the configured limit is {}",
required, limit
)));
}
}
Ok(())
}
#[inline]
pub fn resolved_otg_udc(&self) -> Option<String> {
if self.backend != HidBackend::Otg {
return None;
}
self.otg_udc
.as_ref()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.or_else(|| {
#[cfg(unix)]
{
crate::otg::OtgGadgetManager::find_udc()
}
#[cfg(not(unix))]
{
None
}
})
}
#[inline]
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
if self.backend != HidBackend::Otg {
return None;
}
match self.otg_endpoint_budget {
OtgEndpointBudget::Five => Some(5),
OtgEndpointBudget::Six => Some(6),
OtgEndpointBudget::Unlimited => None,
OtgEndpointBudget::Auto => {
#[cfg(unix)]
let udc = self.resolved_otg_udc().unwrap_or_default();
#[cfg(unix)]
if crate::otg::configfs::is_low_endpoint_udc(&udc) {
Some(5)
} else {
Some(6)
}
#[cfg(not(unix))]
{
Some(6)
}
}
}
}
}

44
src/config/schema/mod.rs Normal file
View File

@@ -0,0 +1,44 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
pub use crate::extensions::ExtensionsConfig;
pub use crate::rustdesk::config::RustDeskConfig;
mod atx;
mod common;
mod hid;
mod stream;
mod web;
pub use atx::*;
pub use common::*;
pub use hid::*;
pub use stream::*;
pub use web::*;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct AppConfig {
pub initialized: bool,
pub auth: AuthConfig,
pub video: VideoConfig,
pub hid: HidConfig,
pub msd: MsdConfig,
pub atx: AtxConfig,
pub audio: AudioConfig,
pub stream: StreamConfig,
pub web: WebConfig,
pub extensions: ExtensionsConfig,
pub rustdesk: RustDeskConfig,
pub rtsp: RtspConfig,
pub redfish: RedfishConfig,
}
impl AppConfig {
pub fn apply_platform_defaults(&mut self) {
crate::platform::defaults::apply(self);
}
}

149
src/config/schema/stream.rs Normal file
View File

@@ -0,0 +1,149 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
use super::BitratePreset;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum StreamMode {
WebRTC,
#[default]
Mjpeg,
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum RtspCodec {
#[default]
H264,
H265,
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RtspConfig {
pub enabled: bool,
pub bind: String,
pub port: u16,
pub path: String,
pub allow_one_client: bool,
pub codec: RtspCodec,
pub username: Option<String>,
#[typeshare(skip)]
pub password: Option<String>,
}
impl Default for RtspConfig {
fn default() -> Self {
Self {
enabled: false,
bind: "0.0.0.0".to_string(),
port: 8554,
path: "live".to_string(),
allow_one_client: true,
codec: RtspCodec::H264,
username: None,
password: None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum EncoderType {
#[default]
Auto,
Software,
Vaapi,
Nvenc,
Qsv,
Amf,
Rkmpp,
V4l2m2m,
}
impl EncoderType {
pub fn display_name(&self) -> &'static str {
match self {
EncoderType::Auto => "Auto (Recommended)",
EncoderType::Software => "Software (CPU)",
EncoderType::Vaapi => "VAAPI",
EncoderType::Nvenc => "NVIDIA NVENC",
EncoderType::Qsv => "Intel Quick Sync",
EncoderType::Amf => "AMD AMF",
EncoderType::Rkmpp => "Rockchip MPP",
EncoderType::V4l2m2m => "V4L2 M2M",
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct StreamConfig {
pub mode: StreamMode,
pub encoder: EncoderType,
pub bitrate_preset: BitratePreset,
pub stun_server: Option<String>,
pub turn_server: Option<String>,
pub turn_username: Option<String>,
pub turn_password: Option<String>,
#[typeshare(skip)]
pub auto_pause_enabled: bool,
#[typeshare(skip)]
pub auto_pause_delay_secs: u64,
#[typeshare(skip)]
pub client_timeout_secs: u64,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
mode: StreamMode::Mjpeg,
encoder: EncoderType::Auto,
bitrate_preset: BitratePreset::Balanced,
stun_server: None,
turn_server: None,
turn_username: None,
turn_password: None,
auto_pause_enabled: false,
auto_pause_delay_secs: 10,
client_timeout_secs: 30,
}
}
}
impl StreamConfig {
pub fn is_using_public_ice_servers(&self) -> bool {
let no_custom_stun = self
.stun_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
let no_custom_turn = self
.turn_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
no_custom_stun && no_custom_turn
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RedfishConfig {
pub enabled: bool,
}
impl Default for RedfishConfig {
fn default() -> Self {
Self { enabled: false }
}
}

129
src/config/schema/web.rs Normal file
View File

@@ -0,0 +1,129 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AuthConfig {
pub session_timeout_secs: u32,
pub single_user_allow_multiple_sessions: bool,
pub totp_enabled: bool,
pub totp_secret: Option<String>,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
session_timeout_secs: 3600 * 24,
single_user_allow_multiple_sessions: false,
totp_enabled: false,
totp_secret: None,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct VideoConfig {
pub device: Option<String>,
pub format: Option<String>,
pub width: u32,
pub height: u32,
pub fps: u32,
pub quality: u32,
}
impl Default for VideoConfig {
fn default() -> Self {
Self {
device: None,
format: None,
width: 1920,
height: 1080,
fps: 30,
quality: 80,
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct MsdConfig {
pub enabled: bool,
pub msd_dir: String,
}
impl Default for MsdConfig {
fn default() -> Self {
Self {
enabled: true,
msd_dir: String::new(),
}
}
}
impl MsdConfig {
pub fn msd_dir_path(&self) -> std::path::PathBuf {
std::path::PathBuf::from(&self.msd_dir)
}
pub fn images_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("images")
}
pub fn ventoy_dir(&self) -> std::path::PathBuf {
self.msd_dir_path().join("ventoy")
}
pub fn drive_path(&self) -> std::path::PathBuf {
self.ventoy_dir().join("ventoy.img")
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AudioConfig {
pub enabled: bool,
pub device: String,
pub quality: String,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
device: String::new(),
quality: "balanced".to_string(),
}
}
}
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WebConfig {
pub http_port: u16,
pub https_port: u16,
pub bind_addresses: Vec<String>,
pub bind_address: String,
pub https_enabled: bool,
pub ssl_cert_path: Option<String>,
pub ssl_key_path: Option<String>,
}
impl Default for WebConfig {
fn default() -> Self {
Self {
http_port: 8080,
https_port: 8443,
bind_addresses: Vec::new(),
bind_address: "0.0.0.0".to_string(),
https_enabled: false,
ssl_cert_path: None,
ssl_key_path: None,
}
}
}

View File

@@ -1,153 +1,37 @@
use arc_swap::ArcSwap;
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
use std::path::Path;
use sqlx::{Pool, Sqlite};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
use super::AppConfig;
use super::ConfigChange;
use crate::error::{AppError, Result};
/// Configuration store backed by SQLite
///
/// Uses `ArcSwap` for lock-free reads, providing high performance
/// for frequent configuration access in hot paths.
#[derive(Clone)]
pub struct ConfigStore {
pool: Pool<Sqlite>,
/// Lock-free cache using ArcSwap for zero-cost reads
cache: Arc<ArcSwap<AppConfig>>,
change_tx: broadcast::Sender<ConfigChange>,
/// Serializes `set` / `update` so concurrent PATCH handlers cannot clobber each other
write_lock: Arc<Mutex<()>>,
}
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}
impl ConfigStore {
/// Create a new configuration store
pub async fn new(db_path: &Path) -> Result<Self> {
// Ensure parent directory exists
if let Some(parent) = db_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePoolOptions::new()
// SQLite uses single-writer mode, 2 connections is sufficient for embedded devices
// One for reads, one for writes to avoid blocking
.max_connections(2)
// Set reasonable timeouts for embedded environments
.acquire_timeout(Duration::from_secs(5))
.idle_timeout(Duration::from_secs(300))
.connect(&db_url)
.await?;
// Initialize database schema
Self::init_schema(&pool).await?;
// Load or create default config
let config = Self::load_config(&pool).await?;
let cache = Arc::new(ArcSwap::from_pointee(config));
let (change_tx, _) = broadcast::channel(16);
pub fn new(pool: Pool<Sqlite>) -> Result<Self> {
Ok(Self {
pool,
cache,
change_tx,
cache: Arc::new(ArcSwap::from_pointee(AppConfig::default())),
change_tx: broadcast::channel(16).0,
write_lock: Arc::new(Mutex::new(())),
})
}
/// Initialize database schema
async fn init_schema(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
expires_at TEXT NOT NULL,
data TEXT
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
permissions TEXT NOT NULL,
expires_at TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
last_used TEXT
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS wol_history (
mac_address TEXT PRIMARY KEY,
updated_at INTEGER NOT NULL
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
ON wol_history(updated_at DESC)
"#,
)
.execute(pool)
.await?;
pub async fn load(&self) -> Result<()> {
let config = Self::load_config(&self.pool).await?;
self.cache.store(Arc::new(config));
Ok(())
}
/// Load configuration from database
async fn load_config(pool: &Pool<Sqlite>) -> Result<AppConfig> {
let row: Option<(String,)> =
sqlx::query_as("SELECT value FROM config WHERE key = 'app_config'")
@@ -159,7 +43,6 @@ impl ConfigStore {
serde_json::from_str(&json).map_err(|e| AppError::Config(e.to_string()))
}
None => {
// Create default config
let config = AppConfig::default();
Self::save_config_to_db(pool, &config).await?;
Ok(config)
@@ -167,7 +50,6 @@ impl ConfigStore {
}
}
/// Save configuration to database
async fn save_config_to_db(pool: &Pool<Sqlite>, config: &AppConfig) -> Result<()> {
let json = serde_json::to_string(config)?;
@@ -185,21 +67,15 @@ impl ConfigStore {
Ok(())
}
/// Get current configuration (lock-free, zero-copy)
///
/// Returns an `Arc<AppConfig>` for efficient sharing without cloning.
/// This is a lock-free operation with minimal overhead.
pub fn get(&self) -> Arc<AppConfig> {
self.cache.load_full()
}
/// Set entire configuration
pub async fn set(&self, config: AppConfig) -> Result<()> {
let _guard = self.write_lock.lock().await;
Self::save_config_to_db(&self.pool, &config).await?;
self.cache.store(Arc::new(config));
// Notify subscribers
let _ = self.change_tx.send(ConfigChange {
key: "app_config".to_string(),
});
@@ -207,27 +83,19 @@ impl ConfigStore {
Ok(())
}
/// Update configuration with a closure
///
/// Uses read-modify-write under a mutex so concurrent `update` / `set` calls are serialized
/// and merged correctly (each closure sees the latest stored config).
pub async fn update<F>(&self, f: F) -> Result<()>
where
F: FnOnce(&mut AppConfig),
{
let _guard = self.write_lock.lock().await;
// Load current config, clone it for modification
let current = self.cache.load();
let mut config = (**current).clone();
f(&mut config);
// Persist to database first
Self::save_config_to_db(&self.pool, &config).await?;
// Then update cache atomically
self.cache.store(Arc::new(config));
// Notify subscribers
let _ = self.change_tx.send(ConfigChange {
key: "app_config".to_string(),
});
@@ -235,25 +103,19 @@ impl ConfigStore {
Ok(())
}
/// Subscribe to configuration changes
pub fn subscribe(&self) -> broadcast::Receiver<ConfigChange> {
self.change_tx.subscribe()
}
/// Check if system is initialized (lock-free)
pub fn is_initialized(&self) -> bool {
self.cache.load().initialized
}
/// Get database pool for session management
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::DatabasePool;
use tempfile::tempdir;
#[tokio::test]
@@ -261,13 +123,15 @@ mod tests {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let store = ConfigStore::new(&db_path).await.unwrap();
let db = DatabasePool::new(&db_path).await.unwrap();
db.init_schema().await.unwrap();
let store = ConfigStore::new(db.clone_pool()).unwrap();
store.load().await.unwrap();
// Check default config (now lock-free, no await needed)
let config = store.get();
assert!(!config.initialized);
// Update config
store
.update(|c| {
c.initialized = true;
@@ -276,13 +140,12 @@ mod tests {
.await
.unwrap();
// Verify update
let config = store.get();
assert!(config.initialized);
assert_eq!(config.web.http_port, 9000);
// Create new store instance and verify persistence
let store2 = ConfigStore::new(&db_path).await.unwrap();
let store2 = ConfigStore::new(db.clone_pool()).unwrap();
store2.load().await.unwrap();
let config = store2.get();
assert!(config.initialized);
assert_eq!(config.web.http_port, 9000);

3
src/db/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
mod pool;
pub use pool::DatabasePool;

119
src/db/pool.rs Normal file
View File

@@ -0,0 +1,119 @@
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
use std::path::Path;
use std::time::Duration;
use crate::error::Result;
#[derive(Clone)]
pub struct DatabasePool {
pool: Pool<Sqlite>,
}
impl DatabasePool {
pub async fn new(db_path: &Path) -> Result<Self> {
if let Some(parent) = db_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePoolOptions::new()
.max_connections(4)
.acquire_timeout(Duration::from_secs(5))
.idle_timeout(Duration::from_secs(300))
.connect(&db_url)
.await?;
Ok(Self { pool })
}
pub async fn init_schema(&self) -> Result<()> {
self.create_config_table().await?;
self.create_users_table().await?;
self.create_api_tokens_table().await?;
self.create_wol_history_table().await?;
Ok(())
}
async fn create_config_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_users_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_api_tokens_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
permissions TEXT NOT NULL,
expires_at TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
last_used TEXT
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_wol_history_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS wol_history (
mac_address TEXT PRIMARY KEY,
updated_at INTEGER NOT NULL
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
ON wol_history(updated_at DESC)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
pub fn clone_pool(&self) -> Pool<Sqlite> {
self.pool.clone()
}
}

280
src/diagnostics/linux.rs Normal file
View File

@@ -0,0 +1,280 @@
use super::{DeviceInfo, DiskSpaceInfo, NetworkAddress};
use crate::error::{AppError, Result};
use crate::utils::hostname_uname;
pub fn get_disk_space(path: &std::path::Path) -> Result<DiskSpaceInfo> {
let stat = nix::sys::statvfs::statvfs(path)
.map_err(|e| AppError::Internal(format!("Failed to get disk space: {}", e)))?;
let block_size = stat.block_size() as u64;
let total = stat.blocks() as u64 * block_size;
let available = stat.blocks_available() as u64 * block_size;
let used = total - available;
Ok(DiskSpaceInfo {
total,
available,
used,
})
}
pub fn get_device_info() -> DeviceInfo {
let mem_info = get_meminfo();
DeviceInfo {
hostname: hostname_uname(),
cpu_model: get_cpu_model(),
cpu_usage: get_cpu_usage(),
memory_total: mem_info.total,
memory_used: mem_info.total.saturating_sub(mem_info.available),
network_addresses: get_network_addresses(),
serial_ports: crate::utils::list_serial_ports(),
}
}
fn get_cpu_model() -> String {
let cpuinfo = std::fs::read_to_string("/proc/cpuinfo").ok();
if let Some(model) = parse_cpu_model_from_cpuinfo_content(cpuinfo.as_deref()) {
return model;
}
if let Some(model) = read_device_tree_model() {
return model;
}
if let Some(content) = cpuinfo.as_deref() {
let cores = content
.lines()
.filter(|line| line.starts_with("processor"))
.count();
if cores > 0 {
return format!("{} {}C", std::env::consts::ARCH, cores);
}
}
std::env::consts::ARCH.to_string()
}
fn parse_cpu_model_from_cpuinfo_content(content: Option<&str>) -> Option<String> {
let content = content?;
content
.lines()
.find(|line| line.starts_with("model name") || line.starts_with("Model"))
.and_then(|line| line.split(':').nth(1))
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}
fn read_device_tree_model() -> Option<String> {
std::fs::read("/proc/device-tree/model")
.ok()
.and_then(|bytes| parse_device_tree_model_bytes(bytes.as_slice()))
}
fn parse_device_tree_model_bytes(bytes: &[u8]) -> Option<String> {
let model = String::from_utf8_lossy(bytes)
.trim_matches(|c: char| c == '\0' || c.is_whitespace())
.to_string();
if model.is_empty() {
None
} else {
Some(model)
}
}
static CPU_PREV_STATS: std::sync::OnceLock<std::sync::Mutex<(u64, u64)>> =
std::sync::OnceLock::new();
fn get_cpu_usage() -> f32 {
let content = match std::fs::read_to_string("/proc/stat") {
Ok(c) => c,
Err(_) => return 0.0,
};
let cpu_line = match content.lines().next() {
Some(line) if line.starts_with("cpu ") => line,
_ => return 0.0,
};
let parts: Vec<u64> = cpu_line
.split_whitespace()
.skip(1)
.take(8)
.filter_map(|s| s.parse().ok())
.collect();
if parts.len() < 4 {
return 0.0;
}
let idle = parts[3] + parts.get(4).unwrap_or(&0);
let total: u64 = parts.iter().sum();
let prev_mutex = CPU_PREV_STATS.get_or_init(|| std::sync::Mutex::new((0, 0)));
let mut prev = prev_mutex.lock().unwrap();
let (prev_idle, prev_total) = *prev;
let idle_delta = idle.saturating_sub(prev_idle);
let total_delta = total.saturating_sub(prev_total);
*prev = (idle, total);
if total_delta == 0 {
return 0.0;
}
let usage = 100.0 * (1.0 - (idle_delta as f64 / total_delta as f64));
usage as f32
}
struct MemInfo {
total: u64,
available: u64,
}
fn get_meminfo() -> MemInfo {
let content = match std::fs::read_to_string("/proc/meminfo") {
Ok(c) => c,
Err(_) => {
return MemInfo {
total: 0,
available: 0,
}
}
};
let mut total = 0u64;
let mut available = 0u64;
for line in content.lines() {
if line.starts_with("MemTotal:") {
if let Some(kb) = line
.split_whitespace()
.nth(1)
.and_then(|v| v.parse::<u64>().ok())
{
total = kb * 1024;
}
} else if line.starts_with("MemAvailable:") {
if let Some(kb) = line
.split_whitespace()
.nth(1)
.and_then(|v| v.parse::<u64>().ok())
{
available = kb * 1024;
}
}
if total > 0 && available > 0 {
break;
}
}
MemInfo { total, available }
}
fn get_network_addresses() -> Vec<NetworkAddress> {
let all_addrs = match nix::ifaddrs::getifaddrs() {
Ok(addrs) => addrs,
Err(_) => return Vec::new(),
};
let mut up_ifaces = std::collections::HashSet::new();
let net_dir = match std::fs::read_dir("/sys/class/net") {
Ok(dir) => dir,
Err(_) => return Vec::new(),
};
for entry in net_dir.flatten() {
let iface_name = match entry.file_name().into_string() {
Ok(name) => name,
Err(_) => continue,
};
if iface_name == "lo" {
continue;
}
let operstate_path = entry.path().join("operstate");
let is_up = std::fs::read_to_string(&operstate_path)
.map(|s| s.trim() == "up")
.unwrap_or(false);
if is_up {
up_ifaces.insert(iface_name);
}
}
let mut addresses = Vec::new();
let mut seen = std::collections::HashSet::new();
for ifaddr in all_addrs {
let iface_name = &ifaddr.interface_name;
if iface_name == "lo" || !up_ifaces.contains(iface_name) {
continue;
}
if let Some(addr) = ifaddr.address {
if let Some(sockaddr_in) = addr.as_sockaddr_in() {
let ip = sockaddr_in.ip();
if ip.is_loopback() {
continue;
}
let ip_str = ip.to_string();
if seen.insert((iface_name.clone(), ip_str.clone())) {
addresses.push(NetworkAddress {
interface: iface_name.clone(),
ip: ip_str,
});
}
} else if let Some(sockaddr_in6) = addr.as_sockaddr_in6() {
let ip = sockaddr_in6.ip();
if ip.is_loopback() || ip.is_unspecified() || ip.is_unicast_link_local() {
continue;
}
let ip_str = ip.to_string();
if seen.insert((iface_name.clone(), ip_str.clone())) {
addresses.push(NetworkAddress {
interface: iface_name.clone(),
ip: ip_str,
});
}
}
}
}
addresses
}
#[cfg(test)]
mod tests {
use super::{parse_cpu_model_from_cpuinfo_content, parse_device_tree_model_bytes};
#[test]
fn parse_cpu_model_from_model_name_field() {
let input = "processor\t: 0\nmodel name\t: Intel(R) Xeon(R)\n";
assert_eq!(
parse_cpu_model_from_cpuinfo_content(input),
Some("Intel(R) Xeon(R)".to_string())
);
}
#[test]
fn parse_cpu_model_from_model_field() {
let input = "processor\t: 0\nModel\t\t: Raspberry Pi 4 Model B Rev 1.4\n";
assert_eq!(
parse_cpu_model_from_cpuinfo_content(input),
Some("Raspberry Pi 4 Model B Rev 1.4".to_string())
);
}
#[test]
fn parse_device_tree_model_trimmed() {
let input = b"Onething OEC Box\0\n";
assert_eq!(
parse_device_tree_model_bytes(input),
Some("Onething OEC Box".to_string())
);
}
}

47
src/diagnostics/mod.rs Normal file
View File

@@ -0,0 +1,47 @@
//! Host diagnostics used by the web status API.
use serde::Serialize;
use crate::error::Result;
#[derive(Debug, Clone, Serialize)]
pub struct DeviceInfo {
pub hostname: String,
pub cpu_model: String,
pub cpu_usage: f32,
pub memory_total: u64,
pub memory_used: u64,
pub network_addresses: Vec<NetworkAddress>,
pub serial_ports: Vec<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct NetworkAddress {
pub interface: String,
pub ip: String,
}
#[derive(Debug, Clone, Serialize)]
pub struct DiskSpaceInfo {
pub total: u64,
pub available: u64,
pub used: u64,
}
#[cfg(unix)]
mod linux;
#[cfg(windows)]
mod windows;
#[cfg(unix)]
use linux as platform;
#[cfg(windows)]
use windows as platform;
pub fn get_disk_space(path: &std::path::Path) -> Result<DiskSpaceInfo> {
platform::get_disk_space(path)
}
pub fn get_device_info() -> DeviceInfo {
platform::get_device_info()
}

249
src/diagnostics/windows.rs Normal file
View File

@@ -0,0 +1,249 @@
use super::{DeviceInfo, DiskSpaceInfo, NetworkAddress};
use crate::error::{AppError, Result};
use crate::utils::hostname_uname;
use std::ffi::CStr;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::{Mutex, OnceLock};
use windows_sys::Win32::Foundation::{ERROR_BUFFER_OVERFLOW, ERROR_SUCCESS, FILETIME};
use windows_sys::Win32::NetworkManagement::IpHelper::{
GetAdaptersAddresses, GAA_FLAG_SKIP_ANYCAST, GAA_FLAG_SKIP_DNS_SERVER, GAA_FLAG_SKIP_MULTICAST,
IP_ADAPTER_ADDRESSES_LH,
};
use windows_sys::Win32::NetworkManagement::Ndis::IfOperStatusUp;
use windows_sys::Win32::Networking::WinSock::{
AF_INET, AF_INET6, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6,
};
use windows_sys::Win32::System::SystemInformation::{
GetNativeSystemInfo, GlobalMemoryStatusEx, MEMORYSTATUSEX, PROCESSOR_ARCHITECTURE_AMD64,
PROCESSOR_ARCHITECTURE_ARM64, PROCESSOR_ARCHITECTURE_INTEL, SYSTEM_INFO,
};
use windows_sys::Win32::System::Threading::GetSystemTimes;
pub fn get_disk_space(_path: &std::path::Path) -> Result<DiskSpaceInfo> {
Err(AppError::Internal(
"Disk space reporting is unavailable on Windows".to_string(),
))
}
pub fn get_device_info() -> DeviceInfo {
let (memory_total, memory_used) = get_memory_usage();
DeviceInfo {
hostname: hostname_uname(),
cpu_model: get_cpu_model(),
cpu_usage: get_cpu_usage(),
memory_total,
memory_used,
network_addresses: get_network_addresses(),
serial_ports: crate::utils::list_serial_ports(),
}
}
fn get_cpu_model() -> String {
std::env::var("PROCESSOR_IDENTIFIER")
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.unwrap_or_else(get_cpu_arch_label)
}
fn get_cpu_arch_label() -> String {
let mut info = std::mem::MaybeUninit::<SYSTEM_INFO>::zeroed();
unsafe {
GetNativeSystemInfo(info.as_mut_ptr());
let info = info.assume_init();
match info.Anonymous.Anonymous.wProcessorArchitecture {
PROCESSOR_ARCHITECTURE_AMD64 => "x86_64".to_string(),
PROCESSOR_ARCHITECTURE_ARM64 => "aarch64".to_string(),
PROCESSOR_ARCHITECTURE_INTEL => "x86".to_string(),
_ => std::env::consts::ARCH.to_string(),
}
}
}
fn get_memory_usage() -> (u64, u64) {
let mut status = MEMORYSTATUSEX {
dwLength: std::mem::size_of::<MEMORYSTATUSEX>() as u32,
..unsafe { std::mem::zeroed() }
};
let ok = unsafe { GlobalMemoryStatusEx(&mut status) };
if ok == 0 {
return (0, 0);
}
(
status.ullTotalPhys,
status.ullTotalPhys.saturating_sub(status.ullAvailPhys),
)
}
fn get_cpu_usage() -> f32 {
static LAST_SAMPLE: OnceLock<Mutex<Option<CpuTimes>>> = OnceLock::new();
let Some(current) = read_cpu_times() else {
return 0.0;
};
let sample = LAST_SAMPLE.get_or_init(|| Mutex::new(None));
let Ok(mut last) = sample.lock() else {
return 0.0;
};
let (previous, current) = if let Some(previous) = last.replace(current) {
(previous, current)
} else {
drop(last);
std::thread::sleep(std::time::Duration::from_millis(100));
let Some(next) = read_cpu_times() else {
return 0.0;
};
if let Ok(mut last) = sample.lock() {
*last = Some(next);
}
(current, next)
};
let idle = current.idle.saturating_sub(previous.idle);
let kernel = current.kernel.saturating_sub(previous.kernel);
let user = current.user.saturating_sub(previous.user);
let total = kernel.saturating_add(user);
if total == 0 {
return 0.0;
}
((total.saturating_sub(idle)) as f64 * 100.0 / total as f64).clamp(0.0, 100.0) as f32
}
#[derive(Clone, Copy)]
struct CpuTimes {
idle: u64,
kernel: u64,
user: u64,
}
fn read_cpu_times() -> Option<CpuTimes> {
let mut idle = FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let mut kernel = idle;
let mut user = idle;
let ok = unsafe { GetSystemTimes(&mut idle, &mut kernel, &mut user) };
if ok == 0 {
return None;
}
Some(CpuTimes {
idle: filetime_to_u64(idle),
kernel: filetime_to_u64(kernel),
user: filetime_to_u64(user),
})
}
fn filetime_to_u64(time: FILETIME) -> u64 {
((time.dwHighDateTime as u64) << 32) | time.dwLowDateTime as u64
}
fn get_network_addresses() -> Vec<NetworkAddress> {
let mut buffer_len = 15_000u32;
let flags = GAA_FLAG_SKIP_ANYCAST | GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_DNS_SERVER;
for _ in 0..2 {
let mut buffer = vec![0u8; buffer_len as usize];
let ret = unsafe {
GetAdaptersAddresses(
0,
flags,
std::ptr::null_mut(),
buffer.as_mut_ptr() as *mut IP_ADAPTER_ADDRESSES_LH,
&mut buffer_len,
)
};
if ret == ERROR_BUFFER_OVERFLOW {
continue;
}
if ret != ERROR_SUCCESS {
return Vec::new();
}
let mut addresses = Vec::new();
let mut adapter = buffer.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH;
while !adapter.is_null() {
let adapter_ref = unsafe { &*adapter };
if adapter_ref.OperStatus != IfOperStatusUp {
adapter = adapter_ref.Next;
continue;
}
let interface = adapter_name(adapter_ref);
let mut unicast = adapter_ref.FirstUnicastAddress;
while !unicast.is_null() {
let unicast_ref = unsafe { &*unicast };
if let Some(ip) = sockaddr_to_ip(unicast_ref.Address.lpSockaddr) {
addresses.push(NetworkAddress {
interface: interface.clone(),
ip,
});
}
unicast = unicast_ref.Next;
}
adapter = adapter_ref.Next;
}
addresses.sort_by(|a, b| a.interface.cmp(&b.interface).then(a.ip.cmp(&b.ip)));
addresses.dedup_by(|a, b| a.interface == b.interface && a.ip == b.ip);
return addresses;
}
Vec::new()
}
fn adapter_name(adapter: &IP_ADAPTER_ADDRESSES_LH) -> String {
unsafe {
if !adapter.FriendlyName.is_null() {
let mut len = 0usize;
while *adapter.FriendlyName.add(len) != 0 {
len += 1;
}
let name =
String::from_utf16_lossy(std::slice::from_raw_parts(adapter.FriendlyName, len));
if !name.trim().is_empty() {
return name;
}
}
if !adapter.AdapterName.is_null() {
return CStr::from_ptr(adapter.AdapterName.cast())
.to_string_lossy()
.into_owned();
}
}
"unknown".to_string()
}
fn sockaddr_to_ip(sockaddr: *const SOCKADDR) -> Option<String> {
if sockaddr.is_null() {
return None;
}
let family = unsafe { (*sockaddr).sa_family };
match family {
AF_INET => {
let addr = unsafe { *(sockaddr as *const SOCKADDR_IN) };
let bytes = unsafe { addr.sin_addr.S_un.S_addr.to_ne_bytes() };
Some(Ipv4Addr::from(bytes).to_string())
}
AF_INET6 => {
let addr = unsafe { *(sockaddr as *const SOCKADDR_IN6) };
let bytes = unsafe { addr.sin6_addr.u.Byte };
Some(Ipv6Addr::from(bytes).to_string())
}
_ => None,
}
}

View File

@@ -1,12 +1,5 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Serialize;
use thiserror::Error;
/// Application-wide error type
#[derive(Error, Debug)]
pub enum AppError {
#[error("Authentication failed: {0}")]
@@ -15,17 +8,14 @@ pub enum AppError {
#[error("Not authenticated")]
Unauthorized,
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Persistence error: {0}")]
Persistence(String),
#[error("Internal error: {0}")]
Internal(String),
@@ -42,8 +32,9 @@ pub enum AppError {
#[error("Video error: {0}")]
VideoError(String),
#[error("Video device lost [{device}]: {reason}")]
VideoDeviceLost { device: String, reason: String },
/// No input signal while opening capture; `kind` is `SignalStatus` as string (`from_str`).
#[error("Capture has no valid signal: {kind}")]
CaptureNoSignal { kind: String },
#[error("Audio error: {0}")]
AudioError(String),
@@ -62,37 +53,10 @@ pub enum AppError {
ServiceUnavailable(String),
}
/// Error response body (unified success format)
#[derive(Serialize)]
pub struct ErrorResponse {
pub success: bool,
pub message: String,
}
impl AppError {
fn status_code(&self) -> StatusCode {
// Always return 200 OK - success/failure is indicated by the success field
StatusCode::OK
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let status = self.status_code();
let body = ErrorResponse {
success: false,
message: self.to_string(),
};
tracing::error!(
error_type = std::any::type_name_of_val(&self),
error_message = %body.message,
"Request failed"
);
(status, Json(body)).into_response()
}
}
/// Result type alias for handlers
pub type Result<T> = std::result::Result<T, AppError>;
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
AppError::Persistence(err.to_string())
}
}

View File

@@ -1,41 +1,28 @@
//! Event system for real-time state notifications
//!
//! This module provides a global event bus for broadcasting system events
//! to WebSocket clients and other subscribers.
//! Event bus: [`SystemEvent`] fan-out to WebSocket subscribers and internal tasks.
pub mod types;
use self::types::EXACT_EVENT_TOPICS;
pub use types::{
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent,
TtydDeviceInfo, VideoDeviceInfo,
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, LedState, MsdDeviceInfo,
StreamDeviceLostKind, SystemEvent, TtydDeviceInfo, VideoDeviceInfo,
};
use tokio::sync::broadcast;
/// Event channel capacity (ring buffer size)
const EVENT_CHANNEL_CAPACITY: usize = 256;
const EXACT_TOPICS: &[&str] = &[
"stream.mode_switching",
"stream.state_changed",
"stream.config_changing",
"stream.config_applied",
"stream.device_lost",
"stream.reconnecting",
"stream.recovered",
"stream.webrtc_ready",
"stream.stats_update",
"stream.mode_changed",
"stream.mode_ready",
"webrtc.ice_candidate",
"webrtc.ice_complete",
"msd.upload_progress",
"msd.download_progress",
"system.device_info",
"error",
];
const PREFIX_TOPICS: &[&str] = &["stream.*", "webrtc.*", "msd.*", "system.*"];
fn collect_prefix_wildcards(exact: &[&'static str]) -> Vec<String> {
use std::collections::BTreeSet;
let mut segments = BTreeSet::new();
for name in exact {
if let Some((seg, _)) = name.split_once('.') {
segments.insert(seg);
}
}
segments.into_iter().map(|s| format!("{}.*", s)).collect()
}
fn make_sender() -> broadcast::Sender<SystemEvent> {
let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
@@ -48,50 +35,23 @@ fn topic_prefix(event_name: &str) -> Option<String> {
.map(|(prefix, _)| format!("{}.*", prefix))
}
/// Global event bus for broadcasting system events
///
/// The event bus uses tokio's broadcast channel to distribute events
/// to multiple subscribers. Events are delivered to all active subscribers.
///
/// # Example
///
/// ```no_run
/// use one_kvm::events::{EventBus, SystemEvent};
///
/// let bus = EventBus::new();
///
/// // Publish an event
/// bus.publish(SystemEvent::StreamStateChanged {
/// state: "streaming".to_string(),
/// device: Some("/dev/video0".to_string()),
/// });
///
/// // Subscribe to events
/// let mut rx = bus.subscribe();
/// tokio::spawn(async move {
/// while let Ok(event) = rx.recv().await {
/// println!("Received event: {:?}", event);
/// }
/// });
/// ```
pub struct EventBus {
tx: broadcast::Sender<SystemEvent>,
exact_topics: std::collections::HashMap<&'static str, broadcast::Sender<SystemEvent>>,
prefix_topics: std::collections::HashMap<&'static str, broadcast::Sender<SystemEvent>>,
prefix_topics: std::collections::HashMap<String, broadcast::Sender<SystemEvent>>,
device_info_dirty_tx: broadcast::Sender<()>,
}
impl EventBus {
/// Create a new event bus
pub fn new() -> Self {
let tx = make_sender();
let exact_topics = EXACT_TOPICS
let exact_topics = EXACT_EVENT_TOPICS
.iter()
.map(|topic| (*topic, make_sender()))
.collect();
let prefix_topics = PREFIX_TOPICS
.iter()
.map(|topic| (*topic, make_sender()))
let prefix_topics = collect_prefix_wildcards(EXACT_EVENT_TOPICS)
.into_iter()
.map(|topic| (topic, make_sender()))
.collect();
let (device_info_dirty_tx, _dirty_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
@@ -103,10 +63,6 @@ impl EventBus {
}
}
/// Publish an event to all subscribers
///
/// If there are no active subscribers, the event is silently dropped.
/// This is by design - events are fire-and-forget notifications.
pub fn publish(&self, event: SystemEvent) {
let event_name = event.event_name();
@@ -115,28 +71,18 @@ impl EventBus {
}
if let Some(prefix) = topic_prefix(event_name) {
if let Some(tx) = self.prefix_topics.get(prefix.as_str()) {
if let Some(tx) = self.prefix_topics.get(&prefix) {
let _ = tx.send(event.clone());
}
}
// If no subscribers, send returns Err which is normal
let _ = self.tx.send(event);
}
/// Subscribe to events
///
/// Returns a receiver that will receive all future events.
/// The receiver uses a ring buffer, so if a subscriber falls too far
/// behind, it will receive a `Lagged` error and miss some events.
pub fn subscribe(&self) -> broadcast::Receiver<SystemEvent> {
self.tx.subscribe()
}
/// Subscribe to a specific topic.
///
/// Supports exact event names, namespace wildcards like `stream.*`, and
/// `*` for the full event stream.
pub fn subscribe_topic(&self, topic: &str) -> Option<broadcast::Receiver<SystemEvent>> {
if topic == "*" {
return Some(self.tx.subscribe());
@@ -149,22 +95,14 @@ impl EventBus {
self.exact_topics.get(topic).map(|tx| tx.subscribe())
}
/// Mark the device-info snapshot as stale.
///
/// This is an internal trigger used to refresh the latest `system.device_info`
/// snapshot without exposing another public WebSocket event.
pub fn mark_device_info_dirty(&self) {
let _ = self.device_info_dirty_tx.send(());
}
/// Subscribe to internal device-info refresh triggers.
pub fn subscribe_device_info_dirty(&self) -> broadcast::Receiver<()> {
self.device_info_dirty_tx.subscribe()
}
/// Get the current number of active subscribers
///
/// Useful for monitoring and debugging.
pub fn subscriber_count(&self) -> usize {
self.tx.receiver_count()
}
@@ -188,6 +126,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
@@ -205,6 +145,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
});
let event1 = rx1.recv().await.unwrap();
@@ -222,6 +164,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
@@ -236,6 +180,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
@@ -253,10 +199,11 @@ mod tests {
let bus = EventBus::new();
assert_eq!(bus.subscriber_count(), 0);
// Should not panic when publishing with no subscribers
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
}
}

View File

@@ -1,359 +1,234 @@
//! System event types
//!
//! Defines all event types that can be broadcast through the event bus.
//! [`SystemEvent`] and device snapshot types (WebSocket / JSON).
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::hid::LedState;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct LedState {
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
pub compose: bool,
pub kana: bool,
}
// ============================================================================
// Device Info Structures (for system.device_info event)
// ============================================================================
/// Video device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoDeviceInfo {
/// Whether video device is available
pub available: bool,
/// Device path (e.g., /dev/video0)
pub device: Option<String>,
/// Pixel format (e.g., "MJPEG", "YUYV")
pub format: Option<String>,
/// Resolution (width, height)
pub resolution: Option<(u32, u32)>,
/// Frames per second
pub fps: u32,
/// Whether stream is currently active
pub online: bool,
/// Current streaming mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
pub stream_mode: String,
/// Whether video config is currently being changed (frontend should skip mode sync)
pub config_changing: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// HID device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HidDeviceInfo {
/// Whether HID backend is available
pub available: bool,
/// Backend type: "otg", "ch9329", "none"
pub backend: String,
/// Whether backend is initialized and ready
pub initialized: bool,
/// Whether backend is currently online
pub online: bool,
/// Whether absolute mouse positioning is supported
pub supports_absolute_mouse: bool,
/// Whether keyboard LED/status feedback is enabled.
pub keyboard_leds_enabled: bool,
/// Last known keyboard LED state.
pub led_state: LedState,
/// Device path (e.g., serial port for CH9329)
pub device: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
/// Error code if any, None if OK
pub error_code: Option<String>,
}
/// MSD device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdDeviceInfo {
/// Whether MSD is available
pub available: bool,
/// Operating mode: "none", "image", "drive"
pub mode: String,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image ID
pub image_id: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// ATX device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxDeviceInfo {
/// Whether ATX controller is available
pub available: bool,
/// Backend type: "gpio", "usb_relay", "none"
pub backend: String,
/// Whether backend is initialized
pub initialized: bool,
/// Whether power is currently on
pub power_on: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// Audio device information
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioDeviceInfo {
/// Whether audio is enabled/available
pub available: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Current audio device name
pub device: Option<String>,
/// Quality preset: "voice", "balanced", "high"
pub quality: String,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// ttyd status information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtydDeviceInfo {
/// Whether ttyd binary is available
pub available: bool,
/// Whether ttyd is currently running
pub running: bool,
}
/// Per-client statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientStats {
/// Client ID
pub id: String,
/// Current FPS for this client (frames sent in last second)
pub fps: u32,
/// Connected duration (seconds)
pub connected_secs: u64,
}
/// System event enumeration
///
/// All events are tagged with their event name for serialization.
/// The `serde(tag = "event", content = "data")` attribute creates a
/// JSON structure like:
/// ```json
/// {
/// "event": "stream.state_changed",
/// "data": { "state": "streaming", "device": "/dev/video0" }
/// }
/// ```
/// Video vs audio source for [`SystemEvent::StreamDeviceLost`] (WebSocket `stream.device_lost`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamDeviceLostKind {
Video,
Audio,
}
/// JSON: `{"event": "<name>", "data": { ... }}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
#[allow(clippy::large_enum_variant)]
pub enum SystemEvent {
// ============================================================================
// Video Stream Events
// ============================================================================
/// Stream mode switching started (transactional, correlates all following events)
///
/// Sent immediately after a mode switch request is accepted.
/// Clients can use `transition_id` to correlate subsequent `stream.*` events.
#[serde(rename = "stream.mode_switching")]
StreamModeSwitching {
/// Unique transition ID for this mode switch transaction
transition_id: String,
/// Target mode: "mjpeg", "h264", "h265", "vp8", "vp9"
to_mode: String,
/// Previous mode: "mjpeg", "h264", "h265", "vp8", "vp9"
from_mode: String,
},
/// Stream state changed (e.g., started, stopped, error)
#[serde(rename = "stream.state_changed")]
StreamStateChanged {
/// Current state: "uninitialized", "ready", "streaming", "no_signal", "error"
state: String,
/// Device path if available
device: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
next_retry_ms: Option<u64>,
},
/// Stream configuration is being changed
///
/// Sent before applying new configuration to notify clients that
/// the stream will be interrupted temporarily.
#[serde(rename = "stream.config_changing")]
StreamConfigChanging {
/// Optional transition ID if this config change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Reason for change: "device_switch", "resolution_change", "format_change"
reason: String,
},
/// Stream configuration has been applied successfully
///
/// Sent after new configuration is active. Clients can reconnect now.
#[serde(rename = "stream.config_applied")]
StreamConfigApplied {
/// Optional transition ID if this config change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Device path
device: String,
/// Resolution (width, height)
resolution: (u32, u32),
/// Pixel format: "mjpeg", "yuyv", etc.
format: String,
/// Frames per second
fps: u32,
},
/// Stream device was lost (disconnected or error)
#[serde(rename = "stream.device_lost")]
StreamDeviceLost {
/// Device path that was lost
kind: StreamDeviceLostKind,
device: String,
/// Reason for loss
reason: String,
},
/// Stream device is reconnecting
#[serde(rename = "stream.reconnecting")]
StreamReconnecting {
/// Device path being reconnected
device: String,
/// Retry attempt number
attempt: u32,
},
StreamReconnecting { device: String, attempt: u32 },
/// Stream device has recovered
#[serde(rename = "stream.recovered")]
StreamRecovered {
/// Device path that was recovered
device: String,
},
StreamRecovered { device: String },
/// WebRTC is ready to accept connections
///
/// Sent after video frame source is connected to WebRTC pipeline.
/// Clients should wait for this event before attempting to create WebRTC sessions.
#[serde(rename = "stream.webrtc_ready")]
WebRTCReady {
/// Optional transition ID if this readiness is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Current video codec
codec: String,
/// Whether hardware encoding is being used
hardware: bool,
},
/// WebRTC ICE candidate (server -> client trickle)
#[serde(rename = "webrtc.ice_candidate")]
WebRTCIceCandidate {
/// WebRTC session ID
session_id: String,
/// ICE candidate data
candidate: crate::webrtc::signaling::IceCandidate,
candidate: serde_json::Value,
},
/// WebRTC ICE gathering complete (server -> client)
#[serde(rename = "webrtc.ice_complete")]
WebRTCIceComplete {
/// WebRTC session ID
session_id: String,
},
WebRTCIceComplete { session_id: String },
/// Stream statistics update (sent periodically for client stats)
#[serde(rename = "stream.stats_update")]
StreamStatsUpdate {
/// Number of connected clients
clients: u64,
/// Per-client statistics (client_id -> client stats)
/// Each client's FPS reflects the actual frames sent in the last second
clients_stat: HashMap<String, ClientStats>,
},
/// Stream mode changed (MJPEG <-> WebRTC)
///
/// Sent when the streaming mode is switched. Clients should disconnect
/// from the current stream and reconnect using the new mode.
#[serde(rename = "stream.mode_changed")]
StreamModeChanged {
/// Optional transition ID if this change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
mode: String,
/// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
previous_mode: String,
},
/// Stream mode switching completed (transactional end marker)
///
/// Sent when the backend considers the new mode ready for clients to connect.
#[serde(rename = "stream.mode_ready")]
StreamModeReady {
/// Unique transition ID for this mode switch transaction
transition_id: String,
/// Active mode after switch: "mjpeg", "h264", "h265", "vp8", "vp9"
mode: String,
},
StreamModeReady { transition_id: String, mode: String },
// ============================================================================
// MSD (Mass Storage Device) Events
// ============================================================================
/// File upload progress (for large file uploads)
#[serde(rename = "msd.upload_progress")]
MsdUploadProgress {
/// Upload operation ID
upload_id: String,
/// Filename being uploaded
filename: String,
/// Bytes uploaded so far
bytes_uploaded: u64,
/// Total file size
total_bytes: u64,
/// Progress percentage (0.0 - 100.0)
progress_pct: f32,
},
/// Image download progress (for URL downloads)
#[serde(rename = "msd.download_progress")]
MsdDownloadProgress {
/// Download operation ID
download_id: String,
/// Source URL
url: String,
/// Target filename
filename: String,
/// Bytes downloaded so far
bytes_downloaded: u64,
/// Total file size (None if unknown)
total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
progress_pct: Option<f32>,
/// Download status: "started", "in_progress", "completed", "failed"
status: String,
},
/// Complete device information (sent on WebSocket connect and state changes)
#[serde(rename = "system.device_info")]
DeviceInfo {
/// Video device information
video: VideoDeviceInfo,
/// HID device information
hid: HidDeviceInfo,
/// MSD device information (None if MSD not enabled)
msd: Option<MsdDeviceInfo>,
/// ATX device information (None if ATX not enabled)
atx: Option<AtxDeviceInfo>,
/// Audio device information (None if audio not enabled)
audio: Option<AudioDeviceInfo>,
/// ttyd status information
ttyd: TtydDeviceInfo,
},
/// WebSocket error notification (for connection-level errors like lag)
#[serde(rename = "error")]
Error {
/// Error message
message: String,
},
Error { message: String },
}
/// One entry per [`SystemEvent::event_name`]. `EventBus` builds `*.`-wildcard channels from the first segment; names without `.` (e.g. `error`) have no wildcard channel.
pub(crate) const EXACT_EVENT_TOPICS: &[&str] = &[
"stream.mode_switching",
"stream.state_changed",
"stream.config_changing",
"stream.config_applied",
"stream.device_lost",
"stream.reconnecting",
"stream.recovered",
"stream.webrtc_ready",
"stream.stats_update",
"stream.mode_changed",
"stream.mode_ready",
"webrtc.ice_candidate",
"webrtc.ice_complete",
"msd.upload_progress",
"msd.download_progress",
"system.device_info",
"error",
];
impl SystemEvent {
/// Get the event name (for filtering/routing)
pub fn event_name(&self) -> &'static str {
match self {
Self::StreamModeSwitching { .. } => "stream.mode_switching",
@@ -375,27 +250,6 @@ impl SystemEvent {
Self::Error { .. } => "error",
}
}
/// Check if event name matches a topic pattern
///
/// Supports wildcards:
/// - `*` matches all events
/// - `stream.*` matches all stream events
/// - `stream.state_changed` matches exact event
pub fn matches_topic(&self, topic: &str) -> bool {
if topic == "*" {
return true;
}
let event_name = self.event_name();
if topic.ends_with(".*") {
let prefix = topic.trim_end_matches(".*");
event_name.starts_with(prefix)
} else {
event_name == topic
}
}
}
#[cfg(test)]
@@ -407,22 +261,145 @@ mod tests {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
};
assert_eq!(event.event_name(), "stream.state_changed");
}
#[test]
fn test_matches_topic() {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: None,
fn stream_device_lost_json_snake_case_kind() {
let event = SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: "hw:0,0".to_string(),
reason: "test".to_string(),
};
let v = serde_json::to_value(&event).unwrap();
let data = v.get("data").unwrap();
assert_eq!(data.get("kind").and_then(|x| x.as_str()), Some("audio"));
assert_eq!(data.get("device").and_then(|x| x.as_str()), Some("hw:0,0"));
}
assert!(event.matches_topic("*"));
assert!(event.matches_topic("stream.*"));
assert!(event.matches_topic("stream.state_changed"));
assert!(!event.matches_topic("msd.*"));
assert!(!event.matches_topic("stream.config_changed"));
#[test]
fn exact_topics_covers_all_variants() {
use std::collections::HashSet;
let samples = vec![
SystemEvent::StreamModeSwitching {
transition_id: String::new(),
to_mode: String::new(),
from_mode: String::new(),
},
SystemEvent::StreamStateChanged {
state: String::new(),
device: None,
reason: None,
next_retry_ms: None,
},
SystemEvent::StreamConfigChanging {
transition_id: None,
reason: String::new(),
},
SystemEvent::StreamConfigApplied {
transition_id: None,
device: String::new(),
resolution: (0, 0),
format: String::new(),
fps: 0,
},
SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Video,
device: String::new(),
reason: String::new(),
},
SystemEvent::StreamReconnecting {
device: String::new(),
attempt: 0,
},
SystemEvent::StreamRecovered {
device: String::new(),
},
SystemEvent::WebRTCReady {
transition_id: None,
codec: String::new(),
hardware: false,
},
SystemEvent::StreamStatsUpdate {
clients: 0,
clients_stat: HashMap::new(),
},
SystemEvent::StreamModeChanged {
transition_id: None,
mode: String::new(),
previous_mode: String::new(),
},
SystemEvent::StreamModeReady {
transition_id: String::new(),
mode: String::new(),
},
SystemEvent::WebRTCIceCandidate {
session_id: String::new(),
candidate: serde_json::Value::Null,
},
SystemEvent::WebRTCIceComplete {
session_id: String::new(),
},
SystemEvent::MsdUploadProgress {
upload_id: String::new(),
filename: String::new(),
bytes_uploaded: 0,
total_bytes: 0,
progress_pct: 0.0,
},
SystemEvent::MsdDownloadProgress {
download_id: String::new(),
url: String::new(),
filename: String::new(),
bytes_downloaded: 0,
total_bytes: None,
progress_pct: None,
status: String::new(),
},
SystemEvent::DeviceInfo {
video: VideoDeviceInfo {
available: false,
device: None,
format: None,
resolution: None,
fps: 0,
online: false,
stream_mode: String::new(),
config_changing: false,
error: None,
},
hid: HidDeviceInfo {
available: false,
backend: String::new(),
initialized: false,
online: false,
supports_absolute_mouse: false,
keyboard_leds_enabled: false,
led_state: LedState::default(),
device: None,
error: None,
error_code: None,
},
msd: None,
atx: None,
audio: None,
ttyd: TtydDeviceInfo {
available: false,
running: false,
},
},
SystemEvent::Error {
message: String::new(),
},
];
let from_enum: HashSet<_> = samples.iter().map(|e| e.event_name()).collect();
let from_const: HashSet<_> = super::EXACT_EVENT_TOPICS.iter().copied().collect();
assert_eq!(from_enum, from_const);
}
#[test]

View File

@@ -1,7 +1,4 @@
//! Extension process manager
use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
@@ -12,25 +9,26 @@ use tokio::sync::RwLock;
use super::types::*;
use crate::events::EventBus;
/// Maximum number of log lines to keep per extension
const LOG_BUFFER_SIZE: usize = 200;
/// Number of log lines to buffer before flushing to shared storage
const LOG_BATCH_SIZE: usize = 16;
/// Unix socket path for ttyd
#[cfg(unix)]
pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock";
/// Extension process with log buffer
#[cfg(windows)]
pub const TTYD_TCP_ADDR: &str = "127.0.0.1:7681";
#[cfg(windows)]
const TTYD_TCP_HOST: &str = "127.0.0.1";
#[cfg(windows)]
const TTYD_TCP_PORT: &str = "7681";
struct ExtensionProcess {
child: Child,
logs: Arc<RwLock<VecDeque<String>>>,
}
/// Extension manager handles lifecycle of external processes
pub struct ExtensionManager {
processes: RwLock<HashMap<ExtensionId, ExtensionProcess>>,
/// Cached availability status (checked once at startup)
availability: HashMap<ExtensionId, bool>,
event_bus: RwLock<Option<Arc<EventBus>>>,
}
@@ -42,12 +40,10 @@ impl Default for ExtensionManager {
}
impl ExtensionManager {
/// Create a new extension manager with cached availability
pub fn new() -> Self {
// Check availability once at startup
let availability = ExtensionId::all()
.iter()
.map(|id| (*id, Path::new(id.binary_path()).exists()))
.map(|id| (*id, id.binary_path().exists()))
.collect();
Self {
@@ -57,7 +53,6 @@ impl ExtensionManager {
}
}
/// Set event bus for ttyd status notifications.
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
*self.event_bus.write().await = Some(event_bus);
}
@@ -72,12 +67,24 @@ impl ExtensionManager {
}
}
/// Check if the binary for an extension is available (cached)
pub fn check_available(&self, id: ExtensionId) -> bool {
*self.availability.get(&id).unwrap_or(&false)
}
/// Get the current status of an extension
fn is_enabled_for_config(id: ExtensionId, config: &ExtensionsConfig) -> bool {
match id {
ExtensionId::Ttyd => config.ttyd.enabled,
ExtensionId::Gostc => {
config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
}
ExtensionId::Easytier => {
config.easytier.enabled && !config.easytier.network_name.is_empty()
}
}
}
pub async fn status(&self, id: ExtensionId) -> ExtensionStatus {
if !self.check_available(id) {
return ExtensionStatus::Unavailable;
@@ -117,27 +124,24 @@ impl ExtensionManager {
ExtensionStatus::Stopped
}
/// Start an extension with the given configuration
pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> {
if !self.check_available(id) {
return Err(format!(
"{} not found at {}",
id.display_name(),
id.binary_path()
id,
id.binary_path().display()
));
}
// Stop existing process first
self.stop(id).await.ok();
// Build command arguments
let args = self.build_args(id, config).await?;
tracing::info!(
"Starting extension {}: {} {}",
id,
id.binary_path(),
args.join(" ")
id.binary_path().display(),
Self::redact_args_for_log(&args).join(" ")
);
let mut child = Command::new(id.binary_path())
@@ -146,11 +150,10 @@ impl ExtensionManager {
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.map_err(|e| format!("Failed to start {}: {}", id.display_name(), e))?;
.map_err(|e| format!("Failed to start {}: {}", id, e))?;
let logs = Arc::new(RwLock::new(VecDeque::with_capacity(LOG_BUFFER_SIZE)));
// Spawn log collector for stdout
if let Some(stdout) = child.stdout.take() {
let logs_clone = logs.clone();
let id_clone = id;
@@ -159,7 +162,6 @@ impl ExtensionManager {
});
}
// Spawn log collector for stderr
if let Some(stderr) = child.stderr.take() {
let logs_clone = logs.clone();
let id_clone = id;
@@ -179,7 +181,6 @@ impl ExtensionManager {
Ok(())
}
/// Stop an extension
pub async fn stop(&self, id: ExtensionId) -> Result<(), String> {
let mut processes = self.processes.write().await;
if let Some(mut proc) = processes.remove(&id) {
@@ -193,7 +194,6 @@ impl ExtensionManager {
Ok(())
}
/// Get recent logs for an extension
pub async fn logs(&self, id: ExtensionId, lines: usize) -> Vec<String> {
let processes = self.processes.read().await;
if let Some(proc) = processes.get(&id) {
@@ -205,7 +205,6 @@ impl ExtensionManager {
}
}
/// Collect logs from a stream with batched writes to reduce lock contention
async fn collect_logs<R: tokio::io::AsyncRead + Unpin>(
id: ExtensionId,
reader: R,
@@ -218,16 +217,14 @@ impl ExtensionManager {
loop {
match lines.next_line().await {
Ok(Some(line)) => {
tracing::debug!("[{}] {}", id, line);
tracing::info!("[{}] {}", id, line);
local_buffer.push(line);
// Flush when batch is full
if local_buffer.len() >= LOG_BATCH_SIZE {
Self::flush_logs(&logs, &mut local_buffer).await;
}
}
Ok(None) => {
// Stream ended, flush remaining logs
if !local_buffer.is_empty() {
Self::flush_logs(&logs, &mut local_buffer).await;
}
@@ -241,7 +238,6 @@ impl ExtensionManager {
}
}
/// Flush buffered logs to shared storage
async fn flush_logs(logs: &RwLock<VecDeque<String>>, buffer: &mut Vec<String>) {
let mut logs = logs.write().await;
for line in buffer.drain(..) {
@@ -252,7 +248,6 @@ impl ExtensionManager {
}
}
/// Build command arguments for an extension
async fn build_args(
&self,
id: ExtensionId,
@@ -262,18 +257,8 @@ impl ExtensionManager {
ExtensionId::Ttyd => {
let c = &config.ttyd;
// Prepare socket directory and clean up old socket (async)
Self::prepare_ttyd_socket().await?;
let mut args = Self::build_ttyd_listen_args().await?;
let mut args = vec![
"-i".to_string(),
TTYD_SOCKET_PATH.to_string(), // Unix socket
"-b".to_string(),
"/api/terminal".to_string(), // Base path for reverse proxy
"-W".to_string(), // Writable (allow input)
];
// Add shell as last argument
args.push(c.shell.clone());
Ok(args)
}
@@ -289,15 +274,12 @@ impl ExtensionManager {
let mut args = Vec::new();
// Add TLS flag
if c.tls {
args.push("--tls=true".to_string());
}
// Server address (validated non-empty above)
args.extend(["-addr".to_string(), c.addr.trim().to_string()]);
// Add client key
args.extend(["-key".to_string(), c.key.clone()]);
Ok(args)
@@ -316,24 +298,19 @@ impl ExtensionManager {
c.network_secret.clone(),
];
// Add peer URLs
for peer in &c.peer_urls {
if !peer.is_empty() {
args.extend(["--peers".to_string(), peer.clone()]);
}
}
// Add virtual IP: use -d for DHCP if empty, or -i for specific IP
if let Some(ref ip) = c.virtual_ip {
if !ip.is_empty() {
// Use specific IP with -i (must include CIDR, e.g., 10.0.0.1/24)
args.extend(["-i".to_string(), ip.clone()]);
} else {
// Empty string means use DHCP
args.push("-d".to_string());
}
} else {
// None means use DHCP
args.push("-d".to_string());
}
@@ -342,11 +319,75 @@ impl ExtensionManager {
}
}
/// Prepare ttyd socket directory and clean up old socket file
async fn prepare_ttyd_socket() -> Result<(), String> {
let socket_path = Path::new(TTYD_SOCKET_PATH);
#[cfg(unix)]
async fn build_ttyd_listen_args() -> Result<Vec<String>, String> {
Self::prepare_ttyd_socket().await?;
Ok(vec![
"-i".to_string(),
TTYD_SOCKET_PATH.to_string(),
"-b".to_string(),
"/api/terminal".to_string(),
"-W".to_string(),
])
}
#[cfg(windows)]
async fn build_ttyd_listen_args() -> Result<Vec<String>, String> {
let cwd = std::env::var("USERPROFILE")
.ok()
.filter(|path| !path.trim().is_empty())
.unwrap_or_else(|| {
std::env::current_dir()
.map(|path| path.to_string_lossy().to_string())
.unwrap_or_else(|_| ".".to_string())
});
Ok(vec![
"-i".to_string(),
TTYD_TCP_HOST.to_string(),
"-p".to_string(),
TTYD_TCP_PORT.to_string(),
"-b".to_string(),
"/api/terminal".to_string(),
"-w".to_string(),
cwd,
"-W".to_string(),
])
}
fn redact_args_for_log(args: &[String]) -> Vec<String> {
let mut redacted = Vec::with_capacity(args.len());
let mut redact_next = false;
for arg in args {
if redact_next {
redacted.push("****".to_string());
redact_next = false;
continue;
}
if arg == "-key" || arg == "--key" {
redacted.push(arg.clone());
redact_next = true;
} else if let Some((flag, _)) = arg.split_once('=') {
if flag == "-key" || flag == "--key" {
redacted.push(format!("{}=****", flag));
} else {
redacted.push(arg.clone());
}
} else {
redacted.push(arg.clone());
}
}
redacted
}
#[cfg(unix)]
async fn prepare_ttyd_socket() -> Result<(), String> {
let socket_path = std::path::Path::new(TTYD_SOCKET_PATH);
// Ensure socket directory exists
if let Some(socket_dir) = socket_path.parent() {
if !socket_dir.exists() {
tokio::fs::create_dir_all(socket_dir)
@@ -355,7 +396,6 @@ impl ExtensionManager {
}
}
// Remove old socket file if exists
if tokio::fs::try_exists(TTYD_SOCKET_PATH)
.await
.unwrap_or(false)
@@ -368,24 +408,11 @@ impl ExtensionManager {
Ok(())
}
/// Health check - restart crashed processes that should be running
pub async fn health_check(&self, config: &ExtensionsConfig) {
// Collect extensions that need restart check
let checks: Vec<_> = ExtensionId::all()
.iter()
.filter_map(|id| {
let should_run = match id {
ExtensionId::Ttyd => config.ttyd.enabled,
ExtensionId::Gostc => {
config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
}
ExtensionId::Easytier => {
config.easytier.enabled && !config.easytier.network_name.is_empty()
}
};
if should_run && self.check_available(*id) {
if Self::is_enabled_for_config(*id, config) && self.check_available(*id) {
Some(*id)
} else {
None
@@ -393,7 +420,6 @@ impl ExtensionManager {
})
.collect();
// Check which ones need restart (single read lock)
let needs_restart: Vec<_> = {
let processes = self.processes.read().await;
checks
@@ -408,7 +434,6 @@ impl ExtensionManager {
.collect()
};
// Restart all crashed extensions in parallel
let restart_futures: Vec<_> = needs_restart
.into_iter()
.map(|id| async move {
@@ -422,50 +447,20 @@ impl ExtensionManager {
futures::future::join_all(restart_futures).await;
}
/// Start all enabled extensions in parallel
pub async fn start_enabled(&self, config: &ExtensionsConfig) {
use futures::Future;
use std::pin::Pin;
let mut start_futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + '_>>> = Vec::new();
// Collect enabled extensions
if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) {
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Ttyd, config).await {
tracing::error!("Failed to start ttyd: {}", e);
let start_futures: Vec<_> = ExtensionId::all()
.iter()
.filter(|id| Self::is_enabled_for_config(**id, config) && self.check_available(**id))
.map(|id| async move {
if let Err(e) = self.start(*id, config).await {
tracing::error!("Failed to start {}: {}", id, e);
}
}));
}
})
.collect();
if config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
&& self.check_available(ExtensionId::Gostc)
{
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Gostc, config).await {
tracing::error!("Failed to start gostc: {}", e);
}
}));
}
if config.easytier.enabled
&& !config.easytier.network_name.is_empty()
&& self.check_available(ExtensionId::Easytier)
{
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Easytier, config).await {
tracing::error!("Failed to start easytier: {}", e);
}
}));
}
// Start all in parallel
futures::future::join_all(start_futures).await;
}
/// Stop all running extensions in parallel
pub async fn stop_all(&self) {
let stop_futures: Vec<_> = ExtensionId::all().iter().map(|id| self.stop(*id)).collect();
futures::future::join_all(stop_futures).await;

View File

@@ -1,7 +1,10 @@
//! Extensions module - manage external processes like ttyd, gostc, easytier
mod manager;
mod software;
mod types;
pub use manager::{ExtensionManager, TTYD_SOCKET_PATH};
pub use manager::ExtensionManager;
#[cfg(unix)]
pub use manager::TTYD_SOCKET_PATH;
#[cfg(windows)]
pub use manager::TTYD_TCP_ADDR;
pub use types::*;

View File

@@ -0,0 +1,15 @@
use std::path::PathBuf;
use super::ExtensionId;
#[cfg_attr(windows, path = "software_windows.rs")]
#[cfg_attr(not(windows), path = "software_linux.rs")]
mod platform;
pub fn binary_path(id: ExtensionId) -> PathBuf {
platform::binary_path(id)
}
pub fn default_ttyd_shell() -> &'static str {
platform::default_ttyd_shell()
}

View File

@@ -0,0 +1,19 @@
use std::path::PathBuf;
use super::ExtensionId;
pub fn default_binary_path(id: ExtensionId) -> &'static str {
match id {
ExtensionId::Ttyd => "/usr/bin/ttyd",
ExtensionId::Gostc => "/usr/bin/gostc",
ExtensionId::Easytier => "/usr/bin/easytier-core",
}
}
pub fn binary_path(id: ExtensionId) -> PathBuf {
PathBuf::from(default_binary_path(id))
}
pub fn default_ttyd_shell() -> &'static str {
"/bin/bash"
}

View File

@@ -0,0 +1,47 @@
use std::path::PathBuf;
use super::ExtensionId;
pub fn default_binary_path(id: ExtensionId) -> &'static str {
match id {
ExtensionId::Ttyd => "ttyd.win32.exe",
ExtensionId::Gostc => "gostc.exe",
ExtensionId::Easytier => "easytier-core.exe",
}
}
pub fn binary_path(id: ExtensionId) -> PathBuf {
if id == ExtensionId::Ttyd {
if let Some(path) = env_path("ONE_KVM_TTYD_PATH") {
return path;
}
}
find_in_app_dir(default_binary_path(id))
.unwrap_or_else(|| PathBuf::from(default_binary_path(id)))
}
pub fn default_ttyd_shell() -> &'static str {
"cmd"
}
fn env_path(name: &str) -> Option<PathBuf> {
std::env::var(name)
.ok()
.map(|path| path.trim().to_string())
.filter(|path| !path.is_empty())
.map(PathBuf::from)
}
fn find_in_app_dir(binary_name: &str) -> Option<PathBuf> {
if let Ok(exe_path) = std::env::current_exe() {
if let Some(exe_dir) = exe_path.parent() {
let bundled = exe_dir.join(binary_name);
if bundled.exists() {
return Some(bundled);
}
}
}
None
}

View File

@@ -1,41 +1,22 @@
//! Extension types and configurations
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Extension identifier (fixed set of supported extensions)
use super::software;
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ExtensionId {
/// Web terminal (ttyd)
Ttyd,
/// NAT traversal client (gostc)
Gostc,
/// P2P VPN (easytier)
Easytier,
}
impl ExtensionId {
/// Get the binary path for this extension
pub fn binary_path(&self) -> &'static str {
match self {
Self::Ttyd => "/usr/bin/ttyd",
Self::Gostc => "/usr/bin/gostc",
Self::Easytier => "/usr/bin/easytier-core",
}
pub fn binary_path(&self) -> std::path::PathBuf {
software::binary_path(*self)
}
/// Get the display name for this extension
pub fn display_name(&self) -> &'static str {
match self {
Self::Ttyd => "Web Terminal",
Self::Gostc => "GOSTC Tunnel",
Self::Easytier => "EasyTier VPN",
}
}
/// Get all extension IDs
pub fn all() -> &'static [ExtensionId] {
&[Self::Ttyd, Self::Gostc, Self::Easytier]
}
@@ -64,25 +45,13 @@ impl std::str::FromStr for ExtensionId {
}
}
/// Extension running status
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "state", content = "data", rename_all = "lowercase")]
pub enum ExtensionStatus {
/// Binary not found at expected path
Unavailable,
/// Extension is stopped
Stopped,
/// Extension is running
Running {
/// Process ID
pid: u32,
},
/// Extension failed to start
Failed {
/// Error message
error: String,
},
Running { pid: u32 },
}
impl ExtensionStatus {
@@ -91,16 +60,11 @@ impl ExtensionStatus {
}
}
/// ttyd configuration (Web Terminal)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TtydConfig {
/// Enable auto-start
pub enabled: bool,
/// Port to listen on
pub port: u16,
/// Shell to execute
pub shell: String,
}
@@ -108,25 +72,19 @@ impl Default for TtydConfig {
fn default() -> Self {
Self {
enabled: false,
port: 7681,
shell: "/bin/bash".to_string(),
shell: software::default_ttyd_shell().to_string(),
}
}
}
/// gostc configuration (NAT traversal based on FRP)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GostcConfig {
/// Enable auto-start
pub enabled: bool,
/// Server address (hostname or IP)
pub addr: String,
/// Client key from GOSTC management panel
#[serde(skip_serializing_if = "String::is_empty")]
pub key: String,
/// Enable TLS
pub tls: bool,
}
@@ -141,28 +99,21 @@ impl Default for GostcConfig {
}
}
/// EasyTier configuration (P2P VPN)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct EasytierConfig {
/// Enable auto-start
pub enabled: bool,
/// Network name
pub network_name: String,
/// Network secret/password
#[serde(skip_serializing_if = "String::is_empty")]
pub network_secret: String,
/// Peer node URLs
#[serde(skip_serializing_if = "Vec::is_empty")]
pub peer_urls: Vec<String>,
/// Virtual IP address (optional, auto-assigned if not set)
#[serde(skip_serializing_if = "Option::is_none")]
pub virtual_ip: Option<String>,
}
/// Combined extensions configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
@@ -172,53 +123,37 @@ pub struct ExtensionsConfig {
pub easytier: EasytierConfig,
}
/// Extension info with status and config
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
}
/// ttyd extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtydInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: TtydConfig,
}
/// gostc extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GostcInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: GostcConfig,
}
/// easytier extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EasytierInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: EasytierConfig,
}
/// All extensions status response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionsStatus {
@@ -227,7 +162,6 @@ pub struct ExtensionsStatus {
pub easytier: EasytierInfo,
}
/// Extension logs response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionLogs {

View File

@@ -1,73 +1,32 @@
//! HID backend trait definition
//! `HidBackend` trait plus serde `HidBackendType` (OTG | CH9329 | disabled).
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::watch;
use super::otg::LedState;
use super::types::{ConsumerEvent, KeyboardEvent, MouseEvent};
use crate::error::Result;
use crate::events::LedState;
/// Default CH9329 baud rate
fn default_ch9329_baud_rate() -> u32 {
9600
}
/// HID backend type
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
#[derive(Default)]
pub enum HidBackendType {
/// USB OTG gadget mode
Otg,
/// CH9329 serial HID controller
Ch9329 {
/// Serial port path
port: String,
/// Baud rate (default: 9600)
#[serde(default = "default_ch9329_baud_rate")]
baud_rate: u32,
},
/// No HID backend (disabled)
#[default]
None,
}
impl HidBackendType {
/// Check if OTG backend is available on this system
pub fn otg_available() -> bool {
// Check for USB gadget support
std::path::Path::new("/sys/class/udc").exists()
}
/// Detect the best available backend
pub fn detect() -> Self {
// Check for OTG gadget support
if Self::otg_available() {
return Self::Otg;
}
// Check for common CH9329 serial ports
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
];
for port in &common_ports {
if std::path::Path::new(port).exists() {
return Self::Ch9329 {
port: port.to_string(),
baud_rate: 9600, // Use default baud rate for auto-detection
};
}
}
Self::None
}
/// Get backend name as string
pub fn name_str(&self) -> &str {
match self {
Self::Otg => "otg",
@@ -77,76 +36,40 @@ impl HidBackendType {
}
}
/// Current runtime status reported by a HID backend.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HidBackendRuntimeSnapshot {
/// Whether the backend has been initialized and can accept requests.
pub initialized: bool,
/// Whether the backend is currently online and communicating successfully.
pub online: bool,
/// Whether absolute mouse positioning is supported.
pub supports_absolute_mouse: bool,
/// Whether keyboard LED/status feedback is currently enabled.
pub keyboard_leds_enabled: bool,
/// Last known keyboard LED state.
pub led_state: LedState,
/// Screen resolution for absolute mouse mode.
pub screen_resolution: Option<(u32, u32)>,
/// Device identifier associated with the backend, if any.
pub device: Option<String>,
/// Current user-facing error, if any.
pub error: Option<String>,
/// Current programmatic error code, if any.
pub error_code: Option<String>,
}
/// HID backend trait
#[async_trait]
pub trait HidBackend: Send + Sync {
/// Initialize the backend
async fn init(&self) -> Result<()>;
/// Send a keyboard event
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()>;
/// Send a mouse event
async fn send_mouse(&self, event: MouseEvent) -> Result<()>;
/// Send a consumer control event (multimedia keys)
/// Default implementation returns an error (not supported)
async fn send_consumer(&self, _event: ConsumerEvent) -> Result<()> {
Err(crate::error::AppError::BadRequest(
"Consumer control not supported by this backend".to_string(),
))
}
/// Reset all inputs (release all keys/buttons)
async fn reset(&self) -> Result<()>;
/// Shutdown the backend
async fn shutdown(&self) -> Result<()>;
/// Get the current backend runtime snapshot.
fn runtime_snapshot(&self) -> HidBackendRuntimeSnapshot;
/// Subscribe to backend runtime changes.
fn subscribe_runtime(&self) -> watch::Receiver<()>;
/// Set screen resolution (for absolute mouse)
fn set_screen_resolution(&mut self, _width: u32, _height: u32) {}
}
/// HID backend information
#[derive(Debug, Clone, Serialize)]
pub struct HidBackendInfo {
/// Backend name
pub name: String,
/// Backend type
pub backend_type: String,
/// Is initialized
pub initialized: bool,
/// Supports absolute mouse
pub absolute_mouse: bool,
/// Screen resolution (if absolute mouse)
pub resolution: Option<(u32, u32)>,
fn set_screen_resolution(&self, _width: u32, _height: u32) {}
}

View File

@@ -1,9 +1,4 @@
//! CH9329 Serial HID Controller backend
//!
//! CH9329 is a USB HID chip controlled via UART from WCH (沁恒).
//! It supports keyboard, mouse (absolute + relative), and custom HID device emulation.
//!
//! ## Protocol Format
//! CH9329 over UART — WCH *Serial Communication Protocol V1.0*.
//! ```text
//! ┌──────┬──────┬──────┬────────┬──────────────┬──────────┐
//! │Header│ ADDR │ CMD │ LEN │ DATA │ SUM │
@@ -11,342 +6,36 @@
//! │57 AB │ 00 │ xx │ N │ N bytes │Checksum │
//! └──────┴──────┴──────┴────────┴──────────────┴──────────┘
//! ```
//!
//! Checksum: Sum of ALL bytes including header (modulo 256)
//!
//! ## Reference
//! Based on WCH CH9329 Serial Communication Protocol V1.0
//! Sum of all octets modulo 256 (including header).
use async_trait::async_trait;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;
use std::time::{Duration, Instant};
use tokio::sync::watch;
use tracing::{info, trace, warn};
use tracing::{info, trace};
use super::backend::{HidBackend, HidBackendRuntimeSnapshot};
use super::otg::LedState;
use super::ch9329_proto::{
build_packet, cmd, expected_response_cmd, try_extract_response, ChipInfo, LedStatus, Response,
DEFAULT_ADDR, DEFAULT_BAUD_RATE, MAX_PACKET_SIZE,
};
use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType};
use crate::error::{AppError, Result};
use crate::events::LedState;
// ============================================================================
// Constants and Command Codes
// ============================================================================
/// CH9329 packet header
const PACKET_HEADER: [u8; 2] = [0x57, 0xAB];
/// Default address (accepts any address)
const DEFAULT_ADDR: u8 = 0x00;
/// Default baud rate for CH9329
pub const DEFAULT_BAUD_RATE: u32 = 9600;
/// Response timeout in milliseconds
const RESPONSE_TIMEOUT_MS: u64 = 500;
/// Maximum data length in a packet
const MAX_DATA_LEN: usize = 64;
/// CH9329 absolute mouse resolution
const CH9329_MOUSE_RESOLUTION: u32 = 4096;
/// How often the worker probes the chip when idle.
const PROBE_INTERVAL_MS: u64 = 100;
/// How long the worker waits before reopening the serial port after a failure.
const RECONNECT_DELAY_MS: u64 = 2000;
/// Initial startup wait for the worker to confirm CH9329 is reachable.
const INIT_WAIT_MS: u64 = 3000;
/// CH9329 command codes
pub mod cmd {
/// Get chip version, USB status, and LED status
pub const GET_INFO: u8 = 0x01;
/// Send standard keyboard data (8 bytes)
pub const SEND_KB_GENERAL_DATA: u8 = 0x02;
/// Send multimedia keyboard data
pub const SEND_KB_MEDIA_DATA: u8 = 0x03;
/// Send absolute mouse data
pub const SEND_MS_ABS_DATA: u8 = 0x04;
/// Send relative mouse data
pub const SEND_MS_REL_DATA: u8 = 0x05;
/// Send custom HID data
pub const SEND_MY_HID_DATA: u8 = 0x06;
/// Restore factory default configuration
pub const SET_DEFAULT_CFG: u8 = 0x0C;
/// Software reset
pub const RESET: u8 = 0x0F;
}
/// Response command mask (success = cmd | 0x80, error = cmd | 0xC0)
const RESPONSE_SUCCESS_MASK: u8 = 0x80;
const RESPONSE_ERROR_MASK: u8 = 0xC0;
/// CH9329 error codes
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Ch9329Error {
/// Command executed successfully
Success = 0x00,
/// Serial receive timeout
Timeout = 0xE1,
/// Invalid packet header
InvalidHeader = 0xE2,
/// Invalid command code
InvalidCommand = 0xE3,
/// Checksum mismatch
ChecksumError = 0xE4,
/// Parameter error
ParameterError = 0xE5,
/// Execution failed
OperationFailed = 0xE6,
}
impl From<u8> for Ch9329Error {
fn from(code: u8) -> Self {
match code {
0x00 => Ch9329Error::Success,
0xE1 => Ch9329Error::Timeout,
0xE2 => Ch9329Error::InvalidHeader,
0xE3 => Ch9329Error::InvalidCommand,
0xE4 => Ch9329Error::ChecksumError,
0xE5 => Ch9329Error::ParameterError,
0xE6 => Ch9329Error::OperationFailed,
_ => Ch9329Error::OperationFailed,
}
}
}
impl std::fmt::Display for Ch9329Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ch9329Error::Success => write!(f, "Success"),
Ch9329Error::Timeout => write!(f, "Serial receive timeout"),
Ch9329Error::InvalidHeader => write!(f, "Invalid packet header"),
Ch9329Error::InvalidCommand => write!(f, "Invalid command code"),
Ch9329Error::ChecksumError => write!(f, "Checksum mismatch"),
Ch9329Error::ParameterError => write!(f, "Parameter error"),
Ch9329Error::OperationFailed => write!(f, "Operation failed"),
}
}
}
// ============================================================================
// Chip Information
// ============================================================================
/// CH9329 chip information
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChipInfo {
/// Chip version (e.g., "V1.0", "V1.1")
pub version: String,
/// Raw version byte
pub version_raw: u8,
/// USB connection status
pub usb_connected: bool,
/// Num Lock LED state
pub num_lock: bool,
/// Caps Lock LED state
pub caps_lock: bool,
/// Scroll Lock LED state
pub scroll_lock: bool,
}
impl ChipInfo {
/// Parse chip info from response data (8 bytes)
pub fn from_response(data: &[u8]) -> Option<Self> {
if data.len() < 8 {
return None;
}
let version_raw = data[0];
let version = format!("V{}.{}", version_raw >> 4, version_raw & 0x0F);
let usb_connected = data[1] == 0x01;
let led_status = data[2];
Some(Self {
version,
version_raw,
usb_connected,
num_lock: (led_status & 0x01) != 0,
caps_lock: (led_status & 0x02) != 0,
scroll_lock: (led_status & 0x04) != 0,
})
}
}
/// Keyboard LED status
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LedStatus {
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
}
impl From<u8> for LedStatus {
fn from(byte: u8) -> Self {
Self {
num_lock: (byte & 0x01) != 0,
caps_lock: (byte & 0x02) != 0,
scroll_lock: (byte & 0x04) != 0,
}
}
}
// ============================================================================
// Configuration
// ============================================================================
/// CH9329 work mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
#[derive(Default)]
pub enum WorkMode {
/// Mode 0: Standard USB Keyboard + Mouse (default)
#[default]
KeyboardMouse = 0x00,
/// Mode 1: Standard USB Keyboard only
KeyboardOnly = 0x01,
/// Mode 2: Standard USB Mouse only
MouseOnly = 0x02,
/// Mode 3: Custom HID device
CustomHid = 0x03,
}
/// CH9329 serial communication mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
#[derive(Default)]
pub enum SerialMode {
/// Mode 0: Protocol transmission mode (default)
#[default]
Protocol = 0x00,
/// Mode 1: ASCII mode
Ascii = 0x01,
/// Mode 2: Transparent mode
Transparent = 0x02,
}
/// CH9329 configuration parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ch9329Config {
/// Work mode
pub work_mode: WorkMode,
/// Serial communication mode
pub serial_mode: SerialMode,
/// Device address (0x00-0xFE, 0xFF = broadcast)
pub address: u8,
/// Baud rate
pub baud_rate: u32,
/// USB VID
pub vid: u16,
/// USB PID
pub pid: u16,
}
impl Default for Ch9329Config {
fn default() -> Self {
Self {
work_mode: WorkMode::KeyboardMouse,
serial_mode: SerialMode::Protocol,
address: 0x00,
baud_rate: 9600,
vid: 0x1A86,
pid: 0xE129,
}
}
}
// ============================================================================
// Response Parsing
// ============================================================================
/// Parsed response from CH9329
#[derive(Debug)]
pub struct Response {
/// Address byte
pub address: u8,
/// Command code (with response bits)
pub cmd: u8,
/// Data payload
pub data: Vec<u8>,
/// Whether this is an error response
pub is_error: bool,
/// Error code (if is_error)
pub error_code: Option<Ch9329Error>,
}
impl Response {
/// Parse a response from raw bytes
pub fn parse(bytes: &[u8]) -> Option<Self> {
// Minimum: Header(2) + Addr(1) + Cmd(1) + Len(1) + Sum(1) = 6
if bytes.len() < 6 {
return None;
}
// Check header
if bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] {
return None;
}
let address = bytes[2];
let cmd = bytes[3];
let len = bytes[4] as usize;
// Check if we have enough bytes
if bytes.len() < 5 + len + 1 {
return None;
}
// Verify checksum
let expected_checksum = bytes[5 + len];
let calculated_checksum = bytes[..5 + len]
.iter()
.fold(0u8, |acc, &x| acc.wrapping_add(x));
if expected_checksum != calculated_checksum {
warn!(
"CH9329 checksum mismatch: expected {:02X}, got {:02X}",
expected_checksum, calculated_checksum
);
return None;
}
let data = bytes[5..5 + len].to_vec();
let is_error = (cmd & RESPONSE_ERROR_MASK) == RESPONSE_ERROR_MASK;
let error_code = if is_error && !data.is_empty() {
Some(Ch9329Error::from(data[0]))
} else {
None
};
Some(Self {
address,
cmd,
data,
is_error,
error_code,
})
}
/// Check if the response indicates success
pub fn is_success(&self) -> bool {
!self.is_error && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8)
}
}
/// Maximum packet size (header 2 + addr 1 + cmd 1 + len 1 + data 64 + checksum 1 = 70)
const MAX_PACKET_SIZE: usize = 70;
// ============================================================================
// CH9329 Backend Implementation
// ============================================================================
struct Ch9329RuntimeState {
initialized: AtomicBool,
@@ -424,47 +113,28 @@ enum WorkerCommand {
Shutdown,
}
/// CH9329 HID backend
pub struct Ch9329Backend {
/// Serial port path
port_path: String,
/// Baud rate
baud_rate: u32,
/// Worker command sender
worker_tx: Mutex<Option<mpsc::Sender<WorkerCommand>>>,
/// Background worker thread
worker_handle: Mutex<Option<thread::JoinHandle<()>>>,
/// Current keyboard state
keyboard_state: Mutex<KeyboardReport>,
/// Current mouse button state
mouse_buttons: AtomicU8,
/// Screen width for absolute mouse coordinate conversion
screen_width: u32,
/// Screen height for absolute mouse coordinate conversion
screen_height: u32,
/// Cached chip information
screen_resolution: RwLock<(u32, u32)>,
chip_info: Arc<RwLock<Option<ChipInfo>>>,
/// LED status cache
led_status: Arc<RwLock<LedStatus>>,
/// Device address (default 0x00)
address: u8,
/// Last absolute mouse X position (CH9329 coordinate: 0-4095)
last_abs_x: AtomicU16,
/// Last absolute mouse Y position (CH9329 coordinate: 0-4095)
last_abs_y: AtomicU16,
/// Whether relative mouse mode is active (set by incoming events)
relative_mouse_active: AtomicBool,
/// Shared runtime status updated only by the worker.
runtime: Arc<Ch9329RuntimeState>,
}
impl Ch9329Backend {
/// Create a new CH9329 backend with default baud rate (9600)
pub fn new(port_path: &str) -> Result<Self> {
Self::with_baud_rate(port_path, DEFAULT_BAUD_RATE)
}
/// Create a new CH9329 backend with custom baud rate
pub fn with_baud_rate(port_path: &str, baud_rate: u32) -> Result<Self> {
Ok(Self {
port_path: port_path.to_string(),
@@ -473,8 +143,7 @@ impl Ch9329Backend {
worker_handle: Mutex::new(None),
keyboard_state: Mutex::new(KeyboardReport::default()),
mouse_buttons: AtomicU8::new(0),
screen_width: 1920,
screen_height: 1080,
screen_resolution: RwLock::new((1920, 1080)),
chip_info: Arc::new(RwLock::new(None)),
led_status: Arc::new(RwLock::new(LedStatus::default())),
address: DEFAULT_ADDR,
@@ -489,12 +158,17 @@ impl Ch9329Backend {
self.runtime.set_error(reason, error_code);
}
/// Check if the serial port device file exists
pub fn check_port_exists(&self) -> bool {
#[cfg(windows)]
{
return crate::utils::list_serial_ports()
.iter()
.any(|port| port.eq_ignore_ascii_case(&self.port_path));
}
#[cfg(not(windows))]
std::path::Path::new(&self.port_path).exists()
}
/// Convert serialport error to HidError
fn serial_error_to_hid_error(e: serialport::Error, operation: &str) -> AppError {
let error_code = match e.kind() {
serialport::ErrorKind::NoDevice => "port_not_found",
@@ -518,52 +192,9 @@ impl Ch9329Backend {
}
}
/// Calculate checksum for CH9329 packet (sum of ALL bytes including header)
#[inline]
fn calculate_checksum(data: &[u8]) -> u8 {
data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x))
}
/// Build a CH9329 packet into a stack-allocated buffer
///
/// Packet format: `[Header 0x57 0xAB] [Address] [Command] [Length] [Data] [Checksum]`
/// Returns the packet buffer and the actual length
#[inline]
fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) {
debug_assert!(
data.len() <= MAX_DATA_LEN,
"Data too long for CH9329 packet"
);
let len = data.len() as u8;
let packet_len = 6 + data.len();
let mut packet = [0u8; MAX_PACKET_SIZE];
// Header (2 bytes)
packet[0] = PACKET_HEADER[0];
packet[1] = PACKET_HEADER[1];
// Address (1 byte)
packet[2] = address;
// Command (1 byte)
packet[3] = cmd;
// Length (1 byte) - data length only
packet[4] = len;
// Data (N bytes)
packet[5..5 + data.len()].copy_from_slice(data);
// Checksum (1 byte) - sum of ALL bytes including header
let checksum = Self::calculate_checksum(&packet[..5 + data.len()]);
packet[5 + data.len()] = checksum;
(packet, packet_len)
}
/// Build a CH9329 packet (legacy Vec version for compatibility)
fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec<u8> {
let (buf, len) = Self::build_packet_buf(address, cmd, data);
buf[..len].to_vec()
}
fn open_port(port_path: &str, baud_rate: u32) -> Result<Box<dyn serialport::SerialPort>> {
#[cfg(not(windows))]
if !std::path::Path::new(port_path).exists() {
return Err(Self::backend_error(
format!("Serial port {} not found", port_path),
@@ -583,46 +214,13 @@ impl Ch9329Backend {
cmd: u8,
data: &[u8],
) -> Result<()> {
let packet = Self::build_packet(address, cmd, data);
let packet = build_packet(address, cmd, data);
port.write_all(&packet).map_err(|e| {
Self::backend_error(format!("Failed to write to CH9329: {}", e), "write_failed")
})?;
Ok(())
}
fn try_extract_response(buffer: &[u8]) -> Option<(Response, usize)> {
let mut offset = 0;
while offset + 6 <= buffer.len() {
if buffer[offset] != PACKET_HEADER[0] || buffer[offset + 1] != PACKET_HEADER[1] {
offset += 1;
continue;
}
let len = buffer[offset + 4] as usize;
let frame_len = 6 + len;
if offset + frame_len > buffer.len() {
return None;
}
let frame = &buffer[offset..offset + frame_len];
if let Some(response) = Response::parse(frame) {
return Some((response, offset + frame_len));
}
offset += 1;
}
None
}
fn expected_response_cmd(cmd: u8, is_error: bool) -> u8 {
cmd | if is_error {
RESPONSE_ERROR_MASK
} else {
RESPONSE_SUCCESS_MASK
}
}
fn xfer_packet(
port: &mut dyn serialport::SerialPort,
address: u8,
@@ -633,8 +231,8 @@ impl Ch9329Backend {
let mut pending = Vec::with_capacity(128);
let deadline = Instant::now() + Duration::from_millis(RESPONSE_TIMEOUT_MS);
let expected_ok = Self::expected_response_cmd(cmd, false);
let expected_err = Self::expected_response_cmd(cmd, true);
let expected_ok = expected_response_cmd(cmd, false);
let expected_err = expected_response_cmd(cmd, true);
loop {
let mut chunk = [0u8; 128];
@@ -642,7 +240,7 @@ impl Ch9329Backend {
Ok(n) if n > 0 => {
pending.extend_from_slice(&chunk[..n]);
while let Some((response, consumed)) = Self::try_extract_response(&pending) {
while let Some((response, consumed)) = try_extract_response(&pending) {
pending.drain(..consumed);
if response.cmd == expected_ok || response.cmd == expected_err {
return Ok(response);
@@ -984,10 +582,6 @@ impl Ch9329Backend {
}
}
// ============================================================================
// HidBackend Trait Implementation
// ============================================================================
#[async_trait]
impl HidBackend for Ch9329Backend {
async fn init(&self) -> Result<()> {
@@ -1060,7 +654,6 @@ impl HidBackend for Ch9329Backend {
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
let usb_key = event.key.to_hid_usage();
// Handle modifier keys separately
if event.key.is_modifier() {
let mut state = self.keyboard_state.lock();
@@ -1078,7 +671,6 @@ impl HidBackend for Ch9329Backend {
} else {
let mut state = self.keyboard_state.lock();
// Update modifiers from event
state.modifiers = event.modifiers.to_hid_byte();
match event.event_type {
@@ -1104,19 +696,15 @@ impl HidBackend for Ch9329Backend {
match event.event_type {
MouseEventType::Move => {
// Relative movement - send delta directly without inversion
self.relative_mouse_active.store(true, Ordering::Relaxed);
let dx = event.x.clamp(-127, 127) as i8;
let dy = event.y.clamp(-127, 127) as i8;
self.send_mouse_relative(buttons, dx, dy, 0)?;
}
MouseEventType::MoveAbs => {
// Absolute movement
self.relative_mouse_active.store(false, Ordering::Relaxed);
// Frontend sends 0-32767 (HID standard), CH9329 expects 0-4095
let x = ((event.x.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16;
let y = ((event.y.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16;
// Store last absolute position for click events
self.last_abs_x.store(x, Ordering::Relaxed);
self.last_abs_y.store(y, Ordering::Relaxed);
self.send_mouse_absolute(buttons, x, y, 0)?;
@@ -1153,7 +741,6 @@ impl HidBackend for Ch9329Backend {
if self.relative_mouse_active.load(Ordering::Relaxed) {
self.send_mouse_relative(buttons, 0, 0, event.scroll)?;
} else {
// Use absolute mouse for scroll with last position
let x = self.last_abs_x.load(Ordering::Relaxed);
let y = self.last_abs_y.load(Ordering::Relaxed);
self.send_mouse_absolute(buttons, x, y, event.scroll)?;
@@ -1165,7 +752,6 @@ impl HidBackend for Ch9329Backend {
}
async fn reset(&self) -> Result<()> {
// Reset keyboard
{
let mut state = self.keyboard_state.lock();
state.clear();
@@ -1174,14 +760,12 @@ impl HidBackend for Ch9329Backend {
self.send_keyboard_report(&report)?;
}
// Reset mouse
self.mouse_buttons.store(0, Ordering::Relaxed);
self.last_abs_x.store(0, Ordering::Relaxed);
self.last_abs_y.store(0, Ordering::Relaxed);
self.relative_mouse_active.store(false, Ordering::Relaxed);
self.send_mouse_absolute(0, 0, 0, 0)?;
// Reset media keys
let _ = self.release_media_keys();
info!("CH9329 HID state reset");
@@ -1210,7 +794,14 @@ impl HidBackend for Ch9329Backend {
let mut online = initialized && self.runtime.online.load(Ordering::Relaxed);
let mut error = self.runtime.last_error.read().clone();
if initialized && !self.check_port_exists() {
#[cfg(windows)]
let port_still_present = crate::utils::list_serial_ports()
.iter()
.any(|port| port.eq_ignore_ascii_case(&self.port_path));
#[cfg(not(windows))]
let port_still_present = self.check_port_exists();
if initialized && !port_still_present {
online = false;
error = Some((
format!("Serial port {} not found", self.port_path),
@@ -1233,7 +824,7 @@ impl HidBackend for Ch9329Backend {
kana: false,
}
},
screen_resolution: Some((self.screen_width, self.screen_height)),
screen_resolution: Some(*self.screen_resolution.read()),
device: Some(self.port_path.clone()),
error: error.as_ref().map(|(reason, _)| reason.clone()),
error_code: error.as_ref().map(|(_, code)| code.clone()),
@@ -1244,127 +835,24 @@ impl HidBackend for Ch9329Backend {
self.runtime.subscribe()
}
fn set_screen_resolution(&mut self, width: u32, height: u32) {
self.screen_width = width;
self.screen_height = height;
fn set_screen_resolution(&self, width: u32, height: u32) {
*self.screen_resolution.write() = (width, height);
self.runtime.notify();
}
}
// ============================================================================
// Detection and Helpers
// ============================================================================
/// Detect CH9329 on common serial ports
pub fn detect_ch9329() -> Option<String> {
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
"/dev/ttyS0",
];
// Try multiple baud rates
let baud_rates = [9600, 115200];
for port_path in &common_ports {
if !std::path::Path::new(port_path).exists() {
continue;
}
for &baud_rate in &baud_rates {
if let Ok(mut port) = serialport::new(*port_path, baud_rate)
.timeout(Duration::from_millis(200))
.open()
{
// Build GET_INFO packet manually (address = 0x00)
let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03];
if port.write_all(&packet).is_ok() {
std::thread::sleep(Duration::from_millis(50));
let mut response = [0u8; 16];
if let Ok(n) = port.read(&mut response) {
// Check for valid CH9329 response header
if n >= 6
&& response[0] == PACKET_HEADER[0]
&& response[1] == PACKET_HEADER[1]
{
info!("CH9329 detected on {} @ {} baud", port_path, baud_rate);
return Some(port_path.to_string());
}
}
}
}
}
}
None
}
/// Detect CH9329 and return both path and working baud rate
pub fn detect_ch9329_with_baud() -> Option<(String, u32)> {
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
"/dev/ttyS0",
];
let baud_rates = [9600, 115200, 57600, 38400, 19200];
for port_path in &common_ports {
if !std::path::Path::new(port_path).exists() {
continue;
}
for &baud_rate in &baud_rates {
if let Ok(mut port) = serialport::new(*port_path, baud_rate)
.timeout(Duration::from_millis(200))
.open()
{
let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03];
if port.write_all(&packet).is_ok() {
std::thread::sleep(Duration::from_millis(50));
let mut response = [0u8; 16];
if let Ok(n) = port.read(&mut response) {
if n >= 6
&& response[0] == PACKET_HEADER[0]
&& response[1] == PACKET_HEADER[1]
{
info!("CH9329 detected on {} @ {} baud", port_path, baud_rate);
return Some((port_path.to_string(), baud_rate));
}
}
}
}
}
}
None
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use super::ch9329_proto::{build_packet, calculate_checksum};
#[test]
fn test_packet_building() {
// Test GET_INFO packet (no data)
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]);
let packet = build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]);
assert_eq!(packet, vec![0x57, 0xAB, 0x00, 0x01, 0x00, 0x03]);
// Test keyboard packet (8 bytes data)
let data = [0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; // 'A' key
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data);
let packet = build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data);
assert_eq!(packet[0], 0x57); // Header
assert_eq!(packet[1], 0xAB); // Header
@@ -1372,7 +860,6 @@ mod tests {
assert_eq!(packet[3], cmd::SEND_KB_GENERAL_DATA); // Command
assert_eq!(packet[4], 8); // Length (8 data bytes)
assert_eq!(&packet[5..13], &data); // Data
// Checksum = 0x57 + 0xAB + 0x00 + 0x02 + 0x08 + 0x00 + 0x00 + 0x04 + ... = 0x10
let expected_checksum: u8 = packet[..13]
.iter()
.fold(0u8, |acc: u8, &x| acc.wrapping_add(x));
@@ -1381,9 +868,8 @@ mod tests {
#[test]
fn test_relative_mouse_packet() {
// Test relative mouse: move right 50 pixels
let data = [0x01, 0x00, 50u8, 0x00, 0x00];
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data);
let packet = build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data);
assert_eq!(packet[0], 0x57);
assert_eq!(packet[1], 0xAB);
@@ -1397,22 +883,19 @@ mod tests {
#[test]
fn test_checksum_calculation() {
// Known packet: GET_INFO
let packet = [0x57u8, 0xAB, 0x00, 0x01, 0x00];
let checksum = Ch9329Backend::calculate_checksum(&packet);
let checksum = calculate_checksum(&packet);
assert_eq!(checksum, 0x03);
// Known packet: Keyboard 'A' press
let packet = [
0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00,
];
let checksum = Ch9329Backend::calculate_checksum(&packet);
let checksum = calculate_checksum(&packet);
assert_eq!(checksum, 0x10);
}
#[test]
fn test_response_parsing() {
// Valid GET_INFO response
let response_bytes = [
0x57, 0xAB, // Header
0x00, // Address
@@ -1422,9 +905,7 @@ mod tests {
0xE0, // Checksum (calculated)
];
// Note: checksum in test is just placeholder, parse will validate
let _result = Response::parse(&response_bytes);
// This will fail because checksum doesn't match, but structure is tested
}
#[test]

225
src/hid/ch9329_proto.rs Normal file
View File

@@ -0,0 +1,225 @@
//! Shared CH9329 protocol types and packet helpers.
use serde::{Deserialize, Serialize};
const PACKET_HEADER: [u8; 2] = [0x57, 0xAB];
pub const RESPONSE_SUCCESS_MASK: u8 = 0x80;
pub const RESPONSE_ERROR_MASK: u8 = 0xC0;
pub const DEFAULT_ADDR: u8 = 0x00;
pub const DEFAULT_BAUD_RATE: u32 = 9600;
pub const MAX_DATA_LEN: usize = 64;
pub const MAX_PACKET_SIZE: usize = 70;
pub mod cmd {
pub const GET_INFO: u8 = 0x01;
pub const SEND_KB_GENERAL_DATA: u8 = 0x02;
pub const SEND_KB_MEDIA_DATA: u8 = 0x03;
pub const SEND_MS_ABS_DATA: u8 = 0x04;
pub const SEND_MS_REL_DATA: u8 = 0x05;
pub const RESET: u8 = 0x0F;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Ch9329Error {
Success = 0x00,
Timeout = 0xE1,
InvalidHeader = 0xE2,
InvalidCommand = 0xE3,
ChecksumError = 0xE4,
ParameterError = 0xE5,
OperationFailed = 0xE6,
}
impl From<u8> for Ch9329Error {
fn from(code: u8) -> Self {
match code {
0x00 => Ch9329Error::Success,
0xE1 => Ch9329Error::Timeout,
0xE2 => Ch9329Error::InvalidHeader,
0xE3 => Ch9329Error::InvalidCommand,
0xE4 => Ch9329Error::ChecksumError,
0xE5 => Ch9329Error::ParameterError,
0xE6 => Ch9329Error::OperationFailed,
_ => Ch9329Error::OperationFailed,
}
}
}
impl std::fmt::Display for Ch9329Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ch9329Error::Success => write!(f, "Success"),
Ch9329Error::Timeout => write!(f, "Serial receive timeout"),
Ch9329Error::InvalidHeader => write!(f, "Invalid packet header"),
Ch9329Error::InvalidCommand => write!(f, "Invalid command code"),
Ch9329Error::ChecksumError => write!(f, "Checksum mismatch"),
Ch9329Error::ParameterError => write!(f, "Parameter error"),
Ch9329Error::OperationFailed => write!(f, "Operation failed"),
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChipInfo {
pub version: String,
pub version_raw: u8,
pub usb_connected: bool,
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
}
impl ChipInfo {
pub fn from_response(data: &[u8]) -> Option<Self> {
if data.len() < 8 {
return None;
}
let version_raw = data[0];
let version = format!("V{}.{}", version_raw >> 4, version_raw & 0x0F);
let usb_connected = data[1] == 0x01;
let led_status = data[2];
Some(Self {
version,
version_raw,
usb_connected,
num_lock: (led_status & 0x01) != 0,
caps_lock: (led_status & 0x02) != 0,
scroll_lock: (led_status & 0x04) != 0,
})
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LedStatus {
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
}
impl From<u8> for LedStatus {
fn from(byte: u8) -> Self {
Self {
num_lock: (byte & 0x01) != 0,
caps_lock: (byte & 0x02) != 0,
scroll_lock: (byte & 0x04) != 0,
}
}
}
#[derive(Debug)]
pub struct Response {
pub cmd: u8,
pub data: Vec<u8>,
pub is_error: bool,
pub error_code: Option<Ch9329Error>,
}
impl Response {
pub fn parse(bytes: &[u8]) -> Option<Self> {
if bytes.len() < 6 || bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] {
return None;
}
let cmd = bytes[3];
let len = bytes[4] as usize;
if bytes.len() < 5 + len + 1 {
return None;
}
let expected_checksum = bytes[5 + len];
let calculated_checksum = bytes[..5 + len]
.iter()
.fold(0u8, |acc, &x| acc.wrapping_add(x));
if expected_checksum != calculated_checksum {
tracing::warn!(
"CH9329 checksum mismatch: expected {:02X}, got {:02X}",
expected_checksum,
calculated_checksum
);
return None;
}
let data = bytes[5..5 + len].to_vec();
let is_error = (cmd & RESPONSE_ERROR_MASK) == RESPONSE_ERROR_MASK;
let error_code = if is_error && !data.is_empty() {
Some(Ch9329Error::from(data[0]))
} else {
None
};
Some(Self {
cmd,
data,
is_error,
error_code,
})
}
}
#[inline]
pub fn calculate_checksum(data: &[u8]) -> u8 {
data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x))
}
#[inline]
pub fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) {
debug_assert!(data.len() <= MAX_DATA_LEN, "Data too long for CH9329 packet");
let len = data.len() as u8;
let packet_len = 6 + data.len();
let mut packet = [0u8; MAX_PACKET_SIZE];
packet[0] = PACKET_HEADER[0];
packet[1] = PACKET_HEADER[1];
packet[2] = address;
packet[3] = cmd;
packet[4] = len;
packet[5..5 + data.len()].copy_from_slice(data);
packet[5 + data.len()] = calculate_checksum(&packet[..5 + data.len()]);
(packet, packet_len)
}
#[inline]
pub fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec<u8> {
let (buf, len) = build_packet_buf(address, cmd, data);
buf[..len].to_vec()
}
#[inline]
pub fn expected_response_cmd(cmd: u8, is_error: bool) -> u8 {
cmd | if is_error {
RESPONSE_ERROR_MASK
} else {
RESPONSE_SUCCESS_MASK
}
}
pub fn try_extract_response(buffer: &[u8]) -> Option<(Response, usize)> {
let mut offset = 0;
while offset + 6 <= buffer.len() {
if buffer[offset] != PACKET_HEADER[0] || buffer[offset + 1] != PACKET_HEADER[1] {
offset += 1;
continue;
}
let len = buffer[offset + 4] as usize;
let frame_len = 6 + len;
if offset + frame_len > buffer.len() {
return None;
}
let frame = &buffer[offset..offset + frame_len];
if let Some(response) = Response::parse(frame) {
return Some((response, offset + frame_len));
}
offset += 1;
}
None
}

View File

@@ -2,21 +2,17 @@
//!
//! Reference: USB HID Usage Tables 1.12, Section 15 (Consumer Page 0x0C)
/// Consumer Control Usage codes for multimedia keys
pub mod usage {
// Transport Controls
pub const PLAY_PAUSE: u16 = 0x00CD;
pub const STOP: u16 = 0x00B7;
pub const NEXT_TRACK: u16 = 0x00B5;
pub const PREV_TRACK: u16 = 0x00B6;
// Volume Controls
pub const MUTE: u16 = 0x00E2;
pub const VOLUME_UP: u16 = 0x00E9;
pub const VOLUME_DOWN: u16 = 0x00EA;
}
/// Check if a usage code is valid
pub fn is_valid_usage(usage: u16) -> bool {
matches!(
usage,

View File

@@ -42,23 +42,19 @@ use super::{
MouseEventType,
};
/// Message types
pub const MSG_KEYBOARD: u8 = 0x01;
pub const MSG_MOUSE: u8 = 0x02;
pub const MSG_CONSUMER: u8 = 0x03;
/// Keyboard event types
pub const KB_EVENT_DOWN: u8 = 0x00;
pub const KB_EVENT_UP: u8 = 0x01;
/// Mouse event types
pub const MS_EVENT_MOVE: u8 = 0x00;
pub const MS_EVENT_MOVE_ABS: u8 = 0x01;
pub const MS_EVENT_DOWN: u8 = 0x02;
pub const MS_EVENT_UP: u8 = 0x03;
pub const MS_EVENT_SCROLL: u8 = 0x04;
/// Parsed HID event from DataChannel
#[derive(Debug, Clone)]
pub enum HidChannelEvent {
Keyboard(KeyboardEvent),
@@ -66,7 +62,6 @@ pub enum HidChannelEvent {
Consumer(ConsumerEvent),
}
/// Parse a binary HID message from DataChannel
pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.is_empty() {
warn!("Empty HID message");
@@ -86,7 +81,6 @@ pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
}
}
/// Parse keyboard message payload
fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 3 {
warn!("Keyboard message too short: {} bytes", data.len());
@@ -129,7 +123,6 @@ fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
}))
}
/// Parse mouse message payload
fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 6 {
warn!("Mouse message too short: {} bytes", data.len());
@@ -148,11 +141,9 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
}
};
// Parse coordinates as i16 LE (works for both relative and absolute)
let x = i16::from_le_bytes([data[1], data[2]]) as i32;
let y = i16::from_le_bytes([data[3], data[4]]) as i32;
// Button or scroll delta
let (button, scroll) = match event_type {
MouseEventType::Down | MouseEventType::Up => {
let btn = match data[5] {
@@ -178,7 +169,6 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
}))
}
/// Parse consumer control message payload
fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 2 {
warn!("Consumer message too short: {} bytes", data.len());
@@ -190,7 +180,6 @@ fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
Some(HidChannelEvent::Consumer(ConsumerEvent { usage }))
}
/// Encode a keyboard event to binary format (for sending to client if needed)
pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
let event_type = match event.event_type {
KeyEventType::Down => KB_EVENT_DOWN,
@@ -207,40 +196,6 @@ pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
]
}
/// Encode a mouse event to binary format (for sending to client if needed)
pub fn encode_mouse_event(event: &MouseEvent) -> Vec<u8> {
let event_type = match event.event_type {
MouseEventType::Move => MS_EVENT_MOVE,
MouseEventType::MoveAbs => MS_EVENT_MOVE_ABS,
MouseEventType::Down => MS_EVENT_DOWN,
MouseEventType::Up => MS_EVENT_UP,
MouseEventType::Scroll => MS_EVENT_SCROLL,
};
let x_bytes = (event.x as i16).to_le_bytes();
let y_bytes = (event.y as i16).to_le_bytes();
let extra = match event.event_type {
MouseEventType::Down | MouseEventType::Up => event
.button
.as_ref()
.map(|b| match b {
MouseButton::Left => 0u8,
MouseButton::Middle => 1u8,
MouseButton::Right => 2u8,
MouseButton::Back => 3u8,
MouseButton::Forward => 4u8,
})
.unwrap_or(0),
MouseEventType::Scroll => event.scroll as u8,
_ => 0,
};
vec![
MSG_MOUSE, event_type, x_bytes[0], x_bytes[1], y_bytes[0], y_bytes[1], extra,
]
}
#[cfg(test)]
mod tests {
use super::*;

80
src/hid/factory.rs Normal file
View File

@@ -0,0 +1,80 @@
use std::sync::Arc;
use tracing::{info, warn};
use super::{ch9329, HidBackend, HidBackendType};
use crate::error::{AppError, Result};
#[cfg(unix)]
use crate::otg::OtgService;
pub struct HidBackendFactory {
#[cfg(unix)]
otg_service: Option<Arc<OtgService>>,
}
impl HidBackendFactory {
#[cfg(unix)]
pub fn new(otg_service: Option<Arc<OtgService>>) -> Self {
Self { otg_service }
}
#[cfg(not(unix))]
pub fn new() -> Self {
Self {}
}
pub async fn create_initialized(
&self,
backend_type: &HidBackendType,
) -> Result<Option<Arc<dyn HidBackend>>> {
let backend = match self.create(backend_type).await? {
Some(backend) => backend,
None => return Ok(None),
};
backend.init().await?;
Ok(Some(backend))
}
async fn create(&self, backend_type: &HidBackendType) -> Result<Option<Arc<dyn HidBackend>>> {
match backend_type {
HidBackendType::Otg => self.create_otg_backend().await.map(Some),
HidBackendType::Ch9329 { port, baud_rate } => {
info!(
"Initializing CH9329 HID backend on {} @ {} baud",
port, baud_rate
);
Ok(Some(Arc::new(ch9329::Ch9329Backend::with_baud_rate(
port, *baud_rate,
)?)))
}
HidBackendType::None => {
warn!("HID backend disabled");
Ok(None)
}
}
}
#[cfg(unix)]
async fn create_otg_backend(&self) -> Result<Arc<dyn HidBackend>> {
let otg_service = self
.otg_service
.as_ref()
.ok_or_else(|| AppError::Config("OTG backend not available".to_string()))?;
let handles = otg_service
.hid_device_paths()
.await
.ok_or_else(|| AppError::Config("OTG HID paths are not available".to_string()))?;
info!("Creating OTG HID backend from device paths");
Ok(Arc::new(super::otg::OtgBackend::from_handles(handles)?))
}
#[cfg(not(unix))]
async fn create_otg_backend(&self) -> Result<Arc<dyn HidBackend>> {
Err(AppError::Config(
"OTG HID is only available on Linux".to_string(),
))
}
}

View File

@@ -1,10 +1,6 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Shared canonical keyboard key identifiers used across frontend and backend.
///
/// The enum names intentionally mirror `KeyboardEvent.code` style values so the
/// browser, virtual keyboard, and HID backend can all speak the same language.
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CanonicalKey {
@@ -128,11 +124,6 @@ pub enum CanonicalKey {
}
impl CanonicalKey {
/// Convert the canonical key to a stable wire code.
///
/// The wire code intentionally matches the USB HID usage for keyboard page
/// keys so existing low-level behavior stays intact while the semantic type
/// becomes explicit.
pub const fn to_hid_usage(self) -> u8 {
match self {
Self::KeyA => 0x04,
@@ -255,7 +246,6 @@ impl CanonicalKey {
}
}
/// Convert a wire code / USB HID usage to its canonical key.
pub const fn from_hid_usage(usage: u8) -> Option<Self> {
match usage {
0x04 => Some(Self::KeyA),

View File

@@ -1,70 +1,39 @@
//! HID (Human Interface Device) control module
//!
//! This module provides keyboard and mouse control for remote KVM:
//! - USB OTG gadget mode (native Linux USB gadget)
//! - CH9329 serial HID controller
//!
//! Architecture:
//! ```text
//! Web Client --> WebSocket/DataChannel --> HID Events --> Backend --> Target PC
//! |
//! [OTG | CH9329]
//! ```
//! HID path: browser (WebSocket or WebRTC DataChannel) → queue → OTG gadget or CH9329.
pub mod backend;
mod ch9329_proto;
pub mod ch9329;
pub mod consumer;
pub mod datachannel;
mod factory;
pub mod keyboard;
#[cfg(unix)]
pub mod otg;
#[cfg(unix)]
mod otg_device;
pub mod types;
pub mod websocket;
pub use crate::events::LedState;
pub use backend::{HidBackend, HidBackendRuntimeSnapshot, HidBackendType};
pub use keyboard::CanonicalKey;
pub use otg::LedState;
pub use types::{
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent,
MouseEventType,
};
/// HID backend information
#[derive(Debug, Clone)]
pub struct HidInfo {
/// Backend name
pub name: String,
/// Whether backend is initialized
pub initialized: bool,
/// Whether absolute mouse positioning is supported
pub supports_absolute_mouse: bool,
/// Screen resolution for absolute mouse
pub screen_resolution: Option<(u32, u32)>,
}
/// Unified HID runtime state used by snapshots and events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HidRuntimeState {
/// Whether a backend is configured and expected to exist.
pub available: bool,
/// Stable backend key: "otg", "ch9329", "none".
pub backend: String,
/// Whether the backend is currently initialized and operational.
pub initialized: bool,
/// Whether the backend is currently online.
pub online: bool,
/// Whether absolute mouse positioning is supported.
pub supports_absolute_mouse: bool,
/// Whether keyboard LED/status feedback is enabled.
pub keyboard_leds_enabled: bool,
/// Last known keyboard LED state.
pub led_state: LedState,
/// Screen resolution for absolute mouse mode.
pub screen_resolution: Option<(u32, u32)>,
/// Device path associated with the backend, if any.
pub device: Option<String>,
/// Current user-facing error, if any.
pub error: Option<String>,
/// Current programmatic error code, if any.
pub error_code: Option<String>,
}
@@ -131,7 +100,9 @@ use tracing::{info, warn};
use crate::error::{AppError, Result};
use crate::events::EventBus;
#[cfg(unix)]
use crate::otg::OtgService;
use factory::HidBackendFactory;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
@@ -140,50 +111,56 @@ const HID_EVENT_QUEUE_CAPACITY: usize = 64;
const HID_EVENT_SEND_TIMEOUT_MS: u64 = 30;
#[derive(Debug)]
enum HidEvent {
enum QueuedHidEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
Consumer(ConsumerEvent),
Reset,
}
/// HID controller managing keyboard and mouse input
pub struct HidController {
/// OTG Service reference (only used when backend is OTG)
otg_service: Option<Arc<OtgService>>,
/// Active backend
backend_factory: HidBackendFactory,
backend: Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
/// Backend type (mutable for reload)
backend_type: Arc<RwLock<HidBackendType>>,
/// Event bus for broadcasting state changes (optional)
events: Arc<tokio::sync::RwLock<Option<Arc<EventBus>>>>,
/// Unified HID runtime state.
runtime_state: Arc<RwLock<HidRuntimeState>>,
/// HID event queue sender (non-blocking)
hid_tx: mpsc::Sender<HidEvent>,
/// HID event queue receiver (moved into worker on first start)
hid_rx: Mutex<Option<mpsc::Receiver<HidEvent>>>,
/// Coalesced mouse move (latest)
hid_tx: mpsc::Sender<QueuedHidEvent>,
hid_rx: Mutex<Option<mpsc::Receiver<QueuedHidEvent>>>,
pending_move: Arc<parking_lot::Mutex<Option<MouseEvent>>>,
/// Pending move flag (fast path)
pending_move_flag: Arc<AtomicBool>,
/// Worker task handle
hid_worker: Mutex<Option<JoinHandle<()>>>,
/// Backend runtime subscription task handle
runtime_worker: Mutex<Option<JoinHandle<()>>>,
/// Backend initialization fast flag
backend_available: Arc<AtomicBool>,
}
impl HidController {
/// Create a new HID controller with specified backend
///
/// For OTG backend, otg_service should be provided to support hot-reload
#[cfg(unix)]
pub fn new(backend_type: HidBackendType, otg_service: Option<Arc<OtgService>>) -> Self {
let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY);
Self {
otg_service,
backend: Arc::new(RwLock::new(None)),
backend_factory: HidBackendFactory::new(otg_service),
backend_type: Arc::new(RwLock::new(backend_type.clone())),
events: Arc::new(tokio::sync::RwLock::new(None)),
runtime_state: Arc::new(RwLock::new(HidRuntimeState::from_backend_type(
&backend_type,
))),
hid_tx,
hid_rx: Mutex::new(Some(hid_rx)),
pending_move: Arc::new(parking_lot::Mutex::new(None)),
pending_move_flag: Arc::new(AtomicBool::new(false)),
hid_worker: Mutex::new(None),
runtime_worker: Mutex::new(None),
backend_available: Arc::new(AtomicBool::new(false)),
}
}
#[cfg(not(unix))]
pub fn new(backend_type: HidBackendType) -> Self {
let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY);
Self {
backend: Arc::new(RwLock::new(None)),
backend_factory: HidBackendFactory::new(),
backend_type: Arc::new(RwLock::new(backend_type.clone())),
events: Arc::new(tokio::sync::RwLock::new(None)),
runtime_state: Arc::new(RwLock::new(HidRuntimeState::from_backend_type(
@@ -199,64 +176,32 @@ impl HidController {
}
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Initialize the HID backend
pub async fn init(&self) -> Result<()> {
let backend_type = self.backend_type.read().await.clone();
let backend: Arc<dyn HidBackend> = match backend_type {
HidBackendType::Otg => {
let otg_service = self
.otg_service
.as_ref()
.ok_or_else(|| AppError::Internal("OtgService not available".into()))?;
let handles = otg_service.hid_device_paths().await.ok_or_else(|| {
AppError::Config("OTG HID paths are not available".to_string())
})?;
info!("Creating OTG HID backend from device paths");
Arc::new(otg::OtgBackend::from_handles(handles)?)
}
HidBackendType::Ch9329 {
ref port,
baud_rate,
} => {
info!(
"Initializing CH9329 HID backend on {} @ {} baud",
port, baud_rate
);
Arc::new(ch9329::Ch9329Backend::with_baud_rate(port, baud_rate)?)
}
HidBackendType::None => {
warn!("HID backend disabled");
return Ok(());
}
};
if let Err(e) = backend.init().await {
self.backend_available.store(false, Ordering::Release);
let error_state = {
let backend_type = self.backend_type.read().await.clone();
let backend = match self.backend_factory.create_initialized(&backend_type).await {
Ok(Some(backend)) => backend,
Ok(None) => return Ok(()),
Err(error) => {
self.backend_available.store(false, Ordering::Release);
let current = self.runtime_state.read().await.clone();
HidRuntimeState::with_error(
let error_state = HidRuntimeState::with_error(
&backend_type,
&current,
format!("Failed to initialize HID backend: {}", e),
format!("Failed to initialize HID backend: {}", error),
"init_failed",
)
};
self.apply_runtime_state(error_state).await;
return Err(e);
}
);
self.apply_runtime_state(error_state).await;
return Err(error);
}
};
*self.backend.write().await = Some(backend);
self.sync_runtime_state_from_backend().await;
// Start HID event worker (once)
self.start_event_worker().await;
self.restart_runtime_worker().await;
@@ -264,12 +209,10 @@ impl HidController {
Ok(())
}
/// Shutdown the HID backend and release resources
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down HID controller");
self.stop_runtime_worker().await;
// Close the backend
if let Some(backend) = self.backend.write().await.take() {
if let Err(e) = backend.shutdown().await {
warn!("Error shutting down HID backend: {}", e);
@@ -290,17 +233,15 @@ impl HidController {
Ok(())
}
/// Send keyboard event
pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
"HID backend not available".to_string(),
));
}
self.enqueue_event(HidEvent::Keyboard(event)).await
self.enqueue_event(QueuedHidEvent::Keyboard(event)).await
}
/// Send mouse event
pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
@@ -312,147 +253,60 @@ impl HidController {
event.event_type,
MouseEventType::Move | MouseEventType::MoveAbs
) {
// Best-effort: drop/merge move events if queue is full
self.enqueue_mouse_move(event)
} else {
self.enqueue_event(HidEvent::Mouse(event)).await
self.enqueue_event(QueuedHidEvent::Mouse(event)).await
}
}
/// Send consumer control event (multimedia keys)
pub async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
"HID backend not available".to_string(),
));
}
self.enqueue_event(HidEvent::Consumer(event)).await
self.enqueue_event(QueuedHidEvent::Consumer(event)).await
}
/// Reset all keys (release all pressed keys)
pub async fn reset(&self) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Ok(());
}
// Reset is important but best-effort; enqueue to avoid blocking
self.enqueue_event(HidEvent::Reset).await
self.enqueue_event(QueuedHidEvent::Reset).await
}
/// Check if backend is available
pub async fn is_available(&self) -> bool {
self.backend_available.load(Ordering::Acquire)
}
/// Get backend type
pub async fn backend_type(&self) -> HidBackendType {
self.backend_type.read().await.clone()
}
/// Get backend info
pub async fn info(&self) -> Option<HidInfo> {
let state = self.runtime_state.read().await.clone();
if !state.available {
return None;
}
Some(HidInfo {
name: state.backend,
initialized: state.initialized,
supports_absolute_mouse: state.supports_absolute_mouse,
screen_resolution: state.screen_resolution,
})
}
/// Get current HID runtime state snapshot.
pub async fn snapshot(&self) -> HidRuntimeState {
self.runtime_state.read().await.clone()
}
/// Reload the HID backend with new type
pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> {
info!("Reloading HID backend: {:?}", new_backend_type);
self.backend_available.store(false, Ordering::Release);
self.stop_runtime_worker().await;
// Shutdown existing backend first
if let Some(backend) = self.backend.write().await.take() {
if let Err(e) = backend.shutdown().await {
warn!("Error shutting down old HID backend: {}", e);
}
}
// Create and initialize new backend
let new_backend: Option<Arc<dyn HidBackend>> = match new_backend_type {
HidBackendType::Otg => {
info!("Initializing OTG HID backend");
// Get OtgService reference
let otg_service = match self.otg_service.as_ref() {
Some(svc) => svc,
None => {
warn!("OTG backend requires OtgService, but it's not available");
return Err(AppError::Config(
"OTG backend not available (OtgService missing)".to_string(),
));
}
};
match otg_service.hid_device_paths().await {
Some(handles) => {
// Create OtgBackend from handles
match otg::OtgBackend::from_handles(handles) {
Ok(backend) => {
let backend = Arc::new(backend);
match backend.init().await {
Ok(_) => {
info!("OTG backend initialized successfully");
Some(backend)
}
Err(e) => {
warn!("Failed to initialize OTG backend: {}", e);
None
}
}
}
Err(e) => {
warn!("Failed to create OTG backend: {}", e);
None
}
}
}
None => {
warn!("OTG HID paths are not available");
None
}
}
}
HidBackendType::Ch9329 {
ref port,
baud_rate,
} => {
info!(
"Initializing CH9329 HID backend on {} @ {} baud",
port, baud_rate
);
match ch9329::Ch9329Backend::with_baud_rate(port, baud_rate) {
Ok(b) => {
let backend = Arc::new(b);
match backend.init().await {
Ok(_) => Some(backend),
Err(e) => {
warn!("Failed to initialize CH9329 backend: {}", e);
None
}
}
}
Err(e) => {
warn!("Failed to create CH9329 backend: {}", e);
None
}
}
}
HidBackendType::None => {
warn!("HID backend disabled");
let new_backend = match self
.backend_factory
.create_initialized(&new_backend_type)
.await
{
Ok(backend) => backend,
Err(error) if matches!(&new_backend_type, HidBackendType::None) => return Err(error),
Err(error) => {
warn!("Failed to initialize HID backend: {}", error);
None
}
};
@@ -470,7 +324,6 @@ impl HidController {
info!("HID backend reloaded successfully: {:?}", new_backend_type);
self.start_event_worker().await;
// Update backend_type on success
*self.backend_type.write().await = new_backend_type.clone();
self.sync_runtime_state_from_backend().await;
@@ -481,7 +334,6 @@ impl HidController {
warn!("HID backend reload resulted in no active backend");
self.backend_available.store(false, Ordering::Release);
// Update backend_type even on failure (to reflect the attempted change)
*self.backend_type.write().await = new_backend_type.clone();
let current = self.runtime_state.read().await.clone();
@@ -541,11 +393,10 @@ impl HidController {
process_hid_event(event, &backend).await;
// After each event, flush latest move if pending
if pending_move_flag.swap(false, Ordering::AcqRel) {
let move_event = { pending_move.lock().take() };
if let Some(move_event) = move_event {
process_hid_event(HidEvent::Mouse(move_event), &backend).await;
process_hid_event(QueuedHidEvent::Mouse(move_event), &backend).await;
}
}
}
@@ -595,7 +446,7 @@ impl HidController {
}
fn enqueue_mouse_move(&self, event: MouseEvent) -> Result<()> {
match self.hid_tx.try_send(HidEvent::Mouse(event.clone())) {
match self.hid_tx.try_send(QueuedHidEvent::Mouse(event.clone())) {
Ok(_) => Ok(()),
Err(mpsc::error::TrySendError::Full(_)) => {
*self.pending_move.lock() = Some(event);
@@ -608,11 +459,10 @@ impl HidController {
}
}
async fn enqueue_event(&self, event: HidEvent) -> Result<()> {
async fn enqueue_event(&self, event: QueuedHidEvent) -> Result<()> {
match self.hid_tx.try_send(event) {
Ok(_) => Ok(()),
Err(mpsc::error::TrySendError::Full(ev)) => {
// For non-move events, wait briefly to avoid dropping critical input
let tx = self.hid_tx.clone();
let send_result = tokio::time::timeout(
Duration::from_millis(HID_EVENT_SEND_TIMEOUT_MS),
@@ -649,7 +499,10 @@ async fn apply_backend_runtime_state(
apply_runtime_state(runtime_state, events, next).await;
}
async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>) {
async fn process_hid_event(
event: QueuedHidEvent,
backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
) {
let backend_opt = backend.read().await.clone();
let backend = match backend_opt {
Some(b) => b,
@@ -660,10 +513,10 @@ async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn
let result = tokio::task::spawn_blocking(move || {
futures::executor::block_on(async move {
match event {
HidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await,
HidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await,
HidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await,
HidEvent::Reset => backend_for_send.reset().await,
QueuedHidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await,
QueuedHidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await,
QueuedHidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await,
QueuedHidEvent::Reset => backend_for_send.reset().await,
}
})
})
@@ -682,12 +535,6 @@ async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn
}
}
impl Default for HidController {
fn default() -> Self {
Self::new(HidBackendType::None, None)
}
}
fn device_for_backend_type(backend_type: &HidBackendType) -> Option<String> {
match backend_type {
HidBackendType::Ch9329 { port, .. } => Some(port.clone()),

View File

@@ -1,48 +1,32 @@
//! OTG USB Gadget HID backend
//! Linux gadget HID: `/dev/hidg*` opened from [`crate::otg::OtgService`].
//! Typical nodes: hidg0 keyboard, hidg1 relative mouse, hidg2 absolute, hidg3 consumer control.
//!
//! This backend uses Linux USB Gadget API to emulate USB HID devices.
//! It opens the HID gadget device nodes created by `OtgService`.
//! Depending on the configured OTG profile, this may include:
//! - hidg0: Keyboard
//! - hidg1: Relative Mouse
//! - hidg2: Absolute Mouse
//! - hidg3: Consumer Control Keyboard
//!
//! Requirements:
//! - USB OTG/Device controller (UDC)
//! - ConfigFS with USB gadget support
//! - Root privileges for gadget setup
//!
//! Error Recovery:
//! This module implements automatic device reconnection based on PiKVM's approach.
//! When ESHUTDOWN or EAGAIN errors occur (common during MSD operations), the device
//! file handles are closed and reopened on the next operation.
//! See: https://github.com/raspberrypi/linux/issues/4373
//! Polled timed writes (JetKVM-style). Treat `ESHUTDOWN` (108) by closing handles and reopening; keep fd on `EAGAIN` (11). Host/gadget teardown during MSD resembles PiKVM. <https://github.com/raspberrypi/linux/issues/4373>
use async_trait::async_trait;
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::io::Read;
use std::os::fd::AsFd;
use std::os::unix::fs::OpenOptionsExt;
use std::os::unix::io::AsFd;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
use tokio::sync::watch;
use tracing::{debug, info, trace, warn};
use super::backend::{HidBackend, HidBackendRuntimeSnapshot};
use super::otg_device::OtgDeviceIo;
use super::types::{
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType,
};
use crate::error::{AppError, Result};
use crate::events::LedState;
use crate::otg::{wait_for_hid_devices, HidDevicePaths};
/// Device type for ensure_device operations
#[derive(Debug, Clone, Copy)]
enum DeviceType {
Keyboard,
@@ -51,23 +35,7 @@ enum DeviceType {
ConsumerControl,
}
/// Keyboard LED state
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct LedState {
/// Num Lock LED
pub num_lock: bool,
/// Caps Lock LED
pub caps_lock: bool,
/// Scroll Lock LED
pub scroll_lock: bool,
/// Compose LED
pub compose: bool,
/// Kana LED
pub kana: bool,
}
impl LedState {
/// Create from raw byte
pub fn from_byte(b: u8) -> Self {
Self {
num_lock: b & 0x01 != 0,
@@ -78,7 +46,6 @@ impl LedState {
}
}
/// Convert to raw byte
pub fn to_byte(&self) -> u8 {
let mut b = 0u8;
if self.num_lock {
@@ -100,76 +67,36 @@ impl LedState {
}
}
/// OTG HID backend with 4 devices
///
/// This backend opens HID device files created by OtgService.
/// It does NOT manage the USB gadget itself - that's handled by OtgService.
///
/// ## Error Recovery
///
/// Based on PiKVM's implementation, this backend automatically handles:
/// - EAGAIN (errno 11): Resource temporarily unavailable - just retry later, don't close device
/// - ESHUTDOWN (errno 108): Transport endpoint shutdown - close and reopen device
///
/// When ESHUTDOWN occurs, the device file handle is closed and will be
/// reopened on the next operation attempt.
/// Opens `/dev/hidg*` nodes provisioned by `OtgService`; gadget lifecycle is not handled here.
pub struct OtgBackend {
/// Keyboard device path (/dev/hidg0)
keyboard_path: Option<PathBuf>,
/// Relative mouse device path (/dev/hidg1)
mouse_rel_path: Option<PathBuf>,
/// Absolute mouse device path (/dev/hidg2)
mouse_abs_path: Option<PathBuf>,
/// Consumer control device path (/dev/hidg3)
consumer_path: Option<PathBuf>,
/// Keyboard device file
keyboard_dev: Mutex<Option<File>>,
/// Relative mouse device file
mouse_rel_dev: Mutex<Option<File>>,
/// Absolute mouse device file
mouse_abs_dev: Mutex<Option<File>>,
/// Consumer control device file
consumer_dev: Mutex<Option<File>>,
/// Whether keyboard LED/status feedback is enabled.
keyboard_leds_enabled: bool,
/// Current keyboard state
keyboard_state: Mutex<KeyboardReport>,
/// Current mouse button state
mouse_buttons: AtomicU8,
/// Last known LED state (using parking_lot::RwLock for sync access)
led_state: Arc<parking_lot::RwLock<LedState>>,
/// Screen resolution for absolute mouse (using parking_lot::RwLock for sync access)
screen_resolution: parking_lot::RwLock<Option<(u32, u32)>>,
/// UDC name for state checking (e.g., "fcc00000.usb")
udc_name: Arc<parking_lot::RwLock<Option<String>>>,
/// Whether the backend has been initialized.
initialized: AtomicBool,
/// Whether the device is currently online (UDC configured and devices accessible)
online: AtomicBool,
/// Last backend error state.
last_error: parking_lot::RwLock<Option<(String, String)>>,
/// Last error log time for throttling (using parking_lot for sync)
last_error_log: parking_lot::Mutex<std::time::Instant>,
/// Error count since last successful operation (for log throttling)
error_count: AtomicU8,
/// Consecutive EAGAIN count (for offline threshold detection)
eagain_count: AtomicU8,
/// Runtime change notifier.
runtime_notify_tx: watch::Sender<()>,
/// Runtime monitor stop flag.
runtime_worker_stop: Arc<AtomicBool>,
/// Runtime monitor thread.
runtime_worker: Mutex<Option<thread::JoinHandle<()>>>,
}
/// Write timeout in milliseconds (same as JetKVM's hidWriteTimeout)
const HID_WRITE_TIMEOUT_MS: i32 = 20;
impl OtgBackend {
/// Create OTG backend from device paths provided by OtgService
///
/// This is the ONLY way to create an OtgBackend - it no longer manages
/// the USB gadget itself. The gadget must already be set up by OtgService.
/// Gadget must already exist; paths come from `OtgService`.
pub fn from_handles(paths: HidDevicePaths) -> Result<Self> {
let (runtime_notify_tx, _runtime_notify_rx) = watch::channel(());
Ok(Self {
@@ -192,7 +119,6 @@ impl OtgBackend {
last_error: parking_lot::RwLock::new(None),
last_error_log: parking_lot::Mutex::new(std::time::Instant::now()),
error_count: AtomicU8::new(0),
eagain_count: AtomicU8::new(0),
runtime_notify_tx,
runtime_worker_stop: Arc::new(AtomicBool::new(false)),
runtime_worker: Mutex::new(None),
@@ -234,7 +160,6 @@ impl OtgBackend {
}
}
/// Log throttled error message (max once per second)
fn log_throttled_error(&self, msg: &str) {
let mut last_log = self.last_error_log.lock();
let now = std::time::Instant::now();
@@ -251,48 +176,15 @@ impl OtgBackend {
}
}
/// Reset error count on successful operation
fn reset_error_count(&self) {
self.error_count.store(0, Ordering::Relaxed);
// Also reset EAGAIN count - successful operation means device is working
self.eagain_count.store(0, Ordering::Relaxed);
}
/// Write data to HID device with timeout (JetKVM style)
///
/// Uses poll() to wait for device to be ready for writing.
/// If timeout expires, silently drops the data (acceptable for mouse movement).
/// Returns Ok(true) if write succeeded, Ok(false) if timed out (silently dropped).
/// Poll-based write with `HID_WRITE_TIMEOUT_MS`; timeout → drop (JetKVM-style).
fn write_with_timeout(&self, file: &mut File, data: &[u8]) -> std::io::Result<bool> {
let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)];
match poll(&mut pollfd, PollTimeout::from(HID_WRITE_TIMEOUT_MS as u16)) {
Ok(1) => {
// Device ready, check for errors
if let Some(revents) = pollfd[0].revents() {
if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP)
{
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Device error or hangup",
));
}
}
// Write the data
file.write_all(data)?;
Ok(true)
}
Ok(0) => {
// Timeout - silently drop (JetKVM behavior)
trace!("HID write timeout, dropping data");
Ok(false)
}
Ok(_) => Ok(false),
Err(e) => Err(std::io::Error::other(e)),
}
OtgDeviceIo::write_with_timeout(file, data, HID_WRITE_TIMEOUT_MS)
}
/// Set the UDC name for state checking
pub fn set_udc_name(&self, udc: &str) {
*self.udc_name.write() = Some(udc.to_string());
}
@@ -324,15 +216,11 @@ impl OtgBackend {
}
}
/// Check if the UDC is in "configured" state
///
/// This is based on PiKVM's `__is_udc_configured()` method.
/// The UDC state file indicates whether the USB host has enumerated and configured the gadget.
/// `true` when `/sys/class/udc/<name>/state` reads `configured` (PiKVM-style).
pub fn is_udc_configured(&self) -> bool {
Self::read_udc_configured(&self.udc_name)
}
/// Find the first available UDC
fn find_udc() -> Option<String> {
let udc_path = PathBuf::from("/sys/class/udc");
if let Ok(entries) = fs::read_dir(&udc_path) {
@@ -345,12 +233,7 @@ impl OtgBackend {
None
}
/// Ensure a device is open and ready for I/O
///
/// This method is based on PiKVM's `__ensure_device()` pattern:
/// 1. Check if device path exists, close handle if not
/// 2. If handle is None but path exists, reopen the device
/// 3. Return whether the device is ready for I/O
/// PiKVM-style: drop handle if node missing; reopen when path reappears.
fn ensure_device(&self, device_type: DeviceType) -> Result<()> {
let (path_opt, dev_mutex) = match device_type {
DeviceType::Keyboard => (&self.keyboard_path, &self.keyboard_dev),
@@ -372,9 +255,7 @@ impl OtgBackend {
}
};
// Check if device path exists
if !path.exists() {
// Close the device if open (device was removed)
let mut dev = dev_mutex.lock();
if dev.is_some() {
debug!(
@@ -392,7 +273,6 @@ impl OtgBackend {
});
}
// If device is not open, try to open it
let mut dev = dev_mutex.lock();
if dev.is_none() {
match Self::open_device(path) {
@@ -415,7 +295,6 @@ impl OtgBackend {
Ok(())
}
/// Open a HID device file with read/write access
fn open_device(path: &PathBuf) -> Result<File> {
OpenOptions::new()
.read(true)
@@ -431,16 +310,15 @@ impl OtgBackend {
})
}
/// Convert I/O error to HidError with appropriate error code
fn io_error_code(e: &std::io::Error) -> &'static str {
match e.raw_os_error() {
Some(32) => "epipe", // EPIPE - broken pipe
Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown
Some(11) => "eagain", // EAGAIN - resource temporarily unavailable
Some(6) => "enxio", // ENXIO - no such device or address
Some(19) => "enodev", // ENODEV - no such device
Some(5) => "eio", // EIO - I/O error
Some(2) => "enoent", // ENOENT - no such file or directory
Some(32) => "epipe",
Some(108) => "eshutdown",
Some(11) => "eagain",
Some(6) => "enxio",
Some(19) => "enodev",
Some(5) => "eio",
Some(2) => "enoent",
_ => "io_error",
}
}
@@ -455,7 +333,32 @@ impl OtgBackend {
}
}
/// Check if all HID device files exist
fn handle_write_error(
&self,
dev: &mut Option<File>,
err: std::io::Error,
operation: &str,
device_label: &str,
) -> Result<()> {
match err.raw_os_error() {
Some(108) => {
debug!("{} ESHUTDOWN, closing for recovery", device_label);
*dev = None;
self.record_error(format!("{}: {}", operation, err), "eshutdown");
Err(Self::io_error_to_hid_error(err, operation))
}
Some(11) => {
trace!("{} EAGAIN after poll, dropping", device_label);
Ok(())
}
_ => {
warn!("{} write error: {}", device_label, err);
self.record_error(format!("{}: {}", operation, err), Self::io_error_code(&err));
Err(Self::io_error_to_hid_error(err, operation))
}
}
}
pub fn check_devices_exist(&self) -> bool {
self.keyboard_path.as_ref().is_none_or(|p| p.exists())
&& self.mouse_rel_path.as_ref().is_none_or(|p| p.exists())
@@ -463,7 +366,6 @@ impl OtgBackend {
&& self.consumer_path.as_ref().is_none_or(|p| p.exists())
}
/// Get list of missing device paths
pub fn get_missing_devices(&self) -> Vec<String> {
let mut missing = Vec::new();
if let Some(ref path) = self.keyboard_path {
@@ -484,17 +386,11 @@ impl OtgBackend {
missing
}
/// Send keyboard report (8 bytes)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// Uses write_with_timeout to avoid blocking on busy devices.
fn send_keyboard_report(&self, report: &KeyboardReport) -> Result<()> {
if self.keyboard_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::Keyboard)?;
let mut dev = self.keyboard_dev.lock();
@@ -508,47 +404,15 @@ impl OtgBackend {
Ok(())
}
Ok(false) => {
// Timeout - silently dropped (JetKVM behavior)
self.log_throttled_error("HID keyboard write timeout, dropped");
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
// ESHUTDOWN - endpoint closed, need to reopen device
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Keyboard ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write keyboard report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write keyboard report",
))
}
Some(11) => {
// EAGAIN after poll - should be rare, silently drop
trace!("Keyboard EAGAIN after poll, dropping");
Ok(())
}
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Keyboard write error: {}", e);
self.record_error(
format!("Failed to write keyboard report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write keyboard report",
))
}
}
}
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write keyboard report",
"Keyboard",
),
}
} else {
Err(AppError::HidError {
@@ -559,17 +423,11 @@ impl OtgBackend {
}
}
/// Send relative mouse report (4 bytes: buttons, dx, dy, wheel)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// Uses write_with_timeout to avoid blocking on busy devices.
fn send_mouse_report_relative(&self, buttons: u8, dx: i8, dy: i8, wheel: i8) -> Result<()> {
if self.mouse_rel_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::MouseRelative)?;
let mut dev = self.mouse_rel_dev.lock();
@@ -582,45 +440,13 @@ impl OtgBackend {
trace!("Sent relative mouse report: {:02X?}", data);
Ok(())
}
Ok(false) => {
// Timeout - silently dropped (JetKVM behavior)
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Relative mouse ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write mouse report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
Some(11) => {
// EAGAIN after poll - should be rare, silently drop
Ok(())
}
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Relative mouse write error: {}", e);
self.record_error(
format!("Failed to write mouse report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
}
}
Ok(false) => Ok(()),
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write mouse report",
"Relative mouse",
),
}
} else {
Err(AppError::HidError {
@@ -631,17 +457,11 @@ impl OtgBackend {
}
}
/// Send absolute mouse report (6 bytes: buttons, x_lo, x_hi, y_lo, y_hi, wheel)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// Uses write_with_timeout to avoid blocking on busy devices.
fn send_mouse_report_absolute(&self, buttons: u8, x: u16, y: u16, wheel: i8) -> Result<()> {
if self.mouse_abs_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::MouseAbsolute)?;
let mut dev = self.mouse_abs_dev.lock();
@@ -660,45 +480,13 @@ impl OtgBackend {
self.reset_error_count();
Ok(())
}
Ok(false) => {
// Timeout - silently dropped (JetKVM behavior)
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Absolute mouse ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write mouse report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
Some(11) => {
// EAGAIN after poll - should be rare, silently drop
Ok(())
}
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Absolute mouse write error: {}", e);
self.record_error(
format!("Failed to write mouse report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write mouse report",
))
}
}
}
Ok(false) => Ok(()),
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write mouse report",
"Absolute mouse",
),
}
} else {
Err(AppError::HidError {
@@ -709,67 +497,33 @@ impl OtgBackend {
}
}
/// Send consumer control report (2 bytes: usage_lo, usage_hi)
///
/// Sends a consumer control usage code and then releases it (sends 0x0000).
/// Press (`usage`) then release (`0x0000`).
fn send_consumer_report(&self, usage: u16) -> Result<()> {
if self.consumer_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::ConsumerControl)?;
let mut dev = self.consumer_dev.lock();
if let Some(ref mut file) = *dev {
// Send the usage code
let data = [(usage & 0xFF) as u8, (usage >> 8) as u8];
match self.write_with_timeout(file, &data) {
Ok(true) => {
trace!("Sent consumer report: {:02X?}", data);
// Send release (0x0000)
let release = [0u8, 0u8];
let _ = self.write_with_timeout(file, &release);
self.mark_online();
self.reset_error_count();
Ok(())
}
Ok(false) => {
// Timeout - silently dropped
Ok(())
}
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
Some(108) => {
debug!("Consumer control ESHUTDOWN, closing for recovery");
*dev = None;
self.record_error(
format!("Failed to write consumer report: {}", e),
"eshutdown",
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write consumer report",
))
}
Some(11) => {
// EAGAIN after poll - silently drop
Ok(())
}
_ => {
warn!("Consumer control write error: {}", e);
self.record_error(
format!("Failed to write consumer report: {}", e),
Self::io_error_code(&e),
);
Err(Self::io_error_to_hid_error(
e,
"Failed to write consumer report",
))
}
}
}
Ok(false) => Ok(()),
Err(e) => self.handle_write_error(
&mut dev,
e,
"Failed to write consumer report",
"Consumer control",
),
}
} else {
Err(AppError::HidError {
@@ -780,12 +534,10 @@ impl OtgBackend {
}
}
/// Send consumer control event
pub fn send_consumer(&self, event: ConsumerEvent) -> Result<()> {
self.send_consumer_report(event.usage)
}
/// Get last known LED state
pub fn led_state(&self) -> LedState {
*self.led_state.read()
}
@@ -973,19 +725,17 @@ impl OtgBackend {
#[async_trait]
impl HidBackend for OtgBackend {
async fn init(&self) -> Result<()> {
info!("Initializing OTG HID backend");
debug!("Initializing OTG HID backend");
// Auto-detect UDC name for state checking only if OtgService did not provide one
if self.udc_name.read().is_none() {
if let Some(udc) = Self::find_udc() {
info!("Auto-detected UDC: {}", udc);
debug!("Auto-detected UDC: {}", udc);
self.set_udc_name(&udc);
}
} else if let Some(udc) = self.udc_name.read().clone() {
info!("Using configured UDC: {}", udc);
debug!("Using configured UDC: {}", udc);
}
// Wait for devices to appear (they should already exist from OtgService)
let mut device_paths = Vec::new();
if let Some(ref path) = self.keyboard_path {
device_paths.push(path.clone());
@@ -1010,51 +760,46 @@ impl HidBackend for OtgBackend {
return Err(AppError::Internal("HID devices did not appear".into()));
}
// Open keyboard device
if let Some(ref path) = self.keyboard_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.keyboard_dev.lock() = Some(file);
info!("Keyboard device opened: {}", path.display());
debug!("Keyboard device opened: {}", path.display());
} else {
warn!("Keyboard device not found: {}", path.display());
}
}
// Open relative mouse device
if let Some(ref path) = self.mouse_rel_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.mouse_rel_dev.lock() = Some(file);
info!("Relative mouse device opened: {}", path.display());
debug!("Relative mouse device opened: {}", path.display());
} else {
warn!("Relative mouse device not found: {}", path.display());
}
}
// Open absolute mouse device
if let Some(ref path) = self.mouse_abs_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.mouse_abs_dev.lock() = Some(file);
info!("Absolute mouse device opened: {}", path.display());
debug!("Absolute mouse device opened: {}", path.display());
} else {
warn!("Absolute mouse device not found: {}", path.display());
}
}
// Open consumer control device (optional, may not exist on older setups)
if let Some(ref path) = self.consumer_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.consumer_dev.lock() = Some(file);
info!("Consumer control device opened: {}", path.display());
debug!("Consumer control device opened: {}", path.display());
} else {
debug!("Consumer control device not found: {}", path.display());
}
}
// Mark as online if all devices opened successfully
self.initialized.store(true, Ordering::Relaxed);
self.notify_runtime_changed();
self.start_runtime_worker();
@@ -1066,7 +811,6 @@ impl HidBackend for OtgBackend {
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
let usb_key = event.key.to_hid_usage();
// Handle modifier keys separately
if event.key.is_modifier() {
let mut state = self.keyboard_state.lock();
@@ -1084,7 +828,6 @@ impl HidBackend for OtgBackend {
} else {
let mut state = self.keyboard_state.lock();
// Update modifiers from event
state.modifiers = event.modifiers.to_hid_byte();
match event.event_type {
@@ -1110,15 +853,12 @@ impl HidBackend for OtgBackend {
match event.event_type {
MouseEventType::Move => {
// Relative movement - use hidg1
let dx = event.x.clamp(-127, 127) as i8;
let dy = event.y.clamp(-127, 127) as i8;
self.send_mouse_report_relative(buttons, dx, dy, 0)?;
}
MouseEventType::MoveAbs => {
// Absolute movement - use hidg2
// Frontend sends 0-32767 range directly (standard HID absolute mouse range)
// Don't send button state with move - buttons are handled separately on relative device
// Coordinates 032767; buttons are sent only on the relative endpoint.
let x = event.x.clamp(0, 32767) as u16;
let y = event.y.clamp(0, 32767) as u16;
self.send_mouse_report_absolute(0, x, y, 0)?;
@@ -1127,7 +867,6 @@ impl HidBackend for OtgBackend {
if let Some(button) = event.button {
let bit = button.to_hid_bit();
let new_buttons = self.mouse_buttons.fetch_or(bit, Ordering::Relaxed) | bit;
// Send on relative device for button clicks
self.send_mouse_report_relative(new_buttons, 0, 0, 0)?;
}
}
@@ -1147,7 +886,6 @@ impl HidBackend for OtgBackend {
}
async fn reset(&self) -> Result<()> {
// Reset keyboard
{
let mut state = self.keyboard_state.lock();
state.clear();
@@ -1156,7 +894,6 @@ impl HidBackend for OtgBackend {
self.send_keyboard_report(&report)?;
}
// Reset mouse
self.mouse_buttons.store(0, Ordering::Relaxed);
self.send_mouse_report_relative(0, 0, 0, 0)?;
self.send_mouse_report_absolute(0, 0, 0, 0)?;
@@ -1168,16 +905,13 @@ impl HidBackend for OtgBackend {
async fn shutdown(&self) -> Result<()> {
self.stop_runtime_worker();
// Reset before closing
self.reset().await?;
// Close devices
*self.keyboard_dev.lock() = None;
*self.mouse_rel_dev.lock() = None;
*self.mouse_abs_dev.lock() = None;
*self.consumer_dev.lock() = None;
// Gadget cleanup is handled by OtgService, not here
self.initialized.store(false, Ordering::Relaxed);
self.online.store(false, Ordering::Relaxed);
self.clear_error();
@@ -1199,31 +933,18 @@ impl HidBackend for OtgBackend {
self.send_consumer_report(event.usage)
}
fn set_screen_resolution(&mut self, width: u32, height: u32) {
fn set_screen_resolution(&self, width: u32, height: u32) {
*self.screen_resolution.write() = Some((width, height));
self.notify_runtime_changed();
}
}
/// Check if OTG HID gadget is available
pub fn is_otg_available() -> bool {
// Check for existing HID devices (they should be created by OtgService)
let kb = PathBuf::from("/dev/hidg0");
let mouse_rel = PathBuf::from("/dev/hidg1");
let mouse_abs = PathBuf::from("/dev/hidg2");
kb.exists() || mouse_rel.exists() || mouse_abs.exists()
}
/// Implement Drop for OtgBackend to close device files
impl Drop for OtgBackend {
fn drop(&mut self) {
self.runtime_worker_stop.store(true, Ordering::Relaxed);
if let Some(handle) = self.runtime_worker.get_mut().take() {
let _ = handle.join();
}
// Close device files
// Note: Gadget cleanup is handled by OtgService, not here
*self.keyboard_dev.lock() = None;
*self.mouse_rel_dev.lock() = None;
*self.mouse_abs_dev.lock() = None;
@@ -1236,12 +957,6 @@ impl Drop for OtgBackend {
mod tests {
use super::*;
#[test]
fn test_otg_availability_check() {
// This just tests the function runs without panicking
let _available = is_otg_available();
}
#[test]
fn test_led_state() {
let state = LedState::from_byte(0b00000011);
@@ -1254,7 +969,6 @@ mod tests {
#[test]
fn test_report_sizes() {
// Keyboard report is 8 bytes
let kb_report = KeyboardReport::default();
assert_eq!(kb_report.to_bytes().len(), 8);
}

46
src/hid/otg_device.rs Normal file
View File

@@ -0,0 +1,46 @@
#[cfg(unix)]
use std::fs::File;
#[cfg(unix)]
use std::io::Write;
#[cfg(unix)]
use std::os::unix::io::AsFd;
#[cfg(unix)]
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
#[cfg(unix)]
use tracing::trace;
#[cfg(unix)]
pub struct OtgDeviceIo;
#[cfg(unix)]
impl OtgDeviceIo {
pub fn write_with_timeout(
file: &mut File,
data: &[u8],
timeout_ms: i32,
) -> std::io::Result<bool> {
let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)];
match poll(&mut pollfd, PollTimeout::from(timeout_ms as u16)) {
Ok(1) => {
if let Some(revents) = pollfd[0].revents() {
if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP)
{
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Device error or hangup",
));
}
}
file.write_all(data)?;
Ok(true)
}
Ok(0) => {
trace!("HID write timeout, dropping data");
Ok(false)
}
Ok(_) => Ok(false),
Err(e) => Err(std::io::Error::other(e)),
}
}
}

View File

@@ -1,50 +1,37 @@
//! HID event types for keyboard and mouse
//! Keyboard/mouse/consumer structs (`KeyboardEvent`, `MouseEvent`, …).
use serde::{Deserialize, Serialize};
use super::keyboard::CanonicalKey;
/// Keyboard event type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KeyEventType {
/// Key pressed down
Down,
/// Key released
Up,
}
/// Keyboard modifier flags
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeyboardModifiers {
/// Left Control
#[serde(default)]
pub left_ctrl: bool,
/// Left Shift
#[serde(default)]
pub left_shift: bool,
/// Left Alt
#[serde(default)]
pub left_alt: bool,
/// Left Meta (Windows/Super key)
#[serde(default)]
pub left_meta: bool,
/// Right Control
#[serde(default)]
pub right_ctrl: bool,
/// Right Shift
#[serde(default)]
pub right_shift: bool,
/// Right Alt (AltGr)
#[serde(default)]
pub right_alt: bool,
/// Right Meta
#[serde(default)]
pub right_meta: bool,
}
impl KeyboardModifiers {
/// Convert to USB HID modifier byte
pub fn to_hid_byte(&self) -> u8 {
let mut byte = 0u8;
if self.left_ctrl {
@@ -74,7 +61,6 @@ impl KeyboardModifiers {
byte
}
/// Create from USB HID modifier byte
pub fn from_hid_byte(byte: u8) -> Self {
Self {
left_ctrl: byte & 0x01 != 0,
@@ -88,7 +74,6 @@ impl KeyboardModifiers {
}
}
/// Check if any modifier is active
pub fn any(&self) -> bool {
self.left_ctrl
|| self.left_shift
@@ -101,21 +86,16 @@ impl KeyboardModifiers {
}
}
/// Keyboard event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyboardEvent {
/// Event type (down/up)
#[serde(rename = "type")]
pub event_type: KeyEventType,
/// Canonical keyboard key identifier shared across frontend and backend
pub key: CanonicalKey,
/// Modifier keys state
#[serde(default)]
pub modifiers: KeyboardModifiers,
}
impl KeyboardEvent {
/// Create a key down event
pub fn key_down(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Down,
@@ -124,7 +104,6 @@ impl KeyboardEvent {
}
}
/// Create a key up event
pub fn key_up(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Up,
@@ -134,7 +113,6 @@ impl KeyboardEvent {
}
}
/// Mouse button
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseButton {
@@ -146,7 +124,6 @@ pub enum MouseButton {
}
impl MouseButton {
/// Convert to USB HID button bit
pub fn to_hid_bit(&self) -> u8 {
match self {
MouseButton::Left => 0x01,
@@ -158,44 +135,31 @@ impl MouseButton {
}
}
/// Mouse event type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseEventType {
/// Mouse moved (relative movement)
Move,
/// Mouse moved (absolute position)
MoveAbs,
/// Button pressed
Down,
/// Button released
Up,
/// Mouse wheel scroll
Scroll,
}
/// Mouse event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MouseEvent {
/// Event type
#[serde(rename = "type")]
pub event_type: MouseEventType,
/// X coordinate or delta
#[serde(default)]
pub x: i32,
/// Y coordinate or delta
#[serde(default)]
pub y: i32,
/// Button (for down/up events)
#[serde(default)]
pub button: Option<MouseButton>,
/// Scroll delta (for scroll events)
#[serde(default)]
pub scroll: i8,
}
impl MouseEvent {
/// Create a relative move event
pub fn move_rel(dx: i32, dy: i32) -> Self {
Self {
event_type: MouseEventType::Move,
@@ -206,7 +170,6 @@ impl MouseEvent {
}
}
/// Create an absolute move event
pub fn move_abs(x: i32, y: i32) -> Self {
Self {
event_type: MouseEventType::MoveAbs,
@@ -217,7 +180,6 @@ impl MouseEvent {
}
}
/// Create a button down event
pub fn button_down(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Down,
@@ -228,7 +190,6 @@ impl MouseEvent {
}
}
/// Create a button up event
pub fn button_up(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Up,
@@ -239,7 +200,6 @@ impl MouseEvent {
}
}
/// Create a scroll event
pub fn scroll(delta: i8) -> Self {
Self {
event_type: MouseEventType::Scroll,
@@ -251,35 +211,19 @@ impl MouseEvent {
}
}
/// Combined HID event (keyboard or mouse)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "device", rename_all = "lowercase")]
pub enum HidEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
Consumer(ConsumerEvent),
}
/// Consumer control event (multimedia keys)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsumerEvent {
/// Consumer control usage code (e.g., 0x00CD for Play/Pause)
pub usage: u16,
}
/// USB HID keyboard report (8 bytes)
#[derive(Debug, Clone, Default)]
pub struct KeyboardReport {
/// Modifier byte
pub modifiers: u8,
/// Reserved byte
pub reserved: u8,
/// Key codes (up to 6 simultaneous keys)
pub keys: [u8; 6],
}
impl KeyboardReport {
/// Convert to bytes for USB HID
pub fn to_bytes(&self) -> [u8; 8] {
[
self.modifiers,
@@ -293,7 +237,6 @@ impl KeyboardReport {
]
}
/// Add a key to the report
pub fn add_key(&mut self, key: u8) -> bool {
for slot in &mut self.keys {
if *slot == 0 {
@@ -304,56 +247,21 @@ impl KeyboardReport {
false // All slots full
}
/// Remove a key from the report
pub fn remove_key(&mut self, key: u8) {
for slot in &mut self.keys {
if *slot == key {
*slot = 0;
}
}
// Compact the array
self.keys.sort_by(|a, b| b.cmp(a));
}
/// Clear all keys
pub fn clear(&mut self) {
self.modifiers = 0;
self.keys = [0; 6];
}
}
/// USB HID mouse report
#[derive(Debug, Clone, Default)]
pub struct MouseReport {
/// Button state
pub buttons: u8,
/// X movement (-127 to 127)
pub x: i8,
/// Y movement (-127 to 127)
pub y: i8,
/// Wheel movement (-127 to 127)
pub wheel: i8,
}
impl MouseReport {
/// Convert to bytes for USB HID (relative mouse)
pub fn to_bytes_relative(&self) -> [u8; 4] {
[self.buttons, self.x as u8, self.y as u8, self.wheel as u8]
}
/// Convert to bytes for USB HID (absolute mouse)
pub fn to_bytes_absolute(&self, x: u16, y: u16) -> [u8; 6] {
[
self.buttons,
(x & 0xFF) as u8,
(x >> 8) as u8,
(y & 0xFF) as u8,
(y >> 8) as u8,
self.wheel as u8,
]
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,13 +1,4 @@
//! WebSocket HID channel for HTTP/MJPEG mode
//!
//! This provides an alternative to WebRTC DataChannel for HID input
//! when using MJPEG streaming mode.
//!
//! Uses binary protocol only (same format as DataChannel):
//! - Keyboard: [0x01, event_type, key, modifiers] (4 bytes)
//! - Mouse: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, button/scroll] (7 bytes)
//!
//! See datachannel.rs for detailed protocol specification.
//! MJPEG mode: HID over WebSocket — same binary framing as [`super::datachannel`] (`0x01`/`0x02`/`0x03`; layout detailed there).
use axum::{
extract::{
@@ -24,25 +15,20 @@ use super::datachannel::{parse_hid_message, HidChannelEvent};
use crate::state::AppState;
use crate::utils::LogThrottler;
/// Binary response codes
const RESP_OK: u8 = 0x00;
const RESP_ERR_HID_UNAVAILABLE: u8 = 0x01;
const RESP_ERR_INVALID_MESSAGE: u8 = 0x02;
/// WebSocket HID upgrade handler
pub async fn ws_hid_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
ws.on_upgrade(move |socket| handle_hid_socket(socket, state))
}
/// Handle HID WebSocket connection
async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
// Log throttler for error messages (5 second interval)
let log_throttler = LogThrottler::with_secs(5);
info!("WebSocket HID connection established (binary protocol)");
// Check if HID controller is available and send initial status
let hid_available = state.hid.is_available().await;
let initial_response = if hid_available {
vec![RESP_OK]
@@ -59,17 +45,14 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
return;
}
// Process incoming messages (binary only)
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Binary(data)) => {
// Check HID availability before processing each message
let hid_available = state.hid.is_available().await;
if !hid_available {
if log_throttler.should_log("hid_unavailable") {
warn!("HID controller not available, ignoring message");
}
// Send error response (optional, for client awareness)
let _ = sender
.send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE].into()))
.await;
@@ -77,15 +60,12 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
}
if let Err(e) = handle_binary_message(&data, &state).await {
// Log with throttling to avoid spam
if log_throttler.should_log("binary_hid_error") {
warn!("Binary HID message error: {}", e);
}
// Don't send error response for every failed message to reduce overhead
}
}
Ok(Message::Text(text)) => {
// Text messages are no longer supported
if log_throttler.should_log("text_message_rejected") {
debug!(
"Received text message (not supported): {} bytes",
@@ -111,7 +91,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
}
}
// Reset HID state to release any held keys/buttons
if let Err(e) = state.hid.reset().await {
warn!("Failed to reset HID on WebSocket disconnect: {}", e);
}
@@ -119,7 +98,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
info!("WebSocket HID connection ended");
}
/// Handle binary HID message (same format as DataChannel)
async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), String> {
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
@@ -160,12 +138,10 @@ mod tests {
assert_eq!(RESP_OK, 0x00);
assert_eq!(RESP_ERR_HID_UNAVAILABLE, 0x01);
assert_eq!(RESP_ERR_INVALID_MESSAGE, 0x02);
// assert_eq!(RESP_ERR_SEND_FAILED, 0x03); // TODO: fix test
}
#[test]
fn test_keyboard_message_format() {
// Keyboard message: [0x01, event_type, key, modifiers]
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // 'A' key with left ctrl
let event = parse_hid_message(&data);
assert!(event.is_some());
@@ -173,7 +149,6 @@ mod tests {
#[test]
fn test_mouse_message_format() {
// Mouse message: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, extra]
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
let event = parse_hid_message(&data);
assert!(event.is_some());

View File

@@ -1,30 +1,35 @@
//! One-KVM - Lightweight IP-KVM solution
//!
//! This crate provides the core functionality for One-KVM,
//! a remote KVM (Keyboard, Video, Mouse) solution written in Rust.
//! Core library for One-KVM (IPKVM: capture, HID, OTG, streaming, Web UI glue).
#[cfg(not(any(unix, windows)))]
compile_error!("One-KVM supports Linux and Windows targets only.");
pub mod atx;
pub mod audio;
pub mod auth;
pub mod config;
pub mod db;
pub mod diagnostics;
pub mod error;
pub mod events;
pub mod extensions;
pub mod hid;
pub mod modules;
#[cfg(unix)]
pub mod msd;
#[cfg(unix)]
pub mod otg;
pub mod platform;
pub mod redfish;
pub mod rtsp;
pub mod rustdesk;
pub mod state;
pub mod stream;
pub mod stream_encoder;
pub mod update;
pub mod utils;
pub mod video;
pub mod web;
pub mod webrtc;
/// Auto-generated secrets module (from secrets.toml at compile time)
pub mod secrets {
include!(concat!(env!("OUT_DIR"), "/secrets_generated.rs"));
}

View File

@@ -1,7 +1,8 @@
use std::collections::HashSet;
use std::future::Future;
use std::io::Write;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
@@ -15,11 +16,15 @@ use one_kvm::atx::AtxController;
use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality};
use one_kvm::auth::{SessionStore, UserStore};
use one_kvm::config::{self, AppConfig, ConfigStore};
use one_kvm::db::DatabasePool;
use one_kvm::events::EventBus;
use one_kvm::extensions::ExtensionManager;
use one_kvm::hid::{HidBackendType, HidController};
#[cfg(unix)]
use one_kvm::msd::MsdController;
#[cfg(unix)]
use one_kvm::otg::OtgService;
use one_kvm::platform::PlatformCapabilities;
use one_kvm::rtsp::RtspService;
use one_kvm::rustdesk::RustDeskService;
use one_kvm::state::AppState;
@@ -33,7 +38,6 @@ use one_kvm::video::{Streamer, VideoStreamManager};
use one_kvm::web;
use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig};
/// Log level for the application
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
enum LogLevel {
Error,
@@ -45,7 +49,6 @@ enum LogLevel {
Trace,
}
/// One-KVM command line arguments
#[derive(Parser, Debug)]
#[command(name = "one-kvm")]
#[command(version, about = "A open and lightweight IP-KVM solution", long_about = None)]
@@ -78,7 +81,7 @@ struct CliArgs {
#[arg(long, value_name = "FILE", requires = "ssl_cert")]
ssl_key: Option<PathBuf>,
/// Data directory path (default: /etc/one-kvm)
/// Data directory path (default: /etc/one-kvm, or the executable directory on Windows)
#[arg(short = 'd', long, value_name = "DIR")]
data_dir: Option<PathBuf>,
@@ -111,65 +114,31 @@ enum UserAction {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Parse command line arguments
let args = CliArgs::parse();
// Initialize logging with CLI arguments
init_logging(args.log_level, args.verbose);
// Install default crypto provider (required by rustls 0.23+)
CryptoProvider::install_default(ring::default_provider())
.expect("Failed to install rustls crypto provider");
tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION"));
let platform = PlatformCapabilities::current();
tracing::info!(
"Platform mode: {:?} ({})",
platform.mode,
platform.mode_label
);
// Determine data directory (CLI arg takes precedence)
let data_dir = args.data_dir.clone().unwrap_or_else(get_data_dir);
tracing::info!("Data directory: {}", data_dir.display());
// Run one-off CLI command and exit.
if let Some(command) = args.command {
run_cli_command(command, data_dir).await?;
return Ok(());
}
// Ensure data directory exists
tokio::fs::create_dir_all(&data_dir).await?;
let (db, config_store, mut config) = load_runtime_config(&data_dir).await?;
// Initialize configuration store
let db_path = data_dir.join("one-kvm.db");
let config_store = ConfigStore::new(&db_path).await?;
let mut config = (*config_store.get()).clone();
// Normalize MSD directory (absolute path under data dir if empty/relative)
let mut msd_dir_updated = false;
if config.msd.msd_dir.trim().is_empty() {
let msd_dir = data_dir.join("msd");
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
} else if !PathBuf::from(&config.msd.msd_dir).is_absolute() {
let msd_dir = data_dir.join(&config.msd.msd_dir);
tracing::warn!(
"MSD directory is relative, rebasing to {}",
msd_dir.display()
);
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
}
if msd_dir_updated {
config_store.set(config.clone()).await?;
}
// Ensure MSD directories exist (msd/images, msd/ventoy)
let msd_dir = PathBuf::from(&config.msd.msd_dir);
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await {
tracing::warn!("Failed to create MSD images directory: {}", e);
}
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("ventoy")).await {
tracing::warn!("Failed to create MSD ventoy directory: {}", e);
}
// Apply CLI argument overrides to config (only if explicitly specified)
if let Some(addr) = args.address {
config.web.bind_address = addr.clone();
config.web.bind_addresses = vec![addr];
@@ -203,29 +172,20 @@ async fn main() -> anyhow::Result<()> {
config.web.http_port
};
// Log final configuration
for ip in &bind_ips {
let addr = SocketAddr::new(*ip, bind_port);
tracing::info!("Server will listen on: {}://{}", scheme, addr);
}
// Initialize session store
let session_store = SessionStore::new(
config_store.pool().clone(),
config.auth.session_timeout_secs as i64,
);
let session_store = SessionStore::new(config.auth.session_timeout_secs as i64);
// Initialize user store
let user_store = UserStore::new(config_store.pool().clone());
let user_store = UserStore::new(db.clone_pool());
// Create shutdown channel
let (shutdown_tx, _) = broadcast::channel::<()>(1);
// Create event bus for real-time notifications
let events = Arc::new(EventBus::new());
tracing::info!("Event bus initialized");
// Parse video configuration once (avoid duplication)
let (video_format, video_resolution) = parse_video_config(&config);
tracing::debug!(
"Parsed video config: {} @ {}x{}",
@@ -234,7 +194,6 @@ async fn main() -> anyhow::Result<()> {
video_resolution.height
);
// Create video streamer and initialize with config if device is set
let streamer = Streamer::new();
streamer.set_event_bus(events.clone()).await;
if let Some(ref device_path) = config.video.device {
@@ -262,19 +221,19 @@ async fn main() -> anyhow::Result<()> {
}
}
// Create WebRTC streamer
let webrtc_streamer = {
let webrtc_config = WebRtcStreamerConfig {
resolution: video_resolution,
input_format: video_format,
fps: config.video.fps,
bitrate_preset: config.stream.bitrate_preset,
encoder_backend: config.stream.encoder.to_backend(),
encoder_backend: one_kvm::stream_encoder::encoder_type_to_backend(
config.stream.encoder.clone(),
),
webrtc: {
let mut stun_servers = vec![];
let mut turn_servers = vec![];
// Check if user configured custom servers
let has_custom_stun = config
.stream
.stun_server
@@ -288,25 +247,12 @@ async fn main() -> anyhow::Result<()> {
.map(|s| !s.is_empty())
.unwrap_or(false);
// If no custom servers, use public ICE servers (like RustDesk)
if !has_custom_stun && !has_custom_turn {
use one_kvm::webrtc::config::public_ice;
if public_ice::is_configured() {
if let Some(stun) = public_ice::stun_server() {
stun_servers.push(stun.clone());
tracing::info!("Using public STUN server: {}", stun);
}
for turn in public_ice::turn_servers() {
tracing::info!("Using public TURN server: {:?}", turn.urls);
turn_servers.push(turn);
}
} else {
tracing::info!(
"No public ICE servers configured, using host candidates only"
);
}
let stun = public_ice::stun_server().to_string();
tracing::info!("Using public STUN server: {}", stun);
stun_servers.push(stun);
} else {
// Use custom servers
if let Some(ref stun) = config.stream.stun_server {
if !stun.is_empty() {
stun_servers.push(stun.clone());
@@ -342,18 +288,18 @@ async fn main() -> anyhow::Result<()> {
};
WebRtcStreamer::with_config(webrtc_config)
};
tracing::info!("WebRTC streamer created (supports H264, extensible to VP8/VP9/H265)");
tracing::info!("WebRTC streamer created");
// Create OTG Service (single instance for centralized USB gadget management)
#[cfg(unix)]
let otg_service = Arc::new(OtgService::new());
#[cfg(unix)]
tracing::info!("OTG Service created");
// Reconcile OTG once from the persisted config so controllers only consume its result.
#[cfg(unix)]
if let Err(e) = otg_service.apply_config(&config.hid, &config.msd).await {
tracing::warn!("Failed to apply OTG config: {}", e);
}
// Create HID controller based on config
let hid_backend = match config.hid.backend {
config::HidBackend::Otg => HidBackendType::Otg,
config::HidBackend::Ch9329 => HidBackendType::Ch9329 {
@@ -362,27 +308,21 @@ async fn main() -> anyhow::Result<()> {
},
config::HidBackend::None => HidBackendType::None,
};
let hid = Arc::new(HidController::new(
hid_backend,
Some(otg_service.clone()), // Always pass OtgService to support hot-reload to OTG
));
#[cfg(unix)]
let hid = Arc::new(HidController::new(hid_backend, Some(otg_service.clone())));
#[cfg(not(unix))]
let hid = Arc::new(HidController::new(hid_backend));
hid.set_event_bus(events.clone()).await;
if let Err(e) = hid.init().await {
tracing::warn!("Failed to initialize HID backend: {}", e);
}
// Create MSD controller (optional, based on config)
#[cfg(unix)]
let msd = if config.msd.enabled {
// Initialize Ventoy resources from data directory
let ventoy_resource_dir = ventoy_img::get_resource_dir(&data_dir);
let ventoy_resource_dir = data_dir.join("ventoy");
if ventoy_resource_dir.exists() {
if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) {
tracing::warn!("Failed to initialize Ventoy resources: {}", e);
tracing::info!(
"Ventoy resource files should be placed in: {}",
ventoy_resource_dir.display()
);
tracing::info!("Required files: {:?}", ventoy_img::required_files());
} else {
tracing::info!(
"Ventoy resources initialized from {}",
@@ -394,10 +334,6 @@ async fn main() -> anyhow::Result<()> {
"Ventoy resource directory not found: {}",
ventoy_resource_dir.display()
);
tracing::info!(
"Create the directory and place the following files: {:?}",
ventoy_img::required_files()
);
}
let controller = MsdController::new(otg_service.clone(), config.msd.msd_dir_path());
@@ -413,7 +349,6 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create ATX controller (optional, based on config)
let atx = if config.atx.enabled {
let controller_config = config.atx.to_controller_config();
let controller = AtxController::new(controller_config);
@@ -429,12 +364,21 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create Audio controller
let audio = {
let audio_config = AudioControllerConfig {
enabled: config.audio.enabled,
device: config.audio.device.clone(),
quality: AudioQuality::from_str(&config.audio.quality),
quality: match config.audio.quality.parse::<AudioQuality>() {
Ok(q) => q,
Err(e) => {
tracing::warn!(
"Invalid audio quality in config (value={:?}): {}, using balanced",
config.audio.quality,
e
);
AudioQuality::Balanced
}
},
};
let controller = AudioController::new(audio_config);
@@ -446,7 +390,6 @@ async fn main() -> anyhow::Result<()> {
config.audio.device,
config.audio.quality
);
// Start audio streaming so WebRTC can subscribe to Opus frames
if let Err(e) = controller.start_streaming().await {
tracing::warn!("Failed to start audio streaming: {}", e);
}
@@ -457,29 +400,23 @@ async fn main() -> anyhow::Result<()> {
Arc::new(controller)
};
// Create Extension manager (ttyd, gostc, easytier)
let extensions = Arc::new(ExtensionManager::new());
tracing::info!("Extension manager initialized");
// Wire up WebRTC streamer with HID controller
// This enables WebRTC DataChannel to process HID events
webrtc_streamer.set_hid_controller(hid.clone()).await;
// Wire up WebRTC streamer with Audio controller
// This enables WebRTC audio track to receive Opus frames
webrtc_streamer.set_audio_controller(audio.clone()).await;
if config.audio.enabled {
if let Err(e) = webrtc_streamer.set_audio_enabled(true).await {
tracing::warn!("Failed to enable WebRTC audio: {}", e);
} else {
tracing::info!("WebRTC audio enabled");
tracing::debug!("WebRTC audio enabled");
}
}
// Configure direct capture for WebRTC encoder pipeline
let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) =
streamer.current_capture_config().await;
tracing::info!(
tracing::debug!(
"Initial video config: {}x{} {:?} @ {}fps",
actual_resolution.width,
actual_resolution.height,
@@ -490,22 +427,50 @@ async fn main() -> anyhow::Result<()> {
.update_video_config(actual_resolution, actual_format, actual_fps)
.await;
if let Some(device_path) = device_path {
let (subdev_path, bridge_kind, v4l2_driver) = streamer
.current_device()
.await
.map(|d| {
(
d.subdev_path.clone(),
d.bridge_kind.clone(),
Some(d.driver.clone()),
)
})
.unwrap_or((None, None, None));
webrtc_streamer
.set_capture_device(device_path, jpeg_quality)
.set_capture_device(
device_path,
jpeg_quality,
subdev_path,
bridge_kind,
v4l2_driver,
)
.await;
tracing::info!("WebRTC streamer configured for direct capture");
tracing::debug!("WebRTC streamer configured for direct capture");
} else {
tracing::warn!("No capture device configured for WebRTC");
}
// Create video stream manager (unified MJPEG/WebRTC management)
// Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance
let stream_manager =
VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone());
let stream_manager = VideoStreamManager::with_webrtc_streamer(
streamer.clone(),
webrtc_streamer.clone() as std::sync::Arc<dyn one_kvm::video::traits::VideoOutput>,
);
stream_manager.set_event_bus(events.clone()).await;
stream_manager.set_config_store(config_store.clone()).await;
{
let stream_manager_weak = Arc::downgrade(&stream_manager);
audio
.set_recovered_callback(Arc::new(move || {
if let Some(stream_manager) = stream_manager_weak.upgrade() {
tokio::spawn(async move {
stream_manager.reconnect_webrtc_audio_sources().await;
});
}
}))
.await;
}
// Initialize stream manager with configured mode
let initial_mode = config.stream.mode.clone();
if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await {
tracing::warn!(
@@ -520,7 +485,6 @@ async fn main() -> anyhow::Result<()> {
);
}
// Create RustDesk service (optional, based on config)
let rustdesk = if config.rustdesk.is_valid() {
tracing::info!(
"Initializing RustDesk service: ID={} -> {}",
@@ -545,7 +509,6 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create RTSP service (optional, based on config)
let rtsp = if config.rtsp.enabled {
tracing::info!(
"Initializing RTSP service: rtsp://{}:{}/{}",
@@ -560,16 +523,19 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create application state
let update_service = Arc::new(UpdateService::new(data_dir.join("updates")));
let state = AppState::new(
db.clone(),
config_store.clone(),
session_store,
user_store,
#[cfg(unix)]
otg_service,
stream_manager,
webrtc_streamer.clone(),
hid,
#[cfg(unix)]
msd,
atx,
audio,
@@ -584,12 +550,10 @@ async fn main() -> anyhow::Result<()> {
extensions.set_event_bus(events.clone()).await;
// Start RustDesk service if enabled
if let Some(ref service) = rustdesk {
if let Err(e) = service.start().await {
tracing::error!("Failed to start RustDesk service: {}", e);
} else {
// Save generated keypair and UUID to config
if let Some(updated_config) = service.save_credentials() {
if let Err(e) = config_store
.update(|cfg| {
@@ -609,7 +573,6 @@ async fn main() -> anyhow::Result<()> {
}
}
// Start RTSP service if enabled
if let Some(ref service) = rtsp {
if let Err(e) = service.start().await {
tracing::error!("Failed to start RTSP service: {}", e);
@@ -618,7 +581,6 @@ async fn main() -> anyhow::Result<()> {
}
}
// Enforce startup codec constraints (e.g. RTSP/RustDesk locks)
{
let runtime_config = state.config.get();
let constraints = StreamCodecConstraints::from_config(&runtime_config);
@@ -633,13 +595,11 @@ async fn main() -> anyhow::Result<()> {
}
}
// Start enabled extensions
{
let ext_config = config_store.get();
extensions.start_enabled(&ext_config.extensions).await;
}
// Start extension health check task (every 30 seconds)
{
let extensions_clone = extensions.clone();
let config_store_clone = config_store.clone();
@@ -656,17 +616,12 @@ async fn main() -> anyhow::Result<()> {
state.publish_device_info().await;
// Start device info broadcast task
// This monitors state change events and broadcasts DeviceInfo to all clients
spawn_device_info_broadcaster(state.clone(), events);
// Create router
let app = web::create_router(state.clone());
// Bind sockets for configured addresses
let listeners = bind_tcp_listeners(&bind_ips, bind_port)?;
// Setup graceful shutdown
let shutdown_signal = async move {
tokio::signal::ctrl_c()
.await
@@ -675,9 +630,7 @@ async fn main() -> anyhow::Result<()> {
let _ = shutdown_tx.send(());
};
// Start server
if config.web.https_enabled {
// Generate self-signed certificate if no custom cert provided
let tls_config = if let (Some(cert_path), Some(key_path)) =
(&config.web.ssl_cert_path, &config.web.ssl_key_path)
{
@@ -687,7 +640,6 @@ async fn main() -> anyhow::Result<()> {
let cert_path = cert_dir.join("server.crt");
let key_path = cert_dir.join("server.key");
// Check if certificate already exists, only generate if missing
if !cert_path.exists() || !key_path.exists() {
tracing::info!("Generating new self-signed TLS certificate");
let cert = generate_self_signed_cert()?;
@@ -701,7 +653,7 @@ async fn main() -> anyhow::Result<()> {
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
};
let mut servers = FuturesUnordered::new();
let servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTPS server on {}", local_addr);
@@ -711,19 +663,9 @@ async fn main() -> anyhow::Result<()> {
servers.push(server);
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTPS server error: {}", e);
}
cleanup(&state).await;
}
}
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTPS").await;
} else {
let mut servers = FuturesUnordered::new();
let servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTP server on {}", local_addr);
@@ -733,26 +675,14 @@ async fn main() -> anyhow::Result<()> {
servers.push(async move { server.await });
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTP server error: {}", e);
}
cleanup(&state).await;
}
}
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTP").await;
}
tracing::info!("Server shutdown complete");
Ok(())
}
/// Initialize logging with tracing
fn init_logging(level: LogLevel, verbose_count: u8) {
// Verbose count overrides log level
let effective_level = match verbose_count {
0 => level,
1 => LogLevel::Verbose,
@@ -760,17 +690,15 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
_ => LogLevel::Trace,
};
// Build filter string based on effective level
let filter = match effective_level {
LogLevel::Error => "one_kvm=error,tower_http=error",
LogLevel::Warn => "one_kvm=warn,tower_http=warn",
LogLevel::Info => "one_kvm=info,tower_http=info",
LogLevel::Verbose => "one_kvm=debug,tower_http=info",
LogLevel::Debug => "one_kvm=debug,tower_http=debug",
LogLevel::Trace => "one_kvm=trace,tower_http=debug",
LogLevel::Error => "one_kvm=error,tower_http=error,webrtc_sctp=warn",
LogLevel::Warn => "one_kvm=warn,tower_http=warn,webrtc_sctp=warn",
LogLevel::Info => "one_kvm=info,tower_http=info,webrtc_sctp=warn",
LogLevel::Verbose => "one_kvm=debug,tower_http=info,webrtc_sctp=warn",
LogLevel::Debug => "one_kvm=debug,tower_http=debug,webrtc_sctp=warn",
LogLevel::Trace => "one_kvm=trace,tower_http=debug,webrtc_sctp=warn",
};
// Environment variable takes highest priority
let env_filter =
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| filter.into());
@@ -783,29 +711,125 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
}
}
/// Get the application data directory
fn get_data_dir() -> PathBuf {
// Check environment variable first
if let Ok(path) = std::env::var("ONE_KVM_DATA_DIR") {
return PathBuf::from(path);
}
// Default to system configuration directory
#[cfg(windows)]
{
if let Ok(exe_path) = std::env::current_exe() {
if let Some(exe_dir) = exe_path.parent() {
return exe_dir.join("one-kvm");
}
}
return std::env::current_dir()
.map(|dir| dir.join("one-kvm"))
.unwrap_or_else(|_| PathBuf::from("one-kvm"));
}
#[cfg(not(windows))]
PathBuf::from("/etc/one-kvm")
}
async fn open_database_pool(data_dir: &Path) -> anyhow::Result<DatabasePool> {
let db_path = data_dir.join("one-kvm.db");
let db = DatabasePool::new(&db_path).await?;
db.init_schema().await?;
Ok(db)
}
async fn run_servers_until_shutdown<F, E>(
mut servers: FuturesUnordered<F>,
shutdown_signal: impl Future<Output = ()>,
state: &Arc<AppState>,
protocol: &'static str,
) where
F: Future<Output = Result<(), E>> + Send,
E: std::fmt::Display,
{
tokio::select! {
_ = shutdown_signal => {
cleanup(state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("{} server error: {}", protocol, e);
}
cleanup(state).await;
}
}
}
async fn run_cli_command(command: CliCommand, data_dir: PathBuf) -> anyhow::Result<()> {
tokio::fs::create_dir_all(&data_dir).await?;
let db_path = data_dir.join("one-kvm.db");
let config_store = ConfigStore::new(&db_path).await?;
let users = UserStore::new(config_store.pool().clone());
let sessions = SessionStore::new(config_store.pool().clone(), 0);
let db = open_database_pool(&data_dir).await?;
let users = UserStore::new(db.clone_pool());
let sessions = SessionStore::new(0);
match command {
CliCommand::User(user) => run_user_action(user.action, &users, &sessions).await,
}
}
async fn load_runtime_config(
data_dir: &Path,
) -> anyhow::Result<(DatabasePool, ConfigStore, AppConfig)> {
tokio::fs::create_dir_all(data_dir).await?;
let db = open_database_pool(data_dir).await?;
let config_store = ConfigStore::new(db.clone_pool())?;
config_store.load().await?;
let mut config = (*config_store.get()).clone();
config.apply_platform_defaults();
prepare_linux_runtime_dirs(data_dir, &config_store, &mut config).await?;
Ok((db, config_store, config))
}
#[cfg(unix)]
async fn prepare_linux_runtime_dirs(
data_dir: &Path,
config_store: &ConfigStore,
config: &mut AppConfig,
) -> anyhow::Result<()> {
let mut msd_dir_updated = false;
if config.msd.msd_dir.trim().is_empty() {
let msd_dir = data_dir.join("msd");
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
} else if !PathBuf::from(&config.msd.msd_dir).is_absolute() {
let msd_dir = data_dir.join(&config.msd.msd_dir);
tracing::warn!(
"MSD directory is relative, rebasing to {}",
msd_dir.display()
);
config.msd.msd_dir = msd_dir.to_string_lossy().to_string();
msd_dir_updated = true;
}
if msd_dir_updated {
config_store.set(config.clone()).await?;
}
let msd_dir = PathBuf::from(&config.msd.msd_dir);
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await {
tracing::warn!("Failed to create MSD images directory: {}", e);
}
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("ventoy")).await {
tracing::warn!("Failed to create MSD ventoy directory: {}", e);
}
Ok(())
}
#[cfg(not(unix))]
async fn prepare_linux_runtime_dirs(
_data_dir: &Path,
_config_store: &ConfigStore,
_config: &mut AppConfig,
) -> anyhow::Result<()> {
Ok(())
}
async fn run_user_action(
action: UserAction,
users: &UserStore,
@@ -817,15 +841,9 @@ async fn run_user_action(
}
async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow::Result<()> {
let all = users.list().await?;
let user = match all.len() {
0 => anyhow::bail!("No local user exists yet; complete setup in the web UI first."),
1 => &all[0],
_ => anyhow::bail!(
"Expected exactly one local user (single-user design), found {}. Remove extra users from the database or contact support.",
all.len()
),
};
let user = users.single_user().await?.ok_or_else(|| {
anyhow::anyhow!("No local user exists yet; complete setup in the web UI first.")
})?;
let new_password = read_new_password_interactive()?;
if new_password.len() < 4 {
@@ -833,7 +851,7 @@ async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow
}
users.update_password(&user.id, &new_password).await?;
let revoked = sessions.delete_by_user_id(&user.id).await?;
let revoked = sessions.delete_all().await?;
tracing::info!(
"Password updated for user '{}' and {} sessions revoked",
@@ -869,7 +887,6 @@ fn read_new_password_interactive() -> anyhow::Result<String> {
Ok(a)
}
/// Resolve bind IPs from config, preferring bind_addresses when set.
fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result<Vec<IpAddr>> {
let raw_addrs = if !web.bind_addresses.is_empty() {
web.bind_addresses.as_slice()
@@ -910,7 +927,6 @@ fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result<Vec<std::ne
Ok(listeners)
}
/// Parse video format and resolution from config (avoids code duplication)
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
let format = config
.video
@@ -922,7 +938,6 @@ fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
(format, resolution)
}
/// Generate a self-signed TLS certificate
fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyPair>> {
use rcgen::generate_simple_self_signed;
@@ -936,8 +951,6 @@ fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyP
Ok(certified_key)
}
/// Spawn a background task that monitors state change events
/// and broadcasts DeviceInfo to all WebSocket clients with debouncing
fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
use std::time::{Duration, Instant};
@@ -1024,7 +1037,6 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
let mut pending_broadcast = false;
loop {
// Use timeout to handle pending broadcasts
let recv_result = if pending_broadcast {
let remaining =
DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64);
@@ -1049,12 +1061,9 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
tracing::info!("Event bus closed, stopping DeviceInfo broadcaster");
break;
}
Err(_timeout) => {
// Debounce timeout reached, broadcast now
}
Err(_timeout) => {}
}
// Broadcast if pending and debounce time has passed
if pending_broadcast && last_broadcast.elapsed() >= Duration::from_millis(DEBOUNCE_MS) {
state.publish_device_info().await;
tracing::trace!("Broadcasted DeviceInfo (debounced)");
@@ -1070,13 +1079,10 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
);
}
/// Clean up subsystems on shutdown
async fn cleanup(state: &Arc<AppState>) {
// Stop all extensions
state.extensions.stop_all().await;
tracing::info!("Extensions stopped");
// Stop RustDesk service
if let Some(ref service) = *state.rustdesk.read().await {
if let Err(e) = service.stop().await {
tracing::warn!("Failed to stop RustDesk service: {}", e);
@@ -1085,7 +1091,6 @@ async fn cleanup(state: &Arc<AppState>) {
}
}
// Stop RTSP service
if let Some(ref service) = *state.rtsp.read().await {
if let Err(e) = service.stop().await {
tracing::warn!("Failed to stop RTSP service: {}", e);
@@ -1094,31 +1099,27 @@ async fn cleanup(state: &Arc<AppState>) {
}
}
// Stop video
if let Err(e) = state.stream_manager.stop().await {
tracing::warn!("Failed to stop streamer: {}", e);
}
// Shutdown HID
if let Err(e) = state.hid.shutdown().await {
tracing::warn!("Failed to shutdown HID: {}", e);
}
// Shutdown MSD
#[cfg(unix)]
if let Some(msd) = state.msd.write().await.as_mut() {
if let Err(e) = msd.shutdown().await {
tracing::warn!("Failed to shutdown MSD: {}", e);
}
}
// Shutdown ATX
if let Some(atx) = state.atx.write().await.as_mut() {
if let Err(e) = atx.shutdown().await {
tracing::warn!("Failed to shutdown ATX: {}", e);
}
}
// Shutdown Audio
if let Err(e) = state.audio.shutdown().await {
tracing::warn!("Failed to shutdown audio: {}", e);
}

View File

@@ -1,49 +0,0 @@
//! Module management for One-KVM
//!
//! This module provides infrastructure for managing feature modules
//! (video streaming, HID control, MSD, ATX) as independent async tasks.
use std::future::Future;
use std::pin::Pin;
use tokio::sync::broadcast;
/// Module status
#[derive(Debug, Clone, PartialEq)]
pub enum ModuleStatus {
Stopped,
Starting,
Running,
Stopping,
Error(String),
}
/// Trait for feature modules
pub trait Module: Send + Sync {
/// Module name
fn name(&self) -> &'static str;
/// Current status
fn status(&self) -> ModuleStatus;
/// Start the module
fn start(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
/// Stop the module
fn stop(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
}
/// Module manager for coordinating feature modules
pub struct ModuleManager {
shutdown_rx: broadcast::Receiver<()>,
}
impl ModuleManager {
pub fn new(shutdown_rx: broadcast::Receiver<()>) -> Self {
Self { shutdown_rx }
}
/// Wait for shutdown signal
pub async fn wait_for_shutdown(&mut self) {
let _ = self.shutdown_rx.recv().await;
}
}

View File

@@ -1,11 +1,3 @@
//! MSD Controller
//!
//! Manages the mass storage device lifecycle including:
//! - Image mounting and unmounting
//! - Virtual drive management
//! - State tracking
//! - Image downloads from URL
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
@@ -14,41 +6,25 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use super::image::ImageManager;
use super::monitor::{MsdHealthMonitor, MsdHealthStatus};
use super::monitor::MsdHealthMonitor;
use super::types::{DownloadProgress, DownloadStatus, DriveInfo, ImageInfo, MsdMode, MsdState};
use crate::error::{AppError, Result};
use crate::otg::{MsdFunction, MsdLunConfig, OtgService};
/// MSD Controller
pub struct MsdController {
/// OTG Service reference
otg_service: Arc<OtgService>,
/// MSD function manager (provided by OtgService)
msd_function: RwLock<Option<MsdFunction>>,
/// Current state
state: RwLock<MsdState>,
/// Images storage path
images_path: PathBuf,
/// Ventoy directory path
ventoy_dir: PathBuf,
/// Virtual drive path
drive_path: PathBuf,
/// Event bus for broadcasting state changes (optional)
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
/// Active downloads (download_id -> CancellationToken)
downloads: Arc<RwLock<HashMap<String, CancellationToken>>>,
/// Operation mutex lock (prevents concurrent operations)
operation_lock: Arc<RwLock<()>>,
/// Health monitor for error tracking and recovery
monitor: Arc<MsdHealthMonitor>,
}
impl MsdController {
/// Create new MSD controller
///
/// # Parameters
/// * `otg_service` - OTG service for gadget management
/// * `msd_dir` - Base directory for MSD storage
pub fn new(otg_service: Arc<OtgService>, msd_dir: impl Into<PathBuf>) -> Self {
let msd_dir = msd_dir.into();
let images_path = msd_dir.join("images");
@@ -68,11 +44,9 @@ impl MsdController {
}
}
/// Initialize the MSD controller
pub async fn init(&self) -> Result<()> {
info!("Initializing MSD controller");
// 1. Ensure images directory exists
if let Err(e) = std::fs::create_dir_all(&self.images_path) {
warn!("Failed to create images directory: {}", e);
}
@@ -80,20 +54,16 @@ impl MsdController {
warn!("Failed to create ventoy directory: {}", e);
}
// 2. Get active MSD function from OtgService
info!("Fetching MSD function from OtgService");
let msd_func = self.otg_service.msd_function().await.ok_or_else(|| {
AppError::Internal("MSD function is not active in OtgService".to_string())
})?;
// 3. Store function handle
*self.msd_function.write().await = Some(msd_func);
// 4. Update state
let mut state = self.state.write().await;
state.available = true;
// 5. Check for existing virtual drive
if self.drive_path.exists() {
if let Ok(metadata) = std::fs::metadata(&self.drive_path) {
state.drive_info = Some(DriveInfo {
@@ -114,17 +84,14 @@ impl MsdController {
Ok(())
}
/// Get current MSD state
pub async fn state(&self) -> MsdState {
self.state.read().await.clone()
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: std::sync::Arc<crate::events::EventBus>) {
*self.events.write().await = Some(events);
}
/// Publish an event to the event bus
async fn publish_event(&self, event: crate::events::SystemEvent) {
if let Some(ref bus) = *self.events.read().await {
bus.publish(event);
@@ -137,43 +104,21 @@ impl MsdController {
}
}
/// Check if MSD is available
pub async fn is_available(&self) -> bool {
self.state.read().await.available
}
/// Connect an image file
///
/// # Parameters
/// * `image` - Image info to mount
/// * `cdrom` - Mount as CD-ROM (read-only, removable)
/// * `read_only` - Mount as read-only
pub async fn connect_image(
&self,
image: &ImageInfo,
cdrom: bool,
read_only: bool,
) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(err);
}
self.assert_can_connect(&state).await?;
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Verify image exists
if !image.path.exists() {
let error_msg = format!("Image file not found: {}", image.path.display());
self.monitor
@@ -182,29 +127,12 @@ impl MsdController {
return Err(AppError::Internal(error_msg));
}
// Configure LUN
let config = if cdrom {
MsdLunConfig::cdrom(image.path.clone())
} else {
MsdLunConfig::disk(image.path.clone(), read_only)
};
let gadget_path = self.active_gadget_path().await?;
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(err);
}
self.configure_lun_now(&config).await?;
state.connected = true;
state.mode = MsdMode::Image;
@@ -215,42 +143,19 @@ impl MsdController {
image.name, cdrom, read_only
);
// Release the lock before publishing events
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
self.mark_device_info_dirty().await;
self.finish_connect_success().await;
Ok(())
}
/// Connect the virtual drive
pub async fn connect_drive(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(err);
}
self.assert_can_connect(&state).await?;
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Check drive exists
if !self.drive_path.exists() {
let err =
AppError::Internal("Virtual drive not initialized. Call init first.".to_string());
@@ -260,25 +165,8 @@ impl MsdController {
return Err(err);
}
// Configure LUN as read-write disk
let config = MsdLunConfig::disk(self.drive_path.clone(), false);
let gadget_path = self.active_gadget_path().await?;
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(err);
}
self.configure_lun_now(&config).await?;
state.connected = true;
state.mode = MsdMode::Drive;
@@ -286,23 +174,57 @@ impl MsdController {
info!("Connected virtual drive: {}", self.drive_path.display());
// Release the lock before publishing event
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
self.mark_device_info_dirty().await;
self.finish_connect_success().await;
Ok(())
}
/// Disconnect current storage
async fn assert_can_connect(&self, state: &MsdState) -> Result<()> {
if !state.available {
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(AppError::Internal("MSD not available".to_string()));
}
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
Ok(())
}
async fn configure_lun_now(&self, config: &MsdLunConfig) -> Result<()> {
let gadget_path = self.active_gadget_path().await?;
let msd_hold = self.msd_function.read().await;
let Some(ref msd) = *msd_hold else {
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(AppError::Internal(
"MSD function not initialized".to_string(),
));
};
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
Ok(())
}
async fn finish_connect_success(&self) {
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
self.mark_device_info_dirty().await;
}
pub async fn disconnect(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
@@ -323,7 +245,6 @@ impl MsdController {
info!("Disconnected storage");
// Release the lock before publishing events
drop(state);
drop(_op_guard);
@@ -332,41 +253,31 @@ impl MsdController {
Ok(())
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
/// Get ventoy directory path
pub fn ventoy_dir(&self) -> &PathBuf {
&self.ventoy_dir
}
/// Get virtual drive path
pub fn drive_path(&self) -> &PathBuf {
&self.drive_path
}
/// Check if currently connected
pub async fn is_connected(&self) -> bool {
self.state.read().await.connected
}
/// Get current mode
pub async fn mode(&self) -> MsdMode {
self.state.read().await.mode.clone()
}
/// Update drive info
pub async fn update_drive_info(&self, info: DriveInfo) {
let mut state = self.state.write().await;
state.drive_info = Some(info);
}
/// Start downloading an image from URL
///
/// Returns the download_id that can be used to track or cancel the download.
/// Progress is reported via MsdDownloadProgress events.
pub async fn download_image(
&self,
url: String,
@@ -375,18 +286,15 @@ impl MsdController {
let download_id = uuid::Uuid::new_v4().to_string();
let cancel_token = CancellationToken::new();
// Register download
{
let mut downloads = self.downloads.write().await;
downloads.insert(download_id.clone(), cancel_token.clone());
}
// Extract filename for initial response
let display_filename = filename
.clone()
.unwrap_or_else(|| url.rsplit('/').next().unwrap_or("download").to_string());
// Create initial progress
let initial_progress = DownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
@@ -398,7 +306,6 @@ impl MsdController {
error: None,
};
// Publish started event
self.publish_event(crate::events::SystemEvent::MsdDownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
@@ -410,18 +317,15 @@ impl MsdController {
})
.await;
// Clone what we need for the spawned task
let images_path = self.images_path.clone();
let events = self.events.read().await.clone();
let downloads = self.downloads.clone();
let download_id_clone = download_id.clone();
let url_clone = url.clone();
// Spawn download task
tokio::spawn(async move {
let manager = ImageManager::new(images_path);
// Create progress callback
let events_for_callback = events.clone();
let download_id_for_callback = download_id_clone.clone();
let url_for_callback = url_clone.clone();
@@ -443,18 +347,15 @@ impl MsdController {
}
};
// Run download
let result = manager
.download_from_url(&url_clone, filename, progress_callback)
.await;
// Remove from active downloads
{
let mut downloads_guard = downloads.write().await;
downloads_guard.remove(&download_id_clone);
}
// Publish completion event
match result {
Ok(image_info) => {
if let Some(ref bus) = events {
@@ -489,7 +390,6 @@ impl MsdController {
Ok(initial_progress)
}
/// Cancel an active download
pub async fn cancel_download(&self, download_id: &str) -> Result<()> {
let mut downloads = self.downloads.write().await;
@@ -505,12 +405,6 @@ impl MsdController {
}
}
/// Get list of active download IDs
pub async fn active_downloads(&self) -> Vec<String> {
let downloads = self.downloads.read().await;
downloads.keys().cloned().collect()
}
async fn active_gadget_path(&self) -> Result<PathBuf> {
self.otg_service
.gadget_path()
@@ -518,16 +412,13 @@ impl MsdController {
.ok_or_else(|| AppError::Internal("OTG gadget path is not available".to_string()))
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down MSD controller");
// 1. Disconnect if connected
if let Err(e) = self.disconnect().await {
warn!("Error disconnecting during shutdown: {}", e);
}
// 2. Clear local state
*self.msd_function.write().await = None;
let mut state = self.state.write().await;
@@ -537,27 +428,9 @@ impl MsdController {
Ok(())
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<MsdHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> MsdHealthStatus {
self.monitor.status().await
}
/// Check if the MSD is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Drop for MsdController {
fn drop(&mut self) {
// Cleanup is handled by OtgGadgetManager when the gadget is torn down
// Individual controllers don't need to cleanup the ConfigFS
}
}
#[cfg(test)]
@@ -573,7 +446,6 @@ mod tests {
let controller = MsdController::new(otg_service, &msd_dir);
// Check that MSD is not initialized (msd_function is None)
let state = controller.state().await;
assert!(!state.available);
assert!(controller.images_path.ends_with("images"));

View File

@@ -1,53 +1,37 @@
//! Image file manager
//!
//! Handles ISO/IMG image file operations:
//! - List available images
//! - Upload new images
//! - Delete images
//! - Metadata management
//! - Download from URL
use chrono::Utc;
use futures::StreamExt;
use std::fs::{self, File};
use std::io::{self, Read, Write};
use std::fs;
#[cfg(test)]
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use time::OffsetDateTime;
use tokio::io::AsyncWriteExt;
use tracing::info;
use super::types::ImageInfo;
use crate::error::{AppError, Result};
/// Maximum image size (32 GB)
const MAX_IMAGE_SIZE: u64 = 32 * 1024 * 1024 * 1024;
/// Progress report throttle interval (milliseconds)
const PROGRESS_THROTTLE_MS: u64 = 200;
/// Progress report throttle bytes threshold (512 KB)
const PROGRESS_THROTTLE_BYTES: u64 = 512 * 1024;
/// Image Manager
pub struct ImageManager {
/// Images storage directory
images_path: PathBuf,
}
impl ImageManager {
/// Create a new image manager
pub fn new(images_path: PathBuf) -> Self {
Self { images_path }
}
/// Ensure images directory exists
pub fn ensure_dir(&self) -> Result<()> {
fs::create_dir_all(&self.images_path)
.map_err(|e| AppError::Internal(format!("Failed to create images directory: {}", e)))?;
Ok(())
}
/// List all available images
pub fn list(&self) -> Result<Vec<ImageInfo>> {
self.ensure_dir()?;
@@ -68,28 +52,26 @@ impl ImageManager {
}
}
// Sort by creation time (newest first)
images.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(images)
}
/// Get image info from path
fn get_image_info(&self, path: &Path) -> Option<ImageInfo> {
let metadata = fs::metadata(path).ok()?;
let name = path.file_name()?.to_string_lossy().to_string();
// Use filename hash as ID (stable across restarts)
let id = format!("{:x}", md5_hash(&name));
let id = stable_image_id_from_filename(&name);
let created_at = metadata
.created()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
.map(|d| {
chrono::DateTime::from_timestamp(d.as_secs() as i64, 0).unwrap_or_else(Utc::now)
OffsetDateTime::from_unix_timestamp(d.as_secs() as i64)
.unwrap_or_else(|_| OffsetDateTime::now_utc())
})
.unwrap_or_else(Utc::now);
.unwrap_or_else(OffsetDateTime::now_utc);
Some(ImageInfo {
id,
@@ -100,7 +82,6 @@ impl ImageManager {
})
}
/// Get image by ID
pub fn get(&self, id: &str) -> Result<ImageInfo> {
for image in self.list()? {
if image.id == id {
@@ -110,24 +91,21 @@ impl ImageManager {
Err(AppError::NotFound(format!("Image not found: {}", id)))
}
/// Get image by name
pub fn get_by_name(&self, name: &str) -> Result<ImageInfo> {
let path = self.images_path.join(name);
self.get_image_info(&path)
.ok_or_else(|| AppError::NotFound(format!("Image not found: {}", name)))
}
/// Create a new image from bytes
pub fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
#[cfg(test)]
fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
self.ensure_dir()?;
// Validate name
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Check size
if data.len() as u64 > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
@@ -135,7 +113,6 @@ impl ImageManager {
)));
}
// Write file
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
@@ -144,11 +121,10 @@ impl ImageManager {
)));
}
let mut file = File::create(&path)
let mut file = fs::File::create(&path)
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
file.write_all(data).map_err(|e| {
// Try to clean up on error
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
@@ -158,55 +134,6 @@ impl ImageManager {
self.get_by_name(&name)
}
/// Create a new image from a file stream (for chunked uploads)
pub fn create_from_stream<R: Read>(
&self,
name: &str,
reader: &mut R,
expected_size: Option<u64>,
) -> Result<ImageInfo> {
self.ensure_dir()?;
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
if let Some(size) = expected_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
)));
}
}
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
name
)));
}
// Create file and copy data
let mut file = File::create(&path)
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
let bytes_written = io::copy(reader, &mut file).map_err(|e| {
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
info!("Created image: {} ({} bytes)", name, bytes_written);
self.get_by_name(&name)
}
/// Create a new image from an async multipart field (streaming, memory-efficient)
///
/// This method streams data directly to disk without buffering the entire file in memory,
/// making it suitable for large files (multi-GB ISOs).
pub async fn create_from_multipart_field(
&self,
name: &str,
@@ -219,12 +146,10 @@ impl ImageManager {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Use a temporary file during upload
let temp_name = format!(".upload_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_name);
let final_path = self.images_path.join(&name);
// Check if final file already exists
if final_path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
@@ -232,23 +157,19 @@ impl ImageManager {
)));
}
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
let mut bytes_written: u64 = 0;
// Stream chunks directly to disk
while let Some(chunk) = field
.chunk()
.await
.map_err(|e| AppError::Internal(format!("Failed to read upload chunk: {}", e)))?
{
// Check size limit
bytes_written += chunk.len() as u64;
if bytes_written > MAX_IMAGE_SIZE {
// Cleanup and return error
drop(file);
let _ = tokio::fs::remove_file(&temp_path).await;
return Err(AppError::Internal(format!(
@@ -257,19 +178,16 @@ impl ImageManager {
)));
}
// Write chunk to file
file.write_all(&chunk)
.await
.map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?;
}
// Flush and close file
file.flush()
.await
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
drop(file);
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
@@ -285,7 +203,6 @@ impl ImageManager {
self.get_by_name(&name)
}
/// Delete an image by ID
pub fn delete(&self, id: &str) -> Result<()> {
let image = self.get(id)?;
@@ -296,45 +213,6 @@ impl ImageManager {
Ok(())
}
/// Delete an image by name
pub fn delete_by_name(&self, name: &str) -> Result<()> {
let path = self.images_path.join(name);
if !path.exists() {
return Err(AppError::NotFound(format!("Image not found: {}", name)));
}
fs::remove_file(&path)
.map_err(|e| AppError::Internal(format!("Failed to delete image: {}", e)))?;
info!("Deleted image: {}", name);
Ok(())
}
/// Get total storage used
pub fn used_space(&self) -> u64 {
self.list()
.map(|images| images.iter().map(|i| i.size).sum())
.unwrap_or(0)
}
/// Check if storage has space for new image
pub fn has_space(&self, size: u64) -> bool {
// For now, just check against max size
// In the future, could check disk space
size <= MAX_IMAGE_SIZE
}
/// Download image from URL with progress callback
///
/// # Arguments
/// * `url` - The URL to download from
/// * `filename` - Optional custom filename (extracted from URL or Content-Disposition if not provided)
/// * `progress_callback` - Callback function called with (bytes_downloaded, total_bytes)
///
/// # Returns
/// * `Ok(ImageInfo)` - The downloaded image info
/// * `Err(AppError)` - If download fails
pub async fn download_from_url<F>(
&self,
url: &str,
@@ -346,20 +224,17 @@ impl ImageManager {
{
self.ensure_dir()?;
// Validate URL
let parsed_url = reqwest::Url::parse(url)
.map_err(|e| AppError::BadRequest(format!("Invalid URL: {}", e)))?;
info!("Starting download from: {}", url);
// Create HTTP client with timeout
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(3600)) // 1 hour timeout for large files
.timeout(std::time::Duration::from_secs(3600))
.connect_timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| AppError::Internal(format!("Failed to create HTTP client: {}", e)))?;
// Send HEAD request first to get content info
let head_response = client
.head(url)
.send()
@@ -379,7 +254,6 @@ impl ImageManager {
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
// Check file size
if let Some(size) = total_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::BadRequest(format!(
@@ -390,11 +264,9 @@ impl ImageManager {
}
}
// Determine filename
let final_filename = if let Some(name) = filename {
sanitize_filename(&name)
} else {
// Try Content-Disposition header first
let from_header = head_response
.headers()
.get(reqwest::header::CONTENT_DISPOSITION)
@@ -404,7 +276,6 @@ impl ImageManager {
if let Some(name) = from_header {
sanitize_filename(&name)
} else {
// Fall back to URL path
let path = parsed_url.path();
let name = path.rsplit('/').next().unwrap_or("download");
let name = urlencoding::decode(name).unwrap_or_else(|_| name.into());
@@ -418,7 +289,6 @@ impl ImageManager {
));
}
// Check if file already exists
let final_path = self.images_path.join(&final_filename);
if final_path.exists() {
return Err(AppError::BadRequest(format!(
@@ -427,11 +297,9 @@ impl ImageManager {
)));
}
// Create temporary file for download
let temp_filename = format!(".download_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_filename);
// Start actual download
let response = client
.get(url)
.send()
@@ -445,7 +313,6 @@ impl ImageManager {
)));
}
// Get actual content length from response (may differ from HEAD)
let content_length = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
@@ -453,19 +320,16 @@ impl ImageManager {
.and_then(|s| s.parse::<u64>().ok())
.or(total_size);
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
// Stream download with progress (throttled)
let mut stream = response.bytes_stream();
let mut downloaded: u64 = 0;
let mut last_report_time = Instant::now();
let mut last_reported_bytes: u64 = 0;
let throttle_interval = Duration::from_millis(PROGRESS_THROTTLE_MS);
// Report initial progress
progress_callback(0, content_length);
while let Some(chunk_result) = stream.next().await {
@@ -473,14 +337,12 @@ impl ImageManager {
chunk_result.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?;
file.write_all(&chunk).await.map_err(|e| {
// Cleanup on error
let _ = std::fs::remove_file(&temp_path);
AppError::Internal(format!("Failed to write data: {}", e))
})?;
downloaded += chunk.len() as u64;
// Throttled progress reporting: report if enough time or bytes have passed
let now = Instant::now();
let time_elapsed = now.duration_since(last_report_time) >= throttle_interval;
let bytes_elapsed = downloaded - last_reported_bytes >= PROGRESS_THROTTLE_BYTES;
@@ -492,18 +354,15 @@ impl ImageManager {
}
}
// Always report final progress
if downloaded != last_reported_bytes {
progress_callback(downloaded, content_length);
}
// Ensure all data is flushed
file.flush()
.await
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
drop(file);
// Verify downloaded size
let metadata = tokio::fs::metadata(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to read file metadata: {}", e)))?;
@@ -519,7 +378,6 @@ impl ImageManager {
}
}
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
@@ -533,35 +391,29 @@ impl ImageManager {
metadata.len()
);
// Return image info
self.get_by_name(&final_filename)
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
}
/// Simple hash function for generating stable IDs
fn md5_hash(s: &str) -> u64 {
fn stable_image_id_from_filename(name: &str) -> String {
let mut hash: u64 = 0;
for (i, byte) in s.bytes().enumerate() {
for (i, byte) in name.bytes().enumerate() {
hash = hash.wrapping_add((byte as u64).wrapping_mul((i as u64).wrapping_add(1)));
hash = hash.wrapping_mul(31);
}
hash
format!("{:x}", hash)
}
/// Sanitize filename to prevent path traversal
fn sanitize_filename(name: &str) -> String {
let name = name.trim();
let name = name.replace(['/', '\\', '\0', ':', '*', '?', '"', '<', '>', '|'], "_");
// Remove leading dots (hidden files)
let name = name.trim_start_matches('.');
// Limit length
if name.len() > 255 {
name[..255].to_string()
} else {
@@ -569,17 +421,10 @@ fn sanitize_filename(name: &str) -> String {
}
}
/// Extract filename from Content-Disposition header
fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
// Handle both:
// Content-Disposition: attachment; filename="example.iso"
// Content-Disposition: attachment; filename*=UTF-8''example.iso
// Try filename* first (RFC 5987)
if let Some(pos) = header.find("filename*=") {
let start = pos + 10;
let value = &header[start..];
// Format: charset'language'value
if let Some(quote_start) = value.find("''") {
let encoded = value[quote_start + 2..].split(';').next()?;
let decoded = urlencoding::decode(encoded.trim()).ok()?;
@@ -590,7 +435,6 @@ fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
}
}
// Try filename next
if let Some(pos) = header.find("filename=") {
let start = pos + 9;
let value = &header[start..];
@@ -612,7 +456,7 @@ mod tests {
#[test]
fn test_sanitize_filename() {
assert_eq!(sanitize_filename("test.iso"), "test.iso");
assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); // .. becomes empty after trim_start_matches('.')
assert_eq!(sanitize_filename("../test.iso"), "_test.iso");
assert_eq!(sanitize_filename("test/file.iso"), "test_file.iso");
assert_eq!(sanitize_filename(".hidden.iso"), "hidden.iso");
}

View File

@@ -1,19 +1,3 @@
//! MSD (Mass Storage Device) module
//!
//! Provides virtual USB storage functionality with two modes:
//! - Image mounting: Mount ISO/IMG files for system installation
//! - Ventoy drive: Bootable exFAT drive for multiple ISO files
//!
//! Architecture:
//! ```text
//! Web API --> MSD Controller --> ConfigFS Mass Storage --> Target PC
//! |
//! ┌──────┴──────┐
//! │ │
//! Image Manager Ventoy Drive
//! (ISO/IMG) (Bootable exFAT)
//! ```
pub mod controller;
pub mod image;
pub mod monitor;
@@ -22,12 +6,11 @@ pub mod ventoy_drive;
pub use controller::MsdController;
pub use image::ImageManager;
pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig};
pub use monitor::MsdHealthMonitor;
pub use types::{
DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest,
ImageInfo, MsdConnectRequest, MsdMode, MsdState,
};
pub use ventoy_drive::VentoyDrive;
// Re-export from otg module for backward compatibility
pub use crate::otg::{MsdFunction, MsdLunConfig};

View File

@@ -1,99 +1,46 @@
//! MSD (Mass Storage Device) health monitoring
//!
//! This module provides health monitoring for MSD operations, including:
//! - ConfigFS operation error tracking
//! - Image mount/unmount error tracking
//! - Error state tracking
//! - Log throttling to prevent log flooding
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::utils::LogThrottler;
/// MSD health status
const LOG_THROTTLE_SECS: u64 = 5;
#[derive(Debug, Clone, PartialEq, Default)]
pub enum MsdHealthStatus {
/// Device is healthy and operational
pub(crate) enum MsdHealthStatus {
#[default]
Healthy,
/// Device has an error
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
},
}
/// MSD health monitor configuration
#[derive(Debug, Clone)]
pub struct MsdMonitorConfig {
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for MsdMonitorConfig {
fn default() -> Self {
Self {
log_throttle_secs: 5,
}
}
}
/// MSD health monitor
///
/// Monitors MSD operation health and manages error state.
pub struct MsdHealthMonitor {
/// Current health status
status: RwLock<MsdHealthStatus>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Error count (for tracking)
error_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
}
impl MsdHealthMonitor {
/// Create a new MSD health monitor with the specified configuration
pub fn new(config: MsdMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
pub fn with_defaults() -> Self {
Self {
status: RwLock::new(MsdHealthStatus::Healthy),
throttler: LogThrottler::with_secs(throttle_secs),
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
error_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
}
}
/// Create a new MSD health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(MsdMonitorConfig::default())
}
/// Report an error from MSD operations
///
/// This method is called when an MSD operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Updates in-memory error state
///
/// # Arguments
///
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, reason: &str, error_code: &str) {
let count = self.error_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("msd_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!(
@@ -102,29 +49,21 @@ impl MsdHealthMonitor {
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = MsdHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
};
}
/// Report that the MSD has recovered from error
///
/// This method is called when an MSD operation succeeds after errors.
/// It resets the error state.
pub async fn report_recovered(&self) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != MsdHealthStatus::Healthy {
let error_count = self.error_count.load(Ordering::Relaxed);
info!("MSD recovered after {} errors", error_count);
// Reset state
self.error_count.store(0, Ordering::Relaxed);
self.throttler.clear_all();
*self.last_error_code.write().await = None;
@@ -132,29 +71,25 @@ impl MsdHealthMonitor {
}
}
/// Get the current health status
pub async fn status(&self) -> MsdHealthStatus {
#[cfg(test)]
pub(crate) async fn status(&self) -> MsdHealthStatus {
self.status.read().await.clone()
}
/// Get the current error count
pub fn error_count(&self) -> u32 {
#[cfg(test)]
pub(crate) fn error_count(&self) -> u32 {
self.error_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
#[cfg(test)]
pub(crate) async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.error_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
@@ -162,7 +97,6 @@ impl MsdHealthMonitor {
self.throttler.clear_all();
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
match &*self.status.read().await {
MsdHealthStatus::Error { reason, .. } => Some(reason.clone()),
@@ -212,13 +146,11 @@ mod tests {
async fn test_report_recovered() {
let monitor = MsdHealthMonitor::with_defaults();
// First report an error
monitor
.report_error("Image not found", "image_not_found")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.error_count(), 0);

View File

@@ -1,52 +1,38 @@
//! MSD data types and structures
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use time::OffsetDateTime;
/// MSD operating mode
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum MsdMode {
/// No storage connected
#[default]
None,
/// Image file mounted (ISO/IMG)
Image,
/// Virtual drive (FAT32) connected
Drive,
}
/// Image file metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageInfo {
/// Unique image ID
pub id: String,
/// Display name
pub name: String,
/// File path on disk
#[serde(skip_serializing)]
pub path: PathBuf,
/// File size in bytes
pub size: u64,
/// Creation timestamp
pub created_at: DateTime<Utc>,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
}
impl ImageInfo {
/// Create new image info
pub fn new(id: String, name: String, path: PathBuf, size: u64) -> Self {
Self {
id,
name,
path,
size,
created_at: Utc::now(),
created_at: OffsetDateTime::now_utc(),
}
}
/// Format size for display
pub fn size_display(&self) -> String {
const KB: u64 = 1024;
const MB: u64 = KB * 1024;
@@ -64,18 +50,12 @@ impl ImageInfo {
}
}
/// MSD state information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdState {
/// Whether MSD feature is available
pub available: bool,
/// Current mode
pub mode: MsdMode,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image (if mode is Image)
pub current_image: Option<ImageInfo>,
/// Virtual drive info (if mode is Drive)
pub drive_info: Option<DriveInfo>,
}
@@ -91,24 +71,17 @@ impl Default for MsdState {
}
}
/// Virtual drive information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveInfo {
/// Drive size in bytes
pub size: u64,
/// Used space in bytes
pub used: u64,
/// Free space in bytes
pub free: u64,
/// Whether drive is initialized
pub initialized: bool,
/// Drive file path
#[serde(skip_serializing)]
pub path: PathBuf,
}
impl DriveInfo {
/// Create new drive info
pub fn new(path: PathBuf, size: u64) -> Self {
Self {
size,
@@ -120,91 +93,60 @@ impl DriveInfo {
}
}
/// File entry in virtual drive
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveFile {
/// File name
pub name: String,
/// Relative path from drive root
pub path: String,
/// File size in bytes (0 for directories)
pub size: u64,
/// Whether this is a directory
pub is_dir: bool,
/// Last modified timestamp
pub modified: Option<DateTime<Utc>>,
#[serde(with = "time::serde::rfc3339::option")]
pub modified: Option<OffsetDateTime>,
}
/// MSD connect request
#[derive(Debug, Clone, Deserialize)]
pub struct MsdConnectRequest {
/// Connection mode: "image" or "drive"
pub mode: MsdMode,
/// Image ID to mount (required for image mode)
pub image_id: Option<String>,
/// Mount as CD-ROM (optional, defaults based on image type)
#[serde(default)]
pub cdrom: Option<bool>,
/// Mount as read-only
#[serde(default)]
pub read_only: Option<bool>,
}
/// Virtual drive init request
#[derive(Debug, Clone, Deserialize)]
pub struct DriveInitRequest {
/// Drive size in megabytes (defaults to 16GB)
#[serde(default = "default_drive_size")]
pub size_mb: u32,
/// Optional custom path for Ventoy installation
pub ventoy_path: Option<String>,
}
fn default_drive_size() -> u32 {
16 * 1024 // 16GB
16 * 1024
}
/// Image download request
#[derive(Debug, Clone, Deserialize)]
pub struct ImageDownloadRequest {
/// URL to download from
pub url: String,
/// Optional custom filename
pub filename: Option<String>,
}
/// Download status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DownloadStatus {
/// Download has started
Started,
/// Download is in progress
InProgress,
/// Download completed successfully
Completed,
/// Download failed
Failed,
}
/// Download progress information
#[derive(Debug, Clone, Serialize)]
pub struct DownloadProgress {
/// Unique download ID
pub download_id: String,
/// Source URL
pub url: String,
/// Target filename
pub filename: String,
/// Bytes downloaded so far
pub bytes_downloaded: u64,
/// Total file size (None if unknown)
pub total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
pub progress_pct: Option<f32>,
/// Download status
pub status: DownloadStatus,
/// Error message if failed
pub error: Option<String>,
}
@@ -218,7 +160,7 @@ mod tests {
"test".into(),
"test.iso".into(),
PathBuf::from("/tmp/test.iso"),
1024 * 1024 * 1024 * 2, // 2 GB
1024 * 1024 * 1024 * 2,
);
assert!(info.size_display().contains("GB"));
}

View File

@@ -1,8 +1,3 @@
//! Ventoy Virtual Drive
//!
//! Replaces FAT32 VirtualDrive with a Ventoy bootable image.
//! Provides a bootable USB with exFAT data partition for ISO files.
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
@@ -13,33 +8,20 @@ use ventoy_img::{FileInfo as VentoyFileInfo, VentoyError, VentoyImage};
use super::types::{DriveFile, DriveInfo};
use crate::error::{AppError, Result};
/// Chunk size for streaming reads (64 KB)
const STREAM_CHUNK_SIZE: usize = 64 * 1024;
/// Minimum drive size (1 GB) - Ventoy requires space for boot partition
const MIN_DRIVE_SIZE_MB: u32 = 1024;
/// Maximum drive size (128 GB)
const MAX_DRIVE_SIZE_MB: u32 = 128 * 1024;
/// Default drive label
const DEFAULT_LABEL: &str = "ONE-KVM";
/// Ventoy Drive Manager
///
/// Thread-safe wrapper around VentoyImage providing async file operations.
/// Uses spawn_blocking for all ventoy-img-rs operations since they are synchronous.
/// Uses RwLock to allow concurrent read operations while serializing writes.
pub struct VentoyDrive {
/// Drive image path
path: PathBuf,
/// RwLock for concurrent reads, exclusive writes
/// (ventoy-img-rs operations are synchronous and not thread-safe)
lock: Arc<RwLock<()>>,
}
impl VentoyDrive {
/// Create new Ventoy drive manager
pub fn new(path: PathBuf) -> Self {
Self {
path,
@@ -47,40 +29,32 @@ impl VentoyDrive {
}
}
/// Check if drive image exists
pub fn exists(&self) -> bool {
self.path.exists()
}
/// Get drive path
pub fn path(&self) -> &PathBuf {
&self.path
}
/// Initialize a new Ventoy drive image
///
/// Creates a bootable Ventoy image with the specified size.
/// The image includes boot partitions and an exFAT data partition.
pub async fn init(&self, size_mb: u32) -> Result<DriveInfo> {
let size_mb = size_mb.clamp(MIN_DRIVE_SIZE_MB, MAX_DRIVE_SIZE_MB);
let size_str = format!("{}M", size_mb);
let path = self.path.clone();
let _lock = self.lock.write().await; // Write lock for initialization
let _lock = self.lock.write().await;
info!("Creating {} MB Ventoy drive at {}", size_mb, path.display());
// Run Ventoy creation in blocking task
let info = tokio::task::spawn_blocking(move || {
VentoyImage::create(&path, &size_str, DEFAULT_LABEL).map_err(ventoy_to_app_error)?;
// Get file metadata for DriveInfo
let metadata = std::fs::metadata(&path)
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
Ok::<DriveInfo, AppError>(DriveInfo {
size: metadata.len(),
used: 0,
free: metadata.len(), // Approximate - exFAT overhead not calculated
free: metadata.len(),
initialized: true,
path,
})
@@ -92,20 +66,18 @@ impl VentoyDrive {
Ok(info)
}
/// Get drive information
pub async fn info(&self) -> Result<DriveInfo> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let _lock = self.lock.read().await; // Read lock for info query
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let metadata = std::fs::metadata(&path)
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
// Open image to get file list and calculate used space
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
let files = image.list_files_recursive().map_err(ventoy_to_app_error)?;
@@ -116,7 +88,6 @@ impl VentoyDrive {
.map(|f| f.size)
.sum();
// Note: This is approximate since we don't have exact exFAT overhead
let size = metadata.len();
let free = size.saturating_sub(used);
@@ -132,7 +103,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// List files at a given path (or root if empty/"/")
pub async fn list_files(&self, dir_path: &str) -> Result<Vec<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -140,7 +110,7 @@ impl VentoyDrive {
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.read().await; // Read lock for listing
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -161,9 +131,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Write a file to the drive from multipart upload (streaming)
///
/// Streams the file directly into the Ventoy image's exFAT partition.
pub async fn write_file_from_multipart_field(
&self,
file_path: &str,
@@ -173,12 +140,10 @@ impl VentoyDrive {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, stream to a temporary file (to get the size)
let temp_dir = self.path.parent().unwrap_or(Path::new("/tmp"));
let temp_name = format!(".upload_ventoy_{}", uuid::Uuid::new_v4());
let temp_path = temp_dir.join(&temp_name);
// Stream upload to temp file
let mut temp_file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
@@ -201,23 +166,16 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Failed to flush temp file: {}", e)))?;
drop(temp_file);
// Now copy from temp file to Ventoy image
let path = self.path.clone();
let file_path = file_path.to_string();
let temp_path_clone = temp_path.clone();
let _lock = self.lock.write().await; // Write lock for file write
let _lock = self.lock.write().await;
let result = tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use add_file_to_path which handles streaming internally
image
.add_file_to_path(
&temp_path_clone,
&file_path,
true, // create_parents
true, // overwrite
)
.add_file_to_path(&temp_path_clone, &file_path, true, true)
.map_err(ventoy_to_app_error)?;
Ok::<(), AppError>(())
@@ -225,14 +183,13 @@ impl VentoyDrive {
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?;
// Cleanup temp file
let _ = tokio::fs::remove_file(&temp_path).await;
result?;
Ok(bytes_written)
}
/// Read a file from the drive (for download)
#[cfg(test)]
pub async fn read_file(&self, file_path: &str) -> Result<Vec<u8>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -240,7 +197,7 @@ impl VentoyDrive {
let path = self.path.clone();
let file_path = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file read
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -251,10 +208,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Get file information without reading content
///
/// Returns file size, name, and other metadata.
/// Returns None if the file doesn't exist.
pub async fn get_file_info(&self, file_path: &str) -> Result<Option<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -262,7 +215,7 @@ impl VentoyDrive {
let path = self.path.clone();
let file_path_owned = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file info
let _lock = self.lock.read().await;
let info = tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -282,10 +235,6 @@ impl VentoyDrive {
}))
}
/// Read a file from the drive as a stream (for large file downloads)
///
/// Returns an async channel receiver that yields chunks of file data.
/// This avoids loading the entire file into memory.
pub async fn read_file_stream(
&self,
file_path: &str,
@@ -297,7 +246,6 @@ impl VentoyDrive {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, get the file size
let file_info = self
.get_file_info(file_path)
.await?
@@ -315,15 +263,12 @@ impl VentoyDrive {
let file_path_owned = file_path.to_string();
let lock = self.lock.clone();
// Create a channel for streaming data
let (tx, rx) =
tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, std::io::Error>>(8);
// Spawn blocking task to read and send chunks
tokio::task::spawn_blocking(move || {
// Hold read lock for the entire read operation
let rt = tokio::runtime::Handle::current();
let _lock = rt.block_on(lock.read()); // Read lock for streaming
let _lock = rt.block_on(lock.read());
let image = match VentoyImage::open(&path) {
Ok(img) => img,
@@ -333,10 +278,8 @@ impl VentoyDrive {
}
};
// Create a channel writer that sends chunks
let mut chunk_writer = ChannelWriter::new(tx.clone(), rt.clone());
// Stream the file through the writer
if let Err(e) = image.read_file_to_writer(&file_path_owned, &mut chunk_writer) {
let _ = rt.block_on(tx.send(Err(std::io::Error::other(e.to_string()))));
}
@@ -345,7 +288,6 @@ impl VentoyDrive {
Ok((file_size, rx))
}
/// Create a directory
pub async fn mkdir(&self, dir_path: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -353,7 +295,7 @@ impl VentoyDrive {
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.write().await; // Write lock for mkdir
let _lock = self.lock.write().await;
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -366,7 +308,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Delete a file or directory
pub async fn delete(&self, path_to_delete: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -374,12 +315,11 @@ impl VentoyDrive {
let path = self.path.clone();
let path_to_delete = path_to_delete.to_string();
let _lock = self.lock.write().await; // Write lock for delete
let _lock = self.lock.write().await;
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use recursive delete to handle directories
image
.remove_recursive(&path_to_delete)
.map_err(ventoy_to_app_error)
@@ -389,7 +329,6 @@ impl VentoyDrive {
}
}
/// Convert VentoyError to AppError
fn ventoy_to_app_error(err: VentoyError) -> AppError {
match err {
VentoyError::Io(e) => AppError::Io(e),
@@ -405,7 +344,6 @@ fn ventoy_to_app_error(err: VentoyError) -> AppError {
}
}
/// Convert VentoyFileInfo to DriveFile
fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFile {
let full_path = if parent_path.is_empty() || parent_path == "/" {
format!("/{}", info.name)
@@ -418,13 +356,10 @@ fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFi
path: full_path,
size: info.size,
is_dir: info.is_directory,
modified: None, // Ventoy FileInfo doesn't include timestamps
modified: None,
}
}
/// A writer that sends chunks to an async channel
///
/// This bridges the sync Write trait with async channels for streaming.
struct ChannelWriter {
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
rt: tokio::runtime::Handle,
@@ -484,7 +419,6 @@ impl std::io::Write for ChannelWriter {
impl Drop for ChannelWriter {
fn drop(&mut self) {
// Flush any remaining data when the writer is dropped
let _ = self.flush_buffer();
}
}
@@ -496,16 +430,13 @@ mod tests {
use std::sync::OnceLock;
use tempfile::TempDir;
/// Path to ventoy resources directory
static RESOURCE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../ventoy-img-rs/resources");
/// Initialize ventoy resources once
fn init_ventoy_resources() -> bool {
static INIT: OnceLock<bool> = OnceLock::new();
*INIT.get_or_init(|| {
let resource_path = std::path::Path::new(RESOURCE_DIR);
// Decompress xz files if needed
let core_xz = resource_path.join("core.img.xz");
let core_img = resource_path.join("core.img");
if core_xz.exists() && !core_img.exists() {
@@ -524,7 +455,6 @@ mod tests {
}
}
// Initialize resources
if let Err(e) = ventoy_img::resources::init_resources(resource_path) {
eprintln!("Failed to init ventoy resources: {}", e);
return false;
@@ -534,7 +464,6 @@ mod tests {
})
}
/// Decompress xz file using system command
fn decompress_xz(src: &std::path::Path, dst: &std::path::Path) -> std::io::Result<()> {
let output = Command::new("xz")
.args(["-d", "-k", "-c", src.to_str().unwrap()])
@@ -551,7 +480,6 @@ mod tests {
Ok(())
}
/// Ensure resources are initialized, skip test if failed
fn ensure_resources() -> bool {
if !init_ventoy_resources() {
eprintln!("Skipping test: ventoy resources not available");
@@ -602,15 +530,12 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Write a test file
let test_content = b"Hello, Ventoy!";
let test_file_path = temp_dir.path().join("test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive using ventoy-img directly
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -619,7 +544,6 @@ mod tests {
.await
.unwrap();
// Read file from drive
let read_data = drive.read_file("/test.txt").await.unwrap();
assert_eq!(read_data, test_content);
}
@@ -633,18 +557,14 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a directory
drive.mkdir("/mydir").await.unwrap();
// Write a test file
let test_content = b"Test file content for info check";
let test_file_path = temp_dir.path().join("info_test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -653,7 +573,6 @@ mod tests {
.await
.unwrap();
// Test get_file_info for file
let file_info = drive.get_file_info("/info_test.txt").await.unwrap();
assert!(file_info.is_some());
let file_info = file_info.unwrap();
@@ -661,14 +580,12 @@ mod tests {
assert_eq!(file_info.size, test_content.len() as u64);
assert!(!file_info.is_dir);
// Test get_file_info for directory
let dir_info = drive.get_file_info("/mydir").await.unwrap();
assert!(dir_info.is_some());
let dir_info = dir_info.unwrap();
assert_eq!(dir_info.name, "mydir");
assert!(dir_info.is_dir);
// Test get_file_info for non-existent file
let not_found = drive.get_file_info("/nonexistent.txt").await.unwrap();
assert!(not_found.is_none());
}
@@ -682,16 +599,13 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create test data that spans multiple chunks (>64KB)
let test_size = 200 * 1024; // 200 KB
let test_size = 200 * 1024;
let test_content: Vec<u8> = (0..test_size).map(|i| (i % 256) as u8).collect();
let test_file_path = temp_dir.path().join("large_file.bin");
std::fs::write(&test_file_path, &test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
let file_path_clone = test_file_path.clone();
tokio::task::spawn_blocking(move || {
@@ -701,18 +615,15 @@ mod tests {
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/large_file.bin").await.unwrap();
assert_eq!(file_size, test_size as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.len(), test_content.len());
assert_eq!(received_data, test_content);
}
@@ -726,15 +637,12 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a small test file
let test_content = b"Small file for streaming test";
let test_file_path = temp_dir.path().join("small.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -743,18 +651,15 @@ mod tests {
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/small.txt").await.unwrap();
assert_eq!(file_size, test_content.len() as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.as_slice(), test_content);
}
}

View File

@@ -1,5 +1,3 @@
//! ConfigFS file operations for USB Gadget
use std::fs::{self, File, OpenOptions};
use std::io::Write;
use std::path::Path;
@@ -7,34 +5,18 @@ use std::process::Command;
use crate::error::{AppError, Result};
/// ConfigFS base path for USB gadgets
pub const CONFIGFS_PATH: &str = "/sys/kernel/config/usb_gadget";
/// Default gadget name
pub const DEFAULT_GADGET_NAME: &str = "one-kvm";
/// USB Vendor ID (Linux Foundation) - default value
pub const DEFAULT_USB_VENDOR_ID: u16 = 0x1d6b;
/// USB Product ID (Multifunction Composite Gadget) - default value
pub const DEFAULT_USB_PRODUCT_ID: u16 = 0x0104;
/// USB device version - default value
pub const DEFAULT_USB_BCD_DEVICE: u16 = 0x0100;
/// USB spec version (USB 2.0)
pub const USB_BCD_USB: u16 = 0x0200;
/// Check if ConfigFS is available
pub fn is_configfs_available() -> bool {
Path::new(CONFIGFS_PATH).exists()
}
/// Ensure libcomposite support is available for USB gadget operations.
///
/// This is a best-effort runtime fallback for systems where `libcomposite`
/// is built as a module and not loaded yet. It does not try to mount configfs;
/// mounting remains an explicit system responsibility.
/// Loads `libcomposite` if needed; does not mount configfs.
pub fn ensure_libcomposite_loaded() -> Result<()> {
if is_configfs_available() {
return Ok(());
@@ -66,7 +48,6 @@ pub fn ensure_libcomposite_loaded() -> Result<()> {
}
}
/// Find available UDC (USB Device Controller)
pub fn find_udc() -> Option<String> {
let udc_path = Path::new("/sys/class/udc");
if !udc_path.exists() {
@@ -80,40 +61,17 @@ pub fn find_udc() -> Option<String> {
.next()
}
/// Check if UDC is known to have low endpoint resources
pub fn is_low_endpoint_udc(name: &str) -> bool {
let name = name.to_ascii_lowercase();
name.contains("musb") || name.contains("musb-hdrc")
}
/// Resolve preferred UDC name if available, otherwise auto-detect
pub fn resolve_udc_name(preferred: Option<&str>) -> Option<String> {
if let Some(name) = preferred {
let path = Path::new("/sys/class/udc").join(name);
if path.exists() {
return Some(name.to_string());
}
}
find_udc()
}
/// Write string content to a file
///
/// For sysfs files, this function appends a newline and flushes
/// to ensure the kernel processes the write immediately.
///
/// IMPORTANT: sysfs attributes require a single atomic write() syscall.
/// The kernel processes the value on the first write(), so we must
/// build the complete buffer (including newline) before writing.
/// Sysfs/configfs: one write syscall with final buffer (incl. newline when needed).
pub fn write_file(path: &Path, content: &str) -> Result<()> {
// For sysfs files (especially write-only ones like forced_eject),
// we need to use simple O_WRONLY without O_TRUNC
// O_TRUNC may fail on special files or require read permission
let mut file = OpenOptions::new()
.write(true)
.open(path)
.or_else(|e| {
// If open fails, try create (for regular files)
if path.exists() {
Err(e)
} else {
@@ -122,9 +80,6 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
})
.map_err(|e| AppError::Internal(format!("Failed to open {}: {}", path.display(), e)))?;
// Build complete buffer with newline, then write in single syscall.
// This is critical for sysfs - multiple write() calls may cause
// the kernel to only process partial data or return EINVAL.
let data: std::borrow::Cow<[u8]> = if content.ends_with('\n') {
content.as_bytes().into()
} else {
@@ -136,14 +91,12 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
file.write_all(&data)
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
// Explicitly flush to ensure sysfs processes the write
file.flush()
.map_err(|e| AppError::Internal(format!("Failed to flush {}: {}", path.display(), e)))?;
Ok(())
}
/// Write binary content to a file
pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
let mut file = File::create(path)
.map_err(|e| AppError::Internal(format!("Failed to create {}: {}", path.display(), e)))?;
@@ -154,14 +107,6 @@ pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
Ok(())
}
/// Read string content from a file
pub fn read_file(path: &Path) -> Result<String> {
fs::read_to_string(path)
.map(|s| s.trim().to_string())
.map_err(|e| AppError::Internal(format!("Failed to read {}: {}", path.display(), e)))
}
/// Create directory if not exists
pub fn create_dir(path: &Path) -> Result<()> {
fs::create_dir_all(path).map_err(|e| {
AppError::Internal(format!(
@@ -172,7 +117,6 @@ pub fn create_dir(path: &Path) -> Result<()> {
})
}
/// Remove directory
pub fn remove_dir(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_dir(path).map_err(|e| {
@@ -186,7 +130,6 @@ pub fn remove_dir(path: &Path) -> Result<()> {
Ok(())
}
/// Remove file
pub fn remove_file(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_file(path).map_err(|e| {
@@ -196,7 +139,6 @@ pub fn remove_file(path: &Path) -> Result<()> {
Ok(())
}
/// Create symlink
pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> {
std::os::unix::fs::symlink(src, dest).map_err(|e| {
AppError::Internal(format!(

View File

@@ -1,11 +1,7 @@
//! USB Endpoint allocation management
use crate::error::{AppError, Result};
/// Default maximum endpoints for typical UDC
pub const DEFAULT_MAX_ENDPOINTS: u8 = 16;
/// Endpoint allocator - manages UDC endpoint resources
#[derive(Debug, Clone)]
pub struct EndpointAllocator {
max_endpoints: u8,
@@ -13,7 +9,6 @@ pub struct EndpointAllocator {
}
impl EndpointAllocator {
/// Create a new endpoint allocator
pub fn new(max_endpoints: u8) -> Self {
Self {
max_endpoints,
@@ -21,7 +16,6 @@ impl EndpointAllocator {
}
}
/// Allocate endpoints for a function
pub fn allocate(&mut self, count: u8) -> Result<()> {
if self.used_endpoints + count > self.max_endpoints {
return Err(AppError::Internal(format!(
@@ -34,27 +28,22 @@ impl EndpointAllocator {
Ok(())
}
/// Release endpoints
pub fn release(&mut self, count: u8) {
self.used_endpoints = self.used_endpoints.saturating_sub(count);
}
/// Get available endpoint count
pub fn available(&self) -> u8 {
self.max_endpoints.saturating_sub(self.used_endpoints)
}
/// Get used endpoint count
pub fn used(&self) -> u8 {
self.used_endpoints
}
/// Get maximum endpoint count
pub fn max(&self) -> u8 {
self.max_endpoints
}
/// Check if can allocate
pub fn can_allocate(&self, count: u8) -> bool {
self.available() >= count
}
@@ -82,7 +71,6 @@ mod tests {
alloc.allocate(4).unwrap();
assert_eq!(alloc.available(), 2);
// Should fail - not enough endpoints
assert!(alloc.allocate(3).is_err());
alloc.release(2);

Some files were not shown because too many files have changed in this diff Show More