Compare commits

..

64 Commits

Author SHA1 Message Date
mofeng-git
3ee3df77b8 chroe: 不再配置 iceCandidatePoolSize,沿用浏览器默认 2026-05-05 01:19:17 +08:00
mofeng-git
8ec2f25e82 chore: bump version to v0.2.0 2026-05-05 00:59:16 +08:00
mofeng-git
c27d3a6703 fix:改进atx usb 继电器适配;修复 webrtc 无法建立连接问题;网页样式优化 2026-05-05 00:52:16 +08:00
mofeng-git
6723f432a3 feat: 允许通过环境变量手动指定前端资源路径,删除 debug 分支默认资源路径 2026-05-04 17:53:27 +08:00
mofeng-git
12a3f1c947 feat: 增加设备丢失自恢复机制
增加音频设备丢失自恢复机制,完善视频设备丢失自恢复机制

降级部分日志级别,GOSTC key打印脱敏

代码格式化
2026-05-02 10:55:05 +08:00
mofeng-git
52754c862b feat: 优化网页消息提醒样式 2026-05-01 21:46:32 +08:00
mofeng-git
e51d243324 feat: 增加 MSD 虚拟盘文件路径编码 2026-05-01 21:27:03 +08:00
mofeng-git
a1ebd34083 feat: 外部扩展程序输出日志级别修改为 info 级别 2026-05-01 21:21:54 +08:00
mofeng-git
89b19ea7dd refactor: 修改为同步请求 2026-05-01 20:06:22 +08:00
mofeng-git
0d47d8395d refactor: 重构视频采集状态与错误分类公共逻辑 2026-05-01 17:56:56 +08:00
mofeng-git
d82c863f40 refactor: 精简依赖 2026-05-01 17:41:11 +08:00
mofeng-git
d8e7de74a6 refactor: 删除部分多余的代码和注释 2026-05-01 17:31:04 +08:00
SilentWind
74035f8e12 Merge pull request #247 from tedaimengtech/main
Update: Add CQU Mirror Information
2026-04-29 14:13:25 +08:00
tedaimeng
8d45186eba Add Mirror Download Services to README
Added a section for Mirror Download Services and updated sponsors.
2026-04-29 13:12:35 +08:00
tedaimeng
c484580b8f Update README with CQUMirror details
Added CQUMirror information and links to the README.
2026-04-29 13:09:23 +08:00
SilentWind
56bce7937c Add GNU General Public License v3 2026-04-29 10:38:30 +08:00
mofeng-git
07b982d1d2 feat: 完善 USB UVC 设备异常处理,添加 USB 设备复位功能 2026-04-27 16:37:04 +08:00
mofeng-git
9065e01225 feat: 优化控制台页面状态工具栏在不同宽度网页下的自适应能力 2026-04-25 20:32:44 +08:00
mofeng-git
cc3cc15774 refactor: 删除部分多余的 Ventoy 逻辑 2026-04-20 14:07:28 +08:00
mofeng-git
fcb39c73fc refactor: 删除未使用的公共 STUN/TURN 逻辑 2026-04-20 10:15:53 +08:00
mofeng-git
7c703b8b4b feat: 深入适配 RK628D CSI 采集卡的设备识别、参数读取、自恢复和音频采集 2026-04-19 11:26:21 +08:00
mofeng-git
8eac31f69f chore: bump version to v0.1.9
Made-with: Cursor
2026-04-12 20:43:44 +08:00
mofeng-git
9653e16a68 feat: CLI 改密、自定义 TLS、移动端适配与扩展校验
- 新增 one-kvm user set-password(交互式),改密后吊销该用户全部会话
- /api/config/web 支持 PEM 证书/密钥上传与清除,响应含 has_custom_cert
- 移动端:ActionBar 溢出菜单、ATX/粘贴底部 Sheet、BrandMark 与控制台等响应式优化
- GOSTC:校验服务器地址非空,管理器启动条件与 HTTP 热更新一致
- RustDesk:中继密钥 relay_key 校验为标准 Base64 且解码后恰好 32 字节
- StatusCard、InfoBar:合并精简冗余状态信息
2026-04-12 19:28:17 +08:00
mofeng-git
d0c0852fbb fix: 修复设置页 IPv6 URL 拼接问题 #241
统一处理 Settings 页面中的主机地址格式,避免 IPv6 场景下 RTSP 地址展示和重启后跳转 URL 因缺少方括号而失效。

Made-with: Cursor
2026-04-12 09:49:42 +08:00
mofeng-git
c0a0c90cbd fix: 修复 ATX 串口继电器共享设备失效 #233
当 power/reset 配置为同一串口设备时,改为复用同一个串口句柄,避免重复 open 导致 reset 不生效。

Made-with: Cursor
2026-04-11 22:31:17 +08:00
mofeng-git
9e3483b836 style: 统一代码格式
应用 fmt 对已暂存 Rust 文件进行格式化,保持代码风格一致并减少无意义差异。

Made-with: Cursor
2026-04-11 22:25:46 +08:00
mofeng-git
132f445c29 完善音频采集 2026-04-11 22:22:05 +08:00
mofeng-git
4952cbaf19 fix: 修复 RTSP 功能对 VLC 视频播放器的兼容性 #237 2026-04-11 21:55:34 +08:00
mofeng-git
099f0b1ca2 fix: 优化视频切换流畅性;修复 OTG HID 功能无法一次保存成功和页面未即刻生效问题 2026-04-11 21:20:54 +08:00
mofeng-git
eecbc0fc13 CSI 采集适配优化 2026-04-11 20:26:33 +08:00
mofeng-git
2d81a071e5 fix: 修复 MJPEG模式可用但状态显示离线的问题 2026-04-11 15:11:47 +08:00
mofeng-git
3e35181583 fix: 修复镜像列表滚动条问题 #238 2026-04-11 13:13:24 +08:00
mofeng-git
c3a3f41a2c 修改 .gitignore 2026-04-11 13:12:11 +08:00
mofeng-git
2b2b471cfb chore(release): bump version to v0.1.8 2026-04-01 21:30:41 +08:00
mofeng-git
abb319068b feat: 适配 RK 原生 HDMI IN 适配采集 2026-04-01 21:28:15 +08:00
mofeng-git
51d7d8b8be chore: bump version to v0.1.7 2026-03-28 22:45:39 +08:00
mofeng-git
f95714d9f0 删除无用测试 2026-03-28 22:41:35 +08:00
mofeng-git
7d52b2e2ea fix(otg): 优化运行时状态监测与未枚举提示 2026-03-28 22:06:53 +08:00
mofeng-git
a2a8b3802d feat(web): 优化视频格式与虚拟键盘显示 2026-03-28 21:34:22 +08:00
mofeng-git
f4283f45a4 refactor(otg): 简化运行时与设置逻辑 2026-03-28 21:09:10 +08:00
mofeng-git
4784cb75e4 调整登录页文案并改为点击切换语言 2026-03-28 20:47:29 +08:00
mofeng-git
abc6bd1677 feat(otg): 增加 libcomposite 自动加载兜底 2026-03-28 17:06:41 +08:00
mofeng-git
1c5288d783 优化控制台与设置页切换时的 WebRTC 会话保活与恢复逻辑 2026-03-27 11:29:27 +08:00
mofeng-git
6bcb54bd22 feat(web): 改为通过 WebSocket 推送 ttyd 状态并清理轮询与冗余接口 2026-03-27 10:49:04 +08:00
mofeng-git
e20136a5ab fix(web): 修复 WebRTC 首帧状态与视频状态判定 2026-03-27 10:44:59 +08:00
mofeng-git
c8fd3648ad refactor(video): 删除废弃视频流水线并收敛 MJPEG/WebRTC 编排与死代码 2026-03-27 08:21:14 +08:00
mofeng-git
6ef2d394d9 fix(video): 启动时丢弃前三帧无效 MJPEG 2026-03-26 23:27:42 +08:00
mofeng-git
762a3b037d chore(hid): 删除 CH9329 收发 trace 日志 2026-03-26 23:05:49 +08:00
mofeng-git
e09a906f93 refactor(hid): 统一 HID 键盘 CanonicalKey 语义并清理前端布局与输入链路冗余代码 2026-03-26 22:51:29 +08:00
mofeng-git
95bf1a852e feat(web): 登录页改为引导卡片样式并增加语言切换按钮 2026-03-26 22:45:28 +08:00
mofeng-git
200f947b5d fix(video): 修复 FFmpeg 硬编码 EAGAIN 刷屏并为编码错误增加日志节流 2026-03-26 22:30:53 +08:00
mofeng-git
46ae0c81e2 refactor(events): 将设备状态广播降级为快照同步并按需订阅 WebSocket 事件,顺带修复相关测试 2026-03-26 22:01:50 +08:00
mofeng-git
779aa180ad refactor: 重构部分事件检查逻辑,修复 ch9329 hid 状态显示异常 2026-03-26 12:33:24 +08:00
mofeng-git
ae26e3c863 feat: 支持自动检测因特尔 GPU 驱动类型 2026-03-25 20:27:26 +08:00
mofeng-git
eeb41159b7 build: 增加硬件编码所需驱动依赖 2026-03-22 20:19:30 +08:00
mofeng-git
24a10aa222 feat: 支持硬件编码能力测试,otg 自检修改为需要手动执行 2026-03-22 20:19:30 +08:00
mofeng-git
c119db4908 perf: 编码器探测测试分辨率由 1080p 调整为 720p 2026-03-22 20:19:30 +08:00
mofeng-git
0db287bf55 refactor: 重构 ffmpeg 编码器探测模块 2026-03-22 20:19:30 +08:00
mofeng-git
e229f35777 fix(web): 修复控制台全屏视频时鼠标定位偏移问题 2026-03-22 20:19:30 +08:00
SilentWind
df647b45cd Merge pull request #230 from a15355447898a/main
修复树莓派4B上 V4L2 编码时 WebRTC 无画面的问题
2026-03-02 19:05:32 +08:00
a15355447898a
b74659dcd4 refactor(video): restore v4l2r and remove temporary debug logs 2026-03-01 01:40:28 +08:00
a15355447898a
4f2fb534a4 fix(video): v4l path + webrtc h264 startup diagnostics 2026-03-01 01:24:26 +08:00
mofeng-git
bd17f8d0f8 chore: 更新版本号到 v0.1.6 2026-02-22 23:03:24 +08:00
mofeng-git
cee43795f8 fix: 添加前端电源状态显示 #226 2026-02-22 22:55:56 +08:00
229 changed files with 20752 additions and 21679 deletions

2
.gitignore vendored
View File

@@ -30,6 +30,7 @@ Thumbs.db
# Build artifacts
/dist/
/build-staging
# Frontend (built files)
/web/node_modules/
@@ -41,3 +42,4 @@ CLAUDE.md
secrets.toml
.env
/docs/
web/package-lock.json

View File

@@ -1,6 +1,6 @@
[package]
name = "one-kvm"
version = "0.1.5"
version = "0.2.0"
edition = "2021"
authors = ["SilentWind"]
description = "A open and lightweight IP-KVM solution written in Rust"
@@ -16,8 +16,8 @@ tokio-util = { version = "0.7", features = ["rt"] }
# Web framework
axum = { version = "0.8", features = ["ws", "multipart", "tokio"] }
axum-extra = { version = "0.12", features = ["typed-header", "cookie"] }
tower-http = { version = "0.6", features = ["fs", "cors", "trace", "compression-gzip"] }
axum-extra = { version = "0.12", features = ["cookie"] }
tower-http = { version = "0.6", features = ["cors", "trace"] }
# Database - Use bundled SQLite for static linking
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
@@ -29,7 +29,6 @@ serde_json = "1"
# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "tracing-log"] }
tracing-log = "0.2"
# Error handling
thiserror = "2"
@@ -41,7 +40,6 @@ rand = "0.9"
# Utilities
uuid = { version = "1", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
base64 = "0.22"
nix = { version = "0.30", features = ["fs", "net", "hostname", "poll"] }
@@ -51,7 +49,7 @@ reqwest = { version = "0.13", features = ["stream", "rustls", "json"], default-f
urlencoding = "2"
# Static file embedding
rust-embed = { version = "8", features = ["compression"] }
rust-embed = { version = "8", features = ["compression", "debug-embed"] }
mime_guess = "2"
# TLS/HTTPS
@@ -62,8 +60,8 @@ axum-server = { version = "0.8", features = ["tls-rustls"] }
# CLI argument parsing
clap = { version = "4", features = ["derive"] }
# Time
time = "0.3"
# Time (cookie max_age + RFC3339 timestamps)
time = { version = "0.3", features = ["serde", "formatting", "parsing"] }
# Video capture (V4L2)
v4l2r = "0.0.7"
@@ -125,13 +123,10 @@ libyuv = { path = "res/vcpkg/libyuv" }
typeshare = "1.0"
[dev-dependencies]
tokio-test = "0.4"
tempfile = "3"
[build-dependencies]
protobuf-codegen = "3.7"
toml = "0.9"
cc = "1"
[profile.release]
opt-level = 3

674
LICENSE Normal file
View File

@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

238
README.en.md Normal file
View File

@@ -0,0 +1,238 @@
<div align="center">
<h1>One-KVM</h1>
<p><strong>An open, lightweight IP-KVM stack in Rust — remote management down to BIOS level</strong></p>
<p><a href="README.md">简体中文</a> · <a href="README.en.md">English</a></p>
[![GitHub Release](https://img.shields.io/github/v/release/mofeng-git/One-KVM)](https://github.com/mofeng-git/One-KVM/releases)
[![GitHub stars](https://img.shields.io/github/stars/mofeng-git/One-KVM?style=social)](https://github.com/mofeng-git/One-KVM/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/mofeng-git/One-KVM?style=social)](https://github.com/mofeng-git/One-KVM/network/members)
[![GitHub issues](https://img.shields.io/github/issues/mofeng-git/One-KVM)](https://github.com/mofeng-git/One-KVM/issues)
</div>
---
## Overview
**One-KVM (Rust)** is a lightweight IP-KVM solution written in Rust. It lets you manage servers and workstations over the network, including at BIOS level.
Goals: an open, lightweight, easy-to-use IP-KVM stack.
- **Open**: not tied to one hardware recipe; runs across many setups.
- **Lightweight**: shipped as a binary with minimal moving parts for deployment.
- **Easy to use**: no hand-edited config files required; settings are done in the web UI.
> **One-KVM (Python)** is no longer maintained. If you still need it, see <https://github.com/mofeng-git/One-KVM/tree/python>.
<div align="center">
![One-KVM web console](https://one-kvm.cn/hero-app-effect.png)
</div>
## Features
### Core
| Area | Capabilities |
|------|----------------|
| Video capture | HDMI USB / MIPI CSI / RK3588 HDMI IN; MJPEG and WebRTC (H.264 / H.265 / VP8 / VP9) |
| Video encoding | VAAPI / QSV / RKMPP / V4L2 M2M hardware paths, with software fallback |
| Keyboard & mouse | USB OTG HID or CH340 + CH9329 HID; absolute / relative mouse |
| Virtual media | USB mass storage; ISO/IMG mount and Ventoy-style virtual USB |
| ATX power | GPIO or USB relay; power and reset control |
| Audio | ALSA capture + Opus (HTTP / WebRTC) |
The web UI supports visual configuration and Chinese/English locales. Built-ins include a web terminal (ttyd), intranet tunnel (gostc), P2P (EasyTier), RustDesk protocol (optional cross-platform remote access), and RTSP streaming.
## Installation
Release artifacts are on [GitHub Releases](https://github.com/mofeng-git/One-KVM/releases). Below are short paths for common setups. For **system requirements, hardware, Docker env vars, USB OTG**, and full troubleshooting, see the [One-KVM documentation](https://docs.one-kvm.cn/) (Chinese; use a translator if needed).
### Debian / Ubuntu (.deb)
Download a `one-kvm_*.deb` matching your CPU architecture from Releases, then from the directory containing the package:
```bash
sudo apt update
sudo apt install ./one-kvm_0.x.x_<arch>.deb
```
Replace the version and architecture in the filename with your actual file name.
### Docker
Images:
- **one-kvm** — main app + ttyd
- **one-kvm-full** — same plus optional extras (e.g. gostc, easytier-core)
Example:
```bash
docker run --name one-kvm -itd \
--privileged=true --restart unless-stopped \
-v /dev:/dev -v /sys:/sys \
--net=host \
silentwind0/one-kvm-full
```
If pulls are slow, use the Aliyun mirror, e.g. `registry.cn-hangzhou.aliyuncs.com/silentwind/one-kvm-full` (and `registry.cn-hangzhou.aliyuncs.com/silentwind/one-kvm` for the slim image).
### fnOS NAS (Feiniu / 飞牛)
One-KVM is listed in the fnOS **app store**; search and install on your NAS.
### Web UI and first run
Open `http://<device-ip>:8080` in a browser (**8420** after fnOS install). The first visit runs initial setup.
## Reporting issues
If something breaks:
1. Open [GitHub Issues](https://github.com/mofeng-git/One-KVM/issues) or report in the project QQ group.
2. Include **useful** error messages and steps to reproduce.
3. Mention software version, hardware, and OS details.
## Sponsorship
One-KVM builds on many great open-source projects; a lot of time goes into testing and maintenance. If you find it useful, you can support development on **[Afdian (为爱发电)](https://afdian.com/a/silentwind)**.
### Thanks
<details>
<summary><strong>Supporter list</strong></summary>
- 浩龙的电子嵌入式之路
- Tsuki
- H_xiaoming
- 0蓝蓝0
- fairybl
- Will
- 自.知
- 观棋不语٩ ི۶
- 爱发电用户_a57a4
- 爱发电用户_2c769
- 霜序
- 远方(闲鱼用户名:小远技术店铺)
- 爱发电用户_399fc
- 斐斐の
- 爱发电用户_09451
- 超高校级的錆鱼
- 爱发电用户_08cff
- guoke
- mgt
- 姜沢掵
- ui_beam
- 爱发电用户_c0dd7
- 爱发电用户_dnjK
- 忍者胖猪
- 永遠の願い
- 爱发电用户_GBrF
- 爱发电用户_fd65c
- 爱发电用户_vhNa
- 爱发电用户_Xu6S
- moss
- woshididi
- 爱发电用户_a0fd1
- 爱发电用户_f6bH
- 码农
- 爱发电用户_6639f
- jeron
- 爱发电用户_CN7y
- 爱发电用户_Up6w
- 爱发电用户_e3202
- 一语念白
- 云边
- 爱发电用户_5a711
- 爱发电用户_9a706
- T0m9ir1SUKI
- 爱发电用户_56d52
- 爱发电用户_3N6F
- DUSK
- 飘零
- .
- 饭太稀
-
- MaxZ
- 爱发电用户_c5f33
- 爱发电用户_09386
- 爱发电用户_JT6c
- 爱发电用户_d3d9c
- ......
</details>
### Sponsors
**Mirror Download Services:**
- **[Chongqing University Open Source Software Mirror](https://mirrors.cqu.edu.cn/)** — provides mirror download services
**File hosting**
- **[Huang1111 public-interest program](https://pan.huang1111.cn/s/mxkx3T1)** — login-free downloads
**Cloud**
- **[林枫云](https://www.dkdun.cn)** — project server sponsorship
![林枫云](https://docs.one-kvm.cn/img/36076FEFF0898A80EBD5756D28F4076C.png)
林枫云 offers premium network routes, high-frequency game servers, and high-bandwidth servers in China and abroad.

View File

@@ -2,8 +2,9 @@
<h1>One-KVM</h1>
<p><strong>Rust 编写的开放轻量 IP-KVM 解决方案,实现 BIOS 级远程管理</strong></p>
<p><a href="README.md">简体中文</a></p>
<p><a href="README.md">简体中文</a> · <a href="README.en.md">English</a></p>
[![GitHub Release](https://img.shields.io/github/v/release/mofeng-git/One-KVM)](https://github.com/mofeng-git/One-KVM/releases)
[![GitHub stars](https://img.shields.io/github/stars/mofeng-git/One-KVM?style=social)](https://github.com/mofeng-git/One-KVM/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/mofeng-git/One-KVM?style=social)](https://github.com/mofeng-git/One-KVM/network/members)
[![GitHub issues](https://img.shields.io/github/issues/mofeng-git/One-KVM)](https://github.com/mofeng-git/One-KVM/issues)
@@ -15,57 +16,78 @@
**One-KVM Rust** 是一个用 Rust 编写的轻量级 IP-KVM 解决方案,可通过网络远程管理服务器和工作站,实现 BIOS 级远程控制。
项目目标:
项目目标:提供一个开放、轻量、易用的 IPKVM 解决方案。
- **开放**:不绑定特定硬件配置,尽量适配常见 Linux 设备
- **轻量**二进制分发,部署过程简单
- **易用**网页界面完成设备与参数配置,无需手动配置文件
- **开放**:不绑定特定硬件配置,可在各类硬件环境中稳定运行。
- **轻量**二进制文件形式分发,无繁杂的依赖项,部署过程简单
- **易用**:无需手动编辑配置文件,参数设置均可通过网页界面完成。
> **注意:** One-KVM Rust 目前仍处于开发早期阶段,功能与细节会快速迭代,欢迎体验与反馈
> **One-KVM Python** 已停止开发,如有需要可访问 <https://github.com/mofeng-git/One-KVM/tree/python>
## 🔁 迁移说明
<div align="center">
开发重心正在从 **One-KVM Python** 逐步转向 **One-KVM Rust**
![One-KVM Web 控制台界面](https://one-kvm.cn/hero-app-effect.png)
- 如果你在使用 **One-KVM Python基于 PiKVM**,请查看 [One-KVM Python 文档](https://docs.one-kvm.cn/python/)
- One-KVM Rust 相较于 One-KVM Python**尚未完全适配 CSI HDMI 采集卡**、**不支持 VNC 访问**,仍处于开发早期阶段
</div>
## 📊 功能介绍
### 核心功能
| 功能 | 说明 |
| 功能 | 能力说明 |
|------|------|
| 视频采集 | HDMI USB 采集支持,提供 MJPEG / WebRTCH.264/H.265/VP8/VP9 |
| 视频采集 | HDMI USB /MIPI CSI/RK3588 HDMI IN 采集支持,提供 MJPEG / WebRTCH.264/H.265/VP8/VP9 视频流|
| 视频编码 | VAAPI/QSV/RKMPP/V4L2M2M 硬件编码支持,以及软件编码兜底 |
| 键鼠控制 | USB OTG HID 或 CH340 + CH9329 HID支持绝对/相对鼠标模式 |
| 虚拟媒体 | USB Mass Storage支持 ISO/IMG 镜像挂载和 Ventoy 虚拟U盘模式 |
| ATX 电源控制 | GPIO 控制电源/重启按钮 |
| ATX 电源控制 | GPIO /USB 继电器,支持控制电源重启按钮 |
| 音频传输 | ALSA 采集 + Opus 编码HTTP/WebRTC |
### 硬件编码
支持自动检测和选择硬件加速:
- **VAAPI**Intel/AMD GPU
- **RKMPP**Rockchip SoC
- **V4L2 M2M**:通用硬件编码器
- **软件编码**CPU 编码
### 扩展能力
- Web UI 配置,多语言支持(中文/英文)
- 内置 Web 终端ttyd、内网穿透支持gostc、P2P 组网支持EasyTier、RustDesk 协议集成(用于跨平台远程访问能力扩展)和 RTSP 视频流(用于视频推流)
此外提供基于 Web UI 的可视化配置与中英文界面;并集成 Web 终端ttyd、内网穿透gostc、P2P 组网EasyTier、RustDesk 协议(扩展跨平台远程访问)以及 RTSP 推流等能力。
## ⚡ 安装使用
可以访问 [One-KVM Rust 文档站点](https://docs.one-kvm.cn/) 获取详细信息
构建产物见 [GitHub Releases](https://github.com/mofeng-git/One-KVM/releases)。以下为常见安装方式的简要步骤;**系统要求、硬件准备、Docker 环境变量与 USB OTG 等完整说明**请查阅 [One-KVM Rust 文档站点](https://docs.one-kvm.cn/)。
### 使用 deb 安装Debian / Ubuntu
从 Releases 下载与本机架构匹配的 `one-kvm_*.deb`,在包所在目录执行:
```bash
sudo apt update
sudo apt install ./one-kvm_0.x.x_<arch>.deb
```
将文件名中的版本号与架构替换为实际下载的包名。
### 使用 Docker
镜像分为 **one-kvm**One-KVM 主程序 + ttyd**one-kvm-full**(另含 gostc、easytier-core 等可选扩展),按需选用。
```bash
docker run --name one-kvm -itd \
--privileged=true --restart unless-stopped \
-v /dev:/dev -v /sys:/sys \
--net=host \
silentwind0/one-kvm-full
```
拉取较慢时,可将镜像名替换为阿里云加速,例如 `registry.cn-hangzhou.aliyuncs.com/silentwind/one-kvm-full``one-kvm` 镜像同理,将 `silentwind0/one-kvm` 换为 `registry.cn-hangzhou.aliyuncs.com/silentwind/one-kvm`)。
### 飞牛 NAS
One-KVM 已上架飞牛 **应用市场**,在 NAS 上直接搜索安装即可。
### 访问 Web 与首次配置
浏览器访问 `http://<设备 IP>:8080`(飞牛 NAS 安装后为 8420 端口)。首次访问将引导完成初始配置。
## 报告问题
如果您发现了问题,请:
1. 使用 [GitHub Issues](https://github.com/mofeng-git/One-KVM/issues) 报告,或加入 QQ 群聊反馈。
2. 提供详细的错误信息和复现步骤
3. 包含您硬件配置和系统信息
2. 提供有帮助的错误信息和复现步骤
3. 包含您使用的软件版本、硬件配置和系统信息
## 赞助支持
@@ -88,8 +110,6 @@
- Will
- 浩龙的电子嵌入式之路
- 自.知
- 观棋不语٩ ི۶
@@ -188,8 +208,6 @@
- 爱发电用户_JT6c
- MaxZ
- 爱发电用户_d3d9c
- ......
@@ -200,13 +218,16 @@
本项目得到以下赞助商的支持:
**镜像下载服务:**
- **[重庆大学开源软件镜像站](https://mirrors.cqu.edu.cn/)** - 提供镜像站下载服务
**文件存储服务:**
- **[Huang1111公益计划](https://pan.huang1111.cn/s/mxkx3T1)** - 提供免登录下载服务
**云服务商**
- **[林枫云](https://www.dkdun.cn)** - 赞助了本项目宁波大带宽服务器
- **[林枫云](https://www.dkdun.cn)** - 赞助了本项目服务器
![林枫云](https://docs.one-kvm.cn/img/36076FEFF0898A80EBD5756D28F4076C.png)
林枫云主营国内外地域的精品线路业务服务器、高主频游戏服务器和大带宽服务器。
林枫云主营国内外地域的精品线路业务服务器、高主频游戏服务器和大带宽服务器。

View File

@@ -64,25 +64,6 @@ fn generate_secrets() {
pub mod ice {
/// Google public STUN server URL (hardcoded)
pub const STUN_SERVER: &str = "stun:stun.l.google.com:19302";
/// TURN server URLs - not provided, users must configure their own
pub const TURN_URLS: &str = "";
/// TURN authentication username
pub const TURN_USERNAME: &str = "";
/// TURN authentication password
pub const TURN_PASSWORD: &str = "";
/// Always returns true since we have STUN
pub const fn is_configured() -> bool {
true
}
/// Always returns false since TURN is not provided
pub const fn has_turn() -> bool {
false
}
}
/// RustDesk public server configuration - NOT PROVIDED

View File

@@ -12,7 +12,8 @@ ARG TARGETPLATFORM
# Install runtime dependencies in a single layer
# All codec libraries (libx264, libx265, libopus) are now statically linked
# Only hardware acceleration drivers and core system libraries remain dynamic
RUN apt-get update && \
RUN sed -i 's/ main$/ main contrib non-free/' /etc/apt/sources.list && \
apt-get update && \
apt-get install -y --no-install-recommends \
# Core runtime (all platforms) - no codec libs needed
ca-certificates \
@@ -24,7 +25,8 @@ RUN apt-get update && \
# Platform-specific hardware acceleration
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
apt-get install -y --no-install-recommends \
libva2 libva-drm2 libva-x11-2 libx11-6 libxcb1 libxau6 libxdmcp6 libmfx1; \
libva2 libva-drm2 libva-x11-2 libx11-6 libxcb1 libxau6 libxdmcp6 libmfx1 \
i965-va-driver-shaders intel-media-va-driver-non-free vainfo; \
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
apt-get install -y --no-install-recommends \
libdrm2 libva2; \

View File

@@ -12,7 +12,8 @@ ARG TARGETPLATFORM
# Install runtime dependencies in a single layer
# All codec libraries (libx264, libx265, libopus) are now statically linked
# Only hardware acceleration drivers and core system libraries remain dynamic
RUN apt-get update && \
RUN sed -i 's/ main$/ main contrib non-free/' /etc/apt/sources.list && \
apt-get update && \
apt-get install -y --no-install-recommends \
# Core runtime (all platforms) - no codec libs needed
ca-certificates \
@@ -24,7 +25,8 @@ RUN apt-get update && \
# Platform-specific hardware acceleration
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
apt-get install -y --no-install-recommends \
libva2 libva-drm2 libva-x11-2 libx11-6 libxcb1 libxau6 libxdmcp6 libmfx1; \
libva2 libva-drm2 libva-x11-2 libx11-6 libxcb1 libxau6 libxdmcp6 libmfx1 \
i965-va-driver-shaders intel-media-va-driver-non-free vainfo; \
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
apt-get install -y --no-install-recommends \
libdrm2 libva2; \

View File

@@ -4,6 +4,68 @@
set -e
detect_intel_libva_driver() {
if [ -n "${LIBVA_DRIVER_NAME:-}" ]; then
echo "[INFO] Using preconfigured LIBVA_DRIVER_NAME=$LIBVA_DRIVER_NAME"
return
fi
if [ "$(uname -m)" != "x86_64" ]; then
return
fi
local devices=()
if [ -n "${LIBVA_DEVICE:-}" ]; then
devices=("$LIBVA_DEVICE")
else
shopt -s nullglob
devices=(/dev/dri/renderD*)
shopt -u nullglob
fi
if [ ${#devices[@]} -eq 0 ]; then
return
fi
local device=""
local node=""
local vendor=""
local driver=""
for device in "${devices[@]}"; do
if [ ! -e "$device" ]; then
continue
fi
node="$(basename "$device")"
vendor=""
if [ -r "/sys/class/drm/$node/device/vendor" ]; then
vendor="$(cat "/sys/class/drm/$node/device/vendor")"
fi
if [ -n "$vendor" ] && [ "$vendor" != "0x8086" ]; then
echo "[INFO] Skipping VA-API probe for $device (vendor=$vendor)"
continue
fi
for driver in iHD i965; do
if LIBVA_DRIVER_NAME="$driver" vainfo --display drm --device "$device" >/dev/null 2>&1; then
export LIBVA_DRIVER_NAME="$driver"
if [ -n "$vendor" ]; then
echo "[INFO] Detected Intel VA-API driver '$driver' on $device (vendor=$vendor)"
else
echo "[INFO] Detected Intel VA-API driver '$driver' on $device"
fi
return
fi
done
done
echo "[WARN] Unable to auto-detect an Intel VA-API driver; leaving LIBVA_DRIVER_NAME unset"
}
detect_intel_libva_driver
# Start one-kvm with default options.
# Additional options can be passed via environment variables.

View File

@@ -7,6 +7,8 @@ Wants=network-online.target
[Service]
Type=simple
User=root
# Example for older Intel GPUs:
# Environment=LIBVA_DRIVER_NAME=i965
ExecStart=/usr/bin/one-kvm
Restart=on-failure
RestartSec=5

View File

@@ -126,7 +126,7 @@ EOF
# Create control file
BASE_DEPS="libc6 (>= 2.31), libgcc-s1, libstdc++6, libasound2 (>= 1.1), libdrm2 (>= 2.4)"
AMD64_DEPS="libva2 (>= 2.0), libva-drm2 (>= 2.10), libva-x11-2 (>= 2.10), libmfx1 (>= 21.1), libx11-6 (>= 1.6), libxcb1 (>= 1.14)"
AMD64_DEPS="libva2 (>= 2.0), libva-drm2 (>= 2.10), libva-x11-2 (>= 2.10), libmfx1 (>= 21.1), libx11-6 (>= 1.6), libxcb1 (>= 1.14), i965-va-driver-shaders (>= 2.4), intel-media-va-driver-non-free (>= 21.1)"
DEPS="$BASE_DEPS"
if [ "$DEB_ARCH" = "amd64" ]; then
DEPS="$DEPS, $AMD64_DEPS"

View File

@@ -4,6 +4,9 @@ version = "0.8.0"
edition = "2021"
description = "Hardware video codec for IP-KVM (Windows/Linux)"
[package.metadata.cargo-machete]
ignored = ["serde"]
[features]
default = []
rkmpp = []
@@ -17,6 +20,3 @@ serde_json = "1.0"
[build-dependencies]
cc = "1.0"
bindgen = "0.59"
[dev-dependencies]
env_logger = "0.10"

View File

@@ -25,6 +25,7 @@ enum AVPixelFormat {
AV_PIX_FMT_NV24 = 188,
};
int av_get_pix_fmt(const char *name);
int av_log_get_level(void);
void av_log_set_level(int level);
void hwcodec_set_av_log_callback();

View File

@@ -271,6 +271,7 @@ extern "C" int ffmpeg_hw_mjpeg_h26x_encode(FfmpegHwMjpegH26x* handle,
*out_data = nullptr;
*out_len = 0;
*out_keyframe = 0;
bool encoded = false;
av_packet_unref(ctx->dec_pkt);
int ret = av_new_packet(ctx->dec_pkt, len);
@@ -290,7 +291,7 @@ extern "C" int ffmpeg_hw_mjpeg_h26x_encode(FfmpegHwMjpegH26x* handle,
while (true) {
ret = avcodec_receive_frame(ctx->dec_ctx, ctx->dec_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
return 0;
return encoded ? 1 : 0;
}
if (ret < 0) {
set_last_error(make_err("avcodec_receive_frame failed", ret));
@@ -370,33 +371,40 @@ extern "C" int ffmpeg_hw_mjpeg_h26x_encode(FfmpegHwMjpegH26x* handle,
return -1;
}
av_packet_unref(ctx->enc_pkt);
ret = avcodec_receive_packet(ctx->enc_ctx, ctx->enc_pkt);
if (ret == AVERROR(EAGAIN)) {
av_frame_unref(ctx->dec_frame);
return 0;
}
if (ret < 0) {
set_last_error(make_err("avcodec_receive_packet failed", ret));
av_frame_unref(ctx->dec_frame);
return -1;
}
if (ctx->enc_pkt->size > 0) {
uint8_t *buf = (uint8_t*)malloc(ctx->enc_pkt->size);
if (!buf) {
set_last_error("malloc for output packet failed");
while (true) {
av_packet_unref(ctx->enc_pkt);
ret = avcodec_receive_packet(ctx->enc_ctx, ctx->enc_pkt);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
break;
}
if (ret < 0) {
set_last_error(make_err("avcodec_receive_packet failed", ret));
av_packet_unref(ctx->enc_pkt);
av_frame_unref(ctx->dec_frame);
return -1;
}
memcpy(buf, ctx->enc_pkt->data, ctx->enc_pkt->size);
*out_data = buf;
*out_len = ctx->enc_pkt->size;
*out_keyframe = (ctx->enc_pkt->flags & AV_PKT_FLAG_KEY) ? 1 : 0;
av_packet_unref(ctx->enc_pkt);
av_frame_unref(ctx->dec_frame);
return 1;
if (ctx->enc_pkt->size <= 0) {
set_last_error("avcodec_receive_packet failed, pkt size is 0");
av_packet_unref(ctx->enc_pkt);
av_frame_unref(ctx->dec_frame);
return -1;
}
if (!encoded) {
uint8_t *buf = (uint8_t*)malloc(ctx->enc_pkt->size);
if (!buf) {
set_last_error("malloc for output packet failed");
av_packet_unref(ctx->enc_pkt);
av_frame_unref(ctx->dec_frame);
return -1;
}
memcpy(buf, ctx->enc_pkt->data, ctx->enc_pkt->size);
*out_data = buf;
*out_len = ctx->enc_pkt->size;
*out_keyframe = (ctx->enc_pkt->flags & AV_PKT_FLAG_KEY) ? 1 : 0;
encoded = true;
}
}
av_frame_unref(ctx->dec_frame);

View File

@@ -30,9 +30,15 @@ static int calculate_offset_length(int pix_fmt, int height, const int *linesize,
*length = offset[1] + linesize[2] * height / 2;
break;
case AV_PIX_FMT_NV12:
case AV_PIX_FMT_NV21:
offset[0] = linesize[0] * height;
*length = offset[0] + linesize[1] * height / 2;
break;
case AV_PIX_FMT_NV16:
case AV_PIX_FMT_NV24:
offset[0] = linesize[0] * height;
*length = offset[0] + linesize[1] * height;
break;
case AV_PIX_FMT_YUYV422:
case AV_PIX_FMT_YVYU422:
case AV_PIX_FMT_UYVY422:
@@ -41,6 +47,11 @@ static int calculate_offset_length(int pix_fmt, int height, const int *linesize,
offset[0] = 0; // Only one plane
*length = linesize[0] * height;
break;
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
offset[0] = 0; // Only one plane
*length = linesize[0] * height;
break;
default:
LOG_ERROR(std::string("unsupported pixfmt") + std::to_string(pix_fmt));
return -1;
@@ -397,9 +408,23 @@ private:
const int *const offset) {
switch (frame->format) {
case AV_PIX_FMT_NV12:
case AV_PIX_FMT_NV21:
if (data_length <
frame->height * (frame->linesize[0] + frame->linesize[1] / 2)) {
LOG_ERROR(std::string("fill_frame: NV12 data length error. data_length:") +
LOG_ERROR(std::string("fill_frame: NV12/NV21 data length error. data_length:") +
std::to_string(data_length) +
", linesize[0]:" + std::to_string(frame->linesize[0]) +
", linesize[1]:" + std::to_string(frame->linesize[1]));
return -1;
}
frame->data[0] = data;
frame->data[1] = data + offset[0];
break;
case AV_PIX_FMT_NV16:
case AV_PIX_FMT_NV24:
if (data_length <
frame->height * (frame->linesize[0] + frame->linesize[1])) {
LOG_ERROR(std::string("fill_frame: NV16/NV24 data length error. data_length:") +
std::to_string(data_length) +
", linesize[0]:" + std::to_string(frame->linesize[0]) +
", linesize[1]:" + std::to_string(frame->linesize[1]));
@@ -436,6 +461,17 @@ private:
}
frame->data[0] = data;
break;
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
if (data_length < frame->height * frame->linesize[0]) {
LOG_ERROR(std::string("fill_frame: RGB24/BGR24 data length error. data_length:") +
std::to_string(data_length) +
", linesize[0]:" + std::to_string(frame->linesize[0]) +
", height:" + std::to_string(frame->height));
return -1;
}
frame->data[0] = data;
break;
default:
LOG_ERROR(std::string("fill_frame: unsupported format, ") +
std::to_string(frame->format));

View File

@@ -6,7 +6,7 @@
include!(concat!(env!("OUT_DIR"), "/ffmpeg_ffi.rs"));
use serde_derive::{Deserialize, Serialize};
use std::env;
use std::{env, ffi::CString};
#[derive(Debug, Eq, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum AVHWDeviceType {
@@ -59,6 +59,22 @@ pub(crate) fn init_av_log() {
});
}
pub fn resolve_pixel_format(name: &str, fallback: AVPixelFormat) -> i32 {
let c_name = match CString::new(name) {
Ok(name) => name,
Err(_) => return fallback as i32,
};
unsafe {
let resolved = av_get_pix_fmt(c_name.as_ptr());
if resolved >= 0 {
resolved
} else {
fallback as i32
}
}
}
fn parse_ffmpeg_log_level() -> i32 {
let raw = match env::var("ONE_KVM_FFMPEG_LOG") {
Ok(value) => value,

View File

@@ -15,12 +15,412 @@ use std::{
slice,
};
use super::Priority;
#[cfg(any(windows, target_os = "linux"))]
use crate::common::Driver;
/// Timeout for encoder test in milliseconds
const TEST_TIMEOUT_MS: u64 = 3000;
const PRIORITY_NVENC: i32 = 0;
const PRIORITY_QSV: i32 = 1;
const PRIORITY_AMF: i32 = 2;
const PRIORITY_RKMPP: i32 = 3;
const PRIORITY_VAAPI: i32 = 4;
const PRIORITY_V4L2M2M: i32 = 5;
#[derive(Clone, Copy)]
struct CandidateCodecSpec {
name: &'static str,
format: DataFormat,
priority: i32,
}
fn push_candidate(codecs: &mut Vec<CodecInfo>, candidate: CandidateCodecSpec) {
codecs.push(CodecInfo {
name: candidate.name.to_owned(),
format: candidate.format,
priority: candidate.priority,
..Default::default()
});
}
#[cfg(target_os = "linux")]
fn linux_support_vaapi() -> bool {
let entries = match std::fs::read_dir("/dev/dri") {
Ok(entries) => entries,
Err(_) => return false,
};
entries.flatten().any(|entry| {
entry
.file_name()
.to_str()
.map(|name| name.starts_with("renderD"))
.unwrap_or(false)
})
}
#[cfg(not(target_os = "linux"))]
fn linux_support_vaapi() -> bool {
false
}
#[cfg(target_os = "linux")]
fn linux_support_rkmpp() -> bool {
extern "C" {
fn linux_support_rkmpp() -> c_int;
}
unsafe { linux_support_rkmpp() == 0 }
}
#[cfg(not(target_os = "linux"))]
fn linux_support_rkmpp() -> bool {
false
}
#[cfg(target_os = "linux")]
fn linux_support_v4l2m2m() -> bool {
extern "C" {
fn linux_support_v4l2m2m() -> c_int;
}
unsafe { linux_support_v4l2m2m() == 0 }
}
#[cfg(not(target_os = "linux"))]
fn linux_support_v4l2m2m() -> bool {
false
}
#[cfg(any(windows, target_os = "linux"))]
fn enumerate_candidate_codecs(ctx: &EncodeContext) -> Vec<CodecInfo> {
use log::debug;
let mut codecs = Vec::new();
let contains = |_vendor: Driver, _format: DataFormat| {
// Without VRAM feature, we can't check SDK availability.
// Keep the prefilter coarse and let FFmpeg validation do the real check.
true
};
let (nv, amf, intel) = crate::common::supported_gpu(true);
debug!(
"GPU support detected - NV: {}, AMF: {}, Intel: {}",
nv, amf, intel
);
if nv && contains(Driver::NV, H264) {
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "h264_nvenc",
format: H264,
priority: PRIORITY_NVENC,
},
);
}
if nv && contains(Driver::NV, H265) {
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "hevc_nvenc",
format: H265,
priority: PRIORITY_NVENC,
},
);
}
if intel && contains(Driver::MFX, H264) {
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "h264_qsv",
format: H264,
priority: PRIORITY_QSV,
},
);
}
if intel && contains(Driver::MFX, H265) {
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "hevc_qsv",
format: H265,
priority: PRIORITY_QSV,
},
);
}
if amf && contains(Driver::AMF, H264) {
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "h264_amf",
format: H264,
priority: PRIORITY_AMF,
},
);
}
if amf && contains(Driver::AMF, H265) {
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "hevc_amf",
format: H265,
priority: PRIORITY_AMF,
},
);
}
if linux_support_rkmpp() {
debug!("RKMPP hardware detected, adding Rockchip encoders");
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "h264_rkmpp",
format: H264,
priority: PRIORITY_RKMPP,
},
);
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "hevc_rkmpp",
format: H265,
priority: PRIORITY_RKMPP,
},
);
}
if cfg!(target_os = "linux") && linux_support_vaapi() {
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "h264_vaapi",
format: H264,
priority: PRIORITY_VAAPI,
},
);
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "hevc_vaapi",
format: H265,
priority: PRIORITY_VAAPI,
},
);
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "vp8_vaapi",
format: VP8,
priority: PRIORITY_VAAPI,
},
);
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "vp9_vaapi",
format: VP9,
priority: PRIORITY_VAAPI,
},
);
}
if linux_support_v4l2m2m() {
debug!("V4L2 M2M hardware detected, adding V4L2 encoders");
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "h264_v4l2m2m",
format: H264,
priority: PRIORITY_V4L2M2M,
},
);
push_candidate(
&mut codecs,
CandidateCodecSpec {
name: "hevc_v4l2m2m",
format: H265,
priority: PRIORITY_V4L2M2M,
},
);
}
codecs.retain(|codec| {
!(ctx.pixfmt == AVPixelFormat::AV_PIX_FMT_YUV420P as i32 && codec.name.contains("qsv"))
});
codecs
}
#[derive(Clone, Copy)]
struct ProbePolicy {
max_attempts: usize,
request_keyframe: bool,
accept_any_output: bool,
}
impl ProbePolicy {
fn for_codec(codec_name: &str) -> Self {
if codec_name.contains("v4l2m2m") {
Self {
max_attempts: 5,
request_keyframe: true,
accept_any_output: true,
}
} else {
Self {
max_attempts: 1,
request_keyframe: false,
accept_any_output: false,
}
}
}
fn prepare_attempt(&self, encoder: &mut Encoder) {
if self.request_keyframe {
encoder.request_keyframe();
}
}
fn passed(&self, frames: &[EncodeFrame], elapsed_ms: u128) -> bool {
if elapsed_ms >= TEST_TIMEOUT_MS as u128 {
return false;
}
if self.accept_any_output {
!frames.is_empty()
} else {
frames.len() == 1 && frames[0].key == 1
}
}
}
fn log_failed_probe_attempt(
codec_name: &str,
policy: ProbePolicy,
attempt: usize,
frames: &[EncodeFrame],
elapsed_ms: u128,
) {
use log::debug;
if policy.accept_any_output {
if frames.is_empty() {
debug!(
"Encoder {} test produced no output on attempt {}",
codec_name, attempt
);
} else {
debug!(
"Encoder {} test failed on attempt {} - frames: {}, timeout: {}ms",
codec_name,
attempt,
frames.len(),
elapsed_ms
);
}
} else if frames.len() == 1 {
debug!(
"Encoder {} test failed on attempt {} - key: {}, timeout: {}ms",
codec_name, attempt, frames[0].key, elapsed_ms
);
} else {
debug!(
"Encoder {} test failed on attempt {} - wrong frame count: {}",
codec_name,
attempt,
frames.len()
);
}
}
fn validate_candidate(codec: &CodecInfo, ctx: &EncodeContext, yuv: &[u8]) -> bool {
use log::debug;
debug!("Testing encoder: {}", codec.name);
let test_ctx = EncodeContext {
name: codec.name.clone(),
mc_name: codec.mc_name.clone(),
..ctx.clone()
};
match Encoder::new(test_ctx) {
Ok(mut encoder) => {
debug!("Encoder {} created successfully", codec.name);
let policy = ProbePolicy::for_codec(&codec.name);
let mut last_err: Option<i32> = None;
for attempt in 0..policy.max_attempts {
let attempt_no = attempt + 1;
policy.prepare_attempt(&mut encoder);
let pts = (attempt as i64) * 33;
let start = std::time::Instant::now();
match encoder.encode(yuv, pts) {
Ok(frames) => {
let elapsed = start.elapsed().as_millis();
if policy.passed(frames, elapsed) {
if policy.accept_any_output {
debug!(
"Encoder {} test passed on attempt {} (frames: {})",
codec.name,
attempt_no,
frames.len()
);
} else {
debug!(
"Encoder {} test passed on attempt {}",
codec.name, attempt_no
);
}
return true;
} else {
log_failed_probe_attempt(
&codec.name,
policy,
attempt_no,
frames,
elapsed,
);
}
}
Err(err) => {
last_err = Some(err);
debug!(
"Encoder {} test attempt {} returned error: {}",
codec.name, attempt_no, err
);
}
}
}
debug!(
"Encoder {} test failed after retries{}",
codec.name,
last_err
.map(|e| format!(" (last err: {})", e))
.unwrap_or_default()
);
false
}
Err(_) => {
debug!("Failed to create encoder {}", codec.name);
false
}
}
}
fn add_software_fallback(codecs: &mut Vec<CodecInfo>) {
use log::debug;
for fallback in CodecInfo::soft().into_vec() {
if !codecs.iter().any(|codec| codec.format == fallback.format) {
debug!(
"Adding software {:?} encoder: {}",
fallback.format, fallback.name
);
codecs.push(fallback);
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct EncodeContext {
@@ -28,7 +428,7 @@ pub struct EncodeContext {
pub mc_name: Option<String>,
pub width: i32,
pub height: i32,
pub pixfmt: AVPixelFormat,
pub pixfmt: i32,
pub align: i32,
pub fps: i32,
pub gop: i32,
@@ -83,7 +483,7 @@ impl Encoder {
CString::new(mc_name.as_str()).map_err(|_| ())?.as_ptr(),
ctx.width,
ctx.height,
ctx.pixfmt as c_int,
ctx.pixfmt,
ctx.align,
ctx.fps,
ctx.gop,
@@ -185,305 +585,21 @@ impl Encoder {
if !(cfg!(windows) || cfg!(target_os = "linux")) {
return vec![];
}
let mut codecs: Vec<CodecInfo> = vec![];
#[cfg(any(windows, target_os = "linux"))]
{
let contains = |_vendor: Driver, _format: DataFormat| {
// Without VRAM feature, we can't check SDK availability
// Just return true and let FFmpeg handle the actual detection
true
};
let (_nv, amf, _intel) = crate::common::supported_gpu(true);
debug!(
"GPU support detected - NV: {}, AMF: {}, Intel: {}",
_nv, amf, _intel
);
#[cfg(windows)]
if _intel && contains(Driver::MFX, H264) {
codecs.push(CodecInfo {
name: "h264_qsv".to_owned(),
format: H264,
priority: Priority::Best as _,
..Default::default()
});
}
#[cfg(windows)]
if _intel && contains(Driver::MFX, H265) {
codecs.push(CodecInfo {
name: "hevc_qsv".to_owned(),
format: H265,
priority: Priority::Best as _,
..Default::default()
});
}
if _nv && contains(Driver::NV, H264) {
codecs.push(CodecInfo {
name: "h264_nvenc".to_owned(),
format: H264,
priority: Priority::Best as _,
..Default::default()
});
}
if _nv && contains(Driver::NV, H265) {
codecs.push(CodecInfo {
name: "hevc_nvenc".to_owned(),
format: H265,
priority: Priority::Best as _,
..Default::default()
});
}
if amf && contains(Driver::AMF, H264) {
codecs.push(CodecInfo {
name: "h264_amf".to_owned(),
format: H264,
priority: Priority::Best as _,
..Default::default()
});
}
if amf {
codecs.push(CodecInfo {
name: "hevc_amf".to_owned(),
format: H265,
priority: Priority::Best as _,
..Default::default()
});
}
#[cfg(target_os = "linux")]
{
codecs.push(CodecInfo {
name: "h264_vaapi".to_owned(),
format: H264,
priority: Priority::Good as _,
..Default::default()
});
codecs.push(CodecInfo {
name: "hevc_vaapi".to_owned(),
format: H265,
priority: Priority::Good as _,
..Default::default()
});
codecs.push(CodecInfo {
name: "vp8_vaapi".to_owned(),
format: VP8,
priority: Priority::Good as _,
..Default::default()
});
codecs.push(CodecInfo {
name: "vp9_vaapi".to_owned(),
format: VP9,
priority: Priority::Good as _,
..Default::default()
});
// Rockchip MPP hardware encoder support
use std::ffi::c_int;
extern "C" {
fn linux_support_rkmpp() -> c_int;
fn linux_support_v4l2m2m() -> c_int;
}
if unsafe { linux_support_rkmpp() } == 0 {
debug!("RKMPP hardware detected, adding Rockchip encoders");
codecs.push(CodecInfo {
name: "h264_rkmpp".to_owned(),
format: H264,
priority: Priority::Best as _,
..Default::default()
});
codecs.push(CodecInfo {
name: "hevc_rkmpp".to_owned(),
format: H265,
priority: Priority::Best as _,
..Default::default()
});
}
// V4L2 Memory-to-Memory hardware encoder support (generic ARM)
if unsafe { linux_support_v4l2m2m() } == 0 {
debug!("V4L2 M2M hardware detected, adding V4L2 encoders");
codecs.push(CodecInfo {
name: "h264_v4l2m2m".to_owned(),
format: H264,
priority: Priority::Good as _,
..Default::default()
});
codecs.push(CodecInfo {
name: "hevc_v4l2m2m".to_owned(),
format: H265,
priority: Priority::Good as _,
..Default::default()
});
}
}
}
// qsv doesn't support yuv420p
codecs.retain(|c| {
let ctx = ctx.clone();
if ctx.pixfmt == AVPixelFormat::AV_PIX_FMT_YUV420P && c.name.contains("qsv") {
return false;
}
return true;
});
let mut res = vec![];
#[cfg(any(windows, target_os = "linux"))]
let codecs = enumerate_candidate_codecs(&ctx);
if let Ok(yuv) = Encoder::dummy_yuv(ctx.clone()) {
for codec in codecs {
// Skip if this format already exists in results
if res
.iter()
.any(|existing: &CodecInfo| existing.format == codec.format)
{
continue;
}
debug!("Testing encoder: {}", codec.name);
let c = EncodeContext {
name: codec.name.clone(),
mc_name: codec.mc_name.clone(),
..ctx
};
match Encoder::new(c) {
Ok(mut encoder) => {
debug!("Encoder {} created successfully", codec.name);
let mut passed = false;
let mut last_err: Option<i32> = None;
let is_v4l2m2m = codec.name.contains("v4l2m2m");
let max_attempts = if is_v4l2m2m { 5 } else { 1 };
for attempt in 0..max_attempts {
if is_v4l2m2m {
encoder.request_keyframe();
}
let pts = (attempt as i64) * 33; // 33ms is an approximation for 30 FPS (1000 / 30)
let start = std::time::Instant::now();
match encoder.encode(&yuv, pts) {
Ok(frames) => {
let elapsed = start.elapsed().as_millis();
if is_v4l2m2m {
if !frames.is_empty() && elapsed < TEST_TIMEOUT_MS as _ {
debug!(
"Encoder {} test passed on attempt {} (frames: {})",
codec.name,
attempt + 1,
frames.len()
);
res.push(codec.clone());
passed = true;
break;
} else if frames.is_empty() {
debug!(
"Encoder {} test produced no output on attempt {}",
codec.name,
attempt + 1
);
} else {
debug!(
"Encoder {} test failed on attempt {} - frames: {}, timeout: {}ms",
codec.name,
attempt + 1,
frames.len(),
elapsed
);
}
} else if frames.len() == 1 {
if frames[0].key == 1 && elapsed < TEST_TIMEOUT_MS as _ {
debug!(
"Encoder {} test passed on attempt {}",
codec.name,
attempt + 1
);
res.push(codec.clone());
passed = true;
break;
} else {
debug!(
"Encoder {} test failed on attempt {} - key: {}, timeout: {}ms",
codec.name,
attempt + 1,
frames[0].key,
elapsed
);
}
} else {
debug!(
"Encoder {} test failed on attempt {} - wrong frame count: {}",
codec.name,
attempt + 1,
frames.len()
);
}
}
Err(err) => {
last_err = Some(err);
debug!(
"Encoder {} test attempt {} returned error: {}",
codec.name,
attempt + 1,
err
);
}
}
}
if !passed {
debug!(
"Encoder {} test failed after retries{}",
codec.name,
last_err
.map(|e| format!(" (last err: {})", e))
.unwrap_or_default()
);
}
}
Err(_) => {
debug!("Failed to create encoder {}", codec.name);
}
if validate_candidate(&codec, &ctx, &yuv) {
res.push(codec);
}
}
} else {
debug!("Failed to generate dummy YUV data");
}
// Add software encoders as fallback
let soft_codecs = CodecInfo::soft();
// Add H264 software encoder if not already present
if !res.iter().any(|c| c.format == H264) {
if let Some(h264_soft) = soft_codecs.h264 {
debug!("Adding software H264 encoder: {}", h264_soft.name);
res.push(h264_soft);
}
}
// Add H265 software encoder if not already present
if !res.iter().any(|c| c.format == H265) {
if let Some(h265_soft) = soft_codecs.h265 {
debug!("Adding software H265 encoder: {}", h265_soft.name);
res.push(h265_soft);
}
}
// Add VP8 software encoder if not already present
if !res.iter().any(|c| c.format == VP8) {
if let Some(vp8_soft) = soft_codecs.vp8 {
debug!("Adding software VP8 encoder: {}", vp8_soft.name);
res.push(vp8_soft);
}
}
// Add VP9 software encoder if not already present
if !res.iter().any(|c| c.format == VP9) {
if let Some(vp9_soft) = soft_codecs.vp9 {
debug!("Adding software VP9 encoder: {}", vp9_soft.name);
res.push(vp9_soft);
}
}
add_software_fallback(&mut res);
res
}

View File

@@ -3,10 +3,7 @@
#![allow(non_snake_case)]
use crate::common::DataFormat::{self, *};
use crate::ffmpeg::{
AVHWDeviceType::{self, *},
AVPixelFormat,
};
use crate::ffmpeg::AVHWDeviceType::{self, *};
use serde_derive::{Deserialize, Serialize};
use std::ffi::c_int;
@@ -86,6 +83,40 @@ impl Default for CodecInfo {
}
impl CodecInfo {
pub fn software(format: DataFormat) -> Option<Self> {
match format {
H264 => Some(CodecInfo {
name: "libx264".to_owned(),
mc_name: Default::default(),
format: H264,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
H265 => Some(CodecInfo {
name: "libx265".to_owned(),
mc_name: Default::default(),
format: H265,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
VP8 => Some(CodecInfo {
name: "libvpx".to_owned(),
mc_name: Default::default(),
format: VP8,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
VP9 => Some(CodecInfo {
name: "libvpx-vp9".to_owned(),
mc_name: Default::default(),
format: VP9,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
AV1 => None,
}
}
pub fn prioritized(coders: Vec<CodecInfo>) -> CodecInfos {
let mut h264: Option<CodecInfo> = None;
let mut h265: Option<CodecInfo> = None;
@@ -148,34 +179,10 @@ impl CodecInfo {
pub fn soft() -> CodecInfos {
CodecInfos {
h264: Some(CodecInfo {
name: "libx264".to_owned(),
mc_name: Default::default(),
format: H264,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
h265: Some(CodecInfo {
name: "libx265".to_owned(),
mc_name: Default::default(),
format: H265,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
vp8: Some(CodecInfo {
name: "libvpx".to_owned(),
mc_name: Default::default(),
format: VP8,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
vp9: Some(CodecInfo {
name: "libvpx-vp9".to_owned(),
mc_name: Default::default(),
format: VP9,
hwdevice: AV_HWDEVICE_TYPE_NONE,
priority: Priority::Soft as _,
}),
h264: CodecInfo::software(H264),
h265: CodecInfo::software(H265),
vp8: CodecInfo::software(VP8),
vp9: CodecInfo::software(VP9),
av1: None,
}
}
@@ -191,6 +198,23 @@ pub struct CodecInfos {
}
impl CodecInfos {
pub fn into_vec(self) -> Vec<CodecInfo> {
let mut codecs = Vec::new();
if let Some(codec) = self.h264 {
codecs.push(codec);
}
if let Some(codec) = self.h265 {
codecs.push(codec);
}
if let Some(codec) = self.vp8 {
codecs.push(codec);
}
if let Some(codec) = self.vp9 {
codecs.push(codec);
}
codecs
}
pub fn serialize(&self) -> Result<String, ()> {
match serde_json::to_string_pretty(self) {
Ok(s) => Ok(s),
@@ -207,7 +231,7 @@ impl CodecInfos {
}
pub fn ffmpeg_linesize_offset_length(
pixfmt: AVPixelFormat,
pixfmt: i32,
width: usize,
height: usize,
align: usize,
@@ -220,7 +244,7 @@ pub fn ffmpeg_linesize_offset_length(
length.resize(1, 0);
unsafe {
if ffmpeg_ram_get_linesize_offset_length(
pixfmt as _,
pixfmt,
width as _,
height as _,
align as _,

View File

@@ -45,4 +45,4 @@ pub use error::{Result, VentoyError};
pub use exfat::FileInfo;
pub use image::VentoyImage;
pub use partition::{parse_size, PartitionLayout};
pub use resources::{get_resource_dir, init_resources, is_initialized, required_files};
pub use resources::{init_resources, is_initialized, required_files};

View File

@@ -5,7 +5,7 @@
use crate::error::{Result, VentoyError};
use std::fs;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::sync::OnceLock;
/// Resource file names
@@ -151,13 +151,6 @@ pub fn get_ventoy_disk_img() -> Result<&'static [u8]> {
})
}
/// Get the resource directory path for a given data directory
///
/// Returns `{data_dir}/ventoy`
pub fn get_resource_dir(data_dir: &Path) -> PathBuf {
data_dir.join("ventoy")
}
/// List required resource files
pub fn required_files() -> &'static [&'static str] {
&[BOOT_IMG_NAME, CORE_IMG_NAME, VENTOY_DISK_IMG_NAME]
@@ -166,22 +159,6 @@ pub fn required_files() -> &'static [&'static str] {
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
fn create_test_resources(dir: &Path) {
// Create boot.img (512 bytes)
let mut boot = std::fs::File::create(dir.join(BOOT_IMG_NAME)).unwrap();
boot.write_all(&[0u8; 512]).unwrap();
// Create core.img (fake, 1KB)
let mut core = std::fs::File::create(dir.join(CORE_IMG_NAME)).unwrap();
core.write_all(&[0u8; 1024]).unwrap();
// Create ventoy.disk.img (fake, 1KB)
let mut ventoy = std::fs::File::create(dir.join(VENTOY_DISK_IMG_NAME)).unwrap();
ventoy.write_all(&[0u8; 1024]).unwrap();
}
#[test]
fn test_required_files() {
@@ -191,11 +168,4 @@ mod tests {
assert!(files.contains(&"core.img"));
assert!(files.contains(&"ventoy.disk.img"));
}
#[test]
fn test_get_resource_dir() {
let data_dir = Path::new("/var/lib/one-kvm");
let resource_dir = get_resource_dir(data_dir);
assert_eq!(resource_dir, PathBuf::from("/var/lib/one-kvm/ventoy"));
}
}

View File

@@ -8,5 +8,4 @@ license = "BSD-3-Clause"
[dependencies]
[build-dependencies]
cc = "1.0"
bindgen = "0.59"

View File

@@ -34,10 +34,12 @@ fn generate_bindings(cpp_dir: &Path) {
.allowlist_function("I420Copy")
// I422 conversions
.allowlist_function("I422ToI420")
.allowlist_function("I444ToI420")
// NV12/NV21 conversions
.allowlist_function("NV12ToI420")
.allowlist_function("NV21ToI420")
.allowlist_function("NV12Copy")
.allowlist_function("SplitUVPlane")
// ARGB/BGRA conversions
.allowlist_function("ARGBToI420")
.allowlist_function("ARGBToNV12")
@@ -53,6 +55,7 @@ fn generate_bindings(cpp_dir: &Path) {
// YUV to RGB conversions
.allowlist_function("I420ToRGB24")
.allowlist_function("I420ToARGB")
.allowlist_function("H444ToARGB")
.allowlist_function("NV12ToRGB24")
.allowlist_function("NV12ToARGB")
.allowlist_function("YUY2ToARGB")

View File

@@ -58,6 +58,15 @@ int I422ToI420(const uint8_t* src_y, int src_stride_y,
uint8_t* dst_v, int dst_stride_v,
int width, int height);
// I444 (YUV444P) -> I420 (YUV420P) with horizontal and vertical chroma downsampling
int I444ToI420(const uint8_t* src_y, int src_stride_y,
const uint8_t* src_u, int src_stride_u,
const uint8_t* src_v, int src_stride_v,
uint8_t* dst_y, int dst_stride_y,
uint8_t* dst_u, int dst_stride_u,
uint8_t* dst_v, int dst_stride_v,
int width, int height);
// I420 -> NV12
int I420ToNV12(const uint8_t* src_y, int src_stride_y,
const uint8_t* src_u, int src_stride_u,
@@ -94,6 +103,12 @@ int NV21ToI420(const uint8_t* src_y, int src_stride_y,
uint8_t* dst_v, int dst_stride_v,
int width, int height);
// Split interleaved UV plane into separate U and V planes
void SplitUVPlane(const uint8_t* src_uv, int src_stride_uv,
uint8_t* dst_u, int dst_stride_u,
uint8_t* dst_v, int dst_stride_v,
int width, int height);
// ----------------------------------------------------------------------------
// ARGB/BGRA conversions (32-bit RGB)
// Note: libyuv uses ARGB to mean BGRA in memory (little-endian)
@@ -180,6 +195,13 @@ int I420ToARGB(const uint8_t* src_y, int src_stride_y,
uint8_t* dst_argb, int dst_stride_argb,
int width, int height);
// H444 (BT.709 limited-range YUV444P) -> ARGB (BGRA)
int H444ToARGB(const uint8_t* src_y, int src_stride_y,
const uint8_t* src_u, int src_stride_u,
const uint8_t* src_v, int src_stride_v,
uint8_t* dst_argb, int dst_stride_argb,
int width, int height);
// NV12 -> RGB24
int NV12ToRGB24(const uint8_t* src_y, int src_stride_y,
const uint8_t* src_uv, int src_stride_uv,

View File

@@ -297,6 +297,93 @@ pub fn i422_to_i420_planar(
))
}
/// Convert I444 (YUV444P) to I420 (YUV420P) with separate planes and explicit strides
/// This performs horizontal and vertical chroma downsampling using SIMD
pub fn i444_to_i420_planar(
src_y: &[u8],
src_y_stride: i32,
src_u: &[u8],
src_u_stride: i32,
src_v: &[u8],
src_v_stride: i32,
dst: &mut [u8],
width: i32,
height: i32,
) -> Result<()> {
if width % 2 != 0 || height % 2 != 0 {
return Err(YuvError::InvalidDimensions);
}
let w = width as usize;
let h = height as usize;
let y_size = w * h;
let uv_size = (w / 2) * (h / 2);
if dst.len() < i420_size(w, h) {
return Err(YuvError::BufferTooSmall);
}
call_yuv!(I444ToI420(
src_y.as_ptr(),
src_y_stride,
src_u.as_ptr(),
src_u_stride,
src_v.as_ptr(),
src_v_stride,
dst.as_mut_ptr(),
width,
dst[y_size..].as_mut_ptr(),
width / 2,
dst[y_size + uv_size..].as_mut_ptr(),
width / 2,
width,
height,
))
}
/// Split an interleaved UV plane into separate U and V planes using libyuv SIMD helpers.
///
/// `width` is the number of chroma samples per row, not the number of source bytes.
pub fn split_uv_plane(
src_uv: &[u8],
src_stride_uv: i32,
dst_u: &mut [u8],
dst_stride_u: i32,
dst_v: &mut [u8],
dst_stride_v: i32,
width: i32,
height: i32,
) -> Result<()> {
if width <= 0 || height <= 0 {
return Err(YuvError::InvalidDimensions);
}
let width = width as usize;
let height = height as usize;
let src_required = (src_stride_uv as usize).saturating_mul(height);
let dst_u_required = (dst_stride_u as usize).saturating_mul(height);
let dst_v_required = (dst_stride_v as usize).saturating_mul(height);
if src_uv.len() < src_required || dst_u.len() < dst_u_required || dst_v.len() < dst_v_required {
return Err(YuvError::BufferTooSmall);
}
unsafe {
SplitUVPlane(
src_uv.as_ptr(),
src_stride_uv,
dst_u.as_mut_ptr(),
dst_stride_u,
dst_v.as_mut_ptr(),
dst_stride_v,
width as i32,
height as i32,
);
}
Ok(())
}
// ============================================================================
// I420 <-> NV12 conversions
// ============================================================================
@@ -761,6 +848,41 @@ pub fn i420_to_bgra(src: &[u8], dst: &mut [u8], width: i32, height: i32) -> Resu
))
}
/// Convert H444 (BT.709 limited-range YUV444P) to BGRA.
pub fn h444_to_bgra(
src_y: &[u8],
src_u: &[u8],
src_v: &[u8],
dst: &mut [u8],
width: i32,
height: i32,
) -> Result<()> {
let w = width as usize;
let h = height as usize;
let plane_size = w * h;
if src_y.len() < plane_size || src_u.len() < plane_size || src_v.len() < plane_size {
return Err(YuvError::BufferTooSmall);
}
if dst.len() < argb_size(w, h) {
return Err(YuvError::BufferTooSmall);
}
call_yuv!(H444ToARGB(
src_y.as_ptr(),
width,
src_u.as_ptr(),
width,
src_v.as_ptr(),
width,
dst.as_mut_ptr(),
width * 4,
width,
height,
))
}
/// Convert NV12 to RGB24
pub fn nv12_to_rgb24(src: &[u8], dst: &mut [u8], width: i32, height: i32) -> Result<()> {
if width % 2 != 0 || height % 2 != 0 {

341
scripts/build-update-site.sh Executable file
View File

@@ -0,0 +1,341 @@
#!/usr/bin/env bash
#
# 生成 One-KVM 在线升级静态站点并打包为可部署 tar.gz。
# 输出目录结构:
# <site_name>/v1/channels.json
# <site_name>/v1/releases.json
# <site_name>/v1/bin/<version>/one-kvm-<triple>
#
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
VERSION=""
RELEASE_CHANNEL="stable"
STABLE_VERSION=""
BETA_VERSION=""
PUBLISHED_AT="$(date -u +"%Y-%m-%dT%H:%M:%SZ")"
ARTIFACTS_DIR=""
X86_64_BIN=""
AARCH64_BIN=""
ARMV7_BIN=""
X86_64_SET=0
AARCH64_SET=0
ARMV7_SET=0
SITE_NAME="one-kvm-update"
OUTPUT_FILE=""
OUTPUT_DIR="${PROJECT_ROOT}/dist"
declare -a NOTES=()
usage() {
cat <<'EOF'
Usage:
./scripts/build-update-site.sh --version <x.x.x> [options]
Required:
--version <x.x.x> Release 版本号(如 0.1.10
Artifact input (二选一,可混用):
--artifacts-dir <dir> 自动扫描目录中的标准文件名:
one-kvm-x86_64-unknown-linux-gnu
one-kvm-aarch64-unknown-linux-gnu
one-kvm-armv7-unknown-linux-gnueabihf
--x86_64 <file> 指定 x86_64 二进制路径
--aarch64 <file> 指定 aarch64 二进制路径
--armv7 <file> 指定 armv7 二进制路径
Manifest options:
--release-channel <stable|beta> releases.json 里该版本所属渠道,默认 stable
--stable <x.x.x> channels.json 的 stable 指针,默认等于 --version
--beta <x.x.x> channels.json 的 beta 指针,默认等于 --version
--published-at <RFC3339> 发布时间,默认当前 UTC 时间
--note <text> 发布说明,可重复传入多次
Output options:
--site-name <name> 打包根目录名,默认 one-kvm-update
--output-dir <dir> 输出目录(默认 <repo>/dist
--output <file.tar.gz> 输出包完整路径(优先级高于 --output-dir
Other:
-h, --help 显示帮助
Example:
./scripts/build-update-site.sh \
--version 0.1.10 \
--artifacts-dir ./target/release \
--release-channel stable \
--stable 0.1.10 \
--beta 0.1.11 \
--note "修复 WebRTC 断流问题" \
--note "优化 HID 输入延迟"
EOF
}
fail() {
echo "Error: $*" >&2
exit 1
}
require_cmd() {
local cmd="$1"
command -v "$cmd" >/dev/null 2>&1 || fail "Missing required command: ${cmd}"
}
json_escape() {
local s="$1"
s=${s//\\/\\\\}
s=${s//\"/\\\"}
s=${s//$'\n'/\\n}
s=${s//$'\r'/\\r}
s=${s//$'\t'/\\t}
printf '%s' "$s"
}
is_valid_version() {
local v="$1"
[[ "$v" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]
}
is_valid_channel() {
local c="$1"
[[ "$c" == "stable" || "$c" == "beta" ]]
}
while [[ $# -gt 0 ]]; do
case "$1" in
--version)
VERSION="${2:-}"
shift 2
;;
--release-channel)
RELEASE_CHANNEL="${2:-}"
shift 2
;;
--stable)
STABLE_VERSION="${2:-}"
shift 2
;;
--beta)
BETA_VERSION="${2:-}"
shift 2
;;
--published-at)
PUBLISHED_AT="${2:-}"
shift 2
;;
--note)
NOTES+=("${2:-}")
shift 2
;;
--artifacts-dir)
ARTIFACTS_DIR="${2:-}"
shift 2
;;
--x86_64)
X86_64_BIN="${2:-}"
X86_64_SET=1
shift 2
;;
--aarch64)
AARCH64_BIN="${2:-}"
AARCH64_SET=1
shift 2
;;
--armv7)
ARMV7_BIN="${2:-}"
ARMV7_SET=1
shift 2
;;
--site-name)
SITE_NAME="${2:-}"
shift 2
;;
--output-dir)
OUTPUT_DIR="${2:-}"
shift 2
;;
--output)
OUTPUT_FILE="${2:-}"
shift 2
;;
-h | --help)
usage
exit 0
;;
*)
fail "Unknown argument: $1 (use --help)"
;;
esac
done
require_cmd sha256sum
require_cmd stat
require_cmd tar
require_cmd mktemp
[[ -n "$VERSION" ]] || fail "--version is required"
is_valid_version "$VERSION" || fail "Invalid --version: ${VERSION} (expected x.x.x)"
is_valid_channel "$RELEASE_CHANNEL" || fail "Invalid --release-channel: ${RELEASE_CHANNEL}"
if [[ -z "$STABLE_VERSION" ]]; then
STABLE_VERSION="$VERSION"
fi
if [[ -z "$BETA_VERSION" ]]; then
BETA_VERSION="$VERSION"
fi
is_valid_version "$STABLE_VERSION" || fail "Invalid --stable: ${STABLE_VERSION}"
is_valid_version "$BETA_VERSION" || fail "Invalid --beta: ${BETA_VERSION}"
if [[ -n "$ARTIFACTS_DIR" ]]; then
[[ -d "$ARTIFACTS_DIR" ]] || fail "--artifacts-dir not found: ${ARTIFACTS_DIR}"
[[ -n "$X86_64_BIN" ]] || X86_64_BIN="${ARTIFACTS_DIR}/one-kvm-x86_64-unknown-linux-gnu"
[[ -n "$AARCH64_BIN" ]] || AARCH64_BIN="${ARTIFACTS_DIR}/one-kvm-aarch64-unknown-linux-gnu"
[[ -n "$ARMV7_BIN" ]] || ARMV7_BIN="${ARTIFACTS_DIR}/one-kvm-armv7-unknown-linux-gnueabihf"
fi
if [[ "$X86_64_SET" -eq 1 && ! -f "$X86_64_BIN" ]]; then
fail "--x86_64 file not found: ${X86_64_BIN}"
fi
if [[ "$AARCH64_SET" -eq 1 && ! -f "$AARCH64_BIN" ]]; then
fail "--aarch64 file not found: ${AARCH64_BIN}"
fi
if [[ "$ARMV7_SET" -eq 1 && ! -f "$ARMV7_BIN" ]]; then
fail "--armv7 file not found: ${ARMV7_BIN}"
fi
declare -A SRC_BY_TRIPLE=()
if [[ -n "$X86_64_BIN" && -f "$X86_64_BIN" ]]; then
SRC_BY_TRIPLE["x86_64-unknown-linux-gnu"]="$X86_64_BIN"
fi
if [[ -n "$AARCH64_BIN" && -f "$AARCH64_BIN" ]]; then
SRC_BY_TRIPLE["aarch64-unknown-linux-gnu"]="$AARCH64_BIN"
fi
if [[ -n "$ARMV7_BIN" && -f "$ARMV7_BIN" ]]; then
SRC_BY_TRIPLE["armv7-unknown-linux-gnueabihf"]="$ARMV7_BIN"
fi
if [[ ${#SRC_BY_TRIPLE[@]} -eq 0 ]]; then
fail "No artifact found. Provide --artifacts-dir or at least one of --x86_64/--aarch64/--armv7."
fi
BUILD_DIR="$(mktemp -d)"
trap 'rm -rf "$BUILD_DIR"' EXIT
SITE_DIR="${BUILD_DIR}/${SITE_NAME}"
V1_DIR="${SITE_DIR}/v1"
BIN_DIR="${V1_DIR}/bin/${VERSION}"
mkdir -p "$BIN_DIR"
declare -A SHA_BY_TRIPLE=()
declare -A SIZE_BY_TRIPLE=()
TRIPLES=(
"x86_64-unknown-linux-gnu"
"aarch64-unknown-linux-gnu"
"armv7-unknown-linux-gnueabihf"
)
for triple in "${TRIPLES[@]}"; do
src="${SRC_BY_TRIPLE[$triple]:-}"
if [[ -z "$src" ]]; then
continue
fi
[[ -f "$src" ]] || fail "Artifact not found for ${triple}: ${src}"
dest_name="one-kvm-${triple}"
dest_path="${BIN_DIR}/${dest_name}"
cp "$src" "$dest_path"
sha="$(sha256sum "$dest_path" | awk '{print $1}')"
size="$(stat -c%s "$dest_path")"
SHA_BY_TRIPLE["$triple"]="$sha"
SIZE_BY_TRIPLE["$triple"]="$size"
done
cat >"${V1_DIR}/channels.json" <<EOF
{
"stable": "${STABLE_VERSION}",
"beta": "${BETA_VERSION}"
}
EOF
RELEASES_FILE="${V1_DIR}/releases.json"
{
echo '{'
echo ' "releases": ['
echo ' {'
echo " \"version\": \"${VERSION}\","
echo " \"channel\": \"${RELEASE_CHANNEL}\","
echo " \"published_at\": \"${PUBLISHED_AT}\","
if [[ ${#NOTES[@]} -eq 0 ]]; then
echo ' "notes": [],'
else
echo ' "notes": ['
for i in "${!NOTES[@]}"; do
esc_note="$(json_escape "${NOTES[$i]}")"
if [[ "$i" -lt $((${#NOTES[@]} - 1)) ]]; then
echo " \"${esc_note}\","
else
echo " \"${esc_note}\""
fi
done
echo ' ],'
fi
echo ' "artifacts": {'
written=0
for triple in "${TRIPLES[@]}"; do
if [[ -z "${SHA_BY_TRIPLE[$triple]:-}" ]]; then
continue
fi
url="/v1/bin/${VERSION}/one-kvm-${triple}"
if [[ $written -eq 1 ]]; then
echo ','
fi
cat <<EOF
"${triple}": {
"url": "${url}",
"sha256": "${SHA_BY_TRIPLE[$triple]}",
"size": ${SIZE_BY_TRIPLE[$triple]}
}
EOF
written=1
done
echo
echo ' }'
echo ' }'
echo ' ]'
echo '}'
} >"$RELEASES_FILE"
if [[ -n "$OUTPUT_FILE" ]]; then
if [[ "$OUTPUT_FILE" != /* ]]; then
OUTPUT_FILE="${PROJECT_ROOT}/${OUTPUT_FILE}"
fi
else
mkdir -p "$OUTPUT_DIR"
OUTPUT_FILE="${OUTPUT_DIR}/${SITE_NAME}-${VERSION}.tar.gz"
fi
mkdir -p "$(dirname "$OUTPUT_FILE")"
tar -C "$BUILD_DIR" -czf "$OUTPUT_FILE" "$SITE_NAME"
echo "Build complete:"
echo " package: ${OUTPUT_FILE}"
echo " site root in tar: ${SITE_NAME}/"
echo " release version: ${VERSION}"
echo " release channel: ${RELEASE_CHANNEL}"
echo " channels: stable=${STABLE_VERSION}, beta=${BETA_VERSION}"
echo " artifacts:"
for triple in "${TRIPLES[@]}"; do
if [[ -n "${SHA_BY_TRIPLE[$triple]:-}" ]]; then
echo " - ${triple}: size=${SIZE_BY_TRIPLE[$triple]} sha256=${SHA_BY_TRIPLE[$triple]}"
fi
done
echo
echo "Deploy example:"
echo " tar -xzf \"${OUTPUT_FILE}\" -C /var/www/"
echo " # then ensure nginx root points to /var/www/${SITE_NAME}"

View File

@@ -43,32 +43,93 @@ pub struct AtxController {
}
impl AtxController {
async fn init_components(inner: &mut AtxInner) {
// Initialize power executor
if inner.config.power.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.power.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver, inner.config.power.device, inner.config.power.pin
);
inner.power_executor = Some(executor);
}
}
fn should_share_serial_device(power: &AtxKeyConfig, reset: &AtxKeyConfig) -> bool {
power.is_configured()
&& reset.is_configured()
&& power.driver == super::types::AtxDriverType::Serial
&& reset.driver == super::types::AtxDriverType::Serial
&& !power.device.is_empty()
&& power.device == reset.device
&& power.baud_rate == reset.baud_rate
}
// Initialize reset executor
if inner.config.reset.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.reset.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver, inner.config.reset.device, inner.config.reset.pin
);
inner.reset_executor = Some(executor);
async fn init_components(inner: &mut AtxInner) {
if Self::should_share_serial_device(&inner.config.power, &inner.config.reset) {
match AtxKeyExecutor::open_shared_serial(
&inner.config.power.device,
inner.config.power.baud_rate,
) {
Ok(shared_serial) => {
let mut power_executor = AtxKeyExecutor::new_with_shared_serial(
inner.config.power.clone(),
shared_serial.clone(),
);
if let Err(e) = power_executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver,
inner.config.power.device,
inner.config.power.pin
);
inner.power_executor = Some(power_executor);
}
let mut reset_executor = AtxKeyExecutor::new_with_shared_serial(
inner.config.reset.clone(),
shared_serial,
);
if let Err(e) = reset_executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver,
inner.config.reset.device,
inner.config.reset.pin
);
inner.reset_executor = Some(reset_executor);
}
}
Err(e) => {
warn!(
"Failed to open shared serial device {} for ATX power/reset: {}",
inner.config.power.device, e
);
}
}
} else {
// Initialize power executor
if inner.config.power.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.power.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize power executor: {}", e);
} else {
info!(
"Power executor initialized: {:?} on {} pin {}",
inner.config.power.driver,
inner.config.power.device,
inner.config.power.pin
);
inner.power_executor = Some(executor);
}
}
// Initialize reset executor
if inner.config.reset.is_configured() {
let mut executor = AtxKeyExecutor::new(inner.config.reset.clone());
if let Err(e) = executor.init().await {
warn!("Failed to initialize reset executor: {}", e);
} else {
info!(
"Reset executor initialized: {:?} on {} pin {}",
inner.config.reset.driver,
inner.config.reset.device,
inner.config.reset.pin
);
inner.reset_executor = Some(executor);
}
}
}
@@ -262,3 +323,49 @@ impl AtxController {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::atx::AtxDriverType;
#[test]
fn test_should_share_serial_device_true() {
let power = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 1,
active_level: super::super::types::ActiveLevel::High,
baud_rate: 9600,
};
let reset = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 2,
active_level: super::super::types::ActiveLevel::High,
baud_rate: 9600,
};
assert!(AtxController::should_share_serial_device(&power, &reset));
}
#[test]
fn test_should_share_serial_device_false_on_different_baud() {
let power = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 1,
active_level: super::super::types::ActiveLevel::High,
baud_rate: 9600,
};
let reset = AtxKeyConfig {
driver: AtxDriverType::Serial,
device: "/dev/ttyUSB0".to_string(),
pin: 2,
active_level: super::super::types::ActiveLevel::High,
baud_rate: 115200,
};
assert!(!AtxController::should_share_serial_device(&power, &reset));
}
}

View File

@@ -7,8 +7,9 @@ use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use serialport::SerialPort;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::os::fd::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, info};
@@ -16,6 +17,12 @@ use tracing::{debug, info};
use super::types::{ActiveLevel, AtxDriverType, AtxKeyConfig};
use crate::error::{AppError, Result};
pub type SharedSerialHandle = Arc<Mutex<Box<dyn SerialPort>>>;
const USB_RELAY_MAX_CHANNEL: u8 = 8;
const USB_RELAY_REPORT_LEN: usize = 9;
const HIDIOCSFEATURE_9: libc::c_ulong = 0xC009_4806; // _IOC(_IOC_READ|_IOC_WRITE, 'H', 0x06, 9)
/// Timing constants for ATX operations
pub mod timing {
use std::time::Duration;
@@ -39,8 +46,8 @@ pub struct AtxKeyExecutor {
gpio_handle: Mutex<Option<LineHandle>>,
/// Cached USB relay file handle to avoid repeated open/close syscalls
usb_relay_handle: Mutex<Option<File>>,
/// Cached Serial port handle
serial_handle: Mutex<Option<Box<dyn SerialPort>>>,
/// Cached Serial port handle (can be shared across power/reset executors)
serial_handle: Mutex<Option<SharedSerialHandle>>,
initialized: AtomicBool,
}
@@ -56,6 +63,26 @@ impl AtxKeyExecutor {
}
}
/// Create a new executor with a pre-opened shared serial handle.
pub fn new_with_shared_serial(config: AtxKeyConfig, serial_handle: SharedSerialHandle) -> Self {
Self {
config,
gpio_handle: Mutex::new(None),
usb_relay_handle: Mutex::new(None),
serial_handle: Mutex::new(Some(serial_handle)),
initialized: AtomicBool::new(false),
}
}
/// Open a serial relay device and wrap it for shared use.
pub fn open_shared_serial(device: &str, baud_rate: u32) -> Result<SharedSerialHandle> {
let port = serialport::new(device, baud_rate)
.timeout(Duration::from_millis(100))
.open()
.map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?;
Ok(Arc::new(Mutex::new(port)))
}
/// Check if this executor is configured
pub fn is_configured(&self) -> bool {
self.config.is_configured()
@@ -107,12 +134,23 @@ impl AtxKeyExecutor {
}
}
AtxDriverType::UsbRelay => {
if self.config.pin == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if self.config.pin > u8::MAX as u32 {
return Err(AppError::Config(format!(
"USB relay channel must be <= {}",
u8::MAX
)));
}
if self.config.pin > USB_RELAY_MAX_CHANNEL as u32 {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
}
AtxDriverType::Gpio | AtxDriverType::None => {}
}
@@ -181,14 +219,11 @@ impl AtxKeyExecutor {
self.config.device, self.config.pin
);
let baud_rate = self.config.baud_rate;
let port = serialport::new(&self.config.device, baud_rate)
.timeout(Duration::from_millis(100))
.open()
.map_err(|e| AppError::Internal(format!("Serial port open failed: {}", e)))?;
*self.serial_handle.lock().unwrap() = Some(port);
let existing_handle = self.serial_handle.lock().unwrap().as_ref().cloned();
if existing_handle.is_none() {
let shared = Self::open_shared_serial(&self.config.device, self.config.baud_rate)?;
*self.serial_handle.lock().unwrap() = Some(shared);
}
// Ensure relay is off initially
self.send_serial_relay_command(false)?;
@@ -273,26 +308,64 @@ impl AtxKeyExecutor {
u8::MAX
))
})?;
if channel == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if channel > USB_RELAY_MAX_CHANNEL {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
// Standard HID relay command format
let cmd = if on {
[0x00, channel + 1, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00]
} else {
[0x00, channel + 1, 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00]
};
let cmd = Self::build_usb_relay_command(channel, on);
let mut guard = self.usb_relay_handle.lock().unwrap();
let device = guard
.as_mut()
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
device
.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("USB relay write failed: {}", e)))?;
if let Err(feature_err) = Self::send_usb_relay_feature_report(device, &cmd) {
debug!(
"USB relay feature report failed ({}), falling back to hidraw write",
feature_err
);
device.write_all(&cmd).map_err(|write_err| {
AppError::Internal(format!(
"USB relay feature report failed: {}; raw write failed: {}",
feature_err, write_err
))
})?;
device
.flush()
.map_err(|e| AppError::Internal(format!("USB relay flush failed: {}", e)))?;
}
Ok(())
}
fn build_usb_relay_command(channel: u8, on: bool) -> [u8; USB_RELAY_REPORT_LEN] {
let mut cmd = [0x00; USB_RELAY_REPORT_LEN];
cmd[1] = if on { 0xFF } else { 0xFD };
cmd[2] = channel;
cmd
}
fn send_usb_relay_feature_report(
device: &File,
report: &[u8; USB_RELAY_REPORT_LEN],
) -> std::io::Result<()> {
// Linux hidraw feature reports include the report ID as the first byte.
let rc = unsafe { libc::ioctl(device.as_raw_fd(), HIDIOCSFEATURE_9, report.as_ptr()) };
if rc < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(())
}
}
/// Pulse Serial relay
async fn pulse_serial(&self, duration: Duration) -> Result<()> {
info!(
@@ -337,13 +410,19 @@ impl AtxKeyExecutor {
// OFF: A0 01 00 A1
let cmd = [0xA0, channel, state, checksum];
let mut guard = self.serial_handle.lock().unwrap();
let port = guard
.as_mut()
let serial_handle = self
.serial_handle
.lock()
.unwrap()
.as_ref()
.cloned()
.ok_or_else(|| AppError::Internal("Serial relay not initialized".to_string()))?;
let mut port = serial_handle.lock().unwrap();
port.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?;
port.flush()
.map_err(|e| AppError::Internal(format!("Serial relay flush failed: {}", e)))?;
Ok(())
}
@@ -430,7 +509,7 @@ mod tests {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 0,
pin: 1,
active_level: ActiveLevel::High, // Ignored for USB relay
baud_rate: 9600,
};
@@ -458,6 +537,18 @@ mod tests {
assert_eq!(timing::RESET_PRESS.as_millis(), 500);
}
#[test]
fn test_usb_relay_command_format() {
assert_eq!(
AtxKeyExecutor::build_usb_relay_command(1, true),
[0x00, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
assert_eq!(
AtxKeyExecutor::build_usb_relay_command(1, false),
[0x00, 0xFD, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_zero() {
let config = AtxKeyConfig {
@@ -472,6 +563,34 @@ mod tests {
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_usb_relay_channel_zero() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 0,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_usb_relay_channel_overflow() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: USB_RELAY_MAX_CHANNEL as u32 + 1,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_overflow() {
let config = AtxKeyConfig {

View File

@@ -15,6 +15,7 @@
//!
//! - **GPIO**: Uses Linux GPIO character device (/dev/gpiochipX) for direct hardware control
//! - **USB Relay**: Uses HID USB relay modules for isolated switching
//! - **Serial Relay**: Uses LCUS-style serial relay modules
//!
//! # Example
//!
@@ -59,9 +60,25 @@ pub use types::{
};
pub use wol::send_wol;
fn hidraw_uevent_is_usb_relay(uevent: &str) -> bool {
let upper = uevent.to_ascii_uppercase();
upper.contains("000016C0:000005DF")
|| upper.contains("16C0:05DF")
|| upper.contains("PRODUCT=16C0/5DF")
|| upper.contains("USBRELAY")
|| upper.contains("USB RELAY")
}
fn is_usb_relay_hidraw(name: &str) -> bool {
let uevent_path = format!("/sys/class/hidraw/{}/device/uevent", name);
std::fs::read_to_string(uevent_path)
.map(|uevent| hidraw_uevent_is_usb_relay(&uevent))
.unwrap_or(false)
}
/// Discover available ATX devices on the system
///
/// Scans for GPIO chips and USB HID relay devices in a single pass.
/// Scans for GPIO chips, LCUS USB HID relay devices, and serial relay ports.
pub fn discover_devices() -> AtxDevices {
let mut devices = AtxDevices::default();
@@ -72,7 +89,7 @@ pub fn discover_devices() -> AtxDevices {
let name_str = name.to_string_lossy();
if name_str.starts_with("gpiochip") {
devices.gpio_chips.push(format!("/dev/{}", name_str));
} else if name_str.starts_with("hidraw") {
} else if name_str.starts_with("hidraw") && is_usb_relay_hidraw(&name_str) {
devices.usb_relays.push(format!("/dev/{}", name_str));
} else if name_str.starts_with("ttyUSB") || name_str.starts_with("ttyACM") {
devices.serial_ports.push(format!("/dev/{}", name_str));
@@ -93,11 +110,21 @@ mod tests {
#[test]
fn test_discover_devices() {
let devices = discover_devices();
// Just verify the function runs without error
assert!(devices.gpio_chips.len() >= 0);
assert!(devices.usb_relays.len() >= 0);
assert!(devices.serial_ports.len() >= 0);
let _devices = discover_devices();
}
#[test]
fn test_hidraw_uevent_detects_usb_relay_id() {
assert!(hidraw_uevent_is_usb_relay(
"HID_ID=0003:000016C0:000005DF\nHID_NAME=www.dcttech.com USBRelay2\n"
));
}
#[test]
fn test_hidraw_uevent_rejects_unrelated_hid() {
assert!(!hidraw_uevent_is_usb_relay(
"HID_ID=0003:0000046D:0000C534\nHID_NAME=Logitech USB Receiver\n"
));
}
#[test]

View File

@@ -61,7 +61,7 @@ pub struct AtxKeyConfig {
pub device: String,
/// Pin or channel number:
/// - For GPIO: GPIO pin number
/// - For USB Relay: relay channel (0-based)
/// - For USB Relay: relay channel (1-based)
/// - For Serial Relay (LCUS): relay channel (1-based)
pub pin: u32,
/// Active level (only applicable to GPIO, ignored for USB Relay)

View File

@@ -1,5 +1,3 @@
//! ALSA audio capture implementation
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
use alsa::{Direction, ValueOr, PCM};
use bytes::Bytes;
@@ -14,30 +12,23 @@ use crate::error::{AppError, Result};
use crate::utils::LogThrottler;
use crate::{error_throttled, warn_throttled};
/// Audio capture configuration
#[derive(Debug, Clone)]
pub struct AudioConfig {
/// ALSA device name (e.g., "hw:0,0" or "default")
pub device_name: String,
/// Sample rate in Hz
pub sample_rate: u32,
/// Number of channels (1 = mono, 2 = stereo)
pub channels: u32,
/// Samples per frame (for Opus, typically 480 for 10ms at 48kHz)
pub frame_size: u32,
/// Buffer size in frames
pub buffer_frames: u32,
/// Period size in frames
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: "default".to_string(),
device_name: String::new(),
sample_rate: 48000,
channels: 2,
frame_size: 960, // 20ms at 48kHz (good for Opus)
frame_size: 960,
buffer_frames: 4096,
period_frames: 960,
}
@@ -45,71 +36,46 @@ impl Default for AudioConfig {
}
impl AudioConfig {
/// Create config for a specific device
pub fn for_device(device: &AudioDeviceInfo) -> Self {
let sample_rate = if device.sample_rates.contains(&48000) {
48000
} else {
*device.sample_rates.first().unwrap_or(&48000)
};
let channels = if device.channels.contains(&2) {
2
} else {
*device.channels.first().unwrap_or(&2)
};
Self {
device_name: device.name.clone(),
sample_rate,
channels,
frame_size: sample_rate / 50, // 20ms
..Default::default()
}
}
/// Bytes per sample (16-bit signed)
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
/// Bytes per frame
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
/// Audio frame data
#[derive(Debug, Clone)]
pub struct AudioFrame {
/// Raw PCM data (S16LE interleaved)
pub data: Bytes,
/// Sample rate
pub sample_rate: u32,
/// Number of channels
pub channels: u32,
/// Number of samples per channel
pub samples: u32,
/// Frame sequence number
pub sequence: u64,
/// Capture timestamp
pub timestamp: Instant,
}
impl AudioFrame {
pub fn new(data: Bytes, config: &AudioConfig, sequence: u64) -> Self {
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
samples: data.len() as u32 / config.bytes_per_sample(),
samples: data.len() as u32 / bps,
data,
sample_rate: config.sample_rate,
channels: config.channels,
sample_rate,
channels,
sequence,
timestamp: Instant::now(),
}
}
}
/// Audio capture state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
@@ -117,7 +83,6 @@ pub enum CaptureState {
Error,
}
/// ALSA audio capturer
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
@@ -126,15 +91,13 @@ pub struct AudioCapturer {
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Log throttler to prevent log flooding
log_throttler: LogThrottler,
}
impl AudioCapturer {
/// Create a new audio capturer
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency
let (frame_tx, _) = broadcast::channel(16);
Self {
config,
@@ -148,28 +111,24 @@ impl AudioCapturer {
}
}
/// Get current state
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
/// Subscribe to audio frames
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
/// Start capturing
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
info!(
debug!(
"Starting audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
@@ -184,14 +143,27 @@ impl AudioCapturer {
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
capture_loop(config, state, frame_tx, stop_flag, sequence, log_throttler);
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
/// Stop capturing
pub async fn stop(&self) -> Result<()> {
info!("Stopping audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
@@ -204,38 +176,11 @@ impl AudioCapturer {
Ok(())
}
/// Check if running
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
/// Main capture loop
fn capture_loop(
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
log_throttler: LogThrottler,
) {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
@@ -244,7 +189,6 @@ fn run_capture(
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
// Open ALSA device
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
AppError::AudioError(format!(
"Failed to open audio device {}: {}",
@@ -252,7 +196,6 @@ fn run_capture(
))
})?;
// Configure hardware parameters
{
let hwp = HwParams::any(&pcm)
.map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?;
@@ -279,30 +222,46 @@ fn run_capture(
.map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?;
}
// Get actual configuration
let actual_rate = pcm
.hw_params_current()
.map(|h| h.get_rate().unwrap_or(config.sample_rate))
.unwrap_or(config.sample_rate);
let hw_now = pcm.hw_params_current().map_err(|e| {
AppError::AudioError(format!("Failed to read hw_params after apply: {}", e))
})?;
let actual_rate = hw_now
.get_rate()
.map_err(|e| AppError::AudioError(format!("Failed to read sample rate: {}", e)))?;
let actual_ch = hw_now
.get_channels()
.map_err(|e| AppError::AudioError(format!("Failed to read channels: {}", e)))?;
if actual_rate != 48_000 {
return Err(AppError::AudioError(format!(
"Audio capture requires 48000 Hz; device is {} Hz",
actual_rate
)));
}
if actual_ch != 2 {
return Err(AppError::AudioError(format!(
"Audio capture requires 2 channels (stereo); device has {}",
actual_ch
)));
}
debug!("Audio capture: 48000 Hz, 2 ch");
info!(
"Audio capture configured: {}Hz {}ch (requested {}Hz)",
actual_rate, config.channels, config.sample_rate
);
// Prepare for capture
pcm.prepare()
.map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?;
let _ = state.send(CaptureState::Running);
// Allocate buffer - use u8 directly for zero-copy
let frame_bytes = config.bytes_per_frame();
let mut buffer = vec![0u8; frame_bytes];
let period_frames = pcm
.hw_params_current()
.ok()
.and_then(|h| h.get_period_size().ok())
.map(|f| f as usize)
.unwrap_or(1024)
.max(256);
let buf_frames = period_frames.saturating_mul(4).max(2048);
let bytes_per_frame = (config.channels as usize) * 2;
let mut buffer = vec![0u8; buf_frames * bytes_per_frame];
// Capture loop
while !stop_flag.load(Ordering::Relaxed) {
// Check PCM state
match pcm.state() {
State::XRun => {
warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering");
@@ -321,9 +280,7 @@ fn run_capture(
_ => {}
}
// Get IO handle and read audio data directly as bytes
// Note: Use io() instead of io_checked() because USB audio devices
// typically don't support mmap, which io_checked() requires
// io_bytes: USB capture often lacks mmap (io_checked requires it).
let io: IO<u8> = pcm.io_bytes();
match io.readi(&mut buffer) {
@@ -332,15 +289,16 @@ fn run_capture(
continue;
}
// Calculate actual byte count
let byte_count = frames_read * config.channels as usize * 2;
// Directly use the buffer slice (already in correct byte format)
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame =
AudioFrame::new(Bytes::copy_from_slice(&buffer[..byte_count]), config, seq);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(&buffer[..byte_count]),
config.channels,
48_000,
seq,
);
// Send to subscribers
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
@@ -348,15 +306,15 @@ fn run_capture(
}
}
Err(e) => {
// Check for buffer overrun (EPIPE = 32 on Linux)
let desc = e.to_string();
if desc.contains("EPIPE") || desc.contains("Broken pipe") {
// Buffer overrun
if is_device_lost_error(&desc) {
return Err(AppError::AudioError(format!(
"Audio device lost while reading {}: {}",
config.device_name, e
)));
} else if desc.contains("EPIPE") || desc.contains("Broken pipe") {
warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun");
let _ = pcm.prepare();
} else if desc.contains("No such device") || desc.contains("ENODEV") {
// Device disconnected - use longer throttle for this
error_throttled!(log_throttler, "no_device", "Audio read error: {}", e);
} else {
error_throttled!(log_throttler, "read_error", "Audio read error: {}", e);
}
@@ -367,3 +325,10 @@ fn run_capture(
info!("Audio capture stopped");
Ok(())
}
fn is_device_lost_error(desc: &str) -> bool {
desc.contains("No such device")
|| desc.contains("ENODEV")
|| desc.contains("ENXIO")
|| desc.contains("ESHUTDOWN")
}

View File

@@ -1,35 +1,38 @@
//! Audio controller for high-level audio management
//!
//! Provides device enumeration, selection, quality control, and streaming management.
//! Device selection, quality presets, streaming.
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::info;
use tracing::{debug, info, warn};
use super::capture::AudioConfig;
use super::device::{enumerate_audio_devices_with_current, AudioDeviceInfo};
use super::device::{
enumerate_audio_devices, enumerate_audio_devices_with_current, find_best_audio_device,
AudioDeviceInfo,
};
use super::encoder::{OpusConfig, OpusFrame};
use super::monitor::{AudioHealthMonitor, AudioHealthStatus};
use super::streamer::{AudioStreamer, AudioStreamerConfig};
use super::monitor::AudioHealthMonitor;
use super::streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
use crate::error::{AppError, Result};
use crate::events::{EventBus, SystemEvent};
use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent};
const AUDIO_RECOVERY_RETRY_DELAY: Duration = Duration::from_secs(1);
type AudioRecoveredCallback = Arc<dyn Fn() + Send + Sync>;
/// Audio quality presets
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioQuality {
/// Low bandwidth voice (32kbps)
Voice,
/// Balanced quality (64kbps) - default
#[default]
Balanced,
/// High quality audio (128kbps)
High,
}
impl AudioQuality {
/// Get the bitrate for this quality level
pub fn bitrate(&self) -> u32 {
match self {
AudioQuality::Voice => 32000,
@@ -38,17 +41,6 @@ impl AudioQuality {
}
}
/// Parse from string
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"voice" | "low" => AudioQuality::Voice,
"high" | "music" => AudioQuality::High,
_ => AudioQuality::Balanced,
}
}
/// Convert to OpusConfig
pub fn to_opus_config(&self) -> OpusConfig {
match self {
AudioQuality::Voice => OpusConfig::voice(),
@@ -58,6 +50,22 @@ impl AudioQuality {
}
}
impl FromStr for AudioQuality {
type Err = AppError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.trim().to_lowercase().as_str() {
"voice" => Ok(Self::Voice),
"balanced" => Ok(Self::Balanced),
"high" => Ok(Self::High),
_ => Err(AppError::BadRequest(format!(
"invalid audio quality {:?} (expected voice, balanced, or high)",
s.trim()
))),
}
}
}
impl std::fmt::Display for AudioQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -68,17 +76,10 @@ impl std::fmt::Display for AudioQuality {
}
}
/// Audio controller configuration
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
/// These are optimal for Opus encoding and match WebRTC requirements.
#[derive(Debug, Clone)]
pub struct AudioControllerConfig {
/// Whether audio is enabled
pub enabled: bool,
/// Selected device name
pub device: String,
/// Audio quality preset
pub quality: AudioQuality,
}
@@ -86,76 +87,347 @@ impl Default for AudioControllerConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
device: String::new(),
quality: AudioQuality::Balanced,
}
}
}
/// Current audio status
#[derive(Debug, Clone, Serialize)]
pub struct AudioStatus {
/// Whether audio feature is enabled
pub enabled: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Currently selected device
pub device: Option<String>,
/// Current quality preset
pub quality: AudioQuality,
/// Number of connected subscribers
pub subscriber_count: usize,
/// Error message if any
pub error: Option<String>,
}
/// Audio controller
///
/// High-level interface for audio management, providing:
/// - Device enumeration and selection
/// - Quality control
/// - Stream start/stop
/// - Status reporting
pub struct AudioController {
config: RwLock<AudioControllerConfig>,
streamer: RwLock<Option<Arc<AudioStreamer>>>,
devices: RwLock<Vec<AudioDeviceInfo>>,
event_bus: RwLock<Option<Arc<EventBus>>>,
last_error: RwLock<Option<String>>,
/// Health monitor for error tracking and recovery
config: Arc<RwLock<AudioControllerConfig>>,
streamer: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
devices: Arc<RwLock<Vec<AudioDeviceInfo>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
}
impl AudioController {
/// Create a new audio controller with configuration
pub fn new(config: AudioControllerConfig) -> Self {
Self {
config: RwLock::new(config),
streamer: RwLock::new(None),
devices: RwLock::new(Vec::new()),
event_bus: RwLock::new(None),
last_error: RwLock::new(None),
monitor: Arc::new(AudioHealthMonitor::with_defaults()),
config: Arc::new(RwLock::new(config)),
streamer: Arc::new(RwLock::new(None)),
devices: Arc::new(RwLock::new(Vec::new())),
event_bus: Arc::new(RwLock::new(None)),
monitor: Arc::new(AudioHealthMonitor::new()),
recovery_in_progress: Arc::new(AtomicBool::new(false)),
recovered_callback: Arc::new(RwLock::new(None)),
}
}
/// Set event bus for publishing audio events
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
*self.event_bus.write().await = Some(event_bus.clone());
// Also set event bus on the monitor for health notifications
self.monitor.set_event_bus(event_bus).await;
*self.event_bus.write().await = Some(event_bus);
}
/// Publish an event to the event bus
async fn publish_event(&self, event: SystemEvent) {
pub async fn set_recovered_callback(&self, callback: Arc<dyn Fn() + Send + Sync>) {
*self.recovered_callback.write().await = Some(callback);
}
async fn mark_device_info_dirty(&self) {
if let Some(ref bus) = *self.event_bus.read().await {
bus.publish(event);
bus.mark_device_info_dirty();
}
}
/// List available audio capture devices
async fn publish_state(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
state: &str,
device: Option<String>,
reason: Option<&str>,
next_retry_ms: Option<u64>,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamStateChanged {
state: state.to_string(),
device,
reason: reason.map(str::to_string),
next_retry_ms,
});
bus.mark_device_info_dirty();
}
}
async fn publish_device_lost(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
reason: &str,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: device.to_string(),
reason: reason.to_string(),
});
}
}
async fn publish_reconnecting(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
attempt: u32,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamReconnecting {
device: device.to_string(),
attempt,
});
}
}
async fn publish_recovered(event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>, device: &str) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamRecovered {
device: device.to_string(),
});
}
}
fn select_recovery_device(
devices: &[AudioDeviceInfo],
preferred: &str,
) -> Option<AudioDeviceInfo> {
if !preferred.trim().is_empty() {
if let Some(device) = devices.iter().find(|d| d.name == preferred) {
return Some(device.clone());
}
}
devices
.iter()
.find(|d| d.is_hdmi && d.sample_rates.contains(&48_000) && d.channels.contains(&2))
.or_else(|| {
devices
.iter()
.find(|d| d.sample_rates.contains(&48_000) && d.channels.contains(&2))
})
.or_else(|| devices.first())
.cloned()
}
fn spawn_stream_monitor_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
streamer: Arc<AudioStreamer>,
device: String,
) {
let mut state_rx = streamer.state_watch();
tokio::spawn(async move {
loop {
if state_rx.changed().await.is_err() {
return;
}
if *state_rx.borrow() != AudioStreamState::Error {
continue;
}
{
let current = streamer_slot.read().await;
if !current
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &streamer))
{
return;
}
}
let reason = format!("Audio device lost: {}", device);
monitor.report_error(&reason, "device_lost").await;
Self::spawn_recovery_task_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
device,
reason,
);
return;
}
});
}
fn spawn_recovery_task_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
lost_device: String,
reason: String,
) {
if recovery_in_progress.swap(true, Ordering::SeqCst) {
debug!("Audio recovery already in progress");
return;
}
tokio::spawn(async move {
warn!("Audio recovery started for {}: {}", lost_device, reason);
Self::publish_device_lost(&event_bus, &lost_device, &reason).await;
Self::publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_device_lost"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
let mut attempt = 0u32;
loop {
if !recovery_in_progress.load(Ordering::SeqCst) {
debug!("Audio recovery canceled");
return;
}
if streamer_slot
.read()
.await
.as_ref()
.is_some_and(|s| s.is_running())
{
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
let cfg = config.read().await.clone();
if !cfg.enabled {
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
attempt = attempt.saturating_add(1);
Self::publish_reconnecting(&event_bus, &lost_device, attempt).await;
Self::publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_reconnecting"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
tokio::time::sleep(AUDIO_RECOVERY_RETRY_DELAY).await;
let devices = match enumerate_audio_devices() {
Ok(devices) => devices,
Err(e) => {
debug!(
"Audio recovery enumerate failed (attempt {}): {}",
attempt, e
);
continue;
}
};
let Some(device) = Self::select_recovery_device(&devices, &cfg.device) else {
debug!("No audio devices found during recovery attempt {}", attempt);
continue;
};
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device.name.clone(),
..Default::default()
},
opus: cfg.quality.to_opus_config(),
};
let new_streamer = Arc::new(AudioStreamer::with_config(streamer_config));
match new_streamer.start().await {
Ok(()) => {
{
let mut cfg = config.write().await;
cfg.device = device.name.clone();
}
*streamer_slot.write().await = Some(new_streamer.clone());
monitor.report_recovered().await;
Self::publish_recovered(&event_bus, &device.name).await;
if let Some(callback) = recovered_callback.read().await.clone() {
callback();
}
Self::publish_state(
&event_bus,
"streaming",
Some(device.name.clone()),
None,
None,
)
.await;
recovery_in_progress.store(false, Ordering::SeqCst);
info!(
"Audio device recovered with {} after {} attempts",
device.name, attempt
);
Self::spawn_stream_monitor_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
new_streamer,
device.name,
);
return;
}
Err(e) => {
debug!(
"Audio recovery start failed with {} (attempt {}): {}",
device.name, attempt, e
);
}
}
}
});
}
fn spawn_recovery_task(&self, lost_device: String, reason: String) {
Self::spawn_recovery_task_from_parts(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
self.monitor.clone(),
self.recovery_in_progress.clone(),
self.recovered_callback.clone(),
lost_device,
reason,
);
}
fn spawn_stream_monitor(&self, streamer: Arc<AudioStreamer>, device: String) {
Self::spawn_stream_monitor_from_parts(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
self.monitor.clone(),
self.recovery_in_progress.clone(),
self.recovered_callback.clone(),
streamer,
device,
);
}
pub async fn list_devices(&self) -> Result<Vec<AudioDeviceInfo>> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
@@ -167,55 +439,30 @@ impl AudioController {
Ok(devices)
}
/// Refresh device list and cache it
pub async fn refresh_devices(&self) -> Result<()> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
None
};
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
*self.devices.write().await = devices;
Ok(())
}
/// Get cached device list
pub async fn get_cached_devices(&self) -> Vec<AudioDeviceInfo> {
self.devices.read().await.clone()
}
/// Select audio device
pub async fn select_device(&self, device: &str) -> Result<()> {
// Validate device exists
let devices = self.list_devices().await?;
let found = devices
.iter()
.any(|d| d.name == device || d.description.contains(device));
if !found && device != "default" {
if !found {
return Err(AppError::AudioError(format!(
"Audio device not found: {}",
device
)));
}
// Update config
{
let mut config = self.config.write().await;
config.device = device.to_string();
}
// Publish event
self.publish_event(SystemEvent::AudioDeviceSelected {
device: device.to_string(),
})
.await;
info!("Audio device selected: {}", device);
// If streaming, restart with new device
if self.is_streaming().await {
self.stop_streaming().await?;
self.start_streaming().await?;
@@ -224,25 +471,16 @@ impl AudioController {
Ok(())
}
/// Set audio quality
pub async fn set_quality(&self, quality: AudioQuality) -> Result<()> {
// Update config
{
let mut config = self.config.write().await;
config.quality = quality;
}
// Update streamer if running
if let Some(ref streamer) = *self.streamer.read().await {
streamer.set_bitrate(quality.bitrate()).await?;
}
// Publish event
self.publish_event(SystemEvent::AudioQualityChanged {
quality: quality.to_string(),
})
.await;
info!(
"Audio quality set to: {:?} ({}bps)",
quality,
@@ -251,90 +489,94 @@ impl AudioController {
Ok(())
}
/// Start audio streaming
pub async fn start_streaming(&self) -> Result<()> {
let config = self.config.read().await.clone();
if !config.enabled {
return Err(AppError::AudioError("Audio is disabled".to_string()));
{
let config = self.config.read().await;
if !config.enabled {
return Err(AppError::AudioError("Audio is disabled".to_string()));
}
}
// Check if already streaming
if self.is_streaming().await {
return Ok(());
}
info!("Starting audio streaming with device: {}", config.device);
// Clear any previous error
*self.last_error.write().await = None;
// Create streamer config (fixed 48kHz stereo)
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: config.device.clone(),
..Default::default()
},
opus: config.quality.to_opus_config(),
let mut select_error = None;
let (device_name, quality) = {
let mut cfg = self.config.write().await;
if cfg.device.trim().is_empty() {
match find_best_audio_device() {
Ok(best) => cfg.device = best.name,
Err(e) => {
select_error = Some(format!("Failed to select audio device: {}", e));
}
}
}
(cfg.device.clone(), cfg.quality)
};
if let Some(error_msg) = select_error {
self.monitor.report_error(&error_msg, "start_failed").await;
self.spawn_recovery_task("auto".to_string(), error_msg.clone());
self.mark_device_info_dirty().await;
return Err(AppError::AudioError(error_msg));
}
debug!("Starting audio streaming with device: {}", device_name);
self.monitor.prepare_retry_attempt();
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device_name.clone(),
..Default::default()
},
opus: quality.to_opus_config(),
};
// Create and start streamer
let streamer = Arc::new(AudioStreamer::with_config(streamer_config));
if let Err(e) = streamer.start().await {
let error_msg = format!("Failed to start audio: {}", e);
*self.last_error.write().await = Some(error_msg.clone());
// Report error to health monitor
self.monitor
.report_error(Some(&config.device), &error_msg, "start_failed")
.await;
self.monitor.report_error(&error_msg, "start_failed").await;
self.spawn_recovery_task(device_name.clone(), error_msg.clone());
self.publish_event(SystemEvent::AudioStateChanged {
streaming: false,
device: None,
})
.await;
self.mark_device_info_dirty().await;
return Err(AppError::AudioError(error_msg));
}
let streamer_for_monitor = streamer.clone();
*self.streamer.write().await = Some(streamer);
self.spawn_stream_monitor(streamer_for_monitor, device_name.clone());
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered(Some(&config.device)).await;
self.monitor.report_recovered().await;
}
// Publish event
self.publish_event(SystemEvent::AudioStateChanged {
streaming: true,
device: Some(config.device),
})
.await;
self.recovery_in_progress.store(false, Ordering::SeqCst);
self.mark_device_info_dirty().await;
info!("Audio streaming started");
Ok(())
}
/// Stop audio streaming
pub async fn stop_streaming(&self) -> Result<()> {
self.recovery_in_progress.store(false, Ordering::SeqCst);
if let Some(streamer) = self.streamer.write().await.take() {
streamer.stop().await?;
}
// Publish event
self.publish_event(SystemEvent::AudioStateChanged {
streaming: false,
device: None,
})
.await;
self.monitor.reset().await;
self.mark_device_info_dirty().await;
info!("Audio streaming stopped");
Ok(())
}
/// Check if currently streaming
pub async fn is_streaming(&self) -> bool {
if let Some(ref streamer) = *self.streamer.read().await {
streamer.is_running()
@@ -343,46 +585,37 @@ impl AudioController {
}
}
/// Get current status
pub async fn status(&self) -> AudioStatus {
let config = self.config.read().await;
let streaming = self.is_streaming().await;
let error = self.last_error.read().await.clone();
let (enabled, device_str, quality) = {
let c = self.config.read().await;
(c.enabled, c.device.clone(), c.quality)
};
let error = self.monitor.error_message().await;
let subscriber_count = if let Some(ref streamer) = *self.streamer.read().await {
streamer.stats().await.subscriber_count
let (streaming, subscriber_count) = if let Some(ref streamer) = *self.streamer.read().await
{
let streaming = streamer.is_running();
let subscriber_count = streamer.stats().subscriber_count;
(streaming, subscriber_count)
} else {
0
(false, 0)
};
AudioStatus {
enabled: config.enabled,
enabled,
streaming,
device: if streaming || config.enabled {
Some(config.device.clone())
device: if streaming || enabled {
Some(device_str)
} else {
None
},
quality: config.quality,
quality,
subscriber_count,
error,
}
}
/// Subscribe to Opus frames (for WebSocket clients)
pub fn subscribe_opus(&self) -> Option<tokio::sync::watch::Receiver<Option<Arc<OpusFrame>>>> {
// Use try_read to avoid blocking - this is called from sync context sometimes
if let Ok(guard) = self.streamer.try_read() {
guard.as_ref().map(|s| s.subscribe_opus())
} else {
None
}
}
/// Subscribe to Opus frames (async version)
pub async fn subscribe_opus_async(
&self,
) -> Option<tokio::sync::watch::Receiver<Option<Arc<OpusFrame>>>> {
pub async fn subscribe_opus(&self) -> Option<tokio::sync::mpsc::Receiver<Arc<OpusFrame>>> {
self.streamer
.read()
.await
@@ -390,7 +623,6 @@ impl AudioController {
.map(|s| s.subscribe_opus())
}
/// Enable or disable audio
pub async fn set_enabled(&self, enabled: bool) -> Result<()> {
{
let mut config = self.config.write().await;
@@ -405,61 +637,25 @@ impl AudioController {
Ok(())
}
/// Update full configuration
pub async fn update_config(&self, new_config: AudioControllerConfig) -> Result<()> {
let was_streaming = self.is_streaming().await;
let old_config = self.config.read().await.clone();
// Stop streaming if running
if was_streaming {
self.stop_streaming().await?;
}
// Update config
*self.config.write().await = new_config.clone();
// Restart streaming if it was running and still enabled
if was_streaming && new_config.enabled {
if new_config.enabled {
self.start_streaming().await?;
}
// Publish events for changes
if old_config.device != new_config.device {
self.publish_event(SystemEvent::AudioDeviceSelected {
device: new_config.device.clone(),
})
.await;
}
if old_config.quality != new_config.quality {
self.publish_event(SystemEvent::AudioQualityChanged {
quality: new_config.quality.to_string(),
})
.await;
}
Ok(())
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
self.stop_streaming().await
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<AudioHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> AudioHealthStatus {
self.monitor.status().await
}
/// Check if the audio is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Default for AudioController {
@@ -481,12 +677,23 @@ mod tests {
#[test]
fn test_audio_quality_from_str() {
assert_eq!(AudioQuality::from_str("voice"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("low"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("balanced"), AudioQuality::Balanced);
assert_eq!(AudioQuality::from_str("high"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("music"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("unknown"), AudioQuality::Balanced);
assert_eq!(
"voice".parse::<AudioQuality>().unwrap(),
AudioQuality::Voice
);
assert_eq!(
"balanced".parse::<AudioQuality>().unwrap(),
AudioQuality::Balanced
);
assert_eq!("high".parse::<AudioQuality>().unwrap(), AudioQuality::High);
}
#[test]
fn test_audio_quality_from_str_rejects_aliases_and_unknown() {
assert!("low".parse::<AudioQuality>().is_err());
assert!("music".parse::<AudioQuality>().is_err());
assert!("unknown".parse::<AudioQuality>().is_err());
assert!("".parse::<AudioQuality>().is_err());
}
#[tokio::test]

View File

@@ -1,5 +1,3 @@
//! Audio device enumeration using ALSA
use alsa::pcm::HwParams;
use alsa::{Direction, PCM};
use serde::Serialize;
@@ -7,54 +5,30 @@ use tracing::{debug, info, warn};
use crate::error::{AppError, Result};
/// Audio device information
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
/// Device name (e.g., "hw:0,0" or "default")
pub name: String,
/// Human-readable description
pub description: String,
/// Card index
pub card_index: i32,
/// Device index
pub device_index: i32,
/// Supported sample rates
pub sample_rates: Vec<u32>,
/// Supported channel counts
pub channels: Vec<u32>,
/// Is this a capture device
pub is_capture: bool,
/// Is this an HDMI audio device (likely from capture card)
pub is_hdmi: bool,
/// USB bus info for matching with video devices (e.g., "1-1" from USB path)
pub usb_bus: Option<String>,
}
impl AudioDeviceInfo {
/// Get ALSA device name
pub fn alsa_name(&self) -> String {
format!("hw:{},{}", self.card_index, self.device_index)
}
}
/// Get USB bus info for an audio card by reading sysfs
/// Returns the USB port path like "1-1" or "1-2.3"
fn get_usb_bus_info(card_index: i32) -> Option<String> {
if card_index < 0 {
return None;
}
// Read the device symlink: /sys/class/sound/cardX/device -> ../../usb1/1-1/1-1:1.0
let device_path = format!("/sys/class/sound/card{}/device", card_index);
let link_target = std::fs::read_link(&device_path).ok()?;
let link_str = link_target.to_string_lossy();
// Extract USB port from path like "../../usb1/1-1/1-1:1.0" or "../../1-1/1-1:1.0"
// We want the "1-1" part (USB bus-port)
for component in link_str.split('/') {
// Match patterns like "1-1", "1-2", "1-1.2", "2-1.3.1"
if component.contains('-') && !component.contains(':') {
// Verify it looks like a USB port (starts with digit)
if component
.chars()
.next()
@@ -69,22 +43,15 @@ fn get_usb_bus_info(card_index: i32) -> Option<String> {
None
}
/// Enumerate available audio capture devices
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
/// Enumerate available audio capture devices, with option to include a currently-in-use device
///
/// # Arguments
/// * `current_device` - Optional device name that is currently in use. This device will be
/// included in the list even if it cannot be opened (because it's already open by us).
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let mut devices = Vec::new();
// Try to enumerate cards
let cards = alsa::card::Iter::new();
for card_result in cards {
@@ -102,104 +69,71 @@ pub fn enumerate_audio_devices_with_current(
debug!("Found audio card {}: {}", card_index, card_longname);
// Check if this looks like an HDMI capture device
let is_hdmi = card_longname.to_lowercase().contains("hdmi")
|| card_longname.to_lowercase().contains("capture")
|| card_longname.to_lowercase().contains("usb");
let long_lower = card_longname.to_lowercase();
let is_hdmi = long_lower.contains("hdmi")
|| long_lower.contains("capture")
|| long_lower.contains("usb");
// Get USB bus info for this card
let usb_bus = get_usb_bus_info(card_index);
// Try to open each device on this card for capture
for device_index in 0..8 {
let device_name = format!("hw:{},{}", card_index, device_index);
// Check if this is the currently-in-use device
let is_current_device = current_device == Some(device_name.as_str());
// Try to open for capture
let mut push_info =
|sample_rates: Vec<u32>, channels: Vec<u32>, description: String| {
devices.push(AudioDeviceInfo {
name: device_name.clone(),
description,
card_index,
device_index,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
};
match PCM::new(&device_name, Direction::Capture, false) {
Ok(pcm) => {
// Query capabilities
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
devices.push(AudioDeviceInfo {
name: device_name,
description: format!("{} - Device {}", card_longname, device_index),
card_index,
device_index,
push_info(
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
format!("{} - Device {}", card_longname, device_index),
);
}
}
Err(_) => {
// Device doesn't exist or can't be opened for capture
// But if it's the current device, include it anyway (it's busy because we're using it)
if is_current_device {
debug!(
"Device {} is busy (in use by us), adding with default caps",
device_name
);
devices.push(AudioDeviceInfo {
name: device_name,
description: format!(
"{} - Device {} (in use)",
card_longname, device_index
),
card_index,
device_index,
// Use common default capabilities for HDMI capture devices
sample_rates: vec![44100, 48000],
channels: vec![2],
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
push_info(
vec![44100, 48000],
vec![2],
format!("{} - Device {} (in use)", card_longname, device_index),
);
}
continue;
}
}
}
}
// Also check for "default" device
if let Ok(pcm) = PCM::new("default", Direction::Capture, false) {
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() {
devices.insert(
0,
AudioDeviceInfo {
name: "default".to_string(),
description: "Default Audio Device".to_string(),
card_index: -1,
device_index: -1,
sample_rates,
channels,
is_capture: true,
is_hdmi: false,
usb_bus: None,
},
);
}
}
info!("Found {} audio capture devices", devices.len());
Ok(devices)
}
/// Query device capabilities
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
let hwp = match HwParams::any(pcm) {
Ok(h) => h,
Err(_) => return (vec![], vec![]),
};
// Common sample rates to check
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
let mut supported_rates = Vec::new();
@@ -209,7 +143,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
}
}
// Check channel counts
let mut supported_channels = Vec::new();
for ch in 1..=8 {
if hwp.test_channels(ch).is_ok() {
@@ -220,8 +153,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
(supported_rates, supported_channels)
}
/// Find the best audio device for capture
/// Prefers HDMI/capture devices over built-in microphones
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
@@ -231,23 +162,24 @@ pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
));
}
// First, look for HDMI/capture card devices that support 48kHz stereo
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
for device in &devices {
if device.is_hdmi && device.sample_rates.contains(&48000) && device.channels.contains(&2) {
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
continue;
}
if device.is_hdmi {
info!("Selected HDMI audio device: {}", device.description);
return Ok(device.clone());
}
}
// Then look for any device supporting 48kHz stereo
for device in &devices {
if device.sample_rates.contains(&48000) && device.channels.contains(&2) {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
if first_48k_stereo.is_none() {
first_48k_stereo = Some(device);
}
}
if let Some(device) = first_48k_stereo {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
}
// Fall back to first device
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback audio device: {} (may not support optimal settings)",
@@ -262,10 +194,8 @@ mod tests {
#[test]
fn test_enumerate_devices() {
// This test may not find devices in CI environment
let result = enumerate_audio_devices();
println!("Audio devices: {:?}", result);
// Just verify it doesn't panic
assert!(result.is_ok());
}
}

View File

@@ -1,26 +1,19 @@
//! Opus audio encoder for WebRTC
//! Opus encoder.
use audiopus::coder::GenericCtl;
use audiopus::{coder::Encoder, Application, Bitrate, Channels, SampleRate};
use bytes::Bytes;
use std::time::Instant;
use tracing::info;
use tracing::debug;
use super::capture::AudioFrame;
use crate::error::{AppError, Result};
/// Opus encoder configuration
#[derive(Debug, Clone)]
pub struct OpusConfig {
/// Sample rate (must be 8000, 12000, 16000, 24000, or 48000)
pub sample_rate: u32,
/// Channels (1 or 2)
pub channels: u32,
/// Target bitrate in bps
pub bitrate: u32,
/// Application mode
pub application: OpusApplication,
/// Enable forward error correction
pub fec: bool,
}
@@ -29,7 +22,7 @@ impl Default for OpusConfig {
Self {
sample_rate: 48000,
channels: 2,
bitrate: 64000, // 64 kbps
bitrate: 64000,
application: OpusApplication::Audio,
fec: true,
}
@@ -37,7 +30,6 @@ impl Default for OpusConfig {
}
impl OpusConfig {
/// Create config for voice (lower latency)
pub fn voice() -> Self {
Self {
application: OpusApplication::Voip,
@@ -46,7 +38,6 @@ impl OpusConfig {
}
}
/// Create config for music (higher quality)
pub fn music() -> Self {
Self {
application: OpusApplication::Audio,
@@ -82,30 +73,18 @@ impl OpusConfig {
}
}
/// Opus application mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpusApplication {
/// Voice over IP
Voip,
/// General audio
Audio,
/// Low delay mode
LowDelay,
}
/// Encoded Opus frame
#[derive(Debug, Clone)]
pub struct OpusFrame {
/// Encoded Opus data
pub data: Bytes,
/// Duration in milliseconds
pub duration_ms: u32,
/// Sequence number
pub sequence: u64,
/// Timestamp
pub timestamp: Instant,
/// RTP timestamp (samples)
pub rtp_timestamp: u32,
}
impl OpusFrame {
@@ -118,20 +97,14 @@ impl OpusFrame {
}
}
/// Opus encoder
pub struct OpusEncoder {
config: OpusConfig,
encoder: Encoder,
/// Output buffer
output_buffer: Vec<u8>,
/// Frame counter for RTP timestamp
frame_count: u64,
/// Samples per frame
samples_per_frame: u32,
}
impl OpusEncoder {
/// Create a new Opus encoder
pub fn new(config: OpusConfig) -> Result<Self> {
let sample_rate = config.to_audiopus_sample_rate();
let channels = config.to_audiopus_channels();
@@ -140,7 +113,6 @@ impl OpusEncoder {
let mut encoder = Encoder::new(sample_rate, channels, application)
.map_err(|e| AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e)))?;
// Configure encoder
encoder
.set_bitrate(Bitrate::BitsPerSecond(config.bitrate as i32))
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
@@ -151,10 +123,7 @@ impl OpusEncoder {
.map_err(|e| AppError::AudioError(format!("Failed to enable FEC: {:?}", e)))?;
}
// Calculate samples per frame (20ms at sample_rate)
let samples_per_frame = config.sample_rate / 50;
info!(
debug!(
"Opus encoder created: {}Hz {}ch {}bps",
config.sample_rate, config.channels, config.bitrate
);
@@ -162,18 +131,11 @@ impl OpusEncoder {
Ok(Self {
config,
encoder,
output_buffer: vec![0u8; 4000], // Max Opus frame size
output_buffer: vec![0u8; 4000],
frame_count: 0,
samples_per_frame,
})
}
/// Create with default configuration
pub fn default_config() -> Result<Self> {
Self::new(OpusConfig::default())
}
/// Encode PCM audio data (S16LE interleaved)
pub fn encode(&mut self, pcm_data: &[i16]) -> Result<OpusFrame> {
let encoded_len = self
.encoder
@@ -182,7 +144,6 @@ impl OpusEncoder {
let samples = pcm_data.len() as u32 / self.config.channels;
let duration_ms = (samples * 1000) / self.config.sample_rate;
let rtp_timestamp = (self.frame_count * self.samples_per_frame as u64) as u32;
self.frame_count += 1;
@@ -190,27 +151,18 @@ impl OpusEncoder {
data: Bytes::copy_from_slice(&self.output_buffer[..encoded_len]),
duration_ms,
sequence: self.frame_count - 1,
timestamp: Instant::now(),
rtp_timestamp,
})
}
/// Encode from AudioFrame
///
/// Uses zero-copy conversion from bytes to i16 samples via bytemuck.
pub fn encode_frame(&mut self, frame: &AudioFrame) -> Result<OpusFrame> {
// Zero-copy: directly cast bytes to i16 slice
// AudioFrame.data is S16LE format, which matches native little-endian i16
let samples: &[i16] = bytemuck::cast_slice(&frame.data);
self.encode(samples)
}
/// Get encoder configuration
pub fn config(&self) -> &OpusConfig {
&self.config
}
/// Reset encoder state
pub fn reset(&mut self) -> Result<()> {
self.encoder
.reset_state()
@@ -219,7 +171,6 @@ impl OpusEncoder {
Ok(())
}
/// Set bitrate dynamically
pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> {
self.encoder
.set_bitrate(Bitrate::BitsPerSecond(bitrate as i32))
@@ -228,15 +179,6 @@ impl OpusEncoder {
}
}
/// Audio encoder statistics
#[derive(Debug, Clone, Default)]
pub struct EncoderStats {
pub frames_encoded: u64,
pub bytes_output: u64,
pub avg_frame_size: usize,
pub current_bitrate: u32,
}
#[cfg(test)]
mod tests {
use super::*;
@@ -261,13 +203,12 @@ mod tests {
let config = OpusConfig::default();
let mut encoder = OpusEncoder::new(config).unwrap();
// 20ms of stereo silence at 48kHz
let silence = vec![0i16; 960 * 2];
let result = encoder.encode(&silence);
assert!(result.is_ok());
let frame = result.unwrap();
assert!(!frame.is_empty());
assert!(frame.len() < silence.len() * 2); // Should be compressed
assert!(frame.len() < silence.len() * 2);
}
}

View File

@@ -1,12 +1,4 @@
//! Audio capture and encoding module
//!
//! This module provides:
//! - ALSA audio capture
//! - Opus encoding for WebRTC
//! - Audio device enumeration
//! - Audio streaming pipeline
//! - High-level audio controller
//! - Device health monitoring
//! ALSA capture, Opus encode, device enumeration, streaming, controller, health monitor.
pub mod capture;
pub mod controller;
@@ -19,5 +11,5 @@ pub use capture::{AudioCapturer, AudioConfig, AudioFrame};
pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus};
pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo};
pub use encoder::{OpusConfig, OpusEncoder, OpusFrame};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus};
pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};

View File

@@ -1,129 +1,58 @@
//! Audio device health monitoring
//!
//! This module provides health monitoring for audio capture devices, including:
//! - Device connectivity checks
//! - Automatic reconnection on failure
//! - Error tracking and notification
//! - Log throttling to prevent log flooding
//! Audio device health and logging throttle for repeated failures.
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use tracing::{info, warn};
use crate::events::{EventBus, SystemEvent};
use crate::utils::LogThrottler;
/// Audio health status
const LOG_THROTTLE_SECS: u64 = 5;
#[derive(Debug, Clone, PartialEq, Default)]
pub enum AudioHealthStatus {
/// Device is healthy and operational
#[default]
Healthy,
/// Device has an error, attempting recovery
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
/// Number of recovery attempts made
retry_count: u32,
},
/// Device is disconnected or not available
Disconnected,
}
/// Audio health monitor configuration
#[derive(Debug, Clone)]
pub struct AudioMonitorConfig {
/// Retry interval when device is lost (milliseconds)
pub retry_interval_ms: u64,
/// Maximum retry attempts before giving up (0 = infinite)
pub max_retries: u32,
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for AudioMonitorConfig {
fn default() -> Self {
Self {
retry_interval_ms: 1000,
max_retries: 0, // infinite retry
log_throttle_secs: 5,
}
}
}
/// Audio health monitor
///
/// Monitors audio device health and manages error recovery.
/// Publishes WebSocket events when device status changes.
pub struct AudioHealthMonitor {
/// Current health status
status: RwLock<AudioHealthStatus>,
/// Event bus for notifications
events: RwLock<Option<Arc<EventBus>>>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
config: AudioMonitorConfig,
/// Whether monitoring is active (reserved for future use)
#[allow(dead_code)]
running: AtomicBool,
/// Current retry count
retry_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
/// Hide `error_message` while a new capture attempt is in flight (internal error state unchanged).
suppress_display: AtomicBool,
}
impl AudioHealthMonitor {
/// Create a new audio health monitor with the specified configuration
pub fn new(config: AudioMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
pub fn new() -> Self {
Self {
status: RwLock::new(AudioHealthStatus::Healthy),
events: RwLock::new(None),
throttler: LogThrottler::with_secs(throttle_secs),
config,
running: AtomicBool::new(false),
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
retry_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
suppress_display: AtomicBool::new(false),
}
}
/// Create a new audio health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(AudioMonitorConfig::default())
/// Clears the error string exposed via [`Self::error_message`] until the next outcome (`report_error` or recovery).
pub fn prepare_retry_attempt(&self) {
self.suppress_display.store(true, Ordering::Relaxed);
}
/// Set the event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
pub async fn report_error(&self, reason: &str, error_code: &str) {
self.suppress_display.store(false, Ordering::Relaxed);
/// Report an error from audio operations
///
/// This method is called when an audio operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Publishes a WebSocket event if the error is new or changed
///
/// # Arguments
///
/// * `device` - The audio device name (if known)
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, device: Option<&str>, reason: &str, error_code: &str) {
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("audio_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!(
@@ -132,127 +61,53 @@ impl AudioHealthMonitor {
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = AudioHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
retry_count: count,
};
// Publish event (only if error changed or first occurrence)
if error_changed || count == 1 {
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::AudioDeviceLost {
device: device.map(|s| s.to_string()),
reason: reason.to_string(),
error_code: error_code.to_string(),
});
}
}
}
/// Report that a reconnection attempt is starting
///
/// Publishes a reconnecting event to notify clients.
pub async fn report_reconnecting(&self) {
let attempt = self.retry_count.load(Ordering::Relaxed);
// Only publish every 5 attempts to avoid event spam
if attempt == 1 || attempt.is_multiple_of(5) {
debug!("Audio reconnecting, attempt {}", attempt);
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::AudioReconnecting { attempt });
}
}
}
/// Report that the device has recovered
///
/// This method is called when the audio device successfully reconnects.
/// It resets the error state and publishes a recovery event.
///
/// # Arguments
///
/// * `device` - The audio device name
pub async fn report_recovered(&self, device: Option<&str>) {
pub async fn report_recovered(&self) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != AudioHealthStatus::Healthy {
let retry_count = self.retry_count.load(Ordering::Relaxed);
info!("Audio recovered after {} retries", retry_count);
// Reset state
self.suppress_display.store(false, Ordering::Relaxed);
self.retry_count.store(0, Ordering::Relaxed);
self.throttler.clear("audio_");
*self.last_error_code.write().await = None;
*self.status.write().await = AudioHealthStatus::Healthy;
// Publish recovery event
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::AudioRecovered {
device: device.map(|s| s.to_string()),
});
}
}
}
/// Get the current health status
pub async fn status(&self) -> AudioHealthStatus {
self.status.read().await.clone()
}
/// Get the current retry count
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.suppress_display.store(false, Ordering::Relaxed);
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = AudioHealthStatus::Healthy;
self.throttler.clear_all();
}
/// Get the configuration
pub fn config(&self) -> &AudioMonitorConfig {
&self.config
pub async fn status(&self) -> AudioHealthStatus {
self.status.read().await.clone()
}
/// Check if we should continue retrying
///
/// Returns `false` if max_retries is set and we've exceeded it.
pub fn should_retry(&self) -> bool {
if self.config.max_retries == 0 {
return true; // Infinite retry
}
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Get the retry interval
pub fn retry_interval(&self) -> Duration {
Duration::from_millis(self.config.retry_interval_ms)
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
if self.suppress_display.load(Ordering::Relaxed) {
return None;
}
match &*self.status.read().await {
AudioHealthStatus::Error { reason, .. } => Some(reason.clone()),
_ => None,
@@ -262,7 +117,7 @@ impl AudioHealthMonitor {
impl Default for AudioHealthMonitor {
fn default() -> Self {
Self::with_defaults()
Self::new()
}
}
@@ -272,32 +127,25 @@ mod tests {
#[tokio::test]
async fn test_initial_status() {
let monitor = AudioHealthMonitor::with_defaults();
assert!(monitor.is_healthy().await);
let monitor = AudioHealthMonitor::new();
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_report_error() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
monitor
.report_error(Some("hw:0,0"), "Device not found", "device_disconnected")
.report_error("Device not found", "device_disconnected")
.await;
assert!(monitor.is_error().await);
assert_eq!(monitor.retry_count(), 1);
if let AudioHealthStatus::Error {
reason,
error_code,
retry_count,
} = monitor.status().await
{
if let AudioHealthStatus::Error { reason, error_code } = monitor.status().await {
assert_eq!(reason, "Device not found");
assert_eq!(error_code, "device_disconnected");
assert_eq!(retry_count, 1);
} else {
panic!("Expected Error status");
}
@@ -305,39 +153,52 @@ mod tests {
#[tokio::test]
async fn test_report_recovered() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
// First report an error
monitor
.report_error(Some("default"), "Capture failed", "capture_error")
.report_error("Capture failed", "capture_error")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered(Some("default")).await;
assert!(monitor.is_healthy().await);
monitor.report_recovered().await;
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_retry_count_increments() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
for i in 1..=5 {
monitor.report_error(None, "Error", "io_error").await;
monitor.report_error("Error", "io_error").await;
assert_eq!(monitor.retry_count(), i);
}
}
#[tokio::test]
async fn test_reset() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
monitor.report_error(None, "Error", "io_error").await;
monitor.report_error("Error", "io_error").await;
assert!(monitor.is_error().await);
monitor.reset().await;
assert!(monitor.is_healthy().await);
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_prepare_retry_hides_error_until_next_failure() {
let monitor = AudioHealthMonitor::new();
monitor.report_error("bad", "e").await;
assert_eq!(monitor.error_message().await.as_deref(), Some("bad"));
monitor.prepare_retry_attempt();
assert!(monitor.is_error().await);
assert!(monitor.error_message().await.is_none());
monitor.report_error("still bad", "e").await;
assert_eq!(monitor.error_message().await.as_deref(), Some("still bad"));
}
}

View File

@@ -1,43 +1,36 @@
//! Audio streaming pipeline
//!
//! Coordinates audio capture and Opus encoding, distributing encoded
//! frames to multiple subscribers via broadcast channel.
//! ALSA 48 kHz stereo → Opus 20 ms frames, fan-out per subscriber.
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex, RwLock};
use tracing::{error, info, warn};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::{broadcast, mpsc, watch, Mutex as AsyncMutex, RwLock};
use tracing::{debug, error, info, warn};
use super::capture::{AudioCapturer, AudioConfig, CaptureState};
use super::capture::{AudioCapturer, AudioConfig, AudioFrame, CaptureState};
use super::encoder::{OpusConfig, OpusEncoder, OpusFrame};
use crate::error::{AppError, Result};
use bytemuck;
use bytes::Bytes;
use std::time::Duration;
/// 48 kHz stereo: 20 ms = 960 × 2 samples (S16LE).
const OPUS_STEREO_SAMPLES: usize = 960 * 2;
/// Audio stream state
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AudioStreamState {
/// Stream is stopped
#[default]
Stopped,
/// Stream is starting up
Starting,
/// Stream is running
Running,
/// Stream encountered an error
Error,
}
/// Audio streamer configuration
#[derive(Debug, Clone, Default)]
pub struct AudioStreamerConfig {
/// Audio capture configuration
pub capture: AudioConfig,
/// Opus encoder configuration
pub opus: OpusConfig,
}
impl AudioStreamerConfig {
/// Create config for a specific device with default quality
pub fn for_device(device_name: &str) -> Self {
Self {
capture: AudioConfig {
@@ -48,90 +41,75 @@ impl AudioStreamerConfig {
}
}
/// Create config with specified bitrate
pub fn with_bitrate(mut self, bitrate: u32) -> Self {
self.opus.bitrate = bitrate;
self
}
}
/// Audio stream statistics
#[derive(Debug, Clone, Default)]
pub struct AudioStreamStats {
/// Frames encoded to Opus
/// Number of active subscribers
pub subscriber_count: usize,
}
/// Audio streamer
///
/// Manages the audio capture -> encode -> broadcast pipeline.
pub struct AudioStreamer {
config: RwLock<AudioStreamerConfig>,
state: watch::Sender<AudioStreamState>,
state_rx: watch::Receiver<AudioStreamState>,
capturer: RwLock<Option<Arc<AudioCapturer>>>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: watch::Sender<Option<Arc<OpusFrame>>>,
stats: Arc<Mutex<AudioStreamStats>>,
sequence: AtomicU64,
stream_start_time: RwLock<Option<Instant>>,
encoder: Arc<AsyncMutex<Option<OpusEncoder>>>,
opus_subscribers: Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
stop_flag: Arc<AtomicBool>,
}
impl AudioStreamer {
/// Create a new audio streamer with default configuration
pub fn new() -> Self {
Self::with_config(AudioStreamerConfig::default())
}
/// Create a new audio streamer with specified configuration
pub fn with_config(config: AudioStreamerConfig) -> Self {
let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped);
let (opus_tx, _opus_rx) = watch::channel(None);
Self {
config: RwLock::new(config),
state: state_tx,
state_rx,
capturer: RwLock::new(None),
encoder: Arc::new(Mutex::new(None)),
opus_tx,
stats: Arc::new(Mutex::new(AudioStreamStats::default())),
sequence: AtomicU64::new(0),
stream_start_time: RwLock::new(None),
encoder: Arc::new(AsyncMutex::new(None)),
opus_subscribers: Arc::new(Mutex::new(Vec::new())),
stop_flag: Arc::new(AtomicBool::new(false)),
}
}
/// Get current state
pub fn state(&self) -> AudioStreamState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<AudioStreamState> {
self.state_rx.clone()
}
/// Subscribe to Opus frames
pub fn subscribe_opus(&self) -> watch::Receiver<Option<Arc<OpusFrame>>> {
self.opus_tx.subscribe()
pub fn subscribe_opus(&self) -> mpsc::Receiver<Arc<OpusFrame>> {
let (tx, rx) = mpsc::channel::<Arc<OpusFrame>>(128);
self.opus_subscribers.lock().unwrap().push(tx);
rx
}
/// Get number of active subscribers
pub fn subscriber_count(&self) -> usize {
self.opus_tx.receiver_count()
self.opus_subscribers
.lock()
.unwrap()
.iter()
.filter(|s| !s.is_closed())
.count()
}
/// Get current statistics
pub async fn stats(&self) -> AudioStreamStats {
let mut stats = self.stats.lock().await.clone();
stats.subscriber_count = self.subscriber_count();
stats
pub fn stats(&self) -> AudioStreamStats {
AudioStreamStats {
subscriber_count: self.subscriber_count(),
}
}
/// Update configuration (only when stopped)
pub async fn set_config(&self, config: AudioStreamerConfig) -> Result<()> {
if self.state() != AudioStreamState::Stopped {
return Err(AppError::AudioError(
@@ -142,12 +120,9 @@ impl AudioStreamer {
Ok(())
}
/// Update bitrate dynamically (can be done while streaming)
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
// Update config
self.config.write().await.opus.bitrate = bitrate;
// Update encoder if running
if let Some(ref mut encoder) = *self.encoder.lock().await {
encoder.set_bitrate(bitrate)?;
}
@@ -156,7 +131,6 @@ impl AudioStreamer {
Ok(())
}
/// Start the audio stream
pub async fn start(&self) -> Result<()> {
if self.state() == AudioStreamState::Running {
return Ok(());
@@ -175,42 +149,77 @@ impl AudioStreamer {
config.opus.bitrate
);
// Create capturer
let capturer = Arc::new(AudioCapturer::new(config.capture.clone()));
*self.capturer.write().await = Some(capturer.clone());
// Create encoder
let encoder = OpusEncoder::new(config.opus.clone())?;
*self.encoder.lock().await = Some(encoder);
// Start capture
capturer.start().await?;
// Reset stats
{
let mut stats = self.stats.lock().await;
*stats = AudioStreamStats::default();
let mut capture_state = capturer.state_watch();
let startup_result = tokio::time::timeout(Duration::from_secs(2), async {
loop {
let current_state = *capture_state.borrow();
match current_state {
CaptureState::Running => return Ok(()),
CaptureState::Error => {
return Err(AppError::AudioError(
"Audio capture failed to start".to_string(),
))
}
CaptureState::Stopped => {
if capture_state.changed().await.is_err() {
return Err(AppError::AudioError(
"Audio capture stopped during startup".to_string(),
));
}
}
}
}
})
.await;
match startup_result {
Ok(Ok(())) => {}
Ok(Err(e)) => {
let _ = capturer.stop().await;
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
let _ = self.state.send(AudioStreamState::Error);
return Err(e);
}
Err(_) => {
let _ = capturer.stop().await;
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
let _ = self.state.send(AudioStreamState::Error);
return Err(AppError::AudioError(
"Timed out waiting for audio capture to start".to_string(),
));
}
}
// Record start time
*self.stream_start_time.write().await = Some(Instant::now());
self.sequence.store(0, Ordering::SeqCst);
// Start encoding task
let capturer_for_task = capturer.clone();
let encoder = self.encoder.clone();
let opus_tx = self.opus_tx.clone();
let opus_subscribers = self.opus_subscribers.clone();
let state = self.state.clone();
let stop_flag = self.stop_flag.clone();
tokio::spawn(async move {
Self::stream_task(capturer_for_task, encoder, opus_tx, state, stop_flag).await;
Self::stream_task(
capturer_for_task,
encoder,
opus_subscribers,
state,
stop_flag,
)
.await;
});
Ok(())
}
/// Stop the audio stream
pub async fn stop(&self) -> Result<()> {
if self.state() == AudioStreamState::Stopped {
return Ok(());
@@ -218,82 +227,120 @@ impl AudioStreamer {
info!("Stopping audio stream");
// Signal stop
self.stop_flag.store(true, Ordering::SeqCst);
// Stop capturer
if let Some(ref capturer) = *self.capturer.read().await {
capturer.stop().await?;
}
// Clear resources
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
*self.stream_start_time.write().await = None;
self.opus_subscribers.lock().unwrap().clear();
let _ = self.state.send(AudioStreamState::Stopped);
info!("Audio stream stopped");
Ok(())
}
/// Check if streaming
pub fn is_running(&self) -> bool {
self.state() == AudioStreamState::Running
}
/// Internal streaming task
async fn fanout_opus(
subscribers: &Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
frame: Arc<OpusFrame>,
) {
let txs: Vec<_> = {
let g = subscribers.lock().unwrap();
if g.is_empty() {
return;
}
g.clone()
};
for tx in &txs {
let _ = tx.send(frame.clone()).await;
}
if txs.iter().any(|tx| tx.is_closed()) {
let mut g = subscribers.lock().unwrap();
g.retain(|tx| !tx.is_closed());
}
}
async fn stream_task(
capturer: Arc<AudioCapturer>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: watch::Sender<Option<Arc<OpusFrame>>>,
encoder: Arc<AsyncMutex<Option<OpusEncoder>>>,
opus_subscribers: Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
state: watch::Sender<AudioStreamState>,
stop_flag: Arc<AtomicBool>,
) {
let mut pcm_rx = capturer.subscribe();
let _ = state.send(AudioStreamState::Running);
info!("Audio stream task started");
debug!("Audio stream task started (48 kHz stereo → Opus, mpsc fan-out)");
let mut pending: Vec<i16> = Vec::new();
loop {
// Check stop flag (atomic, no async lock needed)
if stop_flag.load(Ordering::Relaxed) {
break;
}
// Check capturer state
if capturer.state() == CaptureState::Error {
error!("Audio capture error, stopping stream");
let _ = state.send(AudioStreamState::Error);
break;
}
// Receive PCM frame with timeout
let recv_result =
tokio::time::timeout(std::time::Duration::from_secs(2), pcm_rx.recv()).await;
match recv_result {
Ok(Ok(audio_frame)) => {
// Encode to Opus
let opus_result = {
let mut enc_guard = encoder.lock().await;
(*enc_guard)
.as_mut()
.map(|enc| enc.encode_frame(&audio_frame))
};
if audio_frame.sample_rate != 48_000 || audio_frame.channels != 2 {
warn!(
"Skip non48 kHz/stereo PCM ({} Hz, {} ch)",
audio_frame.sample_rate, audio_frame.channels
);
continue;
}
match opus_result {
Some(Ok(opus_frame)) => {
// Publish latest frame to subscribers
if opus_tx.receiver_count() > 0 {
let _ = opus_tx.send(Some(Arc::new(opus_frame)));
let samples: &[i16] = match bytemuck::try_cast_slice(&audio_frame.data) {
Ok(s) => s,
Err(_) => {
warn!("Audio frame size not multiple of 2; skipping");
continue;
}
};
if !samples.is_empty() {
pending.extend_from_slice(samples);
}
while pending.len() >= OPUS_STEREO_SAMPLES {
let pcm_20ms = Bytes::copy_from_slice(bytemuck::cast_slice(
&pending[..OPUS_STEREO_SAMPLES],
));
pending.drain(..OPUS_STEREO_SAMPLES);
let frame_48k = AudioFrame::new_interleaved(pcm_20ms, 2, 48_000, 0);
let opus_result = {
let mut enc_guard = encoder.lock().await;
(*enc_guard)
.as_mut()
.map(|enc| enc.encode_frame(&frame_48k))
};
match opus_result {
Some(Ok(opus_frame)) => {
Self::fanout_opus(&opus_subscribers, Arc::new(opus_frame)).await;
}
Some(Err(e)) => {
error!("Opus encode error: {}", e);
}
None => {
warn!("Encoder not available");
break;
}
}
Some(Err(e)) => {
error!("Opus encode error: {}", e);
}
None => {
warn!("Encoder not available");
break;
}
}
}
@@ -302,19 +349,23 @@ impl AudioStreamer {
break;
}
Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
warn!("Audio receiver lagged by {} frames", n);
warn!("PCM receiver lagged by {} frames", n);
}
Err(_) => {
// Timeout - check if still capturing
if capturer.state() != CaptureState::Running {
info!("Audio capture stopped, ending stream task");
let _ = state.send(AudioStreamState::Error);
break;
}
}
}
}
let _ = state.send(AudioStreamState::Stopped);
if stop_flag.load(Ordering::Relaxed) {
let _ = state.send(AudioStreamState::Stopped);
} else {
opus_subscribers.lock().unwrap().clear();
}
info!("Audio stream task ended");
}
}

View File

@@ -8,20 +8,16 @@ use axum::{
use axum_extra::extract::CookieJar;
use std::sync::Arc;
use crate::error::ErrorResponse;
use crate::state::AppState;
use crate::web::ErrorResponse;
/// Session cookie name
pub const SESSION_COOKIE: &str = "one_kvm_session";
/// Extract session ID from request
pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option<String> {
// First try cookie
if let Some(cookie) = cookies.get(SESSION_COOKIE) {
return Some(cookie.value().to_string());
}
// Then try Authorization header (Bearer token)
if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
if let Ok(auth_str) = auth_header.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
@@ -33,7 +29,6 @@ pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap)
None
}
/// Authentication middleware
pub async fn auth_middleware(
State(state): State<Arc<AppState>>,
cookies: CookieJar,
@@ -41,29 +36,23 @@ pub async fn auth_middleware(
next: Next,
) -> Result<Response, StatusCode> {
let raw_path = request.uri().path();
// When this middleware is mounted under /api, Axum strips the prefix for the inner router.
// Normalize the path so checks work whether it is mounted or not.
// Mounted under /api: inner path may lack prefix; normalize for whitelist checks.
let path = raw_path.strip_prefix("/api").unwrap_or(raw_path);
// Check if system is initialized
if !state.config.is_initialized() {
// Allow only setup-related endpoints when not initialized
if is_setup_public_endpoint(path) {
return Ok(next.run(request).await);
}
}
// Public endpoints that don't require auth
if is_public_endpoint(path) {
return Ok(next.run(request).await);
}
// Extract session ID
let session_id = extract_session_id(&cookies, request.headers());
if let Some(session_id) = session_id {
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
// Add session to request extensions
request.extensions_mut().insert(session);
return Ok(next.run(request).await);
}
@@ -87,9 +76,7 @@ fn unauthorized_response(message: &str) -> Response {
(StatusCode::UNAUTHORIZED, Json(body)).into_response()
}
/// Check if endpoint is public (no auth required)
fn is_public_endpoint(path: &str) -> bool {
// Note: paths here are relative to /api since middleware is applied within the nested router
matches!(
path,
"/" | "/auth/login" | "/health" | "/setup" | "/setup/init"
@@ -102,7 +89,6 @@ fn is_public_endpoint(path: &str) -> bool {
|| path.ends_with(".svg")
}
/// Setup-only endpoints allowed before initialization.
fn is_setup_public_endpoint(path: &str) -> bool {
matches!(
path,

View File

@@ -5,7 +5,6 @@ use argon2::{
use crate::error::{AppError, Result};
/// Hash a password using Argon2
pub fn hash_password(password: &str) -> Result<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
@@ -16,7 +15,6 @@ pub fn hash_password(password: &str) -> Result<String> {
.map_err(|e| AppError::Internal(format!("Password hashing failed: {}", e)))
}
/// Verify a password against a hash
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| AppError::Internal(format!("Invalid password hash: {}", e)))?;

View File

@@ -1,147 +1,104 @@
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use std::collections::HashMap;
use std::sync::Arc;
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::error::Result;
/// Session data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
pub id: String,
pub user_id: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub expires_at: OffsetDateTime,
pub data: Option<serde_json::Value>,
}
impl Session {
/// Check if session is expired
pub fn is_expired(&self) -> bool {
Utc::now() > self.expires_at
OffsetDateTime::now_utc() > self.expires_at
}
}
/// Session store backed by SQLite
#[derive(Clone)]
pub struct SessionStore {
pool: Pool<Sqlite>,
inner: Arc<RwLock<HashMap<String, Session>>>,
default_ttl: Duration,
}
impl SessionStore {
/// Create a new session store
pub fn new(pool: Pool<Sqlite>, ttl_secs: i64) -> Self {
pub fn new(ttl_secs: i64) -> Self {
Self {
pool,
inner: Arc::new(RwLock::new(HashMap::new())),
default_ttl: Duration::seconds(ttl_secs),
}
}
/// Create a new session
pub async fn create(&self, user_id: &str) -> Result<Session> {
let now = OffsetDateTime::now_utc();
let session = Session {
id: Uuid::new_v4().to_string(),
user_id: user_id.to_string(),
created_at: Utc::now(),
expires_at: Utc::now() + self.default_ttl,
created_at: now,
expires_at: now + self.default_ttl,
data: None,
};
sqlx::query(
r#"
INSERT INTO sessions (id, user_id, created_at, expires_at, data)
VALUES (?1, ?2, ?3, ?4, ?5)
"#,
)
.bind(&session.id)
.bind(&session.user_id)
.bind(session.created_at.to_rfc3339())
.bind(session.expires_at.to_rfc3339())
.bind(session.data.as_ref().map(|d| d.to_string()))
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
guard.insert(session.id.clone(), session.clone());
Ok(session)
}
/// Get a session by ID
pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
let row: Option<(String, String, String, String, Option<String>)> = sqlx::query_as(
"SELECT id, user_id, created_at, expires_at, data FROM sessions WHERE id = ?1",
)
.bind(session_id)
.fetch_optional(&self.pool)
.await?;
match row {
Some((id, user_id, created_at, expires_at, data)) => {
let session = Session {
id,
user_id,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
expires_at: DateTime::parse_from_rfc3339(&expires_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
data: data.and_then(|d| serde_json::from_str(&d).ok()),
};
if session.is_expired() {
self.delete(&session.id).await?;
Ok(None)
} else {
Ok(Some(session))
}
}
None => Ok(None),
let mut guard = self.inner.write().await;
let Some(session) = guard.get(session_id).cloned() else {
return Ok(None);
};
if session.is_expired() {
guard.remove(session_id);
return Ok(None);
}
Ok(Some(session))
}
/// Delete a session
pub async fn delete(&self, session_id: &str) -> Result<()> {
sqlx::query("DELETE FROM sessions WHERE id = ?1")
.bind(session_id)
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
guard.remove(session_id);
Ok(())
}
/// Delete all expired sessions
pub async fn cleanup_expired(&self) -> Result<u64> {
let now = Utc::now().to_rfc3339();
let result = sqlx::query("DELETE FROM sessions WHERE expires_at < ?1")
.bind(now)
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
let mut guard = self.inner.write().await;
let before = guard.len();
guard.retain(|_, s| !s.is_expired());
Ok((before - guard.len()) as u64)
}
/// Delete all sessions
pub async fn delete_all(&self) -> Result<u64> {
let result = sqlx::query("DELETE FROM sessions")
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
let mut guard = self.inner.write().await;
let n = guard.len() as u64;
guard.clear();
Ok(n)
}
/// List all session IDs
pub async fn list_ids(&self) -> Result<Vec<String>> {
let rows: Vec<(String,)> = sqlx::query_as("SELECT id FROM sessions")
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(|(id,)| id).collect())
let guard = self.inner.read().await;
Ok(guard.keys().cloned().collect())
}
/// Extend session expiration
pub async fn extend(&self, session_id: &str) -> Result<()> {
let new_expires = Utc::now() + self.default_ttl;
sqlx::query("UPDATE sessions SET expires_at = ?1 WHERE id = ?2")
.bind(new_expires.to_rfc3339())
.bind(session_id)
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
if let Some(session) = guard.get_mut(session_id) {
if session.is_expired() {
guard.remove(session_id);
} else {
session.expires_at = OffsetDateTime::now_utc() + self.default_ttl;
}
}
Ok(())
}
}

View File

@@ -1,123 +1,99 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime;
use uuid::Uuid;
use super::password::{hash_password, verify_password};
use crate::error::{AppError, Result};
/// User row type from database
type UserRow = (String, String, String, String, String);
type UserRow = (String, String, String);
/// User data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub id: String,
pub username: String,
#[serde(skip_serializing)]
pub password_hash: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl User {
/// Convert from database row to User
fn from_row(row: UserRow) -> Self {
let (id, username, password_hash, created_at, updated_at) = row;
let (id, username, password_hash) = row;
Self {
id,
username,
password_hash,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
updated_at: DateTime::parse_from_rfc3339(&updated_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
}
}
}
/// User store backed by SQLite
#[derive(Clone)]
pub struct UserStore {
pool: Pool<Sqlite>,
}
impl UserStore {
/// Create a new user store
pub fn new(pool: Pool<Sqlite>) -> Self {
Self { pool }
}
/// Create a new user
pub async fn create(&self, username: &str, password: &str) -> Result<User> {
// Check if username already exists
if self.get_by_username(username).await?.is_some() {
return Err(AppError::BadRequest(format!(
"Username '{}' already exists",
username
)));
/// The single local user, or `None` if none exists. Errors if more than one row is present.
pub async fn single_user(&self) -> Result<Option<User>> {
let mut rows: Vec<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash FROM users ORDER BY rowid ASC LIMIT 2",
)
.fetch_all(&self.pool)
.await?;
match rows.len() {
0 => Ok(None),
1 => Ok(Some(User::from_row(rows.remove(0)))),
_ => Err(AppError::Internal(
"Multiple user accounts in database; this build supports only one".to_string(),
)),
}
}
pub async fn create_first_user(&self, username: &str, password: &str) -> Result<User> {
if self.single_user().await?.is_some() {
return Err(AppError::BadRequest(
"A user account already exists".to_string(),
));
}
let password_hash = hash_password(password)?;
let now = Utc::now();
let user = User {
id: Uuid::new_v4().to_string(),
username: username.to_string(),
password_hash,
created_at: now,
updated_at: now,
};
sqlx::query(
r#"
INSERT INTO users (id, username, password_hash, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5)
INSERT INTO users (id, username, password_hash)
VALUES (?1, ?2, ?3)
"#,
)
.bind(&user.id)
.bind(&user.username)
.bind(&user.password_hash)
.bind(user.created_at.to_rfc3339())
.bind(user.updated_at.to_rfc3339())
.execute(&self.pool)
.await?;
Ok(user)
}
/// Get user by ID
pub async fn get(&self, user_id: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE id = ?1",
)
.bind(user_id)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Get user by username
pub async fn get_by_username(&self, username: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE username = ?1",
)
.bind(username)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Verify user credentials
pub async fn verify(&self, username: &str, password: &str) -> Result<Option<User>> {
let user = match self.get_by_username(username).await? {
Some(user) => user,
let user = match self.single_user().await? {
Some(u) => u,
None => return Ok(None),
};
if user.username != username {
return Ok(None);
}
if verify_password(password, &user.password_hash)? {
Ok(Some(user))
} else {
@@ -125,15 +101,23 @@ impl UserStore {
}
}
/// Update user password
pub async fn update_password(&self, user_id: &str, new_password: &str) -> Result<()> {
let user = self
.single_user()
.await?
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
if user.id != user_id {
return Err(AppError::AuthError("Invalid session".to_string()));
}
let password_hash = hash_password(new_password)?;
let now = Utc::now();
let now = OffsetDateTime::now_utc();
let result =
sqlx::query("UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3")
.bind(&password_hash)
.bind(now.to_rfc3339())
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
.bind(user_id)
.execute(&self.pool)
.await?;
@@ -145,21 +129,24 @@ impl UserStore {
Ok(())
}
/// Update username
pub async fn update_username(&self, user_id: &str, new_username: &str) -> Result<()> {
if let Some(existing) = self.get_by_username(new_username).await? {
if existing.id != user_id {
return Err(AppError::BadRequest(format!(
"Username '{}' already exists",
new_username
)));
}
let user = self
.single_user()
.await?
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
if user.id != user_id {
return Err(AppError::AuthError("Invalid session".to_string()));
}
let now = Utc::now();
if new_username == user.username {
return Ok(());
}
let now = OffsetDateTime::now_utc();
let result = sqlx::query("UPDATE users SET username = ?1, updated_at = ?2 WHERE id = ?3")
.bind(new_username)
.bind(now.to_rfc3339())
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
.bind(user_id)
.execute(&self.pool)
.await?;
@@ -170,37 +157,4 @@ impl UserStore {
Ok(())
}
/// List all users
pub async fn list(&self) -> Result<Vec<User>> {
let rows: Vec<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users ORDER BY created_at",
)
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(User::from_row).collect())
}
/// Delete user by ID
pub async fn delete(&self, user_id: &str) -> Result<()> {
let result = sqlx::query("DELETE FROM users WHERE id = ?1")
.bind(user_id)
.execute(&self.pool)
.await?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("User not found".to_string()));
}
Ok(())
}
/// Check if any users exist
pub async fn has_users(&self) -> Result<bool> {
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
.fetch_one(&self.pool)
.await?;
Ok(count.0 > 0)
}
}

View File

@@ -1,5 +1,7 @@
mod persistence;
mod schema;
mod store;
pub use persistence::ConfigChange;
pub use schema::*;
pub use store::ConfigStore;

View File

@@ -0,0 +1,5 @@
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}

View File

@@ -1,12 +1,85 @@
use crate::video::encoder::BitratePreset;
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
// Re-export ExtensionsConfig from extensions module
// Re-export domain config types that are embedded in AppConfig.
// These are simple data types defined in their respective modules;
// keeping the re-export here is acceptable since they flow inward.
pub use crate::extensions::ExtensionsConfig;
// Re-export RustDeskConfig from rustdesk module
pub use crate::rustdesk::config::RustDeskConfig;
/// Bitrate preset for video encoding
///
/// Simplifies bitrate configuration by providing three intuitive presets
/// plus a custom option for advanced users.
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
#[derive(Default)]
pub enum BitratePreset {
/// Speed priority: 1 Mbps, lowest latency, smaller GOP
Speed,
/// Balanced: 4 Mbps, good quality/latency tradeoff
#[default]
Balanced,
/// Quality priority: 8 Mbps, best visual quality
Quality,
/// Custom bitrate in kbps (for advanced users)
Custom(u32),
}
impl BitratePreset {
/// Get bitrate value in kbps
pub fn bitrate_kbps(&self) -> u32 {
match self {
Self::Speed => 1000,
Self::Balanced => 4000,
Self::Quality => 8000,
Self::Custom(kbps) => *kbps,
}
}
/// Get recommended GOP size based on preset
pub fn gop_size(&self, fps: u32) -> u32 {
match self {
Self::Speed => (fps / 2).max(15),
Self::Balanced => fps,
Self::Quality => fps * 2,
Self::Custom(_) => fps,
}
}
/// Get quality preset name for encoder configuration
pub fn quality_level(&self) -> &'static str {
match self {
Self::Speed => "low",
Self::Balanced => "medium",
Self::Quality => "high",
Self::Custom(_) => "medium",
}
}
/// Create from kbps value, mapping to nearest preset or Custom
pub fn from_kbps(kbps: u32) -> Self {
match kbps {
0..=1500 => Self::Speed,
1501..=6000 => Self::Balanced,
6001..=10000 => Self::Quality,
_ => Self::Custom(kbps),
}
}
}
impl std::fmt::Display for BitratePreset {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Speed => write!(f, "Speed (1 Mbps)"),
Self::Balanced => write!(f, "Balanced (4 Mbps)"),
Self::Quality => write!(f, "Quality (8 Mbps)"),
Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps),
}
}
}
/// Main application configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -148,13 +221,11 @@ impl Default for OtgDescriptorConfig {
pub enum OtgHidProfile {
/// Full HID device set (keyboard + relative mouse + absolute mouse + consumer control)
#[default]
#[serde(alias = "full_no_msd")]
Full,
/// Full HID device set without MSD
FullNoMsd,
/// Full HID device set without consumer control
#[serde(alias = "full_no_consumer_no_msd")]
FullNoConsumer,
/// Full HID device set without consumer control and MSD
FullNoConsumerNoMsd,
/// Legacy profile: only keyboard
LegacyKeyboard,
/// Legacy profile: only relative mouse
@@ -163,9 +234,38 @@ pub enum OtgHidProfile {
Custom,
}
/// OTG endpoint budget policy.
#[typeshare]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum OtgEndpointBudget {
/// Derive a safe default from the selected UDC.
#[default]
Auto,
/// Limit OTG gadget functions to 5 endpoints.
Five,
/// Limit OTG gadget functions to 6 endpoints.
Six,
/// Do not impose a software endpoint budget.
Unlimited,
}
impl OtgEndpointBudget {
/// Resolve endpoint limit assuming a known budget variant (not Auto).
pub fn endpoint_limit_raw(&self) -> Option<u8> {
match self {
Self::Five => Some(5),
Self::Six => Some(6),
Self::Unlimited => None,
Self::Auto => None, // resolved via `HidConfig::resolved_otg_endpoint_limit`
}
}
}
/// OTG HID function selection (used when profile is Custom)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct OtgHidFunctions {
pub keyboard: bool,
@@ -214,6 +314,26 @@ impl OtgHidFunctions {
pub fn is_empty(&self) -> bool {
!self.keyboard && !self.mouse_relative && !self.mouse_absolute && !self.consumer
}
pub fn endpoint_cost(&self, keyboard_leds: bool) -> u8 {
let mut endpoints = 0;
if self.keyboard {
endpoints += 1;
if keyboard_leds {
endpoints += 1;
}
}
if self.mouse_relative {
endpoints += 1;
}
if self.mouse_absolute {
endpoints += 1;
}
if self.consumer {
endpoints += 1;
}
endpoints
}
}
impl Default for OtgHidFunctions {
@@ -223,12 +343,21 @@ impl Default for OtgHidFunctions {
}
impl OtgHidProfile {
pub fn from_legacy_str(value: &str) -> Option<Self> {
match value {
"full" | "full_no_msd" => Some(Self::Full),
"full_no_consumer" | "full_no_consumer_no_msd" => Some(Self::FullNoConsumer),
"legacy_keyboard" => Some(Self::LegacyKeyboard),
"legacy_mouse_relative" => Some(Self::LegacyMouseRelative),
"custom" => Some(Self::Custom),
_ => None,
}
}
pub fn resolve_functions(&self, custom: &OtgHidFunctions) -> OtgHidFunctions {
match self {
Self::Full => OtgHidFunctions::full(),
Self::FullNoMsd => OtgHidFunctions::full(),
Self::FullNoConsumer => OtgHidFunctions::full_no_consumer(),
Self::FullNoConsumerNoMsd => OtgHidFunctions::full_no_consumer(),
Self::LegacyKeyboard => OtgHidFunctions::legacy_keyboard(),
Self::LegacyMouseRelative => OtgHidFunctions::legacy_mouse_relative(),
Self::Custom => custom.clone(),
@@ -243,10 +372,6 @@ impl OtgHidProfile {
pub struct HidConfig {
/// HID backend type
pub backend: HidBackend,
/// OTG keyboard device path
pub otg_keyboard: String,
/// OTG mouse device path
pub otg_mouse: String,
/// OTG UDC (USB Device Controller) name
pub otg_udc: Option<String>,
/// OTG USB device descriptor configuration
@@ -255,9 +380,15 @@ pub struct HidConfig {
/// OTG HID function profile
#[serde(default)]
pub otg_profile: OtgHidProfile,
/// OTG endpoint budget policy
#[serde(default)]
pub otg_endpoint_budget: OtgEndpointBudget,
/// OTG HID function selection (used when profile is Custom)
#[serde(default)]
pub otg_functions: OtgHidFunctions,
/// Enable keyboard LED/status feedback for OTG keyboard
#[serde(default)]
pub otg_keyboard_leds: bool,
/// CH9329 serial port
pub ch9329_port: String,
/// CH9329 baud rate
@@ -270,12 +401,12 @@ impl Default for HidConfig {
fn default() -> Self {
Self {
backend: HidBackend::None,
otg_keyboard: "/dev/hidg0".to_string(),
otg_mouse: "/dev/hidg1".to_string(),
otg_udc: None,
otg_descriptor: OtgDescriptorConfig::default(),
otg_profile: OtgHidProfile::default(),
otg_endpoint_budget: OtgEndpointBudget::default(),
otg_functions: OtgHidFunctions::default(),
otg_keyboard_leds: false,
ch9329_port: "/dev/ttyUSB0".to_string(),
ch9329_baudrate: 9600,
mouse_absolute: true,
@@ -284,9 +415,92 @@ impl Default for HidConfig {
}
impl HidConfig {
/// Resolve effective OTG HID functions from profile + custom selection.
/// Pure logic, no external dependency.
pub fn effective_otg_functions(&self) -> OtgHidFunctions {
self.otg_profile.resolve_functions(&self.otg_functions)
}
/// Whether keyboard LED feedback is effectively enabled.
pub fn effective_otg_keyboard_leds(&self) -> bool {
self.otg_keyboard_leds && self.effective_otg_functions().keyboard
}
/// Effective HID functions after applying all constraints.
pub fn constrained_otg_functions(&self) -> OtgHidFunctions {
self.effective_otg_functions()
}
/// Calculate required endpoint count for the current function selection.
pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 {
let functions = self.effective_otg_functions();
let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds());
if msd_enabled {
endpoints += 2;
}
endpoints
}
/// Validate endpoint budget for the current OTG configuration (UDC-aware when budget is Auto).
pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> {
if self.backend != HidBackend::Otg {
return Ok(());
}
let functions = self.effective_otg_functions();
if functions.is_empty() {
return Err(crate::error::AppError::BadRequest(
"OTG HID functions cannot be empty".to_string(),
));
}
let resolved_limit = self.resolved_otg_endpoint_limit();
let required = self.effective_otg_required_endpoints(msd_enabled);
if let Some(limit) = resolved_limit {
if required > limit {
return Err(crate::error::AppError::BadRequest(format!(
"OTG selection requires {} endpoints, but the configured limit is {}",
required, limit
)));
}
}
Ok(())
}
/// Effective OTG UDC name (for change detection / service).
#[inline]
pub fn resolved_otg_udc(&self) -> Option<String> {
if self.backend != HidBackend::Otg {
return None;
}
self.otg_udc
.as_ref()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.or_else(|| crate::otg::OtgGadgetManager::find_udc())
}
/// Resolved endpoint limit used for OTG gadget allocator / validation.
#[inline]
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
if self.backend != HidBackend::Otg {
return None;
}
match self.otg_endpoint_budget {
OtgEndpointBudget::Five => Some(5),
OtgEndpointBudget::Six => Some(6),
OtgEndpointBudget::Unlimited => None,
OtgEndpointBudget::Auto => {
let udc = self.resolved_otg_udc().unwrap_or_default();
if crate::otg::configfs::is_low_endpoint_udc(&udc) {
Some(5)
} else {
Some(6)
}
}
}
}
}
/// MSD configuration
@@ -383,7 +597,7 @@ impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
device: String::new(),
quality: "balanced".to_string(),
}
}
@@ -478,21 +692,6 @@ pub enum EncoderType {
}
impl EncoderType {
/// Convert to EncoderBackend for registry queries
pub fn to_backend(&self) -> Option<crate::video::encoder::registry::EncoderBackend> {
use crate::video::encoder::registry::EncoderBackend;
match self {
EncoderType::Auto => None,
EncoderType::Software => Some(EncoderBackend::Software),
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
EncoderType::Qsv => Some(EncoderBackend::Qsv),
EncoderType::Amf => Some(EncoderBackend::Amf),
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
}
}
/// Get display name for UI
pub fn display_name(&self) -> &'static str {
match self {
@@ -559,23 +758,23 @@ impl Default for StreamConfig {
}
impl StreamConfig {
/// Check if using public ICE servers (user left fields empty)
/// Whether built-in / public ICE is used (no custom STUN or TURN URL configured).
pub fn is_using_public_ice_servers(&self) -> bool {
use crate::webrtc::config::public_ice;
self.stun_server
let no_custom_stun = self
.stun_server
.as_ref()
.map(|s| s.is_empty())
.unwrap_or(true)
&& self
.turn_server
.as_ref()
.map(|s| s.is_empty())
.unwrap_or(true)
&& public_ice::is_configured()
.map_or(true, |s| s.trim().is_empty());
let no_custom_turn = self
.turn_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
no_custom_stun && no_custom_turn
}
}
/// Web server configuration
/// Web server configuration persisted in the database (includes on-disk TLS paths).
///
/// The HTTP API for `/api/config/web` uses `WebConfigResponse` instead: no path fields, includes `has_custom_cert`.
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]

View File

@@ -1,10 +1,10 @@
use arc_swap::ArcSwap;
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
use std::path::Path;
use sqlx::{Pool, Sqlite};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
use super::persistence::ConfigChange;
use super::AppConfig;
use crate::error::{AppError, Result};
@@ -18,128 +18,27 @@ pub struct ConfigStore {
/// Lock-free cache using ArcSwap for zero-cost reads
cache: Arc<ArcSwap<AppConfig>>,
change_tx: broadcast::Sender<ConfigChange>,
}
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
/// Serializes `set` / `update` so concurrent PATCH handlers cannot clobber each other
write_lock: Arc<Mutex<()>>,
}
impl ConfigStore {
/// Create a new configuration store
pub async fn new(db_path: &Path) -> Result<Self> {
// Ensure parent directory exists
if let Some(parent) = db_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePoolOptions::new()
// SQLite uses single-writer mode, 2 connections is sufficient for embedded devices
// One for reads, one for writes to avoid blocking
.max_connections(2)
// Set reasonable timeouts for embedded environments
.acquire_timeout(Duration::from_secs(5))
.idle_timeout(Duration::from_secs(300))
.connect(&db_url)
.await?;
// Initialize database schema
Self::init_schema(&pool).await?;
// Load or create default config
let config = Self::load_config(&pool).await?;
let cache = Arc::new(ArcSwap::from_pointee(config));
let (change_tx, _) = broadcast::channel(16);
pub fn new(pool: Pool<Sqlite>) -> Result<Self> {
// Load or create default config synchronously wrapper
// (actual DB load is async, handled in init())
Ok(Self {
pool,
cache,
change_tx,
cache: Arc::new(ArcSwap::from_pointee(AppConfig::default())),
change_tx: broadcast::channel(16).0,
write_lock: Arc::new(Mutex::new(())),
})
}
/// Initialize database schema
async fn init_schema(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
expires_at TEXT NOT NULL,
data TEXT
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
permissions TEXT NOT NULL,
expires_at TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
last_used TEXT
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS wol_history (
mac_address TEXT PRIMARY KEY,
updated_at INTEGER NOT NULL
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
ON wol_history(updated_at DESC)
"#,
)
.execute(pool)
.await?;
/// Load configuration from database (call after new())
pub async fn load(&self) -> Result<()> {
let config = Self::load_config(&self.pool).await?;
self.cache.store(Arc::new(config));
Ok(())
}
@@ -191,6 +90,7 @@ impl ConfigStore {
/// Set entire configuration
pub async fn set(&self, config: AppConfig) -> Result<()> {
let _guard = self.write_lock.lock().await;
Self::save_config_to_db(&self.pool, &config).await?;
self.cache.store(Arc::new(config));
@@ -204,13 +104,13 @@ impl ConfigStore {
/// Update configuration with a closure
///
/// Note: This uses a read-modify-write pattern. For concurrent updates,
/// the last write wins. This is acceptable for configuration changes
/// which are infrequent and typically user-initiated.
/// Uses read-modify-write under a mutex so concurrent `update` / `set` calls are serialized
/// and merged correctly (each closure sees the latest stored config).
pub async fn update<F>(&self, f: F) -> Result<()>
where
F: FnOnce(&mut AppConfig),
{
let _guard = self.write_lock.lock().await;
// Load current config, clone it for modification
let current = self.cache.load();
let mut config = (**current).clone();
@@ -239,16 +139,12 @@ impl ConfigStore {
pub fn is_initialized(&self) -> bool {
self.cache.load().initialized
}
/// Get database pool for session management
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::DatabasePool;
use tempfile::tempdir;
#[tokio::test]
@@ -256,7 +152,11 @@ mod tests {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let store = ConfigStore::new(&db_path).await.unwrap();
let db = DatabasePool::new(&db_path).await.unwrap();
db.init_schema().await.unwrap();
let store = ConfigStore::new(db.clone_pool()).unwrap();
store.load().await.unwrap();
// Check default config (now lock-free, no await needed)
let config = store.get();
@@ -277,7 +177,8 @@ mod tests {
assert_eq!(config.web.http_port, 9000);
// Create new store instance and verify persistence
let store2 = ConfigStore::new(&db_path).await.unwrap();
let store2 = ConfigStore::new(db.clone_pool()).unwrap();
store2.load().await.unwrap();
let config = store2.get();
assert!(config.initialized);
assert_eq!(config.web.http_port, 9000);

3
src/db/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
mod pool;
pub use pool::DatabasePool;

119
src/db/pool.rs Normal file
View File

@@ -0,0 +1,119 @@
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
use std::path::Path;
use std::time::Duration;
use crate::error::Result;
#[derive(Clone)]
pub struct DatabasePool {
pool: Pool<Sqlite>,
}
impl DatabasePool {
pub async fn new(db_path: &Path) -> Result<Self> {
if let Some(parent) = db_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePoolOptions::new()
.max_connections(4)
.acquire_timeout(Duration::from_secs(5))
.idle_timeout(Duration::from_secs(300))
.connect(&db_url)
.await?;
Ok(Self { pool })
}
pub async fn init_schema(&self) -> Result<()> {
self.create_config_table().await?;
self.create_users_table().await?;
self.create_api_tokens_table().await?;
self.create_wol_history_table().await?;
Ok(())
}
async fn create_config_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_users_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_api_tokens_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
permissions TEXT NOT NULL,
expires_at TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
last_used TEXT
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_wol_history_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS wol_history (
mac_address TEXT PRIMARY KEY,
updated_at INTEGER NOT NULL
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
ON wol_history(updated_at DESC)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
pub fn clone_pool(&self) -> Pool<Sqlite> {
self.pool.clone()
}
}

View File

@@ -1,12 +1,5 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Serialize;
use thiserror::Error;
/// Application-wide error type
#[derive(Error, Debug)]
pub enum AppError {
#[error("Authentication failed: {0}")]
@@ -15,17 +8,14 @@ pub enum AppError {
#[error("Not authenticated")]
Unauthorized,
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Persistence error: {0}")]
Persistence(String),
#[error("Internal error: {0}")]
Internal(String),
@@ -42,8 +32,9 @@ pub enum AppError {
#[error("Video error: {0}")]
VideoError(String),
#[error("Video device lost [{device}]: {reason}")]
VideoDeviceLost { device: String, reason: String },
/// No input signal while opening capture; `kind` is `SignalStatus` as string (`from_str`).
#[error("Capture has no valid signal: {kind}")]
CaptureNoSignal { kind: String },
#[error("Audio error: {0}")]
AudioError(String),
@@ -62,37 +53,10 @@ pub enum AppError {
ServiceUnavailable(String),
}
/// Error response body (unified success format)
#[derive(Serialize)]
pub struct ErrorResponse {
pub success: bool,
pub message: String,
}
impl AppError {
fn status_code(&self) -> StatusCode {
// Always return 200 OK - success/failure is indicated by the success field
StatusCode::OK
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let status = self.status_code();
let body = ErrorResponse {
success: false,
message: self.to_string(),
};
tracing::error!(
error_type = std::any::type_name_of_val(&self),
error_message = %body.message,
"Request failed"
);
(status, Json(body)).into_response()
}
}
/// Result type alias for handlers
pub type Result<T> = std::result::Result<T, AppError>;
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
AppError::Persistence(err.to_string())
}
}

View File

@@ -1,78 +1,108 @@
//! Event system for real-time state notifications
//!
//! This module provides a global event bus for broadcasting system events
//! to WebSocket clients and other subscribers.
//! Event bus: [`SystemEvent`] fan-out to WebSocket subscribers and internal tasks.
pub mod types;
use self::types::EXACT_EVENT_TOPICS;
pub use types::{
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent,
VideoDeviceInfo,
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, LedState, MsdDeviceInfo,
StreamDeviceLostKind, SystemEvent, TtydDeviceInfo, VideoDeviceInfo,
};
use tokio::sync::broadcast;
/// Event channel capacity (ring buffer size)
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// Global event bus for broadcasting system events
///
/// The event bus uses tokio's broadcast channel to distribute events
/// to multiple subscribers. Events are delivered to all active subscribers.
///
/// # Example
///
/// ```no_run
/// use one_kvm::events::{EventBus, SystemEvent};
///
/// let bus = EventBus::new();
///
/// // Publish an event
/// bus.publish(SystemEvent::StreamStateChanged {
/// state: "streaming".to_string(),
/// device: Some("/dev/video0".to_string()),
/// });
///
/// // Subscribe to events
/// let mut rx = bus.subscribe();
/// tokio::spawn(async move {
/// while let Ok(event) = rx.recv().await {
/// println!("Received event: {:?}", event);
/// }
/// });
/// ```
fn collect_prefix_wildcards(exact: &[&'static str]) -> Vec<String> {
use std::collections::BTreeSet;
let mut segments = BTreeSet::new();
for name in exact {
if let Some((seg, _)) = name.split_once('.') {
segments.insert(seg);
}
}
segments.into_iter().map(|s| format!("{}.*", s)).collect()
}
fn make_sender() -> broadcast::Sender<SystemEvent> {
let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
tx
}
fn topic_prefix(event_name: &str) -> Option<String> {
event_name
.split_once('.')
.map(|(prefix, _)| format!("{}.*", prefix))
}
pub struct EventBus {
tx: broadcast::Sender<SystemEvent>,
exact_topics: std::collections::HashMap<&'static str, broadcast::Sender<SystemEvent>>,
prefix_topics: std::collections::HashMap<String, broadcast::Sender<SystemEvent>>,
device_info_dirty_tx: broadcast::Sender<()>,
}
impl EventBus {
/// Create a new event bus
pub fn new() -> Self {
let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
Self { tx }
let tx = make_sender();
let exact_topics = EXACT_EVENT_TOPICS
.iter()
.map(|topic| (*topic, make_sender()))
.collect();
let prefix_topics = collect_prefix_wildcards(EXACT_EVENT_TOPICS)
.into_iter()
.map(|topic| (topic, make_sender()))
.collect();
let (device_info_dirty_tx, _dirty_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
Self {
tx,
exact_topics,
prefix_topics,
device_info_dirty_tx,
}
}
/// Publish an event to all subscribers
///
/// If there are no active subscribers, the event is silently dropped.
/// This is by design - events are fire-and-forget notifications.
pub fn publish(&self, event: SystemEvent) {
// If no subscribers, send returns Err which is normal
let event_name = event.event_name();
if let Some(tx) = self.exact_topics.get(event_name) {
let _ = tx.send(event.clone());
}
if let Some(prefix) = topic_prefix(event_name) {
if let Some(tx) = self.prefix_topics.get(&prefix) {
let _ = tx.send(event.clone());
}
}
let _ = self.tx.send(event);
}
/// Subscribe to events
///
/// Returns a receiver that will receive all future events.
/// The receiver uses a ring buffer, so if a subscriber falls too far
/// behind, it will receive a `Lagged` error and miss some events.
pub fn subscribe(&self) -> broadcast::Receiver<SystemEvent> {
self.tx.subscribe()
}
/// Get the current number of active subscribers
///
/// Useful for monitoring and debugging.
pub fn subscribe_topic(&self, topic: &str) -> Option<broadcast::Receiver<SystemEvent>> {
if topic == "*" {
return Some(self.tx.subscribe());
}
if topic.ends_with(".*") {
return self.prefix_topics.get(topic).map(|tx| tx.subscribe());
}
self.exact_topics.get(topic).map(|tx| tx.subscribe())
}
pub fn mark_device_info_dirty(&self) {
let _ = self.device_info_dirty_tx.send(());
}
pub fn subscribe_device_info_dirty(&self) -> broadcast::Receiver<()> {
self.device_info_dirty_tx.subscribe()
}
pub fn subscriber_count(&self) -> usize {
self.tx.receiver_count()
}
@@ -96,6 +126,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
@@ -110,17 +142,56 @@ mod tests {
assert_eq!(bus.subscriber_count(), 2);
bus.publish(SystemEvent::SystemError {
module: "test".to_string(),
severity: "info".to_string(),
message: "test message".to_string(),
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
});
let event1 = rx1.recv().await.unwrap();
let event2 = rx2.recv().await.unwrap();
assert!(matches!(event1, SystemEvent::SystemError { .. }));
assert!(matches!(event2, SystemEvent::SystemError { .. }));
assert!(matches!(event1, SystemEvent::StreamStateChanged { .. }));
assert!(matches!(event2, SystemEvent::StreamStateChanged { .. }));
}
#[tokio::test]
async fn test_subscribe_topic_exact() {
let bus = EventBus::new();
let mut rx = bus.subscribe_topic("stream.state_changed").unwrap();
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
assert!(matches!(event, SystemEvent::StreamStateChanged { .. }));
}
#[tokio::test]
async fn test_subscribe_topic_prefix() {
let bus = EventBus::new();
let mut rx = bus.subscribe_topic("stream.*").unwrap();
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
assert!(matches!(event, SystemEvent::StreamStateChanged { .. }));
}
#[test]
fn test_subscribe_topic_unknown() {
let bus = EventBus::new();
assert!(bus.subscribe_topic("unknown.topic").is_none());
}
#[test]
@@ -128,11 +199,11 @@ mod tests {
let bus = EventBus::new();
assert_eq!(bus.subscriber_count(), 0);
// Should not panic when publishing with no subscribers
bus.publish(SystemEvent::SystemError {
module: "test".to_string(),
severity: "info".to_string(),
message: "test".to_string(),
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
}
}

View File

@@ -1,548 +1,234 @@
//! System event types
//!
//! Defines all event types that can be broadcast through the event bus.
//! [`SystemEvent`] and device snapshot types (WebSocket / JSON).
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::atx::PowerStatus;
use crate::msd::MsdMode;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct LedState {
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
pub compose: bool,
pub kana: bool,
}
// ============================================================================
// Device Info Structures (for system.device_info event)
// ============================================================================
/// Video device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoDeviceInfo {
/// Whether video device is available
pub available: bool,
/// Device path (e.g., /dev/video0)
pub device: Option<String>,
/// Pixel format (e.g., "MJPEG", "YUYV")
pub format: Option<String>,
/// Resolution (width, height)
pub resolution: Option<(u32, u32)>,
/// Frames per second
pub fps: u32,
/// Whether stream is currently active
pub online: bool,
/// Current streaming mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
pub stream_mode: String,
/// Whether video config is currently being changed (frontend should skip mode sync)
pub config_changing: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// HID device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HidDeviceInfo {
/// Whether HID backend is available
pub available: bool,
/// Backend type: "otg", "ch9329", "none"
pub backend: String,
/// Whether backend is initialized and ready
pub initialized: bool,
/// Whether absolute mouse positioning is supported
pub online: bool,
pub supports_absolute_mouse: bool,
/// Device path (e.g., serial port for CH9329)
pub keyboard_leds_enabled: bool,
pub led_state: LedState,
pub device: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
pub error_code: Option<String>,
}
/// MSD device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdDeviceInfo {
/// Whether MSD is available
pub available: bool,
/// Operating mode: "none", "image", "drive"
pub mode: String,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image ID
pub image_id: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// ATX device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxDeviceInfo {
/// Whether ATX controller is available
pub available: bool,
/// Backend type: "gpio", "usb_relay", "none"
pub backend: String,
/// Whether backend is initialized
pub initialized: bool,
/// Whether power is currently on
pub power_on: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// Audio device information
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioDeviceInfo {
/// Whether audio is enabled/available
pub available: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Current audio device name
pub device: Option<String>,
/// Quality preset: "voice", "balanced", "high"
pub quality: String,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// Per-client statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtydDeviceInfo {
pub available: bool,
pub running: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientStats {
/// Client ID
pub id: String,
/// Current FPS for this client (frames sent in last second)
pub fps: u32,
/// Connected duration (seconds)
pub connected_secs: u64,
}
/// System event enumeration
///
/// All events are tagged with their event name for serialization.
/// The `serde(tag = "event", content = "data")` attribute creates a
/// JSON structure like:
/// ```json
/// {
/// "event": "stream.state_changed",
/// "data": { "state": "streaming", "device": "/dev/video0" }
/// }
/// ```
/// Video vs audio source for [`SystemEvent::StreamDeviceLost`] (WebSocket `stream.device_lost`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamDeviceLostKind {
Video,
Audio,
}
/// JSON: `{"event": "<name>", "data": { ... }}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
#[allow(clippy::large_enum_variant)]
pub enum SystemEvent {
// ============================================================================
// Video Stream Events
// ============================================================================
/// Stream mode switching started (transactional, correlates all following events)
///
/// Sent immediately after a mode switch request is accepted.
/// Clients can use `transition_id` to correlate subsequent `stream.*` events.
#[serde(rename = "stream.mode_switching")]
StreamModeSwitching {
/// Unique transition ID for this mode switch transaction
transition_id: String,
/// Target mode: "mjpeg", "h264", "h265", "vp8", "vp9"
to_mode: String,
/// Previous mode: "mjpeg", "h264", "h265", "vp8", "vp9"
from_mode: String,
},
/// Stream state changed (e.g., started, stopped, error)
#[serde(rename = "stream.state_changed")]
StreamStateChanged {
/// Current state: "uninitialized", "ready", "streaming", "no_signal", "error"
state: String,
/// Device path if available
device: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
next_retry_ms: Option<u64>,
},
/// Stream configuration is being changed
///
/// Sent before applying new configuration to notify clients that
/// the stream will be interrupted temporarily.
#[serde(rename = "stream.config_changing")]
StreamConfigChanging {
/// Optional transition ID if this config change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Reason for change: "device_switch", "resolution_change", "format_change"
reason: String,
},
/// Stream configuration has been applied successfully
///
/// Sent after new configuration is active. Clients can reconnect now.
#[serde(rename = "stream.config_applied")]
StreamConfigApplied {
/// Optional transition ID if this config change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Device path
device: String,
/// Resolution (width, height)
resolution: (u32, u32),
/// Pixel format: "mjpeg", "yuyv", etc.
format: String,
/// Frames per second
fps: u32,
},
/// Stream device was lost (disconnected or error)
#[serde(rename = "stream.device_lost")]
StreamDeviceLost {
/// Device path that was lost
kind: StreamDeviceLostKind,
device: String,
/// Reason for loss
reason: String,
},
/// Stream device is reconnecting
#[serde(rename = "stream.reconnecting")]
StreamReconnecting {
/// Device path being reconnected
device: String,
/// Retry attempt number
attempt: u32,
},
StreamReconnecting { device: String, attempt: u32 },
/// Stream device has recovered
#[serde(rename = "stream.recovered")]
StreamRecovered {
/// Device path that was recovered
device: String,
},
StreamRecovered { device: String },
/// WebRTC is ready to accept connections
///
/// Sent after video frame source is connected to WebRTC pipeline.
/// Clients should wait for this event before attempting to create WebRTC sessions.
#[serde(rename = "stream.webrtc_ready")]
WebRTCReady {
/// Optional transition ID if this readiness is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Current video codec
codec: String,
/// Whether hardware encoding is being used
hardware: bool,
},
/// WebRTC ICE candidate (server -> client trickle)
#[serde(rename = "webrtc.ice_candidate")]
WebRTCIceCandidate {
/// WebRTC session ID
session_id: String,
/// ICE candidate data
candidate: crate::webrtc::signaling::IceCandidate,
candidate: serde_json::Value,
},
/// WebRTC ICE gathering complete (server -> client)
#[serde(rename = "webrtc.ice_complete")]
WebRTCIceComplete {
/// WebRTC session ID
session_id: String,
},
WebRTCIceComplete { session_id: String },
/// Stream statistics update (sent periodically for client stats)
#[serde(rename = "stream.stats_update")]
StreamStatsUpdate {
/// Number of connected clients
clients: u64,
/// Per-client statistics (client_id -> client stats)
/// Each client's FPS reflects the actual frames sent in the last second
clients_stat: HashMap<String, ClientStats>,
},
/// Stream mode changed (MJPEG <-> WebRTC)
///
/// Sent when the streaming mode is switched. Clients should disconnect
/// from the current stream and reconnect using the new mode.
#[serde(rename = "stream.mode_changed")]
StreamModeChanged {
/// Optional transition ID if this change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
mode: String,
/// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
previous_mode: String,
},
/// Stream mode switching completed (transactional end marker)
///
/// Sent when the backend considers the new mode ready for clients to connect.
#[serde(rename = "stream.mode_ready")]
StreamModeReady {
/// Unique transition ID for this mode switch transaction
transition_id: String,
/// Active mode after switch: "mjpeg", "h264", "h265", "vp8", "vp9"
mode: String,
},
StreamModeReady { transition_id: String, mode: String },
// ============================================================================
// HID Events
// ============================================================================
/// HID backend state changed
#[serde(rename = "hid.state_changed")]
HidStateChanged {
/// Backend type: "otg", "ch9329", "none"
backend: String,
/// Whether backend is initialized and ready
initialized: bool,
/// Error message if any, None if OK
error: Option<String>,
/// Error code for programmatic handling: "epipe", "eagain", "port_not_found", etc.
error_code: Option<String>,
},
/// HID backend is being switched
#[serde(rename = "hid.backend_switching")]
HidBackendSwitching {
/// Current backend
from: String,
/// New backend
to: String,
},
/// HID device lost (device file missing or I/O error)
#[serde(rename = "hid.device_lost")]
HidDeviceLost {
/// Backend type: "otg", "ch9329"
backend: String,
/// Device path that was lost (e.g., /dev/hidg0 or /dev/ttyUSB0)
device: Option<String>,
/// Human-readable reason for loss
reason: String,
/// Error code: "epipe", "eshutdown", "eagain", "enxio", "port_not_found", "io_error"
error_code: String,
},
/// HID device is reconnecting
#[serde(rename = "hid.reconnecting")]
HidReconnecting {
/// Backend type: "otg", "ch9329"
backend: String,
/// Current retry attempt number
attempt: u32,
},
/// HID device has recovered after error
#[serde(rename = "hid.recovered")]
HidRecovered {
/// Backend type: "otg", "ch9329"
backend: String,
},
// ============================================================================
// MSD (Mass Storage Device) Events
// ============================================================================
/// MSD state changed
#[serde(rename = "msd.state_changed")]
MsdStateChanged {
/// Operating mode
mode: MsdMode,
/// Whether storage is connected to target
connected: bool,
},
/// Image has been mounted
#[serde(rename = "msd.image_mounted")]
MsdImageMounted {
/// Image ID
image_id: String,
/// Image filename
image_name: String,
/// Image size in bytes
size: u64,
/// Mount as CD-ROM (read-only)
cdrom: bool,
},
/// Image has been unmounted
#[serde(rename = "msd.image_unmounted")]
MsdImageUnmounted,
/// File upload progress (for large file uploads)
#[serde(rename = "msd.upload_progress")]
MsdUploadProgress {
/// Upload operation ID
upload_id: String,
/// Filename being uploaded
filename: String,
/// Bytes uploaded so far
bytes_uploaded: u64,
/// Total file size
total_bytes: u64,
/// Progress percentage (0.0 - 100.0)
progress_pct: f32,
},
/// Image download progress (for URL downloads)
#[serde(rename = "msd.download_progress")]
MsdDownloadProgress {
/// Download operation ID
download_id: String,
/// Source URL
url: String,
/// Target filename
filename: String,
/// Bytes downloaded so far
bytes_downloaded: u64,
/// Total file size (None if unknown)
total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
progress_pct: Option<f32>,
/// Download status: "started", "in_progress", "completed", "failed"
status: String,
},
/// USB gadget connection status changed (host connected/disconnected)
#[serde(rename = "msd.usb_status_changed")]
MsdUsbStatusChanged {
/// Whether host is connected to USB device
connected: bool,
/// USB device state from kernel (e.g., "configured", "not attached")
device_state: String,
},
/// MSD operation error (configfs, image mount, etc.)
#[serde(rename = "msd.error")]
MsdError {
/// Human-readable reason for error
reason: String,
/// Error code: "configfs_error", "image_not_found", "mount_failed", "io_error"
error_code: String,
},
/// MSD has recovered after error
#[serde(rename = "msd.recovered")]
MsdRecovered,
// ============================================================================
// ATX (Power Control) Events
// ============================================================================
/// ATX power state changed
#[serde(rename = "atx.state_changed")]
AtxStateChanged {
/// Power status
power_status: PowerStatus,
},
/// ATX action was executed
#[serde(rename = "atx.action_executed")]
AtxActionExecuted {
/// Action: "short", "long", "reset"
action: String,
/// When the action was executed
timestamp: DateTime<Utc>,
},
// ============================================================================
// Audio Events
// ============================================================================
/// Audio state changed (streaming started/stopped)
#[serde(rename = "audio.state_changed")]
AudioStateChanged {
/// Whether audio is currently streaming
streaming: bool,
/// Current device (None if stopped)
device: Option<String>,
},
/// Audio device was selected
#[serde(rename = "audio.device_selected")]
AudioDeviceSelected {
/// Selected device name
device: String,
},
/// Audio quality was changed
#[serde(rename = "audio.quality_changed")]
AudioQualityChanged {
/// New quality setting: "voice", "balanced", "high"
quality: String,
},
/// Audio device lost (capture error or device disconnected)
#[serde(rename = "audio.device_lost")]
AudioDeviceLost {
/// Audio device name (e.g., "hw:0,0")
device: Option<String>,
/// Human-readable reason for loss
reason: String,
/// Error code: "device_busy", "device_disconnected", "capture_error", "io_error"
error_code: String,
},
/// Audio device is reconnecting
#[serde(rename = "audio.reconnecting")]
AudioReconnecting {
/// Current retry attempt number
attempt: u32,
},
/// Audio device has recovered after error
#[serde(rename = "audio.recovered")]
AudioRecovered {
/// Audio device name
device: Option<String>,
},
// ============================================================================
// System Events
// ============================================================================
/// A device was added (hot-plug)
#[serde(rename = "system.device_added")]
SystemDeviceAdded {
/// Device type: "video", "audio", "hid", etc.
device_type: String,
/// Device path
device_path: String,
/// Device name/description
device_name: String,
},
/// A device was removed (hot-unplug)
#[serde(rename = "system.device_removed")]
SystemDeviceRemoved {
/// Device type
device_type: String,
/// Device path that was removed
device_path: String,
},
/// System error or warning
#[serde(rename = "system.error")]
SystemError {
/// Module that generated the error: "stream", "hid", "msd", "atx"
module: String,
/// Severity: "warning", "error", "critical"
severity: String,
/// Error message
message: String,
},
/// Complete device information (sent on WebSocket connect and state changes)
#[serde(rename = "system.device_info")]
DeviceInfo {
/// Video device information
video: VideoDeviceInfo,
/// HID device information
hid: HidDeviceInfo,
/// MSD device information (None if MSD not enabled)
msd: Option<MsdDeviceInfo>,
/// ATX device information (None if ATX not enabled)
atx: Option<AtxDeviceInfo>,
/// Audio device information (None if audio not enabled)
audio: Option<AudioDeviceInfo>,
ttyd: TtydDeviceInfo,
},
/// WebSocket error notification (for connection-level errors like lag)
#[serde(rename = "error")]
Error {
/// Error message
message: String,
},
Error { message: String },
}
/// One entry per [`SystemEvent::event_name`]. `EventBus` builds `*.`-wildcard channels from the first segment; names without `.` (e.g. `error`) have no wildcard channel.
pub(crate) const EXACT_EVENT_TOPICS: &[&str] = &[
"stream.mode_switching",
"stream.state_changed",
"stream.config_changing",
"stream.config_applied",
"stream.device_lost",
"stream.reconnecting",
"stream.recovered",
"stream.webrtc_ready",
"stream.stats_update",
"stream.mode_changed",
"stream.mode_ready",
"webrtc.ice_candidate",
"webrtc.ice_complete",
"msd.upload_progress",
"msd.download_progress",
"system.device_info",
"error",
];
impl SystemEvent {
/// Get the event name (for filtering/routing)
pub fn event_name(&self) -> &'static str {
match self {
Self::StreamModeSwitching { .. } => "stream.mode_switching",
@@ -558,55 +244,12 @@ impl SystemEvent {
Self::StreamModeReady { .. } => "stream.mode_ready",
Self::WebRTCIceCandidate { .. } => "webrtc.ice_candidate",
Self::WebRTCIceComplete { .. } => "webrtc.ice_complete",
Self::HidStateChanged { .. } => "hid.state_changed",
Self::HidBackendSwitching { .. } => "hid.backend_switching",
Self::HidDeviceLost { .. } => "hid.device_lost",
Self::HidReconnecting { .. } => "hid.reconnecting",
Self::HidRecovered { .. } => "hid.recovered",
Self::MsdStateChanged { .. } => "msd.state_changed",
Self::MsdImageMounted { .. } => "msd.image_mounted",
Self::MsdImageUnmounted => "msd.image_unmounted",
Self::MsdUploadProgress { .. } => "msd.upload_progress",
Self::MsdDownloadProgress { .. } => "msd.download_progress",
Self::MsdUsbStatusChanged { .. } => "msd.usb_status_changed",
Self::MsdError { .. } => "msd.error",
Self::MsdRecovered => "msd.recovered",
Self::AtxStateChanged { .. } => "atx.state_changed",
Self::AtxActionExecuted { .. } => "atx.action_executed",
Self::AudioStateChanged { .. } => "audio.state_changed",
Self::AudioDeviceSelected { .. } => "audio.device_selected",
Self::AudioQualityChanged { .. } => "audio.quality_changed",
Self::AudioDeviceLost { .. } => "audio.device_lost",
Self::AudioReconnecting { .. } => "audio.reconnecting",
Self::AudioRecovered { .. } => "audio.recovered",
Self::SystemDeviceAdded { .. } => "system.device_added",
Self::SystemDeviceRemoved { .. } => "system.device_removed",
Self::SystemError { .. } => "system.error",
Self::DeviceInfo { .. } => "system.device_info",
Self::Error { .. } => "error",
}
}
/// Check if event name matches a topic pattern
///
/// Supports wildcards:
/// - `*` matches all events
/// - `stream.*` matches all stream events
/// - `stream.state_changed` matches exact event
pub fn matches_topic(&self, topic: &str) -> bool {
if topic == "*" {
return true;
}
let event_name = self.event_name();
if topic.ends_with(".*") {
let prefix = topic.trim_end_matches(".*");
event_name.starts_with(prefix)
} else {
event_name == topic
}
}
}
#[cfg(test)]
@@ -618,30 +261,145 @@ mod tests {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
};
assert_eq!(event.event_name(), "stream.state_changed");
let event = SystemEvent::MsdImageMounted {
image_id: "123".to_string(),
image_name: "ubuntu.iso".to_string(),
size: 1024,
cdrom: true,
};
assert_eq!(event.event_name(), "msd.image_mounted");
}
#[test]
fn test_matches_topic() {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: None,
fn stream_device_lost_json_snake_case_kind() {
let event = SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: "hw:0,0".to_string(),
reason: "test".to_string(),
};
let v = serde_json::to_value(&event).unwrap();
let data = v.get("data").unwrap();
assert_eq!(data.get("kind").and_then(|x| x.as_str()), Some("audio"));
assert_eq!(data.get("device").and_then(|x| x.as_str()), Some("hw:0,0"));
}
assert!(event.matches_topic("*"));
assert!(event.matches_topic("stream.*"));
assert!(event.matches_topic("stream.state_changed"));
assert!(!event.matches_topic("msd.*"));
assert!(!event.matches_topic("stream.config_changed"));
#[test]
fn exact_topics_covers_all_variants() {
use std::collections::HashSet;
let samples = vec![
SystemEvent::StreamModeSwitching {
transition_id: String::new(),
to_mode: String::new(),
from_mode: String::new(),
},
SystemEvent::StreamStateChanged {
state: String::new(),
device: None,
reason: None,
next_retry_ms: None,
},
SystemEvent::StreamConfigChanging {
transition_id: None,
reason: String::new(),
},
SystemEvent::StreamConfigApplied {
transition_id: None,
device: String::new(),
resolution: (0, 0),
format: String::new(),
fps: 0,
},
SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Video,
device: String::new(),
reason: String::new(),
},
SystemEvent::StreamReconnecting {
device: String::new(),
attempt: 0,
},
SystemEvent::StreamRecovered {
device: String::new(),
},
SystemEvent::WebRTCReady {
transition_id: None,
codec: String::new(),
hardware: false,
},
SystemEvent::StreamStatsUpdate {
clients: 0,
clients_stat: HashMap::new(),
},
SystemEvent::StreamModeChanged {
transition_id: None,
mode: String::new(),
previous_mode: String::new(),
},
SystemEvent::StreamModeReady {
transition_id: String::new(),
mode: String::new(),
},
SystemEvent::WebRTCIceCandidate {
session_id: String::new(),
candidate: serde_json::Value::Null,
},
SystemEvent::WebRTCIceComplete {
session_id: String::new(),
},
SystemEvent::MsdUploadProgress {
upload_id: String::new(),
filename: String::new(),
bytes_uploaded: 0,
total_bytes: 0,
progress_pct: 0.0,
},
SystemEvent::MsdDownloadProgress {
download_id: String::new(),
url: String::new(),
filename: String::new(),
bytes_downloaded: 0,
total_bytes: None,
progress_pct: None,
status: String::new(),
},
SystemEvent::DeviceInfo {
video: VideoDeviceInfo {
available: false,
device: None,
format: None,
resolution: None,
fps: 0,
online: false,
stream_mode: String::new(),
config_changing: false,
error: None,
},
hid: HidDeviceInfo {
available: false,
backend: String::new(),
initialized: false,
online: false,
supports_absolute_mouse: false,
keyboard_leds_enabled: false,
led_state: LedState::default(),
device: None,
error: None,
error_code: None,
},
msd: None,
atx: None,
audio: None,
ttyd: TtydDeviceInfo {
available: false,
running: false,
},
},
SystemEvent::Error {
message: String::new(),
},
];
let from_enum: HashSet<_> = samples.iter().map(|e| e.event_name()).collect();
let from_const: HashSet<_> = super::EXACT_EVENT_TOPICS.iter().copied().collect();
assert_eq!(from_enum, from_const);
}
#[test]

View File

@@ -1,5 +1,3 @@
//! Extension process manager
use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::process::Stdio;
@@ -10,27 +8,22 @@ use tokio::process::{Child, Command};
use tokio::sync::RwLock;
use super::types::*;
use crate::events::EventBus;
/// Maximum number of log lines to keep per extension
const LOG_BUFFER_SIZE: usize = 200;
/// Number of log lines to buffer before flushing to shared storage
const LOG_BATCH_SIZE: usize = 16;
/// Unix socket path for ttyd
pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock";
/// Extension process with log buffer
struct ExtensionProcess {
child: Child,
logs: Arc<RwLock<VecDeque<String>>>,
}
/// Extension manager handles lifecycle of external processes
pub struct ExtensionManager {
processes: RwLock<HashMap<ExtensionId, ExtensionProcess>>,
/// Cached availability status (checked once at startup)
availability: HashMap<ExtensionId, bool>,
event_bus: RwLock<Option<Arc<EventBus>>>,
}
impl Default for ExtensionManager {
@@ -40,9 +33,7 @@ impl Default for ExtensionManager {
}
impl ExtensionManager {
/// Create a new extension manager with cached availability
pub fn new() -> Self {
// Check availability once at startup
let availability = ExtensionId::all()
.iter()
.map(|id| (*id, Path::new(id.binary_path()).exists()))
@@ -51,54 +42,81 @@ impl ExtensionManager {
Self {
processes: RwLock::new(HashMap::new()),
availability,
event_bus: RwLock::new(None),
}
}
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
*self.event_bus.write().await = Some(event_bus);
}
async fn mark_ttyd_status_dirty(&self, id: ExtensionId) {
if id != ExtensionId::Ttyd {
return;
}
if let Some(ref event_bus) = *self.event_bus.read().await {
event_bus.mark_device_info_dirty();
}
}
/// Check if the binary for an extension is available (cached)
pub fn check_available(&self, id: ExtensionId) -> bool {
*self.availability.get(&id).unwrap_or(&false)
}
/// Get the current status of an extension
pub async fn status(&self, id: ExtensionId) -> ExtensionStatus {
if !self.check_available(id) {
return ExtensionStatus::Unavailable;
}
let processes = self.processes.read().await;
match processes.get(&id) {
Some(proc) => {
if let Some(pid) = proc.child.id() {
ExtensionStatus::Running { pid }
} else {
ExtensionStatus::Stopped
let mut processes = self.processes.write().await;
let exited = {
let Some(proc) = processes.get_mut(&id) else {
return ExtensionStatus::Stopped;
};
match proc.child.try_wait() {
Ok(Some(status)) => {
tracing::info!("Extension {} exited with status {}", id, status);
true
}
Ok(None) => {
return match proc.child.id() {
Some(pid) => ExtensionStatus::Running { pid },
None => ExtensionStatus::Stopped,
};
}
Err(e) => {
tracing::warn!("Failed to query status for {}: {}", id, e);
return match proc.child.id() {
Some(pid) => ExtensionStatus::Running { pid },
None => ExtensionStatus::Stopped,
};
}
}
None => ExtensionStatus::Stopped,
};
if exited {
processes.remove(&id);
}
ExtensionStatus::Stopped
}
/// Start an extension with the given configuration
pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> {
if !self.check_available(id) {
return Err(format!(
"{} not found at {}",
id.display_name(),
id.binary_path()
));
return Err(format!("{} not found at {}", id, id.binary_path()));
}
// Stop existing process first
self.stop(id).await.ok();
// Build command arguments
let args = self.build_args(id, config).await?;
tracing::info!(
"Starting extension {}: {} {}",
id,
id.binary_path(),
args.join(" ")
Self::redact_args_for_log(&args).join(" ")
);
let mut child = Command::new(id.binary_path())
@@ -107,11 +125,10 @@ impl ExtensionManager {
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.map_err(|e| format!("Failed to start {}: {}", id.display_name(), e))?;
.map_err(|e| format!("Failed to start {}: {}", id, e))?;
let logs = Arc::new(RwLock::new(VecDeque::with_capacity(LOG_BUFFER_SIZE)));
// Spawn log collector for stdout
if let Some(stdout) = child.stdout.take() {
let logs_clone = logs.clone();
let id_clone = id;
@@ -120,7 +137,6 @@ impl ExtensionManager {
});
}
// Spawn log collector for stderr
if let Some(stderr) = child.stderr.take() {
let logs_clone = logs.clone();
let id_clone = id;
@@ -134,11 +150,12 @@ impl ExtensionManager {
let mut processes = self.processes.write().await;
processes.insert(id, ExtensionProcess { child, logs });
drop(processes);
self.mark_ttyd_status_dirty(id).await;
Ok(())
}
/// Stop an extension
pub async fn stop(&self, id: ExtensionId) -> Result<(), String> {
let mut processes = self.processes.write().await;
if let Some(mut proc) = processes.remove(&id) {
@@ -146,11 +163,12 @@ impl ExtensionManager {
if let Err(e) = proc.child.kill().await {
tracing::warn!("Failed to kill {}: {}", id, e);
}
drop(processes);
self.mark_ttyd_status_dirty(id).await;
}
Ok(())
}
/// Get recent logs for an extension
pub async fn logs(&self, id: ExtensionId, lines: usize) -> Vec<String> {
let processes = self.processes.read().await;
if let Some(proc) = processes.get(&id) {
@@ -162,7 +180,6 @@ impl ExtensionManager {
}
}
/// Collect logs from a stream with batched writes to reduce lock contention
async fn collect_logs<R: tokio::io::AsyncRead + Unpin>(
id: ExtensionId,
reader: R,
@@ -175,16 +192,14 @@ impl ExtensionManager {
loop {
match lines.next_line().await {
Ok(Some(line)) => {
tracing::debug!("[{}] {}", id, line);
tracing::info!("[{}] {}", id, line);
local_buffer.push(line);
// Flush when batch is full
if local_buffer.len() >= LOG_BATCH_SIZE {
Self::flush_logs(&logs, &mut local_buffer).await;
}
}
Ok(None) => {
// Stream ended, flush remaining logs
if !local_buffer.is_empty() {
Self::flush_logs(&logs, &mut local_buffer).await;
}
@@ -198,7 +213,6 @@ impl ExtensionManager {
}
}
/// Flush buffered logs to shared storage
async fn flush_logs(logs: &RwLock<VecDeque<String>>, buffer: &mut Vec<String>) {
let mut logs = logs.write().await;
for line in buffer.drain(..) {
@@ -209,7 +223,6 @@ impl ExtensionManager {
}
}
/// Build command arguments for an extension
async fn build_args(
&self,
id: ExtensionId,
@@ -219,41 +232,37 @@ impl ExtensionManager {
ExtensionId::Ttyd => {
let c = &config.ttyd;
// Prepare socket directory and clean up old socket (async)
Self::prepare_ttyd_socket().await?;
let mut args = vec![
"-i".to_string(),
TTYD_SOCKET_PATH.to_string(), // Unix socket
TTYD_SOCKET_PATH.to_string(),
"-b".to_string(),
"/api/terminal".to_string(), // Base path for reverse proxy
"-W".to_string(), // Writable (allow input)
"/api/terminal".to_string(),
"-W".to_string(),
];
// Add shell as last argument
args.push(c.shell.clone());
Ok(args)
}
ExtensionId::Gostc => {
let c = &config.gostc;
if c.addr.trim().is_empty() {
return Err("GOSTC server address is required".into());
}
if c.key.is_empty() {
return Err("GOSTC client key is required".into());
}
let mut args = Vec::new();
// Add TLS flag
if c.tls {
args.push("--tls=true".to_string());
}
// Add server address
if !c.addr.is_empty() {
args.extend(["-addr".to_string(), c.addr.clone()]);
}
args.extend(["-addr".to_string(), c.addr.trim().to_string()]);
// Add client key
args.extend(["-key".to_string(), c.key.clone()]);
Ok(args)
@@ -272,24 +281,19 @@ impl ExtensionManager {
c.network_secret.clone(),
];
// Add peer URLs
for peer in &c.peer_urls {
if !peer.is_empty() {
args.extend(["--peers".to_string(), peer.clone()]);
}
}
// Add virtual IP: use -d for DHCP if empty, or -i for specific IP
if let Some(ref ip) = c.virtual_ip {
if !ip.is_empty() {
// Use specific IP with -i (must include CIDR, e.g., 10.0.0.1/24)
args.extend(["-i".to_string(), ip.clone()]);
} else {
// Empty string means use DHCP
args.push("-d".to_string());
}
} else {
// None means use DHCP
args.push("-d".to_string());
}
@@ -298,11 +302,37 @@ impl ExtensionManager {
}
}
/// Prepare ttyd socket directory and clean up old socket file
fn redact_args_for_log(args: &[String]) -> Vec<String> {
let mut redacted = Vec::with_capacity(args.len());
let mut redact_next = false;
for arg in args {
if redact_next {
redacted.push("****".to_string());
redact_next = false;
continue;
}
if arg == "-key" || arg == "--key" {
redacted.push(arg.clone());
redact_next = true;
} else if let Some((flag, _)) = arg.split_once('=') {
if flag == "-key" || flag == "--key" {
redacted.push(format!("{}=****", flag));
} else {
redacted.push(arg.clone());
}
} else {
redacted.push(arg.clone());
}
}
redacted
}
async fn prepare_ttyd_socket() -> Result<(), String> {
let socket_path = Path::new(TTYD_SOCKET_PATH);
// Ensure socket directory exists
if let Some(socket_dir) = socket_path.parent() {
if !socket_dir.exists() {
tokio::fs::create_dir_all(socket_dir)
@@ -311,7 +341,6 @@ impl ExtensionManager {
}
}
// Remove old socket file if exists
if tokio::fs::try_exists(TTYD_SOCKET_PATH)
.await
.unwrap_or(false)
@@ -324,15 +353,17 @@ impl ExtensionManager {
Ok(())
}
/// Health check - restart crashed processes that should be running
pub async fn health_check(&self, config: &ExtensionsConfig) {
// Collect extensions that need restart check
let checks: Vec<_> = ExtensionId::all()
.iter()
.filter_map(|id| {
let should_run = match id {
ExtensionId::Ttyd => config.ttyd.enabled,
ExtensionId::Gostc => config.gostc.enabled && !config.gostc.key.is_empty(),
ExtensionId::Gostc => {
config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
}
ExtensionId::Easytier => {
config.easytier.enabled && !config.easytier.network_name.is_empty()
}
@@ -345,7 +376,6 @@ impl ExtensionManager {
})
.collect();
// Check which ones need restart (single read lock)
let needs_restart: Vec<_> = {
let processes = self.processes.read().await;
checks
@@ -360,7 +390,6 @@ impl ExtensionManager {
.collect()
};
// Restart all crashed extensions in parallel
let restart_futures: Vec<_> = needs_restart
.into_iter()
.map(|id| async move {
@@ -374,14 +403,12 @@ impl ExtensionManager {
futures::future::join_all(restart_futures).await;
}
/// Start all enabled extensions in parallel
pub async fn start_enabled(&self, config: &ExtensionsConfig) {
use futures::Future;
use std::pin::Pin;
let mut start_futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + '_>>> = Vec::new();
// Collect enabled extensions
if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) {
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Ttyd, config).await {
@@ -392,6 +419,7 @@ impl ExtensionManager {
if config.gostc.enabled
&& !config.gostc.key.is_empty()
&& !config.gostc.addr.trim().is_empty()
&& self.check_available(ExtensionId::Gostc)
{
start_futures.push(Box::pin(async {
@@ -412,11 +440,9 @@ impl ExtensionManager {
}));
}
// Start all in parallel
futures::future::join_all(start_futures).await;
}
/// Stop all running extensions in parallel
pub async fn stop_all(&self) {
let stop_futures: Vec<_> = ExtensionId::all().iter().map(|id| self.stop(*id)).collect();
futures::future::join_all(stop_futures).await;

View File

@@ -1,5 +1,3 @@
//! Extensions module - manage external processes like ttyd, gostc, easytier
mod manager;
mod types;

View File

@@ -1,23 +1,16 @@
//! Extension types and configurations
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Extension identifier (fixed set of supported extensions)
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ExtensionId {
/// Web terminal (ttyd)
Ttyd,
/// NAT traversal client (gostc)
Gostc,
/// P2P VPN (easytier)
Easytier,
}
impl ExtensionId {
/// Get the binary path for this extension
pub fn binary_path(&self) -> &'static str {
match self {
Self::Ttyd => "/usr/bin/ttyd",
@@ -26,16 +19,6 @@ impl ExtensionId {
}
}
/// Get the display name for this extension
pub fn display_name(&self) -> &'static str {
match self {
Self::Ttyd => "Web Terminal",
Self::Gostc => "GOSTC Tunnel",
Self::Easytier => "EasyTier VPN",
}
}
/// Get all extension IDs
pub fn all() -> &'static [ExtensionId] {
&[Self::Ttyd, Self::Gostc, Self::Easytier]
}
@@ -64,25 +47,14 @@ impl std::str::FromStr for ExtensionId {
}
}
/// Extension running status
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "state", content = "data", rename_all = "lowercase")]
pub enum ExtensionStatus {
/// Binary not found at expected path
Unavailable,
/// Extension is stopped
Stopped,
/// Extension is running
Running {
/// Process ID
pid: u32,
},
/// Extension failed to start
Failed {
/// Error message
error: String,
},
Running { pid: u32 },
Failed { error: String },
}
impl ExtensionStatus {
@@ -91,16 +63,11 @@ impl ExtensionStatus {
}
}
/// ttyd configuration (Web Terminal)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TtydConfig {
/// Enable auto-start
pub enabled: bool,
/// Port to listen on
pub port: u16,
/// Shell to execute
pub shell: String,
}
@@ -108,25 +75,19 @@ impl Default for TtydConfig {
fn default() -> Self {
Self {
enabled: false,
port: 7681,
shell: "/bin/bash".to_string(),
}
}
}
/// gostc configuration (NAT traversal based on FRP)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GostcConfig {
/// Enable auto-start
pub enabled: bool,
/// Server address (e.g., gostc.mofeng.run)
pub addr: String,
/// Client key from GOSTC management panel
#[serde(skip_serializing_if = "String::is_empty")]
pub key: String,
/// Enable TLS
pub tls: bool,
}
@@ -134,35 +95,28 @@ impl Default for GostcConfig {
fn default() -> Self {
Self {
enabled: false,
addr: "gostc.mofeng.run".to_string(),
addr: String::new(),
key: String::new(),
tls: true,
}
}
}
/// EasyTier configuration (P2P VPN)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct EasytierConfig {
/// Enable auto-start
pub enabled: bool,
/// Network name
pub network_name: String,
/// Network secret/password
#[serde(skip_serializing_if = "String::is_empty")]
pub network_secret: String,
/// Peer node URLs
#[serde(skip_serializing_if = "Vec::is_empty")]
pub peer_urls: Vec<String>,
/// Virtual IP address (optional, auto-assigned if not set)
#[serde(skip_serializing_if = "Option::is_none")]
pub virtual_ip: Option<String>,
}
/// Combined extensions configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
@@ -172,53 +126,37 @@ pub struct ExtensionsConfig {
pub easytier: EasytierConfig,
}
/// Extension info with status and config
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
}
/// ttyd extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtydInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: TtydConfig,
}
/// gostc extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GostcInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: GostcConfig,
}
/// easytier extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EasytierInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: EasytierConfig,
}
/// All extensions status response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionsStatus {
@@ -227,7 +165,6 @@ pub struct ExtensionsStatus {
pub easytier: EasytierInfo,
}
/// Extension logs response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionLogs {

View File

@@ -1,71 +1,32 @@
//! HID backend trait definition
//! `HidBackend` trait plus serde `HidBackendType` (OTG | CH9329 | disabled).
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::watch;
use super::types::{ConsumerEvent, KeyboardEvent, MouseEvent};
use crate::error::Result;
use crate::events::LedState;
/// Default CH9329 baud rate
fn default_ch9329_baud_rate() -> u32 {
9600
}
/// HID backend type
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
#[derive(Default)]
pub enum HidBackendType {
/// USB OTG gadget mode
Otg,
/// CH9329 serial HID controller
Ch9329 {
/// Serial port path
port: String,
/// Baud rate (default: 9600)
#[serde(default = "default_ch9329_baud_rate")]
baud_rate: u32,
},
/// No HID backend (disabled)
#[default]
None,
}
impl HidBackendType {
/// Check if OTG backend is available on this system
pub fn otg_available() -> bool {
// Check for USB gadget support
std::path::Path::new("/sys/class/udc").exists()
}
/// Detect the best available backend
pub fn detect() -> Self {
// Check for OTG gadget support
if Self::otg_available() {
return Self::Otg;
}
// Check for common CH9329 serial ports
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
];
for port in &common_ports {
if std::path::Path::new(port).exists() {
return Self::Ch9329 {
port: port.to_string(),
baud_rate: 9600, // Use default baud rate for auto-detection
};
}
}
Self::None
}
/// Get backend name as string
pub fn name_str(&self) -> &str {
match self {
Self::Otg => "otg",
@@ -75,67 +36,40 @@ impl HidBackendType {
}
}
/// HID backend trait
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HidBackendRuntimeSnapshot {
pub initialized: bool,
pub online: bool,
pub supports_absolute_mouse: bool,
pub keyboard_leds_enabled: bool,
pub led_state: LedState,
pub screen_resolution: Option<(u32, u32)>,
pub device: Option<String>,
pub error: Option<String>,
pub error_code: Option<String>,
}
#[async_trait]
pub trait HidBackend: Send + Sync {
/// Get backend name
fn name(&self) -> &'static str;
/// Initialize the backend
async fn init(&self) -> Result<()>;
/// Send a keyboard event
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()>;
/// Send a mouse event
async fn send_mouse(&self, event: MouseEvent) -> Result<()>;
/// Send a consumer control event (multimedia keys)
/// Default implementation returns an error (not supported)
async fn send_consumer(&self, _event: ConsumerEvent) -> Result<()> {
Err(crate::error::AppError::BadRequest(
"Consumer control not supported by this backend".to_string(),
))
}
/// Reset all inputs (release all keys/buttons)
async fn reset(&self) -> Result<()>;
/// Shutdown the backend
async fn shutdown(&self) -> Result<()>;
/// Perform backend health check.
///
/// Default implementation assumes backend is healthy.
fn health_check(&self) -> Result<()> {
Ok(())
}
fn runtime_snapshot(&self) -> HidBackendRuntimeSnapshot;
/// Check if backend supports absolute mouse positioning
fn supports_absolute_mouse(&self) -> bool {
false
}
fn subscribe_runtime(&self) -> watch::Receiver<()>;
/// Get screen resolution (for absolute mouse)
fn screen_resolution(&self) -> Option<(u32, u32)> {
None
}
/// Set screen resolution (for absolute mouse)
fn set_screen_resolution(&mut self, _width: u32, _height: u32) {}
}
/// HID backend information
#[derive(Debug, Clone, Serialize)]
pub struct HidBackendInfo {
/// Backend name
pub name: String,
/// Backend type
pub backend_type: String,
/// Is initialized
pub initialized: bool,
/// Supports absolute mouse
pub absolute_mouse: bool,
/// Screen resolution (if absolute mouse)
pub resolution: Option<(u32, u32)>,
fn set_screen_resolution(&self, _width: u32, _height: u32) {}
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,21 +2,17 @@
//!
//! Reference: USB HID Usage Tables 1.12, Section 15 (Consumer Page 0x0C)
/// Consumer Control Usage codes for multimedia keys
pub mod usage {
// Transport Controls
pub const PLAY_PAUSE: u16 = 0x00CD;
pub const STOP: u16 = 0x00B7;
pub const NEXT_TRACK: u16 = 0x00B5;
pub const PREV_TRACK: u16 = 0x00B6;
// Volume Controls
pub const MUTE: u16 = 0x00E2;
pub const VOLUME_UP: u16 = 0x00E9;
pub const VOLUME_DOWN: u16 = 0x00EA;
}
/// Check if a usage code is valid
pub fn is_valid_usage(usage: u16) -> bool {
matches!(
usage,

View File

@@ -9,7 +9,7 @@
//!
//! Keyboard event (type 0x01):
//! - Byte 1: Event type (0x00 = down, 0x01 = up)
//! - Byte 2: Key code (USB HID usage code)
//! - Byte 2: Canonical key code (stable One-KVM key id aligned with HID usage)
//! - Byte 3: Modifiers bitmask
//! - Bit 0: Left Ctrl
//! - Bit 1: Left Shift
@@ -38,26 +38,23 @@ use tracing::warn;
use super::types::ConsumerEvent;
use super::{
KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent, MouseEventType,
CanonicalKey, KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent,
MouseEventType,
};
/// Message types
pub const MSG_KEYBOARD: u8 = 0x01;
pub const MSG_MOUSE: u8 = 0x02;
pub const MSG_CONSUMER: u8 = 0x03;
/// Keyboard event types
pub const KB_EVENT_DOWN: u8 = 0x00;
pub const KB_EVENT_UP: u8 = 0x01;
/// Mouse event types
pub const MS_EVENT_MOVE: u8 = 0x00;
pub const MS_EVENT_MOVE_ABS: u8 = 0x01;
pub const MS_EVENT_DOWN: u8 = 0x02;
pub const MS_EVENT_UP: u8 = 0x03;
pub const MS_EVENT_SCROLL: u8 = 0x04;
/// Parsed HID event from DataChannel
#[derive(Debug, Clone)]
pub enum HidChannelEvent {
Keyboard(KeyboardEvent),
@@ -65,7 +62,6 @@ pub enum HidChannelEvent {
Consumer(ConsumerEvent),
}
/// Parse a binary HID message from DataChannel
pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.is_empty() {
warn!("Empty HID message");
@@ -85,7 +81,6 @@ pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
}
}
/// Parse keyboard message payload
fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 3 {
warn!("Keyboard message too short: {} bytes", data.len());
@@ -101,7 +96,13 @@ fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
}
};
let key = data[1];
let key = match CanonicalKey::from_hid_usage(data[1]) {
Some(key) => key,
None => {
warn!("Unknown canonical keyboard key code: 0x{:02X}", data[1]);
return None;
}
};
let modifiers_byte = data[2];
let modifiers = KeyboardModifiers {
@@ -119,11 +120,9 @@ fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
event_type,
key,
modifiers,
is_usb_hid: true, // WebRTC/WebSocket HID channel sends USB HID usages
}))
}
/// Parse mouse message payload
fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 6 {
warn!("Mouse message too short: {} bytes", data.len());
@@ -142,11 +141,9 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
}
};
// Parse coordinates as i16 LE (works for both relative and absolute)
let x = i16::from_le_bytes([data[1], data[2]]) as i32;
let y = i16::from_le_bytes([data[3], data[4]]) as i32;
// Button or scroll delta
let (button, scroll) = match event_type {
MouseEventType::Down | MouseEventType::Up => {
let btn = match data[5] {
@@ -172,7 +169,6 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
}))
}
/// Parse consumer control message payload
fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 2 {
warn!("Consumer message too short: {} bytes", data.len());
@@ -184,7 +180,6 @@ fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
Some(HidChannelEvent::Consumer(ConsumerEvent { usage }))
}
/// Encode a keyboard event to binary format (for sending to client if needed)
pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
let event_type = match event.event_type {
KeyEventType::Down => KB_EVENT_DOWN,
@@ -193,40 +188,11 @@ pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
let modifiers = event.modifiers.to_hid_byte();
vec![MSG_KEYBOARD, event_type, event.key, modifiers]
}
/// Encode a mouse event to binary format (for sending to client if needed)
pub fn encode_mouse_event(event: &MouseEvent) -> Vec<u8> {
let event_type = match event.event_type {
MouseEventType::Move => MS_EVENT_MOVE,
MouseEventType::MoveAbs => MS_EVENT_MOVE_ABS,
MouseEventType::Down => MS_EVENT_DOWN,
MouseEventType::Up => MS_EVENT_UP,
MouseEventType::Scroll => MS_EVENT_SCROLL,
};
let x_bytes = (event.x as i16).to_le_bytes();
let y_bytes = (event.y as i16).to_le_bytes();
let extra = match event.event_type {
MouseEventType::Down | MouseEventType::Up => event
.button
.as_ref()
.map(|b| match b {
MouseButton::Left => 0u8,
MouseButton::Middle => 1u8,
MouseButton::Right => 2u8,
MouseButton::Back => 3u8,
MouseButton::Forward => 4u8,
})
.unwrap_or(0),
MouseEventType::Scroll => event.scroll as u8,
_ => 0,
};
vec![
MSG_MOUSE, event_type, x_bytes[0], x_bytes[1], y_bytes[0], y_bytes[1], extra,
MSG_KEYBOARD,
event_type,
event.key.to_hid_usage(),
modifiers,
]
}
@@ -242,10 +208,9 @@ mod tests {
match event {
HidChannelEvent::Keyboard(kb) => {
assert!(matches!(kb.event_type, KeyEventType::Down));
assert_eq!(kb.key, 0x04);
assert_eq!(kb.key, CanonicalKey::KeyA);
assert!(kb.modifiers.left_ctrl);
assert!(!kb.modifiers.left_shift);
assert!(kb.is_usb_hid);
}
_ => panic!("Expected keyboard event"),
}
@@ -270,7 +235,7 @@ mod tests {
fn test_encode_keyboard() {
let event = KeyboardEvent {
event_type: KeyEventType::Down,
key: 0x04,
key: CanonicalKey::KeyA,
modifiers: KeyboardModifiers {
left_ctrl: true,
left_shift: false,
@@ -281,7 +246,6 @@ mod tests {
right_alt: false,
right_meta: false,
},
is_usb_hid: true,
};
let encoded = encode_keyboard_event(&event);

399
src/hid/keyboard.rs Normal file
View File

@@ -0,0 +1,399 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CanonicalKey {
KeyA,
KeyB,
KeyC,
KeyD,
KeyE,
KeyF,
KeyG,
KeyH,
KeyI,
KeyJ,
KeyK,
KeyL,
KeyM,
KeyN,
KeyO,
KeyP,
KeyQ,
KeyR,
KeyS,
KeyT,
KeyU,
KeyV,
KeyW,
KeyX,
KeyY,
KeyZ,
Digit1,
Digit2,
Digit3,
Digit4,
Digit5,
Digit6,
Digit7,
Digit8,
Digit9,
Digit0,
Enter,
Escape,
Backspace,
Tab,
Space,
Minus,
Equal,
BracketLeft,
BracketRight,
Backslash,
Semicolon,
Quote,
Backquote,
Comma,
Period,
Slash,
CapsLock,
F1,
F2,
F3,
F4,
F5,
F6,
F7,
F8,
F9,
F10,
F11,
F12,
PrintScreen,
ScrollLock,
Pause,
Insert,
Home,
PageUp,
Delete,
End,
PageDown,
ArrowRight,
ArrowLeft,
ArrowDown,
ArrowUp,
NumLock,
NumpadDivide,
NumpadMultiply,
NumpadSubtract,
NumpadAdd,
NumpadEnter,
Numpad1,
Numpad2,
Numpad3,
Numpad4,
Numpad5,
Numpad6,
Numpad7,
Numpad8,
Numpad9,
Numpad0,
NumpadDecimal,
IntlBackslash,
ContextMenu,
F13,
F14,
F15,
F16,
F17,
F18,
F19,
F20,
F21,
F22,
F23,
F24,
ControlLeft,
ShiftLeft,
AltLeft,
MetaLeft,
ControlRight,
ShiftRight,
AltRight,
MetaRight,
}
impl CanonicalKey {
pub const fn to_hid_usage(self) -> u8 {
match self {
Self::KeyA => 0x04,
Self::KeyB => 0x05,
Self::KeyC => 0x06,
Self::KeyD => 0x07,
Self::KeyE => 0x08,
Self::KeyF => 0x09,
Self::KeyG => 0x0A,
Self::KeyH => 0x0B,
Self::KeyI => 0x0C,
Self::KeyJ => 0x0D,
Self::KeyK => 0x0E,
Self::KeyL => 0x0F,
Self::KeyM => 0x10,
Self::KeyN => 0x11,
Self::KeyO => 0x12,
Self::KeyP => 0x13,
Self::KeyQ => 0x14,
Self::KeyR => 0x15,
Self::KeyS => 0x16,
Self::KeyT => 0x17,
Self::KeyU => 0x18,
Self::KeyV => 0x19,
Self::KeyW => 0x1A,
Self::KeyX => 0x1B,
Self::KeyY => 0x1C,
Self::KeyZ => 0x1D,
Self::Digit1 => 0x1E,
Self::Digit2 => 0x1F,
Self::Digit3 => 0x20,
Self::Digit4 => 0x21,
Self::Digit5 => 0x22,
Self::Digit6 => 0x23,
Self::Digit7 => 0x24,
Self::Digit8 => 0x25,
Self::Digit9 => 0x26,
Self::Digit0 => 0x27,
Self::Enter => 0x28,
Self::Escape => 0x29,
Self::Backspace => 0x2A,
Self::Tab => 0x2B,
Self::Space => 0x2C,
Self::Minus => 0x2D,
Self::Equal => 0x2E,
Self::BracketLeft => 0x2F,
Self::BracketRight => 0x30,
Self::Backslash => 0x31,
Self::Semicolon => 0x33,
Self::Quote => 0x34,
Self::Backquote => 0x35,
Self::Comma => 0x36,
Self::Period => 0x37,
Self::Slash => 0x38,
Self::CapsLock => 0x39,
Self::F1 => 0x3A,
Self::F2 => 0x3B,
Self::F3 => 0x3C,
Self::F4 => 0x3D,
Self::F5 => 0x3E,
Self::F6 => 0x3F,
Self::F7 => 0x40,
Self::F8 => 0x41,
Self::F9 => 0x42,
Self::F10 => 0x43,
Self::F11 => 0x44,
Self::F12 => 0x45,
Self::PrintScreen => 0x46,
Self::ScrollLock => 0x47,
Self::Pause => 0x48,
Self::Insert => 0x49,
Self::Home => 0x4A,
Self::PageUp => 0x4B,
Self::Delete => 0x4C,
Self::End => 0x4D,
Self::PageDown => 0x4E,
Self::ArrowRight => 0x4F,
Self::ArrowLeft => 0x50,
Self::ArrowDown => 0x51,
Self::ArrowUp => 0x52,
Self::NumLock => 0x53,
Self::NumpadDivide => 0x54,
Self::NumpadMultiply => 0x55,
Self::NumpadSubtract => 0x56,
Self::NumpadAdd => 0x57,
Self::NumpadEnter => 0x58,
Self::Numpad1 => 0x59,
Self::Numpad2 => 0x5A,
Self::Numpad3 => 0x5B,
Self::Numpad4 => 0x5C,
Self::Numpad5 => 0x5D,
Self::Numpad6 => 0x5E,
Self::Numpad7 => 0x5F,
Self::Numpad8 => 0x60,
Self::Numpad9 => 0x61,
Self::Numpad0 => 0x62,
Self::NumpadDecimal => 0x63,
Self::IntlBackslash => 0x64,
Self::ContextMenu => 0x65,
Self::F13 => 0x68,
Self::F14 => 0x69,
Self::F15 => 0x6A,
Self::F16 => 0x6B,
Self::F17 => 0x6C,
Self::F18 => 0x6D,
Self::F19 => 0x6E,
Self::F20 => 0x6F,
Self::F21 => 0x70,
Self::F22 => 0x71,
Self::F23 => 0x72,
Self::F24 => 0x73,
Self::ControlLeft => 0xE0,
Self::ShiftLeft => 0xE1,
Self::AltLeft => 0xE2,
Self::MetaLeft => 0xE3,
Self::ControlRight => 0xE4,
Self::ShiftRight => 0xE5,
Self::AltRight => 0xE6,
Self::MetaRight => 0xE7,
}
}
pub const fn from_hid_usage(usage: u8) -> Option<Self> {
match usage {
0x04 => Some(Self::KeyA),
0x05 => Some(Self::KeyB),
0x06 => Some(Self::KeyC),
0x07 => Some(Self::KeyD),
0x08 => Some(Self::KeyE),
0x09 => Some(Self::KeyF),
0x0A => Some(Self::KeyG),
0x0B => Some(Self::KeyH),
0x0C => Some(Self::KeyI),
0x0D => Some(Self::KeyJ),
0x0E => Some(Self::KeyK),
0x0F => Some(Self::KeyL),
0x10 => Some(Self::KeyM),
0x11 => Some(Self::KeyN),
0x12 => Some(Self::KeyO),
0x13 => Some(Self::KeyP),
0x14 => Some(Self::KeyQ),
0x15 => Some(Self::KeyR),
0x16 => Some(Self::KeyS),
0x17 => Some(Self::KeyT),
0x18 => Some(Self::KeyU),
0x19 => Some(Self::KeyV),
0x1A => Some(Self::KeyW),
0x1B => Some(Self::KeyX),
0x1C => Some(Self::KeyY),
0x1D => Some(Self::KeyZ),
0x1E => Some(Self::Digit1),
0x1F => Some(Self::Digit2),
0x20 => Some(Self::Digit3),
0x21 => Some(Self::Digit4),
0x22 => Some(Self::Digit5),
0x23 => Some(Self::Digit6),
0x24 => Some(Self::Digit7),
0x25 => Some(Self::Digit8),
0x26 => Some(Self::Digit9),
0x27 => Some(Self::Digit0),
0x28 => Some(Self::Enter),
0x29 => Some(Self::Escape),
0x2A => Some(Self::Backspace),
0x2B => Some(Self::Tab),
0x2C => Some(Self::Space),
0x2D => Some(Self::Minus),
0x2E => Some(Self::Equal),
0x2F => Some(Self::BracketLeft),
0x30 => Some(Self::BracketRight),
0x31 => Some(Self::Backslash),
0x33 => Some(Self::Semicolon),
0x34 => Some(Self::Quote),
0x35 => Some(Self::Backquote),
0x36 => Some(Self::Comma),
0x37 => Some(Self::Period),
0x38 => Some(Self::Slash),
0x39 => Some(Self::CapsLock),
0x3A => Some(Self::F1),
0x3B => Some(Self::F2),
0x3C => Some(Self::F3),
0x3D => Some(Self::F4),
0x3E => Some(Self::F5),
0x3F => Some(Self::F6),
0x40 => Some(Self::F7),
0x41 => Some(Self::F8),
0x42 => Some(Self::F9),
0x43 => Some(Self::F10),
0x44 => Some(Self::F11),
0x45 => Some(Self::F12),
0x46 => Some(Self::PrintScreen),
0x47 => Some(Self::ScrollLock),
0x48 => Some(Self::Pause),
0x49 => Some(Self::Insert),
0x4A => Some(Self::Home),
0x4B => Some(Self::PageUp),
0x4C => Some(Self::Delete),
0x4D => Some(Self::End),
0x4E => Some(Self::PageDown),
0x4F => Some(Self::ArrowRight),
0x50 => Some(Self::ArrowLeft),
0x51 => Some(Self::ArrowDown),
0x52 => Some(Self::ArrowUp),
0x53 => Some(Self::NumLock),
0x54 => Some(Self::NumpadDivide),
0x55 => Some(Self::NumpadMultiply),
0x56 => Some(Self::NumpadSubtract),
0x57 => Some(Self::NumpadAdd),
0x58 => Some(Self::NumpadEnter),
0x59 => Some(Self::Numpad1),
0x5A => Some(Self::Numpad2),
0x5B => Some(Self::Numpad3),
0x5C => Some(Self::Numpad4),
0x5D => Some(Self::Numpad5),
0x5E => Some(Self::Numpad6),
0x5F => Some(Self::Numpad7),
0x60 => Some(Self::Numpad8),
0x61 => Some(Self::Numpad9),
0x62 => Some(Self::Numpad0),
0x63 => Some(Self::NumpadDecimal),
0x64 => Some(Self::IntlBackslash),
0x65 => Some(Self::ContextMenu),
0x68 => Some(Self::F13),
0x69 => Some(Self::F14),
0x6A => Some(Self::F15),
0x6B => Some(Self::F16),
0x6C => Some(Self::F17),
0x6D => Some(Self::F18),
0x6E => Some(Self::F19),
0x6F => Some(Self::F20),
0x70 => Some(Self::F21),
0x71 => Some(Self::F22),
0x72 => Some(Self::F23),
0x73 => Some(Self::F24),
0xE0 => Some(Self::ControlLeft),
0xE1 => Some(Self::ShiftLeft),
0xE2 => Some(Self::AltLeft),
0xE3 => Some(Self::MetaLeft),
0xE4 => Some(Self::ControlRight),
0xE5 => Some(Self::ShiftRight),
0xE6 => Some(Self::AltRight),
0xE7 => Some(Self::MetaRight),
_ => None,
}
}
pub const fn is_modifier(self) -> bool {
matches!(
self,
Self::ControlLeft
| Self::ShiftLeft
| Self::AltLeft
| Self::MetaLeft
| Self::ControlRight
| Self::ShiftRight
| Self::AltRight
| Self::MetaRight
)
}
pub const fn modifier_bit(self) -> Option<u8> {
match self {
Self::ControlLeft => Some(0x01),
Self::ShiftLeft => Some(0x02),
Self::AltLeft => Some(0x04),
Self::MetaLeft => Some(0x08),
Self::ControlRight => Some(0x10),
Self::ShiftRight => Some(0x20),
Self::AltRight => Some(0x40),
Self::MetaRight => Some(0x80),
_ => None,
}
}
}

View File

@@ -1,430 +0,0 @@
//! USB HID keyboard key codes mapping
//!
//! This module provides mapping between JavaScript key codes and USB HID usage codes.
//! Reference: USB HID Usage Tables 1.12, Section 10 (Keyboard/Keypad Page)
/// USB HID key codes (Usage Page 0x07)
#[allow(dead_code)]
pub mod usb {
// Letters A-Z (0x04 - 0x1D)
pub const KEY_A: u8 = 0x04;
pub const KEY_B: u8 = 0x05;
pub const KEY_C: u8 = 0x06;
pub const KEY_D: u8 = 0x07;
pub const KEY_E: u8 = 0x08;
pub const KEY_F: u8 = 0x09;
pub const KEY_G: u8 = 0x0A;
pub const KEY_H: u8 = 0x0B;
pub const KEY_I: u8 = 0x0C;
pub const KEY_J: u8 = 0x0D;
pub const KEY_K: u8 = 0x0E;
pub const KEY_L: u8 = 0x0F;
pub const KEY_M: u8 = 0x10;
pub const KEY_N: u8 = 0x11;
pub const KEY_O: u8 = 0x12;
pub const KEY_P: u8 = 0x13;
pub const KEY_Q: u8 = 0x14;
pub const KEY_R: u8 = 0x15;
pub const KEY_S: u8 = 0x16;
pub const KEY_T: u8 = 0x17;
pub const KEY_U: u8 = 0x18;
pub const KEY_V: u8 = 0x19;
pub const KEY_W: u8 = 0x1A;
pub const KEY_X: u8 = 0x1B;
pub const KEY_Y: u8 = 0x1C;
pub const KEY_Z: u8 = 0x1D;
// Numbers 1-9, 0 (0x1E - 0x27)
pub const KEY_1: u8 = 0x1E;
pub const KEY_2: u8 = 0x1F;
pub const KEY_3: u8 = 0x20;
pub const KEY_4: u8 = 0x21;
pub const KEY_5: u8 = 0x22;
pub const KEY_6: u8 = 0x23;
pub const KEY_7: u8 = 0x24;
pub const KEY_8: u8 = 0x25;
pub const KEY_9: u8 = 0x26;
pub const KEY_0: u8 = 0x27;
// Control keys
pub const KEY_ENTER: u8 = 0x28;
pub const KEY_ESCAPE: u8 = 0x29;
pub const KEY_BACKSPACE: u8 = 0x2A;
pub const KEY_TAB: u8 = 0x2B;
pub const KEY_SPACE: u8 = 0x2C;
pub const KEY_MINUS: u8 = 0x2D;
pub const KEY_EQUAL: u8 = 0x2E;
pub const KEY_LEFT_BRACKET: u8 = 0x2F;
pub const KEY_RIGHT_BRACKET: u8 = 0x30;
pub const KEY_BACKSLASH: u8 = 0x31;
pub const KEY_HASH: u8 = 0x32; // Non-US # and ~
pub const KEY_SEMICOLON: u8 = 0x33;
pub const KEY_APOSTROPHE: u8 = 0x34;
pub const KEY_GRAVE: u8 = 0x35;
pub const KEY_COMMA: u8 = 0x36;
pub const KEY_PERIOD: u8 = 0x37;
pub const KEY_SLASH: u8 = 0x38;
pub const KEY_CAPS_LOCK: u8 = 0x39;
// Function keys F1-F12
pub const KEY_F1: u8 = 0x3A;
pub const KEY_F2: u8 = 0x3B;
pub const KEY_F3: u8 = 0x3C;
pub const KEY_F4: u8 = 0x3D;
pub const KEY_F5: u8 = 0x3E;
pub const KEY_F6: u8 = 0x3F;
pub const KEY_F7: u8 = 0x40;
pub const KEY_F8: u8 = 0x41;
pub const KEY_F9: u8 = 0x42;
pub const KEY_F10: u8 = 0x43;
pub const KEY_F11: u8 = 0x44;
pub const KEY_F12: u8 = 0x45;
// Special keys
pub const KEY_PRINT_SCREEN: u8 = 0x46;
pub const KEY_SCROLL_LOCK: u8 = 0x47;
pub const KEY_PAUSE: u8 = 0x48;
pub const KEY_INSERT: u8 = 0x49;
pub const KEY_HOME: u8 = 0x4A;
pub const KEY_PAGE_UP: u8 = 0x4B;
pub const KEY_DELETE: u8 = 0x4C;
pub const KEY_END: u8 = 0x4D;
pub const KEY_PAGE_DOWN: u8 = 0x4E;
pub const KEY_RIGHT_ARROW: u8 = 0x4F;
pub const KEY_LEFT_ARROW: u8 = 0x50;
pub const KEY_DOWN_ARROW: u8 = 0x51;
pub const KEY_UP_ARROW: u8 = 0x52;
// Numpad
pub const KEY_NUM_LOCK: u8 = 0x53;
pub const KEY_NUMPAD_DIVIDE: u8 = 0x54;
pub const KEY_NUMPAD_MULTIPLY: u8 = 0x55;
pub const KEY_NUMPAD_MINUS: u8 = 0x56;
pub const KEY_NUMPAD_PLUS: u8 = 0x57;
pub const KEY_NUMPAD_ENTER: u8 = 0x58;
pub const KEY_NUMPAD_1: u8 = 0x59;
pub const KEY_NUMPAD_2: u8 = 0x5A;
pub const KEY_NUMPAD_3: u8 = 0x5B;
pub const KEY_NUMPAD_4: u8 = 0x5C;
pub const KEY_NUMPAD_5: u8 = 0x5D;
pub const KEY_NUMPAD_6: u8 = 0x5E;
pub const KEY_NUMPAD_7: u8 = 0x5F;
pub const KEY_NUMPAD_8: u8 = 0x60;
pub const KEY_NUMPAD_9: u8 = 0x61;
pub const KEY_NUMPAD_0: u8 = 0x62;
pub const KEY_NUMPAD_DECIMAL: u8 = 0x63;
// Additional keys
pub const KEY_NON_US_BACKSLASH: u8 = 0x64;
pub const KEY_APPLICATION: u8 = 0x65; // Context menu
pub const KEY_POWER: u8 = 0x66;
pub const KEY_NUMPAD_EQUAL: u8 = 0x67;
// F13-F24
pub const KEY_F13: u8 = 0x68;
pub const KEY_F14: u8 = 0x69;
pub const KEY_F15: u8 = 0x6A;
pub const KEY_F16: u8 = 0x6B;
pub const KEY_F17: u8 = 0x6C;
pub const KEY_F18: u8 = 0x6D;
pub const KEY_F19: u8 = 0x6E;
pub const KEY_F20: u8 = 0x6F;
pub const KEY_F21: u8 = 0x70;
pub const KEY_F22: u8 = 0x71;
pub const KEY_F23: u8 = 0x72;
pub const KEY_F24: u8 = 0x73;
// Modifier keys (these are handled separately in the modifier byte)
pub const KEY_LEFT_CTRL: u8 = 0xE0;
pub const KEY_LEFT_SHIFT: u8 = 0xE1;
pub const KEY_LEFT_ALT: u8 = 0xE2;
pub const KEY_LEFT_META: u8 = 0xE3;
pub const KEY_RIGHT_CTRL: u8 = 0xE4;
pub const KEY_RIGHT_SHIFT: u8 = 0xE5;
pub const KEY_RIGHT_ALT: u8 = 0xE6;
pub const KEY_RIGHT_META: u8 = 0xE7;
}
/// JavaScript key codes (event.keyCode / event.code)
#[allow(dead_code)]
pub mod js {
// Letters
pub const KEY_A: u8 = 65;
pub const KEY_B: u8 = 66;
pub const KEY_C: u8 = 67;
pub const KEY_D: u8 = 68;
pub const KEY_E: u8 = 69;
pub const KEY_F: u8 = 70;
pub const KEY_G: u8 = 71;
pub const KEY_H: u8 = 72;
pub const KEY_I: u8 = 73;
pub const KEY_J: u8 = 74;
pub const KEY_K: u8 = 75;
pub const KEY_L: u8 = 76;
pub const KEY_M: u8 = 77;
pub const KEY_N: u8 = 78;
pub const KEY_O: u8 = 79;
pub const KEY_P: u8 = 80;
pub const KEY_Q: u8 = 81;
pub const KEY_R: u8 = 82;
pub const KEY_S: u8 = 83;
pub const KEY_T: u8 = 84;
pub const KEY_U: u8 = 85;
pub const KEY_V: u8 = 86;
pub const KEY_W: u8 = 87;
pub const KEY_X: u8 = 88;
pub const KEY_Y: u8 = 89;
pub const KEY_Z: u8 = 90;
// Numbers (top row)
pub const KEY_0: u8 = 48;
pub const KEY_1: u8 = 49;
pub const KEY_2: u8 = 50;
pub const KEY_3: u8 = 51;
pub const KEY_4: u8 = 52;
pub const KEY_5: u8 = 53;
pub const KEY_6: u8 = 54;
pub const KEY_7: u8 = 55;
pub const KEY_8: u8 = 56;
pub const KEY_9: u8 = 57;
// Function keys
pub const KEY_F1: u8 = 112;
pub const KEY_F2: u8 = 113;
pub const KEY_F3: u8 = 114;
pub const KEY_F4: u8 = 115;
pub const KEY_F5: u8 = 116;
pub const KEY_F6: u8 = 117;
pub const KEY_F7: u8 = 118;
pub const KEY_F8: u8 = 119;
pub const KEY_F9: u8 = 120;
pub const KEY_F10: u8 = 121;
pub const KEY_F11: u8 = 122;
pub const KEY_F12: u8 = 123;
// Control keys
pub const KEY_BACKSPACE: u8 = 8;
pub const KEY_TAB: u8 = 9;
pub const KEY_ENTER: u8 = 13;
pub const KEY_SHIFT: u8 = 16;
pub const KEY_CTRL: u8 = 17;
pub const KEY_ALT: u8 = 18;
pub const KEY_PAUSE: u8 = 19;
pub const KEY_CAPS_LOCK: u8 = 20;
pub const KEY_ESCAPE: u8 = 27;
pub const KEY_SPACE: u8 = 32;
pub const KEY_PAGE_UP: u8 = 33;
pub const KEY_PAGE_DOWN: u8 = 34;
pub const KEY_END: u8 = 35;
pub const KEY_HOME: u8 = 36;
pub const KEY_LEFT: u8 = 37;
pub const KEY_UP: u8 = 38;
pub const KEY_RIGHT: u8 = 39;
pub const KEY_DOWN: u8 = 40;
pub const KEY_INSERT: u8 = 45;
pub const KEY_DELETE: u8 = 46;
// Punctuation
pub const KEY_SEMICOLON: u8 = 186;
pub const KEY_EQUAL: u8 = 187;
pub const KEY_COMMA: u8 = 188;
pub const KEY_MINUS: u8 = 189;
pub const KEY_PERIOD: u8 = 190;
pub const KEY_SLASH: u8 = 191;
pub const KEY_GRAVE: u8 = 192;
pub const KEY_LEFT_BRACKET: u8 = 219;
pub const KEY_BACKSLASH: u8 = 220;
pub const KEY_RIGHT_BRACKET: u8 = 221;
pub const KEY_APOSTROPHE: u8 = 222;
// Numpad
pub const KEY_NUMPAD_0: u8 = 96;
pub const KEY_NUMPAD_1: u8 = 97;
pub const KEY_NUMPAD_2: u8 = 98;
pub const KEY_NUMPAD_3: u8 = 99;
pub const KEY_NUMPAD_4: u8 = 100;
pub const KEY_NUMPAD_5: u8 = 101;
pub const KEY_NUMPAD_6: u8 = 102;
pub const KEY_NUMPAD_7: u8 = 103;
pub const KEY_NUMPAD_8: u8 = 104;
pub const KEY_NUMPAD_9: u8 = 105;
pub const KEY_NUMPAD_MULTIPLY: u8 = 106;
pub const KEY_NUMPAD_ADD: u8 = 107;
pub const KEY_NUMPAD_SUBTRACT: u8 = 109;
pub const KEY_NUMPAD_DECIMAL: u8 = 110;
pub const KEY_NUMPAD_DIVIDE: u8 = 111;
// Lock keys
pub const KEY_NUM_LOCK: u8 = 144;
pub const KEY_SCROLL_LOCK: u8 = 145;
// Windows keys
pub const KEY_META_LEFT: u8 = 91;
pub const KEY_META_RIGHT: u8 = 92;
pub const KEY_CONTEXT_MENU: u8 = 93;
}
/// JavaScript keyCode to USB HID keyCode mapping table
/// Using a fixed-size array for O(1) lookup instead of HashMap
/// Index = JavaScript keyCode, Value = USB HID keyCode (0 means unmapped)
static JS_TO_USB_TABLE: [u8; 256] = {
let mut table = [0u8; 256];
// Letters A-Z (JS 65-90 -> USB 0x04-0x1D)
let mut i = 0u8;
while i < 26 {
table[(65 + i) as usize] = usb::KEY_A + i;
i += 1;
}
// Numbers 1-9, 0 (JS 49-57, 48 -> USB 0x1E-0x27)
table[49] = usb::KEY_1; // 1
table[50] = usb::KEY_2; // 2
table[51] = usb::KEY_3; // 3
table[52] = usb::KEY_4; // 4
table[53] = usb::KEY_5; // 5
table[54] = usb::KEY_6; // 6
table[55] = usb::KEY_7; // 7
table[56] = usb::KEY_8; // 8
table[57] = usb::KEY_9; // 9
table[48] = usb::KEY_0; // 0
// Function keys F1-F12 (JS 112-123 -> USB 0x3A-0x45)
table[112] = usb::KEY_F1;
table[113] = usb::KEY_F2;
table[114] = usb::KEY_F3;
table[115] = usb::KEY_F4;
table[116] = usb::KEY_F5;
table[117] = usb::KEY_F6;
table[118] = usb::KEY_F7;
table[119] = usb::KEY_F8;
table[120] = usb::KEY_F9;
table[121] = usb::KEY_F10;
table[122] = usb::KEY_F11;
table[123] = usb::KEY_F12;
// Control keys
table[13] = usb::KEY_ENTER; // Enter
table[27] = usb::KEY_ESCAPE; // Escape
table[8] = usb::KEY_BACKSPACE; // Backspace
table[9] = usb::KEY_TAB; // Tab
table[32] = usb::KEY_SPACE; // Space
table[20] = usb::KEY_CAPS_LOCK; // Caps Lock
// Punctuation (JS codes vary by browser/layout)
table[189] = usb::KEY_MINUS; // -
table[187] = usb::KEY_EQUAL; // =
table[219] = usb::KEY_LEFT_BRACKET; // [
table[221] = usb::KEY_RIGHT_BRACKET; // ]
table[220] = usb::KEY_BACKSLASH; // \
table[186] = usb::KEY_SEMICOLON; // ;
table[222] = usb::KEY_APOSTROPHE; // '
table[192] = usb::KEY_GRAVE; // `
table[188] = usb::KEY_COMMA; // ,
table[190] = usb::KEY_PERIOD; // .
table[191] = usb::KEY_SLASH; // /
// Navigation keys
table[45] = usb::KEY_INSERT;
table[46] = usb::KEY_DELETE;
table[36] = usb::KEY_HOME;
table[35] = usb::KEY_END;
table[33] = usb::KEY_PAGE_UP;
table[34] = usb::KEY_PAGE_DOWN;
// Arrow keys
table[39] = usb::KEY_RIGHT_ARROW;
table[37] = usb::KEY_LEFT_ARROW;
table[40] = usb::KEY_DOWN_ARROW;
table[38] = usb::KEY_UP_ARROW;
// Numpad
table[144] = usb::KEY_NUM_LOCK;
table[111] = usb::KEY_NUMPAD_DIVIDE;
table[106] = usb::KEY_NUMPAD_MULTIPLY;
table[109] = usb::KEY_NUMPAD_MINUS;
table[107] = usb::KEY_NUMPAD_PLUS;
table[96] = usb::KEY_NUMPAD_0;
table[97] = usb::KEY_NUMPAD_1;
table[98] = usb::KEY_NUMPAD_2;
table[99] = usb::KEY_NUMPAD_3;
table[100] = usb::KEY_NUMPAD_4;
table[101] = usb::KEY_NUMPAD_5;
table[102] = usb::KEY_NUMPAD_6;
table[103] = usb::KEY_NUMPAD_7;
table[104] = usb::KEY_NUMPAD_8;
table[105] = usb::KEY_NUMPAD_9;
table[110] = usb::KEY_NUMPAD_DECIMAL;
// Special keys
table[19] = usb::KEY_PAUSE;
table[145] = usb::KEY_SCROLL_LOCK;
table[93] = usb::KEY_APPLICATION; // Context menu
// Modifier keys
table[17] = usb::KEY_LEFT_CTRL;
table[16] = usb::KEY_LEFT_SHIFT;
table[18] = usb::KEY_LEFT_ALT;
table[91] = usb::KEY_LEFT_META; // Left Windows/Command
table[92] = usb::KEY_RIGHT_META; // Right Windows/Command
table
};
/// Convert JavaScript keyCode to USB HID keyCode
///
/// Uses a fixed-size lookup table for O(1) performance.
/// Returns None if the key code is not mapped.
#[inline]
pub fn js_to_usb(js_code: u8) -> Option<u8> {
let usb_code = JS_TO_USB_TABLE[js_code as usize];
if usb_code != 0 {
Some(usb_code)
} else {
None
}
}
/// Check if a key code is a modifier key
pub fn is_modifier_key(usb_code: u8) -> bool {
(0xE0..=0xE7).contains(&usb_code)
}
/// Get modifier bit for a modifier key
pub fn modifier_bit(usb_code: u8) -> Option<u8> {
match usb_code {
usb::KEY_LEFT_CTRL => Some(0x01),
usb::KEY_LEFT_SHIFT => Some(0x02),
usb::KEY_LEFT_ALT => Some(0x04),
usb::KEY_LEFT_META => Some(0x08),
usb::KEY_RIGHT_CTRL => Some(0x10),
usb::KEY_RIGHT_SHIFT => Some(0x20),
usb::KEY_RIGHT_ALT => Some(0x40),
usb::KEY_RIGHT_META => Some(0x80),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_letter_mapping() {
assert_eq!(js_to_usb(65), Some(usb::KEY_A)); // A
assert_eq!(js_to_usb(90), Some(usb::KEY_Z)); // Z
}
#[test]
fn test_number_mapping() {
assert_eq!(js_to_usb(48), Some(usb::KEY_0));
assert_eq!(js_to_usb(49), Some(usb::KEY_1));
}
#[test]
fn test_modifier_key() {
assert!(is_modifier_key(usb::KEY_LEFT_CTRL));
assert!(is_modifier_key(usb::KEY_RIGHT_SHIFT));
assert!(!is_modifier_key(usb::KEY_A));
}
}

View File

@@ -1,143 +1,169 @@
//! HID (Human Interface Device) control module
//!
//! This module provides keyboard and mouse control for remote KVM:
//! - USB OTG gadget mode (native Linux USB gadget)
//! - CH9329 serial HID controller
//!
//! Architecture:
//! ```text
//! Web Client --> WebSocket/DataChannel --> HID Events --> Backend --> Target PC
//! |
//! [OTG | CH9329]
//! ```
//! HID path: browser (WebSocket or WebRTC DataChannel) → queue → OTG gadget or CH9329.
pub mod backend;
pub mod ch9329;
pub mod consumer;
pub mod datachannel;
pub mod keymap;
pub mod monitor;
pub mod keyboard;
pub mod otg;
pub mod types;
pub mod websocket;
pub use backend::{HidBackend, HidBackendType};
pub use monitor::{HidHealthMonitor, HidHealthStatus, HidMonitorConfig};
pub use otg::LedState;
pub use crate::events::LedState;
pub use backend::{HidBackend, HidBackendRuntimeSnapshot, HidBackendType};
pub use keyboard::CanonicalKey;
pub use types::{
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent,
MouseEventType,
};
/// HID backend information
#[derive(Debug, Clone)]
pub struct HidInfo {
/// Backend name
pub name: &'static str,
/// Whether backend is initialized
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HidRuntimeState {
pub available: bool,
pub backend: String,
pub initialized: bool,
/// Whether absolute mouse positioning is supported
pub online: bool,
pub supports_absolute_mouse: bool,
/// Screen resolution for absolute mouse
pub keyboard_leds_enabled: bool,
pub led_state: LedState,
pub screen_resolution: Option<(u32, u32)>,
pub device: Option<String>,
pub error: Option<String>,
pub error_code: Option<String>,
}
impl HidRuntimeState {
fn from_backend_type(backend_type: &HidBackendType) -> Self {
Self {
available: !matches!(backend_type, HidBackendType::None),
backend: backend_type.name_str().to_string(),
initialized: false,
online: false,
supports_absolute_mouse: false,
keyboard_leds_enabled: false,
led_state: LedState::default(),
screen_resolution: None,
device: device_for_backend_type(backend_type),
error: None,
error_code: None,
}
}
fn from_backend(backend_type: &HidBackendType, snapshot: HidBackendRuntimeSnapshot) -> Self {
Self {
available: !matches!(backend_type, HidBackendType::None),
backend: backend_type.name_str().to_string(),
initialized: snapshot.initialized,
online: snapshot.online,
supports_absolute_mouse: snapshot.supports_absolute_mouse,
keyboard_leds_enabled: snapshot.keyboard_leds_enabled,
led_state: snapshot.led_state,
screen_resolution: snapshot.screen_resolution,
device: snapshot
.device
.or_else(|| device_for_backend_type(backend_type)),
error: snapshot.error,
error_code: snapshot.error_code,
}
}
fn with_error(
backend_type: &HidBackendType,
current: &Self,
reason: impl Into<String>,
error_code: impl Into<String>,
) -> Self {
let mut next = current.clone();
next.available = !matches!(backend_type, HidBackendType::None);
next.backend = backend_type.name_str().to_string();
next.initialized = false;
next.online = false;
next.keyboard_leds_enabled = false;
next.led_state = LedState::default();
next.device = device_for_backend_type(backend_type);
next.error = Some(reason.into());
next.error_code = Some(error_code.into());
next
}
}
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::error::{AppError, Result};
use crate::events::EventBus;
use crate::otg::OtgService;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
const HID_EVENT_QUEUE_CAPACITY: usize = 64;
const HID_EVENT_SEND_TIMEOUT_MS: u64 = 30;
const HID_HEALTH_CHECK_INTERVAL_MS: u64 = 1000;
#[derive(Debug)]
enum HidEvent {
enum QueuedHidEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
Consumer(ConsumerEvent),
Reset,
}
/// HID controller managing keyboard and mouse input
pub struct HidController {
/// OTG Service reference (only used when backend is OTG)
otg_service: Option<Arc<OtgService>>,
/// Active backend
backend: Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
/// Backend type (mutable for reload)
backend_type: Arc<RwLock<HidBackendType>>,
/// Event bus for broadcasting state changes (optional)
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
/// Health monitor for error tracking and recovery
monitor: Arc<HidHealthMonitor>,
/// HID event queue sender (non-blocking)
hid_tx: mpsc::Sender<HidEvent>,
/// HID event queue receiver (moved into worker on first start)
hid_rx: Mutex<Option<mpsc::Receiver<HidEvent>>>,
/// Coalesced mouse move (latest)
events: Arc<tokio::sync::RwLock<Option<Arc<EventBus>>>>,
runtime_state: Arc<RwLock<HidRuntimeState>>,
hid_tx: mpsc::Sender<QueuedHidEvent>,
hid_rx: Mutex<Option<mpsc::Receiver<QueuedHidEvent>>>,
pending_move: Arc<parking_lot::Mutex<Option<MouseEvent>>>,
/// Pending move flag (fast path)
pending_move_flag: Arc<AtomicBool>,
/// Worker task handle
hid_worker: Mutex<Option<JoinHandle<()>>>,
/// Health check task handle
hid_health_checker: Mutex<Option<JoinHandle<()>>>,
/// Backend availability fast flag
backend_available: AtomicBool,
runtime_worker: Mutex<Option<JoinHandle<()>>>,
backend_available: Arc<AtomicBool>,
}
impl HidController {
/// Create a new HID controller with specified backend
///
/// For OTG backend, otg_service should be provided to support hot-reload
pub fn new(backend_type: HidBackendType, otg_service: Option<Arc<OtgService>>) -> Self {
let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY);
Self {
otg_service,
backend: Arc::new(RwLock::new(None)),
backend_type: Arc::new(RwLock::new(backend_type)),
events: tokio::sync::RwLock::new(None),
monitor: Arc::new(HidHealthMonitor::with_defaults()),
backend_type: Arc::new(RwLock::new(backend_type.clone())),
events: Arc::new(tokio::sync::RwLock::new(None)),
runtime_state: Arc::new(RwLock::new(HidRuntimeState::from_backend_type(
&backend_type,
))),
hid_tx,
hid_rx: Mutex::new(Some(hid_rx)),
pending_move: Arc::new(parking_lot::Mutex::new(None)),
pending_move_flag: Arc::new(AtomicBool::new(false)),
hid_worker: Mutex::new(None),
hid_health_checker: Mutex::new(None),
backend_available: AtomicBool::new(false),
runtime_worker: Mutex::new(None),
backend_available: Arc::new(AtomicBool::new(false)),
}
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<crate::events::EventBus>) {
*self.events.write().await = Some(events.clone());
// Also set event bus on the monitor for health notifications
self.monitor.set_event_bus(events).await;
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Initialize the HID backend
pub async fn init(&self) -> Result<()> {
let backend_type = self.backend_type.read().await.clone();
let backend: Arc<dyn HidBackend> = match backend_type {
HidBackendType::Otg => {
// Request HID functions from OtgService
let otg_service = self
.otg_service
.as_ref()
.ok_or_else(|| AppError::Internal("OtgService not available".into()))?;
info!("Requesting HID functions from OtgService");
let handles = otg_service.enable_hid().await?;
let handles = otg_service.hid_device_paths().await.ok_or_else(|| {
AppError::Config("OTG HID paths are not available".to_string())
})?;
// Create OtgBackend from handles (no longer manages gadget itself)
info!("Creating OTG HID backend from device paths");
Arc::new(otg::OtgBackend::from_handles(handles)?)
}
@@ -157,51 +183,65 @@ impl HidController {
}
};
backend.init().await?;
*self.backend.write().await = Some(backend);
self.backend_available.store(true, Ordering::Release);
if let Err(e) = backend.init().await {
self.backend_available.store(false, Ordering::Release);
let error_state = {
let backend_type = self.backend_type.read().await.clone();
let current = self.runtime_state.read().await.clone();
HidRuntimeState::with_error(
&backend_type,
&current,
format!("Failed to initialize HID backend: {}", e),
"init_failed",
)
};
self.apply_runtime_state(error_state).await;
return Err(e);
}
*self.backend.write().await = Some(backend);
self.sync_runtime_state_from_backend().await;
// Start HID event worker (once)
self.start_event_worker().await;
self.start_health_checker().await;
self.restart_runtime_worker().await;
info!("HID backend initialized: {:?}", backend_type);
Ok(())
}
/// Shutdown the HID backend and release resources
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down HID controller");
self.stop_health_checker().await;
self.stop_runtime_worker().await;
// Close the backend
*self.backend.write().await = None;
self.backend_available.store(false, Ordering::Release);
// If OTG backend, notify OtgService to disable HID
let backend_type = self.backend_type.read().await.clone();
if matches!(backend_type, HidBackendType::Otg) {
if let Some(ref otg_service) = self.otg_service {
info!("Disabling HID functions in OtgService");
otg_service.disable_hid().await?;
if let Some(backend) = self.backend.write().await.take() {
if let Err(e) = backend.shutdown().await {
warn!("Error shutting down HID backend: {}", e);
}
}
self.backend_available.store(false, Ordering::Release);
let backend_type = self.backend_type.read().await.clone();
let mut shutdown_state = HidRuntimeState::from_backend_type(&backend_type);
if matches!(backend_type, HidBackendType::None) {
shutdown_state.available = false;
} else {
shutdown_state.error = Some("HID backend stopped".to_string());
shutdown_state.error_code = Some("shutdown".to_string());
}
self.apply_runtime_state(shutdown_state).await;
info!("HID controller shutdown complete");
Ok(())
}
/// Send keyboard event
pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
"HID backend not available".to_string(),
));
}
self.enqueue_event(HidEvent::Keyboard(event)).await
self.enqueue_event(QueuedHidEvent::Keyboard(event)).await
}
/// Send mouse event
pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
@@ -213,112 +253,55 @@ impl HidController {
event.event_type,
MouseEventType::Move | MouseEventType::MoveAbs
) {
// Best-effort: drop/merge move events if queue is full
self.enqueue_mouse_move(event)
} else {
self.enqueue_event(HidEvent::Mouse(event)).await
self.enqueue_event(QueuedHidEvent::Mouse(event)).await
}
}
/// Send consumer control event (multimedia keys)
pub async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
"HID backend not available".to_string(),
));
}
self.enqueue_event(HidEvent::Consumer(event)).await
self.enqueue_event(QueuedHidEvent::Consumer(event)).await
}
/// Reset all keys (release all pressed keys)
pub async fn reset(&self) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Ok(());
}
// Reset is important but best-effort; enqueue to avoid blocking
self.enqueue_event(HidEvent::Reset).await
self.enqueue_event(QueuedHidEvent::Reset).await
}
/// Check if backend is available
pub async fn is_available(&self) -> bool {
self.backend.read().await.is_some()
self.backend_available.load(Ordering::Acquire)
}
/// Get backend type
pub async fn backend_type(&self) -> HidBackendType {
self.backend_type.read().await.clone()
}
/// Get backend info
pub async fn info(&self) -> Option<HidInfo> {
let backend = self.backend.read().await;
backend.as_ref().map(|b| HidInfo {
name: b.name(),
initialized: true,
supports_absolute_mouse: b.supports_absolute_mouse(),
screen_resolution: b.screen_resolution(),
})
pub async fn snapshot(&self) -> HidRuntimeState {
self.runtime_state.read().await.clone()
}
/// Get current state as SystemEvent
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
let backend = self.backend.read().await;
let backend_type = self.backend_type().await;
let (backend_name, initialized) = match backend.as_ref() {
Some(b) => (b.name(), true),
None => (backend_type.name_str(), false),
};
// Include error information from monitor
let (error, error_code) = match self.monitor.status().await {
HidHealthStatus::Error {
reason, error_code, ..
} => (Some(reason), Some(error_code)),
_ => (None, None),
};
crate::events::SystemEvent::HidStateChanged {
backend: backend_name.to_string(),
initialized,
error,
error_code,
}
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<HidHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> HidHealthStatus {
self.monitor.status().await
}
/// Check if the HID backend is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
/// Reload the HID backend with new type
pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> {
info!("Reloading HID backend: {:?}", new_backend_type);
self.backend_available.store(false, Ordering::Release);
self.stop_health_checker().await;
self.stop_runtime_worker().await;
// Shutdown existing backend first
if let Some(backend) = self.backend.write().await.take() {
if let Err(e) = backend.shutdown().await {
warn!("Error shutting down old HID backend: {}", e);
}
}
// Create and initialize new backend
let new_backend: Option<Arc<dyn HidBackend>> = match new_backend_type {
HidBackendType::Otg => {
info!("Initializing OTG HID backend");
// Get OtgService reference
let otg_service = match self.otg_service.as_ref() {
Some(svc) => svc,
None => {
@@ -329,43 +312,28 @@ impl HidController {
}
};
// Request HID functions from OtgService
match otg_service.enable_hid().await {
Ok(handles) => {
// Create OtgBackend from handles
match otg::OtgBackend::from_handles(handles) {
Ok(backend) => {
let backend = Arc::new(backend);
match backend.init().await {
Ok(_) => {
info!("OTG backend initialized successfully");
Some(backend)
}
Err(e) => {
warn!("Failed to initialize OTG backend: {}", e);
// Cleanup: disable HID in OtgService
if let Err(e2) = otg_service.disable_hid().await {
warn!(
"Failed to cleanup HID after init failure: {}",
e2
);
}
None
}
match otg_service.hid_device_paths().await {
Some(handles) => match otg::OtgBackend::from_handles(handles) {
Ok(backend) => {
let backend = Arc::new(backend);
match backend.init().await {
Ok(_) => {
info!("OTG backend initialized successfully");
Some(backend)
}
}
Err(e) => {
warn!("Failed to create OTG backend: {}", e);
// Cleanup: disable HID in OtgService
if let Err(e2) = otg_service.disable_hid().await {
warn!("Failed to cleanup HID after creation failure: {}", e2);
Err(e) => {
warn!("Failed to initialize OTG backend: {}", e);
None
}
None
}
}
}
Err(e) => {
warn!("Failed to enable HID in OtgService: {}", e);
Err(e) => {
warn!("Failed to create OTG backend: {}", e);
None
}
},
None => {
warn!("OTG HID paths are not available");
None
}
}
@@ -403,44 +371,37 @@ impl HidController {
*self.backend.write().await = new_backend;
if matches!(new_backend_type, HidBackendType::None) {
*self.backend_type.write().await = HidBackendType::None;
self.apply_runtime_state(HidRuntimeState::from_backend_type(&HidBackendType::None))
.await;
return Ok(());
}
if self.backend.read().await.is_some() {
info!("HID backend reloaded successfully: {:?}", new_backend_type);
self.backend_available.store(true, Ordering::Release);
self.start_event_worker().await;
self.start_health_checker().await;
// Update backend_type on success
*self.backend_type.write().await = new_backend_type.clone();
// Reset monitor state on successful reload
self.monitor.reset().await;
// Publish HID state changed event
let backend_name = new_backend_type.name_str().to_string();
self.publish_event(crate::events::SystemEvent::HidStateChanged {
backend: backend_name,
initialized: true,
error: None,
error_code: None,
})
.await;
self.sync_runtime_state_from_backend().await;
self.restart_runtime_worker().await;
Ok(())
} else {
warn!("HID backend reload resulted in no active backend");
self.backend_available.store(false, Ordering::Release);
// Update backend_type even on failure (to reflect the attempted change)
*self.backend_type.write().await = new_backend_type.clone();
// Publish event with initialized=false
self.publish_event(crate::events::SystemEvent::HidStateChanged {
backend: new_backend_type.name_str().to_string(),
initialized: false,
error: Some("Failed to initialize HID backend".to_string()),
error_code: Some("init_failed".to_string()),
})
.await;
let current = self.runtime_state.read().await.clone();
let error_state = HidRuntimeState::with_error(
&new_backend_type,
&current,
"Failed to initialize HID backend",
"init_failed",
);
self.apply_runtime_state(error_state).await;
Err(AppError::Internal(
"Failed to reload HID backend".to_string(),
@@ -448,11 +409,20 @@ impl HidController {
}
}
/// Publish event to event bus if available
async fn publish_event(&self, event: crate::events::SystemEvent) {
if let Some(events) = self.events.read().await.as_ref() {
events.publish(event);
}
async fn apply_runtime_state(&self, next: HidRuntimeState) {
apply_runtime_state(&self.runtime_state, &self.events, next).await;
}
async fn sync_runtime_state_from_backend(&self) {
let backend_opt = self.backend.read().await.clone();
apply_backend_runtime_state(
&self.backend_type,
&self.runtime_state,
&self.events,
self.backend_available.as_ref(),
backend_opt.as_deref(),
)
.await;
}
async fn start_event_worker(&self) {
@@ -468,8 +438,6 @@ impl HidController {
};
let backend = self.backend.clone();
let monitor = self.monitor.clone();
let backend_type = self.backend_type.clone();
let pending_move = self.pending_move.clone();
let pending_move_flag = self.pending_move_flag.clone();
@@ -481,19 +449,12 @@ impl HidController {
None => break,
};
process_hid_event(event, &backend, &monitor, &backend_type).await;
process_hid_event(event, &backend).await;
// After each event, flush latest move if pending
if pending_move_flag.swap(false, Ordering::AcqRel) {
let move_event = { pending_move.lock().take() };
if let Some(move_event) = move_event {
process_hid_event(
HidEvent::Mouse(move_event),
&backend,
&monitor,
&backend_type,
)
.await;
process_hid_event(QueuedHidEvent::Mouse(move_event), &backend).await;
}
}
}
@@ -502,89 +463,48 @@ impl HidController {
*worker_guard = Some(handle);
}
async fn start_health_checker(&self) {
let mut checker_guard = self.hid_health_checker.lock().await;
if checker_guard.is_some() {
return;
}
async fn restart_runtime_worker(&self) {
self.stop_runtime_worker().await;
let backend = self.backend.clone();
let backend_opt = self.backend.read().await.clone();
let Some(backend) = backend_opt else {
return;
};
let mut runtime_rx = backend.subscribe_runtime();
let runtime_state = self.runtime_state.clone();
let events = self.events.clone();
let backend_available = self.backend_available.clone();
let backend_type = self.backend_type.clone();
let monitor = self.monitor.clone();
let handle = tokio::spawn(async move {
let mut ticker =
tokio::time::interval(Duration::from_millis(HID_HEALTH_CHECK_INTERVAL_MS));
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
ticker.tick().await;
let backend_opt = backend.read().await.clone();
let Some(active_backend) = backend_opt else {
continue;
};
let backend_name = backend_type.read().await.name_str().to_string();
let result =
tokio::task::spawn_blocking(move || active_backend.health_check()).await;
match result {
Ok(Ok(())) => {
if monitor.is_error().await {
monitor.report_recovered(&backend_name).await;
}
}
Ok(Err(AppError::HidError {
backend,
reason,
error_code,
})) => {
monitor
.report_error(&backend, None, &reason, &error_code)
.await;
}
Ok(Err(e)) => {
monitor
.report_error(
&backend_name,
None,
&format!("HID health check failed: {}", e),
"health_check_failed",
)
.await;
}
Err(e) => {
monitor
.report_error(
&backend_name,
None,
&format!("HID health check task failed: {}", e),
"health_check_join_failed",
)
.await;
}
if runtime_rx.changed().await.is_err() {
break;
}
apply_backend_runtime_state(
&backend_type,
&runtime_state,
&events,
backend_available.as_ref(),
Some(backend.as_ref()),
)
.await;
}
});
*checker_guard = Some(handle);
*self.runtime_worker.lock().await = Some(handle);
}
async fn stop_health_checker(&self) {
let handle_opt = {
let mut checker_guard = self.hid_health_checker.lock().await;
checker_guard.take()
};
if let Some(handle) = handle_opt {
async fn stop_runtime_worker(&self) {
if let Some(handle) = self.runtime_worker.lock().await.take() {
handle.abort();
let _ = handle.await;
}
}
fn enqueue_mouse_move(&self, event: MouseEvent) -> Result<()> {
match self.hid_tx.try_send(HidEvent::Mouse(event.clone())) {
match self.hid_tx.try_send(QueuedHidEvent::Mouse(event.clone())) {
Ok(_) => Ok(()),
Err(mpsc::error::TrySendError::Full(_)) => {
*self.pending_move.lock() = Some(event);
@@ -597,11 +517,10 @@ impl HidController {
}
}
async fn enqueue_event(&self, event: HidEvent) -> Result<()> {
async fn enqueue_event(&self, event: QueuedHidEvent) -> Result<()> {
match self.hid_tx.try_send(event) {
Ok(_) => Ok(()),
Err(mpsc::error::TrySendError::Full(ev)) => {
// For non-move events, wait briefly to avoid dropping critical input
let tx = self.hid_tx.clone();
let send_result = tokio::time::timeout(
Duration::from_millis(HID_EVENT_SEND_TIMEOUT_MS),
@@ -622,11 +541,25 @@ impl HidController {
}
}
async fn process_hid_event(
event: HidEvent,
backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
monitor: &Arc<HidHealthMonitor>,
async fn apply_backend_runtime_state(
backend_type: &Arc<RwLock<HidBackendType>>,
runtime_state: &Arc<RwLock<HidRuntimeState>>,
events: &Arc<tokio::sync::RwLock<Option<Arc<EventBus>>>>,
backend_available: &AtomicBool,
backend: Option<&dyn HidBackend>,
) {
let backend_kind = backend_type.read().await.clone();
let next = match backend {
Some(backend) => HidRuntimeState::from_backend(&backend_kind, backend.runtime_snapshot()),
None => HidRuntimeState::from_backend_type(&backend_kind),
};
backend_available.store(next.initialized, Ordering::Release);
apply_runtime_state(runtime_state, events, next).await;
}
async fn process_hid_event(
event: QueuedHidEvent,
backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
) {
let backend_opt = backend.read().await.clone();
let backend = match backend_opt {
@@ -634,13 +567,14 @@ async fn process_hid_event(
None => return,
};
let backend_for_send = backend.clone();
let result = tokio::task::spawn_blocking(move || {
futures::executor::block_on(async move {
match event {
HidEvent::Keyboard(ev) => backend.send_keyboard(ev).await,
HidEvent::Mouse(ev) => backend.send_mouse(ev).await,
HidEvent::Consumer(ev) => backend.send_consumer(ev).await,
HidEvent::Reset => backend.reset().await,
QueuedHidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await,
QueuedHidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await,
QueuedHidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await,
QueuedHidEvent::Reset => backend_for_send.reset().await,
}
})
})
@@ -652,31 +586,40 @@ async fn process_hid_event(
};
match result {
Ok(_) => {
if monitor.is_error().await {
let backend_type = backend_type.read().await;
monitor.report_recovered(backend_type.name_str()).await;
}
}
Ok(_) => {}
Err(e) => {
if let AppError::HidError {
ref backend,
ref reason,
ref error_code,
} = e
{
if error_code != "eagain_retry" {
monitor
.report_error(backend, None, reason, error_code)
.await;
}
}
warn!("HID event processing failed: {}", e);
}
}
}
impl Default for HidController {
fn default() -> Self {
Self::new(HidBackendType::None, None)
fn device_for_backend_type(backend_type: &HidBackendType) -> Option<String> {
match backend_type {
HidBackendType::Ch9329 { port, .. } => Some(port.clone()),
_ => None,
}
}
async fn apply_runtime_state(
runtime_state: &Arc<RwLock<HidRuntimeState>>,
events: &Arc<tokio::sync::RwLock<Option<Arc<EventBus>>>>,
next: HidRuntimeState,
) {
let changed = {
let mut guard = runtime_state.write().await;
if *guard == next {
false
} else {
*guard = next.clone();
true
}
};
if !changed {
return;
}
if let Some(events) = events.read().await.as_ref() {
events.mark_device_info_dirty();
}
}

View File

@@ -1,416 +0,0 @@
//! HID device health monitoring
//!
//! This module provides health monitoring for HID devices, including:
//! - Device connectivity checks
//! - Automatic reconnection on failure
//! - Error tracking and notification
//! - Log throttling to prevent log flooding
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, warn};
use crate::events::{EventBus, SystemEvent};
use crate::utils::LogThrottler;
/// HID health status
#[derive(Debug, Clone, PartialEq, Default)]
pub enum HidHealthStatus {
/// Device is healthy and operational
#[default]
Healthy,
/// Device has an error, attempting recovery
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
/// Number of recovery attempts made
retry_count: u32,
},
/// Device is disconnected
Disconnected,
}
/// HID health monitor configuration
#[derive(Debug, Clone)]
pub struct HidMonitorConfig {
/// Health check interval in milliseconds
pub check_interval_ms: u64,
/// Retry interval when device is lost (milliseconds)
pub retry_interval_ms: u64,
/// Maximum retry attempts before giving up (0 = infinite)
pub max_retries: u32,
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
/// Recovery cooldown in milliseconds (suppress logs after recovery)
pub recovery_cooldown_ms: u64,
}
impl Default for HidMonitorConfig {
fn default() -> Self {
Self {
check_interval_ms: 1000,
retry_interval_ms: 1000,
max_retries: 0, // infinite retry
log_throttle_secs: 5,
recovery_cooldown_ms: 1000, // 1 second cooldown after recovery
}
}
}
/// HID health monitor
///
/// Monitors HID device health and manages error recovery.
/// Publishes WebSocket events when device status changes.
pub struct HidHealthMonitor {
/// Current health status
status: RwLock<HidHealthStatus>,
/// Event bus for notifications
events: RwLock<Option<Arc<EventBus>>>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
config: HidMonitorConfig,
/// Whether monitoring is active (reserved for future use)
#[allow(dead_code)]
running: AtomicBool,
/// Current retry count
retry_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
/// Last recovery timestamp (milliseconds since start, for cooldown)
last_recovery_ms: AtomicU64,
/// Start instant for timing
start_instant: Instant,
}
impl HidHealthMonitor {
/// Create a new HID health monitor with the specified configuration
pub fn new(config: HidMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
Self {
status: RwLock::new(HidHealthStatus::Healthy),
events: RwLock::new(None),
throttler: LogThrottler::with_secs(throttle_secs),
config,
running: AtomicBool::new(false),
retry_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
last_recovery_ms: AtomicU64::new(0),
start_instant: Instant::now(),
}
}
/// Create a new HID health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(HidMonitorConfig::default())
}
/// Set the event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Report an error from HID operations
///
/// This method is called when an HID operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling and cooldown respect)
/// 3. Publishes a WebSocket event if the error is new or changed
///
/// # Arguments
///
/// * `backend` - The HID backend type ("otg" or "ch9329")
/// * `device` - The device path (if known)
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(
&self,
backend: &str,
device: Option<&str>,
reason: &str,
error_code: &str,
) {
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if we're in cooldown period after recent recovery
let current_ms = self.start_instant.elapsed().as_millis() as u64;
let last_recovery = self.last_recovery_ms.load(Ordering::Relaxed);
let in_cooldown =
last_recovery > 0 && current_ms < last_recovery + self.config.recovery_cooldown_ms;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (skip if in cooldown period unless error type changed)
let throttle_key = format!("hid_{}_{}", backend, error_code);
if !in_cooldown && (error_changed || self.throttler.should_log(&throttle_key)) {
warn!(
"HID {} error: {} (code: {}, attempt: {})",
backend, reason, error_code, count
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = HidHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
retry_count: count,
};
// Publish event (only if error changed or first occurrence, and not in cooldown)
if !in_cooldown && (error_changed || count == 1) {
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::HidDeviceLost {
backend: backend.to_string(),
device: device.map(|s| s.to_string()),
reason: reason.to_string(),
error_code: error_code.to_string(),
});
}
}
}
/// Report that a reconnection attempt is starting
///
/// Publishes a reconnecting event to notify clients.
///
/// # Arguments
///
/// * `backend` - The HID backend type
pub async fn report_reconnecting(&self, backend: &str) {
let attempt = self.retry_count.load(Ordering::Relaxed);
// Only publish every 5 attempts to avoid event spam
if attempt == 1 || attempt.is_multiple_of(5) {
debug!("HID {} reconnecting, attempt {}", backend, attempt);
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::HidReconnecting {
backend: backend.to_string(),
attempt,
});
}
}
}
/// Report that the device has recovered
///
/// This method is called when the HID device successfully reconnects.
/// It resets the error state and publishes a recovery event.
///
/// # Arguments
///
/// * `backend` - The HID backend type
pub async fn report_recovered(&self, backend: &str) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != HidHealthStatus::Healthy {
let retry_count = self.retry_count.load(Ordering::Relaxed);
// Set cooldown timestamp
let current_ms = self.start_instant.elapsed().as_millis() as u64;
self.last_recovery_ms.store(current_ms, Ordering::Relaxed);
// Only log and publish events if there were multiple retries
// (avoid log spam for transient single-retry recoveries)
if retry_count > 1 {
debug!("HID {} recovered after {} retries", backend, retry_count);
// Publish recovery event
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::HidRecovered {
backend: backend.to_string(),
});
// Also publish state changed to indicate healthy state
events.publish(SystemEvent::HidStateChanged {
backend: backend.to_string(),
initialized: true,
error: None,
error_code: None,
});
}
}
// Reset state (always reset, even for single-retry recoveries)
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = HidHealthStatus::Healthy;
}
}
/// Get the current health status
pub async fn status(&self) -> HidHealthStatus {
self.status.read().await.clone()
}
/// Get the current retry count
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, HidHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, HidHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = HidHealthStatus::Healthy;
self.throttler.clear_all();
}
/// Get the configuration
pub fn config(&self) -> &HidMonitorConfig {
&self.config
}
/// Check if we should continue retrying
///
/// Returns `false` if max_retries is set and we've exceeded it.
pub fn should_retry(&self) -> bool {
if self.config.max_retries == 0 {
return true; // Infinite retry
}
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
}
/// Get the retry interval
pub fn retry_interval(&self) -> Duration {
Duration::from_millis(self.config.retry_interval_ms)
}
}
impl Default for HidHealthMonitor {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_initial_status() {
let monitor = HidHealthMonitor::with_defaults();
assert!(monitor.is_healthy().await);
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_report_error() {
let monitor = HidHealthMonitor::with_defaults();
monitor
.report_error("otg", Some("/dev/hidg0"), "Device not found", "enoent")
.await;
assert!(monitor.is_error().await);
assert_eq!(monitor.retry_count(), 1);
if let HidHealthStatus::Error {
reason,
error_code,
retry_count,
} = monitor.status().await
{
assert_eq!(reason, "Device not found");
assert_eq!(error_code, "enoent");
assert_eq!(retry_count, 1);
} else {
panic!("Expected Error status");
}
}
#[tokio::test]
async fn test_report_recovered() {
let monitor = HidHealthMonitor::with_defaults();
// First report an error
monitor
.report_error("ch9329", None, "Port not found", "port_not_found")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered("ch9329").await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_retry_count_increments() {
let monitor = HidHealthMonitor::with_defaults();
for i in 1..=5 {
monitor.report_error("otg", None, "Error", "io_error").await;
assert_eq!(monitor.retry_count(), i);
}
}
#[tokio::test]
async fn test_should_retry_infinite() {
let monitor = HidHealthMonitor::new(HidMonitorConfig {
max_retries: 0, // infinite
..Default::default()
});
for _ in 0..100 {
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(monitor.should_retry());
}
}
#[tokio::test]
async fn test_should_retry_limited() {
let monitor = HidHealthMonitor::new(HidMonitorConfig {
max_retries: 3,
..Default::default()
});
assert!(monitor.should_retry());
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(monitor.should_retry()); // 1 < 3
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(monitor.should_retry()); // 2 < 3
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(!monitor.should_retry()); // 3 >= 3
}
#[tokio::test]
async fn test_reset() {
let monitor = HidHealthMonitor::with_defaults();
monitor.report_error("otg", None, "Error", "io_error").await;
assert!(monitor.is_error().await);
monitor.reset().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.retry_count(), 0);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,48 +1,37 @@
//! HID event types for keyboard and mouse
//! Keyboard/mouse/consumer structs (`KeyboardEvent`, `MouseEvent`, …).
use serde::{Deserialize, Serialize};
/// Keyboard event type
use super::keyboard::CanonicalKey;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KeyEventType {
/// Key pressed down
Down,
/// Key released
Up,
}
/// Keyboard modifier flags
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeyboardModifiers {
/// Left Control
#[serde(default)]
pub left_ctrl: bool,
/// Left Shift
#[serde(default)]
pub left_shift: bool,
/// Left Alt
#[serde(default)]
pub left_alt: bool,
/// Left Meta (Windows/Super key)
#[serde(default)]
pub left_meta: bool,
/// Right Control
#[serde(default)]
pub right_ctrl: bool,
/// Right Shift
#[serde(default)]
pub right_shift: bool,
/// Right Alt (AltGr)
#[serde(default)]
pub right_alt: bool,
/// Right Meta
#[serde(default)]
pub right_meta: bool,
}
impl KeyboardModifiers {
/// Convert to USB HID modifier byte
pub fn to_hid_byte(&self) -> u8 {
let mut byte = 0u8;
if self.left_ctrl {
@@ -72,7 +61,6 @@ impl KeyboardModifiers {
byte
}
/// Create from USB HID modifier byte
pub fn from_hid_byte(byte: u8) -> Self {
Self {
left_ctrl: byte & 0x01 != 0,
@@ -86,7 +74,6 @@ impl KeyboardModifiers {
}
}
/// Check if any modifier is active
pub fn any(&self) -> bool {
self.left_ctrl
|| self.left_shift
@@ -99,45 +86,33 @@ impl KeyboardModifiers {
}
}
/// Keyboard event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyboardEvent {
/// Event type (down/up)
#[serde(rename = "type")]
pub event_type: KeyEventType,
/// Key code (USB HID usage code or JavaScript key code)
pub key: u8,
/// Modifier keys state
pub key: CanonicalKey,
#[serde(default)]
pub modifiers: KeyboardModifiers,
/// If true, key is already USB HID code (skip js_to_usb conversion)
#[serde(default)]
pub is_usb_hid: bool,
}
impl KeyboardEvent {
/// Create a key down event (JS keycode, needs conversion)
pub fn key_down(key: u8, modifiers: KeyboardModifiers) -> Self {
pub fn key_down(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Down,
key,
modifiers,
is_usb_hid: false,
}
}
/// Create a key up event (JS keycode, needs conversion)
pub fn key_up(key: u8, modifiers: KeyboardModifiers) -> Self {
pub fn key_up(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Up,
key,
modifiers,
is_usb_hid: false,
}
}
}
/// Mouse button
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseButton {
@@ -149,7 +124,6 @@ pub enum MouseButton {
}
impl MouseButton {
/// Convert to USB HID button bit
pub fn to_hid_bit(&self) -> u8 {
match self {
MouseButton::Left => 0x01,
@@ -161,44 +135,31 @@ impl MouseButton {
}
}
/// Mouse event type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseEventType {
/// Mouse moved (relative movement)
Move,
/// Mouse moved (absolute position)
MoveAbs,
/// Button pressed
Down,
/// Button released
Up,
/// Mouse wheel scroll
Scroll,
}
/// Mouse event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MouseEvent {
/// Event type
#[serde(rename = "type")]
pub event_type: MouseEventType,
/// X coordinate or delta
#[serde(default)]
pub x: i32,
/// Y coordinate or delta
#[serde(default)]
pub y: i32,
/// Button (for down/up events)
#[serde(default)]
pub button: Option<MouseButton>,
/// Scroll delta (for scroll events)
#[serde(default)]
pub scroll: i8,
}
impl MouseEvent {
/// Create a relative move event
pub fn move_rel(dx: i32, dy: i32) -> Self {
Self {
event_type: MouseEventType::Move,
@@ -209,7 +170,6 @@ impl MouseEvent {
}
}
/// Create an absolute move event
pub fn move_abs(x: i32, y: i32) -> Self {
Self {
event_type: MouseEventType::MoveAbs,
@@ -220,7 +180,6 @@ impl MouseEvent {
}
}
/// Create a button down event
pub fn button_down(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Down,
@@ -231,7 +190,6 @@ impl MouseEvent {
}
}
/// Create a button up event
pub fn button_up(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Up,
@@ -242,7 +200,6 @@ impl MouseEvent {
}
}
/// Create a scroll event
pub fn scroll(delta: i8) -> Self {
Self {
event_type: MouseEventType::Scroll,
@@ -254,35 +211,19 @@ impl MouseEvent {
}
}
/// Combined HID event (keyboard or mouse)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "device", rename_all = "lowercase")]
pub enum HidEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
Consumer(ConsumerEvent),
}
/// Consumer control event (multimedia keys)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsumerEvent {
/// Consumer control usage code (e.g., 0x00CD for Play/Pause)
pub usage: u16,
}
/// USB HID keyboard report (8 bytes)
#[derive(Debug, Clone, Default)]
pub struct KeyboardReport {
/// Modifier byte
pub modifiers: u8,
/// Reserved byte
pub reserved: u8,
/// Key codes (up to 6 simultaneous keys)
pub keys: [u8; 6],
}
impl KeyboardReport {
/// Convert to bytes for USB HID
pub fn to_bytes(&self) -> [u8; 8] {
[
self.modifiers,
@@ -296,7 +237,6 @@ impl KeyboardReport {
]
}
/// Add a key to the report
pub fn add_key(&mut self, key: u8) -> bool {
for slot in &mut self.keys {
if *slot == 0 {
@@ -307,56 +247,21 @@ impl KeyboardReport {
false // All slots full
}
/// Remove a key from the report
pub fn remove_key(&mut self, key: u8) {
for slot in &mut self.keys {
if *slot == key {
*slot = 0;
}
}
// Compact the array
self.keys.sort_by(|a, b| b.cmp(a));
}
/// Clear all keys
pub fn clear(&mut self) {
self.modifiers = 0;
self.keys = [0; 6];
}
}
/// USB HID mouse report
#[derive(Debug, Clone, Default)]
pub struct MouseReport {
/// Button state
pub buttons: u8,
/// X movement (-127 to 127)
pub x: i8,
/// Y movement (-127 to 127)
pub y: i8,
/// Wheel movement (-127 to 127)
pub wheel: i8,
}
impl MouseReport {
/// Convert to bytes for USB HID (relative mouse)
pub fn to_bytes_relative(&self) -> [u8; 4] {
[self.buttons, self.x as u8, self.y as u8, self.wheel as u8]
}
/// Convert to bytes for USB HID (absolute mouse)
pub fn to_bytes_absolute(&self, x: u16, y: u16) -> [u8; 6] {
[
self.buttons,
(x & 0xFF) as u8,
(x >> 8) as u8,
(y & 0xFF) as u8,
(y >> 8) as u8,
self.wheel as u8,
]
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,13 +1,4 @@
//! WebSocket HID channel for HTTP/MJPEG mode
//!
//! This provides an alternative to WebRTC DataChannel for HID input
//! when using MJPEG streaming mode.
//!
//! Uses binary protocol only (same format as DataChannel):
//! - Keyboard: [0x01, event_type, key, modifiers] (4 bytes)
//! - Mouse: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, button/scroll] (7 bytes)
//!
//! See datachannel.rs for detailed protocol specification.
//! MJPEG mode: HID over WebSocket — same binary framing as [`super::datachannel`] (`0x01`/`0x02`/`0x03`; layout detailed there).
use axum::{
extract::{
@@ -24,25 +15,20 @@ use super::datachannel::{parse_hid_message, HidChannelEvent};
use crate::state::AppState;
use crate::utils::LogThrottler;
/// Binary response codes
const RESP_OK: u8 = 0x00;
const RESP_ERR_HID_UNAVAILABLE: u8 = 0x01;
const RESP_ERR_INVALID_MESSAGE: u8 = 0x02;
/// WebSocket HID upgrade handler
pub async fn ws_hid_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
ws.on_upgrade(move |socket| handle_hid_socket(socket, state))
}
/// Handle HID WebSocket connection
async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
// Log throttler for error messages (5 second interval)
let log_throttler = LogThrottler::with_secs(5);
info!("WebSocket HID connection established (binary protocol)");
// Check if HID controller is available and send initial status
let hid_available = state.hid.is_available().await;
let initial_response = if hid_available {
vec![RESP_OK]
@@ -59,17 +45,14 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
return;
}
// Process incoming messages (binary only)
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Binary(data)) => {
// Check HID availability before processing each message
let hid_available = state.hid.is_available().await;
if !hid_available {
if log_throttler.should_log("hid_unavailable") {
warn!("HID controller not available, ignoring message");
}
// Send error response (optional, for client awareness)
let _ = sender
.send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE].into()))
.await;
@@ -77,15 +60,12 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
}
if let Err(e) = handle_binary_message(&data, &state).await {
// Log with throttling to avoid spam
if log_throttler.should_log("binary_hid_error") {
warn!("Binary HID message error: {}", e);
}
// Don't send error response for every failed message to reduce overhead
}
}
Ok(Message::Text(text)) => {
// Text messages are no longer supported
if log_throttler.should_log("text_message_rejected") {
debug!(
"Received text message (not supported): {} bytes",
@@ -111,7 +91,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
}
}
// Reset HID state to release any held keys/buttons
if let Err(e) = state.hid.reset().await {
warn!("Failed to reset HID on WebSocket disconnect: {}", e);
}
@@ -119,7 +98,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
info!("WebSocket HID connection ended");
}
/// Handle binary HID message (same format as DataChannel)
async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), String> {
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
@@ -160,12 +138,10 @@ mod tests {
assert_eq!(RESP_OK, 0x00);
assert_eq!(RESP_ERR_HID_UNAVAILABLE, 0x01);
assert_eq!(RESP_ERR_INVALID_MESSAGE, 0x02);
// assert_eq!(RESP_ERR_SEND_FAILED, 0x03); // TODO: fix test
}
#[test]
fn test_keyboard_message_format() {
// Keyboard message: [0x01, event_type, key, modifiers]
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // 'A' key with left ctrl
let event = parse_hid_message(&data);
assert!(event.is_some());
@@ -173,7 +149,6 @@ mod tests {
#[test]
fn test_mouse_message_format() {
// Mouse message: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, extra]
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
let event = parse_hid_message(&data);
assert!(event.is_some());

View File

@@ -1,30 +1,27 @@
//! One-KVM - Lightweight IP-KVM solution
//!
//! This crate provides the core functionality for One-KVM,
//! a remote KVM (Keyboard, Video, Mouse) solution written in Rust.
//! Core library for One-KVM (IPKVM: capture, HID, OTG, streaming, Web UI glue).
pub mod atx;
pub mod audio;
pub mod auth;
pub mod config;
pub mod db;
pub mod error;
pub mod events;
pub mod extensions;
pub mod hid;
pub mod modules;
pub mod msd;
pub mod otg;
pub mod rtsp;
pub mod rustdesk;
pub mod state;
pub mod stream;
pub mod stream_encoder;
pub mod update;
pub mod utils;
pub mod video;
pub mod web;
pub mod webrtc;
/// Auto-generated secrets module (from secrets.toml at compile time)
pub mod secrets {
include!(concat!(env!("OUT_DIR"), "/secrets_generated.rs"));
}

View File

@@ -1,24 +1,27 @@
use std::collections::HashSet;
use std::future::Future;
use std::io::Write;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
use clap::{Parser, ValueEnum};
use clap::{Args, Parser, Subcommand, ValueEnum};
use futures::{stream::FuturesUnordered, StreamExt};
use rustls::crypto::{ring, CryptoProvider};
use tokio::sync::broadcast;
use tokio::sync::{broadcast, mpsc};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use one_kvm::atx::AtxController;
use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality};
use one_kvm::auth::{SessionStore, UserStore};
use one_kvm::config::{self, AppConfig, ConfigStore};
use one_kvm::db::DatabasePool;
use one_kvm::events::EventBus;
use one_kvm::extensions::ExtensionManager;
use one_kvm::hid::{HidBackendType, HidController};
use one_kvm::msd::MsdController;
use one_kvm::otg::{configfs, OtgService};
use one_kvm::otg::OtgService;
use one_kvm::rtsp::RtspService;
use one_kvm::rustdesk::RustDeskService;
use one_kvm::state::AppState;
@@ -32,7 +35,6 @@ use one_kvm::video::{Streamer, VideoStreamManager};
use one_kvm::web;
use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig};
/// Log level for the application
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
enum LogLevel {
Error,
@@ -44,11 +46,14 @@ enum LogLevel {
Trace,
}
/// One-KVM command line arguments
#[derive(Parser, Debug)]
#[command(name = "one-kvm")]
#[command(version, about = "A open and lightweight IP-KVM solution", long_about = None)]
struct CliArgs {
/// User management commands
#[command(subcommand)]
command: Option<CliCommand>,
/// Listen address (overrides database config)
#[arg(short = 'a', long, value_name = "ADDRESS")]
address: Option<String>,
@@ -86,33 +91,50 @@ struct CliArgs {
verbose: u8,
}
#[derive(Subcommand, Debug)]
enum CliCommand {
/// Manage local users
User(UserCommand),
}
#[derive(Args, Debug)]
struct UserCommand {
#[command(subcommand)]
action: UserAction,
}
#[derive(Subcommand, Debug)]
enum UserAction {
/// Set password for the single local user (interactive terminal prompt)
SetPassword,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Parse command line arguments
let args = CliArgs::parse();
// Initialize logging with CLI arguments
init_logging(args.log_level, args.verbose);
// Install default crypto provider (required by rustls 0.23+)
CryptoProvider::install_default(ring::default_provider())
.expect("Failed to install rustls crypto provider");
tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION"));
// Determine data directory (CLI arg takes precedence)
let data_dir = args.data_dir.unwrap_or_else(get_data_dir);
let data_dir = args.data_dir.clone().unwrap_or_else(get_data_dir);
tracing::info!("Data directory: {}", data_dir.display());
// Ensure data directory exists
if let Some(command) = args.command {
run_cli_command(command, data_dir).await?;
return Ok(());
}
tokio::fs::create_dir_all(&data_dir).await?;
// Initialize configuration store
let db_path = data_dir.join("one-kvm.db");
let config_store = ConfigStore::new(&db_path).await?;
let db = open_database_pool(&data_dir).await?;
let config_store = ConfigStore::new(db.clone_pool())?;
config_store.load().await?;
let mut config = (*config_store.get()).clone();
// Normalize MSD directory (absolute path under data dir if empty/relative)
let mut msd_dir_updated = false;
if config.msd.msd_dir.trim().is_empty() {
let msd_dir = data_dir.join("msd");
@@ -130,8 +152,6 @@ async fn main() -> anyhow::Result<()> {
if msd_dir_updated {
config_store.set(config.clone()).await?;
}
// Ensure MSD directories exist (msd/images, msd/ventoy)
let msd_dir = PathBuf::from(&config.msd.msd_dir);
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await {
tracing::warn!("Failed to create MSD images directory: {}", e);
@@ -140,7 +160,6 @@ async fn main() -> anyhow::Result<()> {
tracing::warn!("Failed to create MSD ventoy directory: {}", e);
}
// Apply CLI argument overrides to config (only if explicitly specified)
if let Some(addr) = args.address {
config.web.bind_address = addr.clone();
config.web.bind_addresses = vec![addr];
@@ -174,29 +193,20 @@ async fn main() -> anyhow::Result<()> {
config.web.http_port
};
// Log final configuration
for ip in &bind_ips {
let addr = SocketAddr::new(*ip, bind_port);
tracing::info!("Server will listen on: {}://{}", scheme, addr);
}
// Initialize session store
let session_store = SessionStore::new(
config_store.pool().clone(),
config.auth.session_timeout_secs as i64,
);
let session_store = SessionStore::new(config.auth.session_timeout_secs as i64);
// Initialize user store
let user_store = UserStore::new(config_store.pool().clone());
let user_store = UserStore::new(db.clone_pool());
// Create shutdown channel
let (shutdown_tx, _) = broadcast::channel::<()>(1);
// Create event bus for real-time notifications
let events = Arc::new(EventBus::new());
tracing::info!("Event bus initialized");
// Parse video configuration once (avoid duplication)
let (video_format, video_resolution) = parse_video_config(&config);
tracing::debug!(
"Parsed video config: {} @ {}x{}",
@@ -205,7 +215,6 @@ async fn main() -> anyhow::Result<()> {
video_resolution.height
);
// Create video streamer and initialize with config if device is set
let streamer = Streamer::new();
streamer.set_event_bus(events.clone()).await;
if let Some(ref device_path) = config.video.device {
@@ -233,19 +242,19 @@ async fn main() -> anyhow::Result<()> {
}
}
// Create WebRTC streamer
let webrtc_streamer = {
let webrtc_config = WebRtcStreamerConfig {
resolution: video_resolution,
input_format: video_format,
fps: config.video.fps,
bitrate_preset: config.stream.bitrate_preset,
encoder_backend: config.stream.encoder.to_backend(),
encoder_backend: one_kvm::stream_encoder::encoder_type_to_backend(
config.stream.encoder.clone(),
),
webrtc: {
let mut stun_servers = vec![];
let mut turn_servers = vec![];
// Check if user configured custom servers
let has_custom_stun = config
.stream
.stun_server
@@ -259,25 +268,12 @@ async fn main() -> anyhow::Result<()> {
.map(|s| !s.is_empty())
.unwrap_or(false);
// If no custom servers, use public ICE servers (like RustDesk)
if !has_custom_stun && !has_custom_turn {
use one_kvm::webrtc::config::public_ice;
if public_ice::is_configured() {
if let Some(stun) = public_ice::stun_server() {
stun_servers.push(stun.clone());
tracing::info!("Using public STUN server: {}", stun);
}
for turn in public_ice::turn_servers() {
tracing::info!("Using public TURN server: {:?}", turn.urls);
turn_servers.push(turn);
}
} else {
tracing::info!(
"No public ICE servers configured, using host candidates only"
);
}
let stun = public_ice::stun_server().to_string();
tracing::info!("Using public STUN server: {}", stun);
stun_servers.push(stun);
} else {
// Use custom servers
if let Some(ref stun) = config.stream.stun_server {
if !stun.is_empty() {
stun_servers.push(stun.clone());
@@ -313,41 +309,15 @@ async fn main() -> anyhow::Result<()> {
};
WebRtcStreamer::with_config(webrtc_config)
};
tracing::info!("WebRTC streamer created (supports H264, extensible to VP8/VP9/H265)");
tracing::info!("WebRTC streamer created");
// Create OTG Service (single instance for centralized USB gadget management)
let otg_service = Arc::new(OtgService::new());
tracing::info!("OTG Service created");
// Pre-enable OTG functions to avoid gadget recreation (prevents kernel crashes)
let will_use_otg_hid = matches!(config.hid.backend, config::HidBackend::Otg);
let will_use_msd = config.msd.enabled;
if will_use_otg_hid {
let mut hid_functions = config.hid.effective_otg_functions();
if let Some(udc) = configfs::resolve_udc_name(config.hid.otg_udc.as_deref()) {
if configfs::is_low_endpoint_udc(&udc) && hid_functions.consumer {
tracing::warn!(
"UDC {} has low endpoint resources, disabling consumer control",
udc
);
hid_functions.consumer = false;
}
}
if let Err(e) = otg_service.update_hid_functions(hid_functions).await {
tracing::warn!("Failed to apply HID functions: {}", e);
}
if let Err(e) = otg_service.enable_hid().await {
tracing::warn!("Failed to pre-enable HID: {}", e);
}
}
if will_use_msd {
if let Err(e) = otg_service.enable_msd().await {
tracing::warn!("Failed to pre-enable MSD: {}", e);
}
if let Err(e) = otg_service.apply_config(&config.hid, &config.msd).await {
tracing::warn!("Failed to apply OTG config: {}", e);
}
// Create HID controller based on config
let hid_backend = match config.hid.backend {
config::HidBackend::Otg => HidBackendType::Otg,
config::HidBackend::Ch9329 => HidBackendType::Ch9329 {
@@ -356,27 +326,17 @@ async fn main() -> anyhow::Result<()> {
},
config::HidBackend::None => HidBackendType::None,
};
let hid = Arc::new(HidController::new(
hid_backend,
Some(otg_service.clone()), // Always pass OtgService to support hot-reload to OTG
));
let hid = Arc::new(HidController::new(hid_backend, Some(otg_service.clone())));
hid.set_event_bus(events.clone()).await;
if let Err(e) = hid.init().await {
tracing::warn!("Failed to initialize HID backend: {}", e);
}
// Create MSD controller (optional, based on config)
let msd = if config.msd.enabled {
// Initialize Ventoy resources from data directory
let ventoy_resource_dir = ventoy_img::get_resource_dir(&data_dir);
let ventoy_resource_dir = data_dir.join("ventoy");
if ventoy_resource_dir.exists() {
if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) {
tracing::warn!("Failed to initialize Ventoy resources: {}", e);
tracing::info!(
"Ventoy resource files should be placed in: {}",
ventoy_resource_dir.display()
);
tracing::info!("Required files: {:?}", ventoy_img::required_files());
} else {
tracing::info!(
"Ventoy resources initialized from {}",
@@ -388,10 +348,6 @@ async fn main() -> anyhow::Result<()> {
"Ventoy resource directory not found: {}",
ventoy_resource_dir.display()
);
tracing::info!(
"Create the directory and place the following files: {:?}",
ventoy_img::required_files()
);
}
let controller = MsdController::new(otg_service.clone(), config.msd.msd_dir_path());
@@ -407,7 +363,6 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create ATX controller (optional, based on config)
let atx = if config.atx.enabled {
let controller_config = config.atx.to_controller_config();
let controller = AtxController::new(controller_config);
@@ -423,12 +378,21 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create Audio controller
let audio = {
let audio_config = AudioControllerConfig {
enabled: config.audio.enabled,
device: config.audio.device.clone(),
quality: AudioQuality::from_str(&config.audio.quality),
quality: match config.audio.quality.parse::<AudioQuality>() {
Ok(q) => q,
Err(e) => {
tracing::warn!(
"Invalid audio quality in config (value={:?}): {}, using balanced",
config.audio.quality,
e
);
AudioQuality::Balanced
}
},
};
let controller = AudioController::new(audio_config);
@@ -440,7 +404,6 @@ async fn main() -> anyhow::Result<()> {
config.audio.device,
config.audio.quality
);
// Start audio streaming so WebRTC can subscribe to Opus frames
if let Err(e) = controller.start_streaming().await {
tracing::warn!("Failed to start audio streaming: {}", e);
}
@@ -451,29 +414,23 @@ async fn main() -> anyhow::Result<()> {
Arc::new(controller)
};
// Create Extension manager (ttyd, gostc, easytier)
let extensions = Arc::new(ExtensionManager::new());
tracing::info!("Extension manager initialized");
// Wire up WebRTC streamer with HID controller
// This enables WebRTC DataChannel to process HID events
webrtc_streamer.set_hid_controller(hid.clone()).await;
// Wire up WebRTC streamer with Audio controller
// This enables WebRTC audio track to receive Opus frames
webrtc_streamer.set_audio_controller(audio.clone()).await;
if config.audio.enabled {
if let Err(e) = webrtc_streamer.set_audio_enabled(true).await {
tracing::warn!("Failed to enable WebRTC audio: {}", e);
} else {
tracing::info!("WebRTC audio enabled");
tracing::debug!("WebRTC audio enabled");
}
}
// Configure direct capture for WebRTC encoder pipeline
let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) =
streamer.current_capture_config().await;
tracing::info!(
tracing::debug!(
"Initial video config: {}x{} {:?} @ {}fps",
actual_resolution.width,
actual_resolution.height,
@@ -484,22 +441,50 @@ async fn main() -> anyhow::Result<()> {
.update_video_config(actual_resolution, actual_format, actual_fps)
.await;
if let Some(device_path) = device_path {
let (subdev_path, bridge_kind, v4l2_driver) = streamer
.current_device()
.await
.map(|d| {
(
d.subdev_path.clone(),
d.bridge_kind.clone(),
Some(d.driver.clone()),
)
})
.unwrap_or((None, None, None));
webrtc_streamer
.set_capture_device(device_path, jpeg_quality)
.set_capture_device(
device_path,
jpeg_quality,
subdev_path,
bridge_kind,
v4l2_driver,
)
.await;
tracing::info!("WebRTC streamer configured for direct capture");
tracing::debug!("WebRTC streamer configured for direct capture");
} else {
tracing::warn!("No capture device configured for WebRTC");
}
// Create video stream manager (unified MJPEG/WebRTC management)
// Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance
let stream_manager =
VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone());
let stream_manager = VideoStreamManager::with_webrtc_streamer(
streamer.clone(),
webrtc_streamer.clone() as std::sync::Arc<dyn one_kvm::video::traits::VideoOutput>,
);
stream_manager.set_event_bus(events.clone()).await;
stream_manager.set_config_store(config_store.clone()).await;
{
let stream_manager_weak = Arc::downgrade(&stream_manager);
audio
.set_recovered_callback(Arc::new(move || {
if let Some(stream_manager) = stream_manager_weak.upgrade() {
tokio::spawn(async move {
stream_manager.reconnect_webrtc_audio_sources().await;
});
}
}))
.await;
}
// Initialize stream manager with configured mode
let initial_mode = config.stream.mode.clone();
if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await {
tracing::warn!(
@@ -514,7 +499,6 @@ async fn main() -> anyhow::Result<()> {
);
}
// Create RustDesk service (optional, based on config)
let rustdesk = if config.rustdesk.is_valid() {
tracing::info!(
"Initializing RustDesk service: ID={} -> {}",
@@ -539,7 +523,6 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create RTSP service (optional, based on config)
let rtsp = if config.rtsp.enabled {
tracing::info!(
"Initializing RTSP service: rtsp://{}:{}/{}",
@@ -554,15 +537,16 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create application state
let update_service = Arc::new(UpdateService::new(data_dir.join("updates")));
let state = AppState::new(
db.clone(),
config_store.clone(),
session_store,
user_store,
otg_service,
stream_manager,
webrtc_streamer.clone(),
hid,
msd,
atx,
@@ -576,12 +560,12 @@ async fn main() -> anyhow::Result<()> {
data_dir.clone(),
);
// Start RustDesk service if enabled
extensions.set_event_bus(events.clone()).await;
if let Some(ref service) = rustdesk {
if let Err(e) = service.start().await {
tracing::error!("Failed to start RustDesk service: {}", e);
} else {
// Save generated keypair and UUID to config
if let Some(updated_config) = service.save_credentials() {
if let Err(e) = config_store
.update(|cfg| {
@@ -601,7 +585,6 @@ async fn main() -> anyhow::Result<()> {
}
}
// Start RTSP service if enabled
if let Some(ref service) = rtsp {
if let Err(e) = service.start().await {
tracing::error!("Failed to start RTSP service: {}", e);
@@ -610,7 +593,6 @@ async fn main() -> anyhow::Result<()> {
}
}
// Enforce startup codec constraints (e.g. RTSP/RustDesk locks)
{
let runtime_config = state.config.get();
let constraints = StreamCodecConstraints::from_config(&runtime_config);
@@ -625,13 +607,11 @@ async fn main() -> anyhow::Result<()> {
}
}
// Start enabled extensions
{
let ext_config = config_store.get();
extensions.start_enabled(&ext_config.extensions).await;
}
// Start extension health check task (every 30 seconds)
{
let extensions_clone = extensions.clone();
let config_store_clone = config_store.clone();
@@ -646,17 +626,14 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("Extension health check task started");
}
// Start device info broadcast task
// This monitors state change events and broadcasts DeviceInfo to all clients
state.publish_device_info().await;
spawn_device_info_broadcaster(state.clone(), events);
// Create router
let app = web::create_router(state.clone());
// Bind sockets for configured addresses
let listeners = bind_tcp_listeners(&bind_ips, bind_port)?;
// Setup graceful shutdown
let shutdown_signal = async move {
tokio::signal::ctrl_c()
.await
@@ -665,9 +642,7 @@ async fn main() -> anyhow::Result<()> {
let _ = shutdown_tx.send(());
};
// Start server
if config.web.https_enabled {
// Generate self-signed certificate if no custom cert provided
let tls_config = if let (Some(cert_path), Some(key_path)) =
(&config.web.ssl_cert_path, &config.web.ssl_key_path)
{
@@ -677,7 +652,6 @@ async fn main() -> anyhow::Result<()> {
let cert_path = cert_dir.join("server.crt");
let key_path = cert_dir.join("server.key");
// Check if certificate already exists, only generate if missing
if !cert_path.exists() || !key_path.exists() {
tracing::info!("Generating new self-signed TLS certificate");
let cert = generate_self_signed_cert()?;
@@ -691,7 +665,7 @@ async fn main() -> anyhow::Result<()> {
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
};
let mut servers = FuturesUnordered::new();
let servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTPS server on {}", local_addr);
@@ -701,19 +675,9 @@ async fn main() -> anyhow::Result<()> {
servers.push(server);
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTPS server error: {}", e);
}
cleanup(&state).await;
}
}
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTPS").await;
} else {
let mut servers = FuturesUnordered::new();
let servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTP server on {}", local_addr);
@@ -723,26 +687,14 @@ async fn main() -> anyhow::Result<()> {
servers.push(async move { server.await });
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTP server error: {}", e);
}
cleanup(&state).await;
}
}
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTP").await;
}
tracing::info!("Server shutdown complete");
Ok(())
}
/// Initialize logging with tracing
fn init_logging(level: LogLevel, verbose_count: u8) {
// Verbose count overrides log level
let effective_level = match verbose_count {
0 => level,
1 => LogLevel::Verbose,
@@ -750,7 +702,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
_ => LogLevel::Trace,
};
// Build filter string based on effective level
let filter = match effective_level {
LogLevel::Error => "one_kvm=error,tower_http=error",
LogLevel::Warn => "one_kvm=warn,tower_http=warn",
@@ -760,7 +711,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
LogLevel::Trace => "one_kvm=trace,tower_http=debug",
};
// Environment variable takes highest priority
let env_filter =
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| filter.into());
@@ -773,18 +723,111 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
}
}
/// Get the application data directory
fn get_data_dir() -> PathBuf {
// Check environment variable first
if let Ok(path) = std::env::var("ONE_KVM_DATA_DIR") {
return PathBuf::from(path);
}
// Default to system configuration directory
PathBuf::from("/etc/one-kvm")
}
/// Resolve bind IPs from config, preferring bind_addresses when set.
async fn open_database_pool(data_dir: &Path) -> anyhow::Result<DatabasePool> {
let db_path = data_dir.join("one-kvm.db");
let db = DatabasePool::new(&db_path).await?;
db.init_schema().await?;
Ok(db)
}
async fn run_servers_until_shutdown<F, E>(
mut servers: FuturesUnordered<F>,
shutdown_signal: impl Future<Output = ()>,
state: &Arc<AppState>,
protocol: &'static str,
) where
F: Future<Output = Result<(), E>> + Send,
E: std::fmt::Display,
{
tokio::select! {
_ = shutdown_signal => {
cleanup(state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("{} server error: {}", protocol, e);
}
cleanup(state).await;
}
}
}
async fn run_cli_command(command: CliCommand, data_dir: PathBuf) -> anyhow::Result<()> {
tokio::fs::create_dir_all(&data_dir).await?;
let db = open_database_pool(&data_dir).await?;
let users = UserStore::new(db.clone_pool());
let sessions = SessionStore::new(0);
match command {
CliCommand::User(user) => run_user_action(user.action, &users, &sessions).await,
}
}
async fn run_user_action(
action: UserAction,
users: &UserStore,
sessions: &SessionStore,
) -> anyhow::Result<()> {
match action {
UserAction::SetPassword => set_user_password(users, sessions).await,
}
}
async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow::Result<()> {
let user = users.single_user().await?.ok_or_else(|| {
anyhow::anyhow!("No local user exists yet; complete setup in the web UI first.")
})?;
let new_password = read_new_password_interactive()?;
if new_password.len() < 4 {
anyhow::bail!("Password must be at least 4 characters");
}
users.update_password(&user.id, &new_password).await?;
let revoked = sessions.delete_all().await?;
tracing::info!(
"Password updated for user '{}' and {} sessions revoked",
user.username,
revoked
);
println!(
"Password updated for user '{}' (revoked {} sessions).",
user.username, revoked
);
Ok(())
}
fn read_new_password_interactive() -> anyhow::Result<String> {
let once = |label: &str| -> anyhow::Result<String> {
print!("{}", label);
std::io::stdout().flush()?;
let mut line = String::new();
std::io::stdin().read_line(&mut line)?;
let s = line.trim_end_matches(['\r', '\n']).to_string();
if s.is_empty() {
anyhow::bail!("Password cannot be empty");
}
Ok(s)
};
let a = once("New password: ")?;
let b = once("Confirm password: ")?;
if a != b {
anyhow::bail!("Passwords do not match");
}
Ok(a)
}
fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result<Vec<IpAddr>> {
let raw_addrs = if !web.bind_addresses.is_empty() {
web.bind_addresses.as_slice()
@@ -825,7 +868,6 @@ fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result<Vec<std::ne
Ok(listeners)
}
/// Parse video format and resolution from config (avoids code duplication)
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
let format = config
.video
@@ -837,7 +879,6 @@ fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
(format, resolution)
}
/// Generate a self-signed TLS certificate
fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyPair>> {
use rcgen::generate_simple_self_signed;
@@ -851,59 +892,119 @@ fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyP
Ok(certified_key)
}
/// Spawn a background task that monitors state change events
/// and broadcasts DeviceInfo to all WebSocket clients with debouncing
fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
use one_kvm::events::SystemEvent;
use std::time::{Duration, Instant};
let mut rx = events.subscribe();
enum DeviceInfoTrigger {
Event,
Lagged { topic: &'static str, count: u64 },
}
const DEVICE_INFO_TOPICS: &[&str] = &[
"stream.state_changed",
"stream.config_applied",
"stream.mode_ready",
];
const DEBOUNCE_MS: u64 = 100;
let (trigger_tx, mut trigger_rx) = mpsc::unbounded_channel();
for topic in DEVICE_INFO_TOPICS {
let Some(mut rx) = events.subscribe_topic(topic) else {
tracing::warn!(
"DeviceInfo broadcaster missing topic subscription: {}",
topic
);
continue;
};
let trigger_tx = trigger_tx.clone();
let topic_name = *topic;
tokio::spawn(async move {
loop {
match rx.recv().await {
Ok(_) => {
if trigger_tx.send(DeviceInfoTrigger::Event).is_err() {
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
if trigger_tx
.send(DeviceInfoTrigger::Lagged {
topic: topic_name,
count,
})
.is_err()
{
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
});
}
{
let mut dirty_rx = events.subscribe_device_info_dirty();
let trigger_tx = trigger_tx.clone();
tokio::spawn(async move {
loop {
match dirty_rx.recv().await {
Ok(()) => {
if trigger_tx.send(DeviceInfoTrigger::Event).is_err() {
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => {
if trigger_tx
.send(DeviceInfoTrigger::Lagged {
topic: "device_info_dirty",
count,
})
.is_err()
{
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
});
}
tokio::spawn(async move {
let mut last_broadcast = Instant::now() - Duration::from_millis(DEBOUNCE_MS);
let mut pending_broadcast = false;
loop {
// Use timeout to handle pending broadcasts
let recv_result = if pending_broadcast {
let remaining =
DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64);
tokio::time::timeout(Duration::from_millis(remaining), rx.recv()).await
tokio::time::timeout(Duration::from_millis(remaining), trigger_rx.recv()).await
} else {
Ok(rx.recv().await)
Ok(trigger_rx.recv().await)
};
match recv_result {
Ok(Ok(event)) => {
let should_broadcast = matches!(
event,
SystemEvent::StreamStateChanged { .. }
| SystemEvent::StreamConfigApplied { .. }
| SystemEvent::StreamModeReady { .. }
| SystemEvent::HidStateChanged { .. }
| SystemEvent::MsdStateChanged { .. }
| SystemEvent::AtxStateChanged { .. }
| SystemEvent::AudioStateChanged { .. }
);
if should_broadcast {
pending_broadcast = true;
}
}
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(n))) => {
tracing::warn!("DeviceInfo broadcaster lagged by {} events", n);
Ok(Some(DeviceInfoTrigger::Event)) => {
pending_broadcast = true;
}
Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => {
Ok(Some(DeviceInfoTrigger::Lagged { topic, count })) => {
tracing::warn!(
"DeviceInfo broadcaster lagged by {} events on topic {}",
count,
topic
);
pending_broadcast = true;
}
Ok(None) => {
tracing::info!("Event bus closed, stopping DeviceInfo broadcaster");
break;
}
Err(_timeout) => {
// Debounce timeout reached, broadcast now
}
Err(_timeout) => {}
}
// Broadcast if pending and debounce time has passed
if pending_broadcast && last_broadcast.elapsed() >= Duration::from_millis(DEBOUNCE_MS) {
state.publish_device_info().await;
tracing::trace!("Broadcasted DeviceInfo (debounced)");
@@ -919,13 +1020,10 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
);
}
/// Clean up subsystems on shutdown
async fn cleanup(state: &Arc<AppState>) {
// Stop all extensions
state.extensions.stop_all().await;
tracing::info!("Extensions stopped");
// Stop RustDesk service
if let Some(ref service) = *state.rustdesk.read().await {
if let Err(e) = service.stop().await {
tracing::warn!("Failed to stop RustDesk service: {}", e);
@@ -934,7 +1032,6 @@ async fn cleanup(state: &Arc<AppState>) {
}
}
// Stop RTSP service
if let Some(ref service) = *state.rtsp.read().await {
if let Err(e) = service.stop().await {
tracing::warn!("Failed to stop RTSP service: {}", e);
@@ -943,31 +1040,26 @@ async fn cleanup(state: &Arc<AppState>) {
}
}
// Stop video
if let Err(e) = state.stream_manager.stop().await {
tracing::warn!("Failed to stop streamer: {}", e);
}
// Shutdown HID
if let Err(e) = state.hid.shutdown().await {
tracing::warn!("Failed to shutdown HID: {}", e);
}
// Shutdown MSD
if let Some(msd) = state.msd.write().await.as_mut() {
if let Err(e) = msd.shutdown().await {
tracing::warn!("Failed to shutdown MSD: {}", e);
}
}
// Shutdown ATX
if let Some(atx) = state.atx.write().await.as_mut() {
if let Err(e) = atx.shutdown().await {
tracing::warn!("Failed to shutdown ATX: {}", e);
}
}
// Shutdown Audio
if let Err(e) = state.audio.shutdown().await {
tracing::warn!("Failed to shutdown audio: {}", e);
}

View File

@@ -1,49 +0,0 @@
//! Module management for One-KVM
//!
//! This module provides infrastructure for managing feature modules
//! (video streaming, HID control, MSD, ATX) as independent async tasks.
use std::future::Future;
use std::pin::Pin;
use tokio::sync::broadcast;
/// Module status
#[derive(Debug, Clone, PartialEq)]
pub enum ModuleStatus {
Stopped,
Starting,
Running,
Stopping,
Error(String),
}
/// Trait for feature modules
pub trait Module: Send + Sync {
/// Module name
fn name(&self) -> &'static str;
/// Current status
fn status(&self) -> ModuleStatus;
/// Start the module
fn start(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
/// Stop the module
fn stop(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
}
/// Module manager for coordinating feature modules
pub struct ModuleManager {
shutdown_rx: broadcast::Receiver<()>,
}
impl ModuleManager {
pub fn new(shutdown_rx: broadcast::Receiver<()>) -> Self {
Self { shutdown_rx }
}
/// Wait for shutdown signal
pub async fn wait_for_shutdown(&mut self) {
let _ = self.shutdown_rx.recv().await;
}
}

View File

@@ -1,11 +1,3 @@
//! MSD Controller
//!
//! Manages the mass storage device lifecycle including:
//! - Image mounting and unmounting
//! - Virtual drive management
//! - State tracking
//! - Image downloads from URL
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
@@ -14,44 +6,25 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use super::image::ImageManager;
use super::monitor::{MsdHealthMonitor, MsdHealthStatus};
use super::monitor::MsdHealthMonitor;
use super::types::{DownloadProgress, DownloadStatus, DriveInfo, ImageInfo, MsdMode, MsdState};
use crate::error::{AppError, Result};
use crate::otg::{MsdFunction, MsdLunConfig, OtgService};
/// USB Gadget path (system constant)
const GADGET_PATH: &str = "/sys/kernel/config/usb_gadget/one-kvm";
/// MSD Controller
pub struct MsdController {
/// OTG Service reference
otg_service: Arc<OtgService>,
/// MSD function manager (provided by OtgService)
msd_function: RwLock<Option<MsdFunction>>,
/// Current state
state: RwLock<MsdState>,
/// Images storage path
images_path: PathBuf,
/// Ventoy directory path
ventoy_dir: PathBuf,
/// Virtual drive path
drive_path: PathBuf,
/// Event bus for broadcasting state changes (optional)
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
/// Active downloads (download_id -> CancellationToken)
downloads: Arc<RwLock<HashMap<String, CancellationToken>>>,
/// Operation mutex lock (prevents concurrent operations)
operation_lock: Arc<RwLock<()>>,
/// Health monitor for error tracking and recovery
monitor: Arc<MsdHealthMonitor>,
}
impl MsdController {
/// Create new MSD controller
///
/// # Parameters
/// * `otg_service` - OTG service for gadget management
/// * `msd_dir` - Base directory for MSD storage
pub fn new(otg_service: Arc<OtgService>, msd_dir: impl Into<PathBuf>) -> Self {
let msd_dir = msd_dir.into();
let images_path = msd_dir.join("images");
@@ -71,11 +44,9 @@ impl MsdController {
}
}
/// Initialize the MSD controller
pub async fn init(&self) -> Result<()> {
info!("Initializing MSD controller");
// 1. Ensure images directory exists
if let Err(e) = std::fs::create_dir_all(&self.images_path) {
warn!("Failed to create images directory: {}", e);
}
@@ -83,18 +54,16 @@ impl MsdController {
warn!("Failed to create ventoy directory: {}", e);
}
// 2. Request MSD function from OtgService
info!("Requesting MSD function from OtgService");
let msd_func = self.otg_service.enable_msd().await?;
info!("Fetching MSD function from OtgService");
let msd_func = self.otg_service.msd_function().await.ok_or_else(|| {
AppError::Internal("MSD function is not active in OtgService".to_string())
})?;
// 3. Store function handle
*self.msd_function.write().await = Some(msd_func);
// 4. Update state
let mut state = self.state.write().await;
state.available = true;
// 5. Check for existing virtual drive
if self.drive_path.exists() {
if let Ok(metadata) = std::fs::metadata(&self.drive_path) {
state.drive_info = Some(DriveInfo {
@@ -115,71 +84,41 @@ impl MsdController {
Ok(())
}
/// Get current state as SystemEvent
pub async fn current_state_event(&self) -> crate::events::SystemEvent {
let state = self.state.read().await;
crate::events::SystemEvent::MsdStateChanged {
mode: state.mode.clone(),
connected: state.connected,
}
}
/// Get current MSD state
pub async fn state(&self) -> MsdState {
self.state.read().await.clone()
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: std::sync::Arc<crate::events::EventBus>) {
*self.events.write().await = Some(events.clone());
// Also set event bus on the monitor for health notifications
self.monitor.set_event_bus(events).await;
*self.events.write().await = Some(events);
}
/// Publish an event to the event bus
async fn publish_event(&self, event: crate::events::SystemEvent) {
if let Some(ref bus) = *self.events.read().await {
bus.publish(event);
}
}
/// Check if MSD is available
async fn mark_device_info_dirty(&self) {
if let Some(ref bus) = *self.events.read().await {
bus.mark_device_info_dirty();
}
}
pub async fn is_available(&self) -> bool {
self.state.read().await.available
}
/// Connect an image file
///
/// # Parameters
/// * `image` - Image info to mount
/// * `cdrom` - Mount as CD-ROM (read-only, removable)
/// * `read_only` - Mount as read-only
pub async fn connect_image(
&self,
image: &ImageInfo,
cdrom: bool,
read_only: bool,
) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(err);
}
self.assert_can_connect(&state).await?;
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Verify image exists
if !image.path.exists() {
let error_msg = format!("Image file not found: {}", image.path.display());
self.monitor
@@ -188,29 +127,12 @@ impl MsdController {
return Err(AppError::Internal(error_msg));
}
// Configure LUN
let config = if cdrom {
MsdLunConfig::cdrom(image.path.clone())
} else {
MsdLunConfig::disk(image.path.clone(), read_only)
};
let gadget_path = PathBuf::from(GADGET_PATH);
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(err);
}
self.configure_lun_now(&config).await?;
state.connected = true;
state.mode = MsdMode::Image;
@@ -221,55 +143,19 @@ impl MsdController {
image.name, cdrom, read_only
);
// Release the lock before publishing events
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
// Publish events
self.publish_event(crate::events::SystemEvent::MsdImageMounted {
image_id: image.id.clone(),
image_name: image.name.clone(),
size: image.size,
cdrom,
})
.await;
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
mode: MsdMode::Image,
connected: true,
})
.await;
self.finish_connect_success().await;
Ok(())
}
/// Connect the virtual drive
pub async fn connect_drive(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(err);
}
self.assert_can_connect(&state).await?;
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Check drive exists
if !self.drive_path.exists() {
let err =
AppError::Internal("Virtual drive not initialized. Call init first.".to_string());
@@ -279,25 +165,8 @@ impl MsdController {
return Err(err);
}
// Configure LUN as read-write disk
let config = MsdLunConfig::disk(self.drive_path.clone(), false);
let gadget_path = PathBuf::from(GADGET_PATH);
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(err);
}
self.configure_lun_now(&config).await?;
state.connected = true;
state.mode = MsdMode::Drive;
@@ -305,28 +174,57 @@ impl MsdController {
info!("Connected virtual drive: {}", self.drive_path.display());
// Release the lock before publishing event
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
// Publish event
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
mode: MsdMode::Drive,
connected: true,
})
.await;
self.finish_connect_success().await;
Ok(())
}
/// Disconnect current storage
async fn assert_can_connect(&self, state: &MsdState) -> Result<()> {
if !state.available {
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(AppError::Internal("MSD not available".to_string()));
}
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
Ok(())
}
async fn configure_lun_now(&self, config: &MsdLunConfig) -> Result<()> {
let gadget_path = self.active_gadget_path().await?;
let msd_hold = self.msd_function.read().await;
let Some(ref msd) = *msd_hold else {
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(AppError::Internal(
"MSD function not initialized".to_string(),
));
};
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
Ok(())
}
async fn finish_connect_success(&self) {
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
self.mark_device_info_dirty().await;
}
pub async fn disconnect(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
@@ -336,7 +234,7 @@ impl MsdController {
return Ok(());
}
let gadget_path = PathBuf::from(GADGET_PATH);
let gadget_path = self.active_gadget_path().await?;
if let Some(ref msd) = *self.msd_function.read().await {
msd.disconnect_lun_async(&gadget_path, 0).await?;
}
@@ -347,58 +245,39 @@ impl MsdController {
info!("Disconnected storage");
// Release the lock before publishing events
drop(state);
drop(_op_guard);
// Publish events
self.publish_event(crate::events::SystemEvent::MsdImageUnmounted)
.await;
self.publish_event(crate::events::SystemEvent::MsdStateChanged {
mode: MsdMode::None,
connected: false,
})
.await;
self.mark_device_info_dirty().await;
Ok(())
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
/// Get ventoy directory path
pub fn ventoy_dir(&self) -> &PathBuf {
&self.ventoy_dir
}
/// Get virtual drive path
pub fn drive_path(&self) -> &PathBuf {
&self.drive_path
}
/// Check if currently connected
pub async fn is_connected(&self) -> bool {
self.state.read().await.connected
}
/// Get current mode
pub async fn mode(&self) -> MsdMode {
self.state.read().await.mode.clone()
}
/// Update drive info
pub async fn update_drive_info(&self, info: DriveInfo) {
let mut state = self.state.write().await;
state.drive_info = Some(info);
}
/// Start downloading an image from URL
///
/// Returns the download_id that can be used to track or cancel the download.
/// Progress is reported via MsdDownloadProgress events.
pub async fn download_image(
&self,
url: String,
@@ -407,18 +286,15 @@ impl MsdController {
let download_id = uuid::Uuid::new_v4().to_string();
let cancel_token = CancellationToken::new();
// Register download
{
let mut downloads = self.downloads.write().await;
downloads.insert(download_id.clone(), cancel_token.clone());
}
// Extract filename for initial response
let display_filename = filename
.clone()
.unwrap_or_else(|| url.rsplit('/').next().unwrap_or("download").to_string());
// Create initial progress
let initial_progress = DownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
@@ -430,7 +306,6 @@ impl MsdController {
error: None,
};
// Publish started event
self.publish_event(crate::events::SystemEvent::MsdDownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
@@ -442,18 +317,15 @@ impl MsdController {
})
.await;
// Clone what we need for the spawned task
let images_path = self.images_path.clone();
let events = self.events.read().await.clone();
let downloads = self.downloads.clone();
let download_id_clone = download_id.clone();
let url_clone = url.clone();
// Spawn download task
tokio::spawn(async move {
let manager = ImageManager::new(images_path);
// Create progress callback
let events_for_callback = events.clone();
let download_id_for_callback = download_id_clone.clone();
let url_for_callback = url_clone.clone();
@@ -475,18 +347,15 @@ impl MsdController {
}
};
// Run download
let result = manager
.download_from_url(&url_clone, filename, progress_callback)
.await;
// Remove from active downloads
{
let mut downloads_guard = downloads.write().await;
downloads_guard.remove(&download_id_clone);
}
// Publish completion event
match result {
Ok(image_info) => {
if let Some(ref bus) = events {
@@ -521,7 +390,6 @@ impl MsdController {
Ok(initial_progress)
}
/// Cancel an active download
pub async fn cancel_download(&self, download_id: &str) -> Result<()> {
let mut downloads = self.downloads.write().await;
@@ -537,26 +405,20 @@ impl MsdController {
}
}
/// Get list of active download IDs
pub async fn active_downloads(&self) -> Vec<String> {
let downloads = self.downloads.read().await;
downloads.keys().cloned().collect()
async fn active_gadget_path(&self) -> Result<PathBuf> {
self.otg_service
.gadget_path()
.await
.ok_or_else(|| AppError::Internal("OTG gadget path is not available".to_string()))
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down MSD controller");
// 1. Disconnect if connected
if let Err(e) = self.disconnect().await {
warn!("Error disconnecting during shutdown: {}", e);
}
// 2. Notify OtgService to disable MSD
info!("Disabling MSD function in OtgService");
self.otg_service.disable_msd().await?;
// 3. Clear local state
*self.msd_function.write().await = None;
let mut state = self.state.write().await;
@@ -566,27 +428,9 @@ impl MsdController {
Ok(())
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<MsdHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> MsdHealthStatus {
self.monitor.status().await
}
/// Check if the MSD is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Drop for MsdController {
fn drop(&mut self) {
// Cleanup is handled by OtgGadgetManager when the gadget is torn down
// Individual controllers don't need to cleanup the ConfigFS
}
}
#[cfg(test)]
@@ -602,7 +446,6 @@ mod tests {
let controller = MsdController::new(otg_service, &msd_dir);
// Check that MSD is not initialized (msd_function is None)
let state = controller.state().await;
assert!(!state.available);
assert!(controller.images_path.ends_with("images"));

View File

@@ -1,53 +1,37 @@
//! Image file manager
//!
//! Handles ISO/IMG image file operations:
//! - List available images
//! - Upload new images
//! - Delete images
//! - Metadata management
//! - Download from URL
use chrono::Utc;
use futures::StreamExt;
use std::fs::{self, File};
use std::io::{self, Read, Write};
use std::fs;
#[cfg(test)]
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use time::OffsetDateTime;
use tokio::io::AsyncWriteExt;
use tracing::info;
use super::types::ImageInfo;
use crate::error::{AppError, Result};
/// Maximum image size (32 GB)
const MAX_IMAGE_SIZE: u64 = 32 * 1024 * 1024 * 1024;
/// Progress report throttle interval (milliseconds)
const PROGRESS_THROTTLE_MS: u64 = 200;
/// Progress report throttle bytes threshold (512 KB)
const PROGRESS_THROTTLE_BYTES: u64 = 512 * 1024;
/// Image Manager
pub struct ImageManager {
/// Images storage directory
images_path: PathBuf,
}
impl ImageManager {
/// Create a new image manager
pub fn new(images_path: PathBuf) -> Self {
Self { images_path }
}
/// Ensure images directory exists
pub fn ensure_dir(&self) -> Result<()> {
fs::create_dir_all(&self.images_path)
.map_err(|e| AppError::Internal(format!("Failed to create images directory: {}", e)))?;
Ok(())
}
/// List all available images
pub fn list(&self) -> Result<Vec<ImageInfo>> {
self.ensure_dir()?;
@@ -68,28 +52,26 @@ impl ImageManager {
}
}
// Sort by creation time (newest first)
images.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(images)
}
/// Get image info from path
fn get_image_info(&self, path: &Path) -> Option<ImageInfo> {
let metadata = fs::metadata(path).ok()?;
let name = path.file_name()?.to_string_lossy().to_string();
// Use filename hash as ID (stable across restarts)
let id = format!("{:x}", md5_hash(&name));
let id = stable_image_id_from_filename(&name);
let created_at = metadata
.created()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
.map(|d| {
chrono::DateTime::from_timestamp(d.as_secs() as i64, 0).unwrap_or_else(Utc::now)
OffsetDateTime::from_unix_timestamp(d.as_secs() as i64)
.unwrap_or_else(|_| OffsetDateTime::now_utc())
})
.unwrap_or_else(Utc::now);
.unwrap_or_else(OffsetDateTime::now_utc);
Some(ImageInfo {
id,
@@ -100,7 +82,6 @@ impl ImageManager {
})
}
/// Get image by ID
pub fn get(&self, id: &str) -> Result<ImageInfo> {
for image in self.list()? {
if image.id == id {
@@ -110,24 +91,21 @@ impl ImageManager {
Err(AppError::NotFound(format!("Image not found: {}", id)))
}
/// Get image by name
pub fn get_by_name(&self, name: &str) -> Result<ImageInfo> {
let path = self.images_path.join(name);
self.get_image_info(&path)
.ok_or_else(|| AppError::NotFound(format!("Image not found: {}", name)))
}
/// Create a new image from bytes
pub fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
#[cfg(test)]
fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
self.ensure_dir()?;
// Validate name
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Check size
if data.len() as u64 > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
@@ -135,7 +113,6 @@ impl ImageManager {
)));
}
// Write file
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
@@ -144,11 +121,10 @@ impl ImageManager {
)));
}
let mut file = File::create(&path)
let mut file = fs::File::create(&path)
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
file.write_all(data).map_err(|e| {
// Try to clean up on error
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
@@ -158,55 +134,6 @@ impl ImageManager {
self.get_by_name(&name)
}
/// Create a new image from a file stream (for chunked uploads)
pub fn create_from_stream<R: Read>(
&self,
name: &str,
reader: &mut R,
expected_size: Option<u64>,
) -> Result<ImageInfo> {
self.ensure_dir()?;
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
if let Some(size) = expected_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
)));
}
}
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
name
)));
}
// Create file and copy data
let mut file = File::create(&path)
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
let bytes_written = io::copy(reader, &mut file).map_err(|e| {
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
info!("Created image: {} ({} bytes)", name, bytes_written);
self.get_by_name(&name)
}
/// Create a new image from an async multipart field (streaming, memory-efficient)
///
/// This method streams data directly to disk without buffering the entire file in memory,
/// making it suitable for large files (multi-GB ISOs).
pub async fn create_from_multipart_field(
&self,
name: &str,
@@ -219,12 +146,10 @@ impl ImageManager {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Use a temporary file during upload
let temp_name = format!(".upload_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_name);
let final_path = self.images_path.join(&name);
// Check if final file already exists
if final_path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
@@ -232,23 +157,19 @@ impl ImageManager {
)));
}
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
let mut bytes_written: u64 = 0;
// Stream chunks directly to disk
while let Some(chunk) = field
.chunk()
.await
.map_err(|e| AppError::Internal(format!("Failed to read upload chunk: {}", e)))?
{
// Check size limit
bytes_written += chunk.len() as u64;
if bytes_written > MAX_IMAGE_SIZE {
// Cleanup and return error
drop(file);
let _ = tokio::fs::remove_file(&temp_path).await;
return Err(AppError::Internal(format!(
@@ -257,19 +178,16 @@ impl ImageManager {
)));
}
// Write chunk to file
file.write_all(&chunk)
.await
.map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?;
}
// Flush and close file
file.flush()
.await
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
drop(file);
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
@@ -285,7 +203,6 @@ impl ImageManager {
self.get_by_name(&name)
}
/// Delete an image by ID
pub fn delete(&self, id: &str) -> Result<()> {
let image = self.get(id)?;
@@ -296,45 +213,6 @@ impl ImageManager {
Ok(())
}
/// Delete an image by name
pub fn delete_by_name(&self, name: &str) -> Result<()> {
let path = self.images_path.join(name);
if !path.exists() {
return Err(AppError::NotFound(format!("Image not found: {}", name)));
}
fs::remove_file(&path)
.map_err(|e| AppError::Internal(format!("Failed to delete image: {}", e)))?;
info!("Deleted image: {}", name);
Ok(())
}
/// Get total storage used
pub fn used_space(&self) -> u64 {
self.list()
.map(|images| images.iter().map(|i| i.size).sum())
.unwrap_or(0)
}
/// Check if storage has space for new image
pub fn has_space(&self, size: u64) -> bool {
// For now, just check against max size
// In the future, could check disk space
size <= MAX_IMAGE_SIZE
}
/// Download image from URL with progress callback
///
/// # Arguments
/// * `url` - The URL to download from
/// * `filename` - Optional custom filename (extracted from URL or Content-Disposition if not provided)
/// * `progress_callback` - Callback function called with (bytes_downloaded, total_bytes)
///
/// # Returns
/// * `Ok(ImageInfo)` - The downloaded image info
/// * `Err(AppError)` - If download fails
pub async fn download_from_url<F>(
&self,
url: &str,
@@ -346,20 +224,17 @@ impl ImageManager {
{
self.ensure_dir()?;
// Validate URL
let parsed_url = reqwest::Url::parse(url)
.map_err(|e| AppError::BadRequest(format!("Invalid URL: {}", e)))?;
info!("Starting download from: {}", url);
// Create HTTP client with timeout
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(3600)) // 1 hour timeout for large files
.timeout(std::time::Duration::from_secs(3600))
.connect_timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| AppError::Internal(format!("Failed to create HTTP client: {}", e)))?;
// Send HEAD request first to get content info
let head_response = client
.head(url)
.send()
@@ -379,7 +254,6 @@ impl ImageManager {
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
// Check file size
if let Some(size) = total_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::BadRequest(format!(
@@ -390,11 +264,9 @@ impl ImageManager {
}
}
// Determine filename
let final_filename = if let Some(name) = filename {
sanitize_filename(&name)
} else {
// Try Content-Disposition header first
let from_header = head_response
.headers()
.get(reqwest::header::CONTENT_DISPOSITION)
@@ -404,7 +276,6 @@ impl ImageManager {
if let Some(name) = from_header {
sanitize_filename(&name)
} else {
// Fall back to URL path
let path = parsed_url.path();
let name = path.rsplit('/').next().unwrap_or("download");
let name = urlencoding::decode(name).unwrap_or_else(|_| name.into());
@@ -418,7 +289,6 @@ impl ImageManager {
));
}
// Check if file already exists
let final_path = self.images_path.join(&final_filename);
if final_path.exists() {
return Err(AppError::BadRequest(format!(
@@ -427,11 +297,9 @@ impl ImageManager {
)));
}
// Create temporary file for download
let temp_filename = format!(".download_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_filename);
// Start actual download
let response = client
.get(url)
.send()
@@ -445,7 +313,6 @@ impl ImageManager {
)));
}
// Get actual content length from response (may differ from HEAD)
let content_length = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
@@ -453,19 +320,16 @@ impl ImageManager {
.and_then(|s| s.parse::<u64>().ok())
.or(total_size);
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
// Stream download with progress (throttled)
let mut stream = response.bytes_stream();
let mut downloaded: u64 = 0;
let mut last_report_time = Instant::now();
let mut last_reported_bytes: u64 = 0;
let throttle_interval = Duration::from_millis(PROGRESS_THROTTLE_MS);
// Report initial progress
progress_callback(0, content_length);
while let Some(chunk_result) = stream.next().await {
@@ -473,14 +337,12 @@ impl ImageManager {
chunk_result.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?;
file.write_all(&chunk).await.map_err(|e| {
// Cleanup on error
let _ = std::fs::remove_file(&temp_path);
AppError::Internal(format!("Failed to write data: {}", e))
})?;
downloaded += chunk.len() as u64;
// Throttled progress reporting: report if enough time or bytes have passed
let now = Instant::now();
let time_elapsed = now.duration_since(last_report_time) >= throttle_interval;
let bytes_elapsed = downloaded - last_reported_bytes >= PROGRESS_THROTTLE_BYTES;
@@ -492,18 +354,15 @@ impl ImageManager {
}
}
// Always report final progress
if downloaded != last_reported_bytes {
progress_callback(downloaded, content_length);
}
// Ensure all data is flushed
file.flush()
.await
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
drop(file);
// Verify downloaded size
let metadata = tokio::fs::metadata(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to read file metadata: {}", e)))?;
@@ -519,7 +378,6 @@ impl ImageManager {
}
}
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
@@ -533,35 +391,29 @@ impl ImageManager {
metadata.len()
);
// Return image info
self.get_by_name(&final_filename)
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
}
/// Simple hash function for generating stable IDs
fn md5_hash(s: &str) -> u64 {
fn stable_image_id_from_filename(name: &str) -> String {
let mut hash: u64 = 0;
for (i, byte) in s.bytes().enumerate() {
for (i, byte) in name.bytes().enumerate() {
hash = hash.wrapping_add((byte as u64).wrapping_mul((i as u64).wrapping_add(1)));
hash = hash.wrapping_mul(31);
}
hash
format!("{:x}", hash)
}
/// Sanitize filename to prevent path traversal
fn sanitize_filename(name: &str) -> String {
let name = name.trim();
let name = name.replace(['/', '\\', '\0', ':', '*', '?', '"', '<', '>', '|'], "_");
// Remove leading dots (hidden files)
let name = name.trim_start_matches('.');
// Limit length
if name.len() > 255 {
name[..255].to_string()
} else {
@@ -569,17 +421,10 @@ fn sanitize_filename(name: &str) -> String {
}
}
/// Extract filename from Content-Disposition header
fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
// Handle both:
// Content-Disposition: attachment; filename="example.iso"
// Content-Disposition: attachment; filename*=UTF-8''example.iso
// Try filename* first (RFC 5987)
if let Some(pos) = header.find("filename*=") {
let start = pos + 10;
let value = &header[start..];
// Format: charset'language'value
if let Some(quote_start) = value.find("''") {
let encoded = value[quote_start + 2..].split(';').next()?;
let decoded = urlencoding::decode(encoded.trim()).ok()?;
@@ -590,7 +435,6 @@ fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
}
}
// Try filename next
if let Some(pos) = header.find("filename=") {
let start = pos + 9;
let value = &header[start..];
@@ -612,7 +456,7 @@ mod tests {
#[test]
fn test_sanitize_filename() {
assert_eq!(sanitize_filename("test.iso"), "test.iso");
assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); // .. becomes empty after trim_start_matches('.')
assert_eq!(sanitize_filename("../test.iso"), "_test.iso");
assert_eq!(sanitize_filename("test/file.iso"), "test_file.iso");
assert_eq!(sanitize_filename(".hidden.iso"), "hidden.iso");
}

View File

@@ -1,19 +1,3 @@
//! MSD (Mass Storage Device) module
//!
//! Provides virtual USB storage functionality with two modes:
//! - Image mounting: Mount ISO/IMG files for system installation
//! - Ventoy drive: Bootable exFAT drive for multiple ISO files
//!
//! Architecture:
//! ```text
//! Web API --> MSD Controller --> ConfigFS Mass Storage --> Target PC
//! |
//! ┌──────┴──────┐
//! │ │
//! Image Manager Ventoy Drive
//! (ISO/IMG) (Bootable exFAT)
//! ```
pub mod controller;
pub mod image;
pub mod monitor;
@@ -22,12 +6,11 @@ pub mod ventoy_drive;
pub use controller::MsdController;
pub use image::ImageManager;
pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig};
pub use monitor::MsdHealthMonitor;
pub use types::{
DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest,
ImageInfo, MsdConnectRequest, MsdMode, MsdState,
};
pub use ventoy_drive::VentoyDrive;
// Re-export from otg module for backward compatibility
pub use crate::otg::{MsdFunction, MsdLunConfig};

View File

@@ -1,118 +1,46 @@
//! MSD (Mass Storage Device) health monitoring
//!
//! This module provides health monitoring for MSD operations, including:
//! - ConfigFS operation error tracking
//! - Image mount/unmount error tracking
//! - Error notification
//! - Log throttling to prevent log flooding
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::events::{EventBus, SystemEvent};
use crate::utils::LogThrottler;
/// MSD health status
const LOG_THROTTLE_SECS: u64 = 5;
#[derive(Debug, Clone, PartialEq, Default)]
pub enum MsdHealthStatus {
/// Device is healthy and operational
pub(crate) enum MsdHealthStatus {
#[default]
Healthy,
/// Device has an error
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
},
}
/// MSD health monitor configuration
#[derive(Debug, Clone)]
pub struct MsdMonitorConfig {
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for MsdMonitorConfig {
fn default() -> Self {
Self {
log_throttle_secs: 5,
}
}
}
/// MSD health monitor
///
/// Monitors MSD operation health and manages error notifications.
/// Publishes WebSocket events when operation status changes.
pub struct MsdHealthMonitor {
/// Current health status
status: RwLock<MsdHealthStatus>,
/// Event bus for notifications
events: RwLock<Option<Arc<EventBus>>>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
#[allow(dead_code)]
config: MsdMonitorConfig,
/// Whether monitoring is active (reserved for future use)
#[allow(dead_code)]
running: AtomicBool,
/// Error count (for tracking)
error_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
}
impl MsdHealthMonitor {
/// Create a new MSD health monitor with the specified configuration
pub fn new(config: MsdMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
pub fn with_defaults() -> Self {
Self {
status: RwLock::new(MsdHealthStatus::Healthy),
events: RwLock::new(None),
throttler: LogThrottler::with_secs(throttle_secs),
config,
running: AtomicBool::new(false),
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
error_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
}
}
/// Create a new MSD health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(MsdMonitorConfig::default())
}
/// Set the event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Report an error from MSD operations
///
/// This method is called when an MSD operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Publishes a WebSocket event if the error is new or changed
///
/// # Arguments
///
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, reason: &str, error_code: &str) {
let count = self.error_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("msd_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!(
@@ -121,74 +49,47 @@ impl MsdHealthMonitor {
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = MsdHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
};
// Publish event (only if error changed or first occurrence)
if error_changed || count == 1 {
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::MsdError {
reason: reason.to_string(),
error_code: error_code.to_string(),
});
}
}
}
/// Report that the MSD has recovered from error
///
/// This method is called when an MSD operation succeeds after errors.
/// It resets the error state and publishes a recovery event.
pub async fn report_recovered(&self) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != MsdHealthStatus::Healthy {
let error_count = self.error_count.load(Ordering::Relaxed);
info!("MSD recovered after {} errors", error_count);
// Reset state
self.error_count.store(0, Ordering::Relaxed);
self.throttler.clear_all();
*self.last_error_code.write().await = None;
*self.status.write().await = MsdHealthStatus::Healthy;
// Publish recovery event
if let Some(ref events) = *self.events.read().await {
events.publish(SystemEvent::MsdRecovered);
}
}
}
/// Get the current health status
pub async fn status(&self) -> MsdHealthStatus {
#[cfg(test)]
pub(crate) async fn status(&self) -> MsdHealthStatus {
self.status.read().await.clone()
}
/// Get the current error count
pub fn error_count(&self) -> u32 {
#[cfg(test)]
pub(crate) fn error_count(&self) -> u32 {
self.error_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
#[cfg(test)]
pub(crate) async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.error_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
@@ -196,7 +97,6 @@ impl MsdHealthMonitor {
self.throttler.clear_all();
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
match &*self.status.read().await {
MsdHealthStatus::Error { reason, .. } => Some(reason.clone()),
@@ -246,13 +146,11 @@ mod tests {
async fn test_report_recovered() {
let monitor = MsdHealthMonitor::with_defaults();
// First report an error
monitor
.report_error("Image not found", "image_not_found")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.error_count(), 0);

View File

@@ -1,52 +1,38 @@
//! MSD data types and structures
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use time::OffsetDateTime;
/// MSD operating mode
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum MsdMode {
/// No storage connected
#[default]
None,
/// Image file mounted (ISO/IMG)
Image,
/// Virtual drive (FAT32) connected
Drive,
}
/// Image file metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageInfo {
/// Unique image ID
pub id: String,
/// Display name
pub name: String,
/// File path on disk
#[serde(skip_serializing)]
pub path: PathBuf,
/// File size in bytes
pub size: u64,
/// Creation timestamp
pub created_at: DateTime<Utc>,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
}
impl ImageInfo {
/// Create new image info
pub fn new(id: String, name: String, path: PathBuf, size: u64) -> Self {
Self {
id,
name,
path,
size,
created_at: Utc::now(),
created_at: OffsetDateTime::now_utc(),
}
}
/// Format size for display
pub fn size_display(&self) -> String {
const KB: u64 = 1024;
const MB: u64 = KB * 1024;
@@ -64,18 +50,12 @@ impl ImageInfo {
}
}
/// MSD state information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdState {
/// Whether MSD feature is available
pub available: bool,
/// Current mode
pub mode: MsdMode,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image (if mode is Image)
pub current_image: Option<ImageInfo>,
/// Virtual drive info (if mode is Drive)
pub drive_info: Option<DriveInfo>,
}
@@ -91,24 +71,17 @@ impl Default for MsdState {
}
}
/// Virtual drive information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveInfo {
/// Drive size in bytes
pub size: u64,
/// Used space in bytes
pub used: u64,
/// Free space in bytes
pub free: u64,
/// Whether drive is initialized
pub initialized: bool,
/// Drive file path
#[serde(skip_serializing)]
pub path: PathBuf,
}
impl DriveInfo {
/// Create new drive info
pub fn new(path: PathBuf, size: u64) -> Self {
Self {
size,
@@ -120,91 +93,60 @@ impl DriveInfo {
}
}
/// File entry in virtual drive
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveFile {
/// File name
pub name: String,
/// Relative path from drive root
pub path: String,
/// File size in bytes (0 for directories)
pub size: u64,
/// Whether this is a directory
pub is_dir: bool,
/// Last modified timestamp
pub modified: Option<DateTime<Utc>>,
#[serde(with = "time::serde::rfc3339::option")]
pub modified: Option<OffsetDateTime>,
}
/// MSD connect request
#[derive(Debug, Clone, Deserialize)]
pub struct MsdConnectRequest {
/// Connection mode: "image" or "drive"
pub mode: MsdMode,
/// Image ID to mount (required for image mode)
pub image_id: Option<String>,
/// Mount as CD-ROM (optional, defaults based on image type)
#[serde(default)]
pub cdrom: Option<bool>,
/// Mount as read-only
#[serde(default)]
pub read_only: Option<bool>,
}
/// Virtual drive init request
#[derive(Debug, Clone, Deserialize)]
pub struct DriveInitRequest {
/// Drive size in megabytes (defaults to 16GB)
#[serde(default = "default_drive_size")]
pub size_mb: u32,
/// Optional custom path for Ventoy installation
pub ventoy_path: Option<String>,
}
fn default_drive_size() -> u32 {
16 * 1024 // 16GB
16 * 1024
}
/// Image download request
#[derive(Debug, Clone, Deserialize)]
pub struct ImageDownloadRequest {
/// URL to download from
pub url: String,
/// Optional custom filename
pub filename: Option<String>,
}
/// Download status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DownloadStatus {
/// Download has started
Started,
/// Download is in progress
InProgress,
/// Download completed successfully
Completed,
/// Download failed
Failed,
}
/// Download progress information
#[derive(Debug, Clone, Serialize)]
pub struct DownloadProgress {
/// Unique download ID
pub download_id: String,
/// Source URL
pub url: String,
/// Target filename
pub filename: String,
/// Bytes downloaded so far
pub bytes_downloaded: u64,
/// Total file size (None if unknown)
pub total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
pub progress_pct: Option<f32>,
/// Download status
pub status: DownloadStatus,
/// Error message if failed
pub error: Option<String>,
}
@@ -218,7 +160,7 @@ mod tests {
"test".into(),
"test.iso".into(),
PathBuf::from("/tmp/test.iso"),
1024 * 1024 * 1024 * 2, // 2 GB
1024 * 1024 * 1024 * 2,
);
assert!(info.size_display().contains("GB"));
}

View File

@@ -1,8 +1,3 @@
//! Ventoy Virtual Drive
//!
//! Replaces FAT32 VirtualDrive with a Ventoy bootable image.
//! Provides a bootable USB with exFAT data partition for ISO files.
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
@@ -13,33 +8,20 @@ use ventoy_img::{FileInfo as VentoyFileInfo, VentoyError, VentoyImage};
use super::types::{DriveFile, DriveInfo};
use crate::error::{AppError, Result};
/// Chunk size for streaming reads (64 KB)
const STREAM_CHUNK_SIZE: usize = 64 * 1024;
/// Minimum drive size (1 GB) - Ventoy requires space for boot partition
const MIN_DRIVE_SIZE_MB: u32 = 1024;
/// Maximum drive size (128 GB)
const MAX_DRIVE_SIZE_MB: u32 = 128 * 1024;
/// Default drive label
const DEFAULT_LABEL: &str = "ONE-KVM";
/// Ventoy Drive Manager
///
/// Thread-safe wrapper around VentoyImage providing async file operations.
/// Uses spawn_blocking for all ventoy-img-rs operations since they are synchronous.
/// Uses RwLock to allow concurrent read operations while serializing writes.
pub struct VentoyDrive {
/// Drive image path
path: PathBuf,
/// RwLock for concurrent reads, exclusive writes
/// (ventoy-img-rs operations are synchronous and not thread-safe)
lock: Arc<RwLock<()>>,
}
impl VentoyDrive {
/// Create new Ventoy drive manager
pub fn new(path: PathBuf) -> Self {
Self {
path,
@@ -47,40 +29,32 @@ impl VentoyDrive {
}
}
/// Check if drive image exists
pub fn exists(&self) -> bool {
self.path.exists()
}
/// Get drive path
pub fn path(&self) -> &PathBuf {
&self.path
}
/// Initialize a new Ventoy drive image
///
/// Creates a bootable Ventoy image with the specified size.
/// The image includes boot partitions and an exFAT data partition.
pub async fn init(&self, size_mb: u32) -> Result<DriveInfo> {
let size_mb = size_mb.clamp(MIN_DRIVE_SIZE_MB, MAX_DRIVE_SIZE_MB);
let size_str = format!("{}M", size_mb);
let path = self.path.clone();
let _lock = self.lock.write().await; // Write lock for initialization
let _lock = self.lock.write().await;
info!("Creating {} MB Ventoy drive at {}", size_mb, path.display());
// Run Ventoy creation in blocking task
let info = tokio::task::spawn_blocking(move || {
VentoyImage::create(&path, &size_str, DEFAULT_LABEL).map_err(ventoy_to_app_error)?;
// Get file metadata for DriveInfo
let metadata = std::fs::metadata(&path)
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
Ok::<DriveInfo, AppError>(DriveInfo {
size: metadata.len(),
used: 0,
free: metadata.len(), // Approximate - exFAT overhead not calculated
free: metadata.len(),
initialized: true,
path,
})
@@ -92,20 +66,18 @@ impl VentoyDrive {
Ok(info)
}
/// Get drive information
pub async fn info(&self) -> Result<DriveInfo> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let _lock = self.lock.read().await; // Read lock for info query
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let metadata = std::fs::metadata(&path)
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
// Open image to get file list and calculate used space
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
let files = image.list_files_recursive().map_err(ventoy_to_app_error)?;
@@ -116,7 +88,6 @@ impl VentoyDrive {
.map(|f| f.size)
.sum();
// Note: This is approximate since we don't have exact exFAT overhead
let size = metadata.len();
let free = size.saturating_sub(used);
@@ -132,7 +103,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// List files at a given path (or root if empty/"/")
pub async fn list_files(&self, dir_path: &str) -> Result<Vec<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -140,7 +110,7 @@ impl VentoyDrive {
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.read().await; // Read lock for listing
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -161,9 +131,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Write a file to the drive from multipart upload (streaming)
///
/// Streams the file directly into the Ventoy image's exFAT partition.
pub async fn write_file_from_multipart_field(
&self,
file_path: &str,
@@ -173,12 +140,10 @@ impl VentoyDrive {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, stream to a temporary file (to get the size)
let temp_dir = self.path.parent().unwrap_or(Path::new("/tmp"));
let temp_name = format!(".upload_ventoy_{}", uuid::Uuid::new_v4());
let temp_path = temp_dir.join(&temp_name);
// Stream upload to temp file
let mut temp_file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
@@ -201,23 +166,16 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Failed to flush temp file: {}", e)))?;
drop(temp_file);
// Now copy from temp file to Ventoy image
let path = self.path.clone();
let file_path = file_path.to_string();
let temp_path_clone = temp_path.clone();
let _lock = self.lock.write().await; // Write lock for file write
let _lock = self.lock.write().await;
let result = tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use add_file_to_path which handles streaming internally
image
.add_file_to_path(
&temp_path_clone,
&file_path,
true, // create_parents
true, // overwrite
)
.add_file_to_path(&temp_path_clone, &file_path, true, true)
.map_err(ventoy_to_app_error)?;
Ok::<(), AppError>(())
@@ -225,14 +183,13 @@ impl VentoyDrive {
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?;
// Cleanup temp file
let _ = tokio::fs::remove_file(&temp_path).await;
result?;
Ok(bytes_written)
}
/// Read a file from the drive (for download)
#[cfg(test)]
pub async fn read_file(&self, file_path: &str) -> Result<Vec<u8>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -240,7 +197,7 @@ impl VentoyDrive {
let path = self.path.clone();
let file_path = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file read
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -251,10 +208,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Get file information without reading content
///
/// Returns file size, name, and other metadata.
/// Returns None if the file doesn't exist.
pub async fn get_file_info(&self, file_path: &str) -> Result<Option<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -262,7 +215,7 @@ impl VentoyDrive {
let path = self.path.clone();
let file_path_owned = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file info
let _lock = self.lock.read().await;
let info = tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -282,10 +235,6 @@ impl VentoyDrive {
}))
}
/// Read a file from the drive as a stream (for large file downloads)
///
/// Returns an async channel receiver that yields chunks of file data.
/// This avoids loading the entire file into memory.
pub async fn read_file_stream(
&self,
file_path: &str,
@@ -297,7 +246,6 @@ impl VentoyDrive {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, get the file size
let file_info = self
.get_file_info(file_path)
.await?
@@ -315,15 +263,12 @@ impl VentoyDrive {
let file_path_owned = file_path.to_string();
let lock = self.lock.clone();
// Create a channel for streaming data
let (tx, rx) =
tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, std::io::Error>>(8);
// Spawn blocking task to read and send chunks
tokio::task::spawn_blocking(move || {
// Hold read lock for the entire read operation
let rt = tokio::runtime::Handle::current();
let _lock = rt.block_on(lock.read()); // Read lock for streaming
let _lock = rt.block_on(lock.read());
let image = match VentoyImage::open(&path) {
Ok(img) => img,
@@ -333,10 +278,8 @@ impl VentoyDrive {
}
};
// Create a channel writer that sends chunks
let mut chunk_writer = ChannelWriter::new(tx.clone(), rt.clone());
// Stream the file through the writer
if let Err(e) = image.read_file_to_writer(&file_path_owned, &mut chunk_writer) {
let _ = rt.block_on(tx.send(Err(std::io::Error::other(e.to_string()))));
}
@@ -345,7 +288,6 @@ impl VentoyDrive {
Ok((file_size, rx))
}
/// Create a directory
pub async fn mkdir(&self, dir_path: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -353,7 +295,7 @@ impl VentoyDrive {
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.write().await; // Write lock for mkdir
let _lock = self.lock.write().await;
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -366,7 +308,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Delete a file or directory
pub async fn delete(&self, path_to_delete: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -374,12 +315,11 @@ impl VentoyDrive {
let path = self.path.clone();
let path_to_delete = path_to_delete.to_string();
let _lock = self.lock.write().await; // Write lock for delete
let _lock = self.lock.write().await;
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use recursive delete to handle directories
image
.remove_recursive(&path_to_delete)
.map_err(ventoy_to_app_error)
@@ -389,7 +329,6 @@ impl VentoyDrive {
}
}
/// Convert VentoyError to AppError
fn ventoy_to_app_error(err: VentoyError) -> AppError {
match err {
VentoyError::Io(e) => AppError::Io(e),
@@ -405,7 +344,6 @@ fn ventoy_to_app_error(err: VentoyError) -> AppError {
}
}
/// Convert VentoyFileInfo to DriveFile
fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFile {
let full_path = if parent_path.is_empty() || parent_path == "/" {
format!("/{}", info.name)
@@ -418,13 +356,10 @@ fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFi
path: full_path,
size: info.size,
is_dir: info.is_directory,
modified: None, // Ventoy FileInfo doesn't include timestamps
modified: None,
}
}
/// A writer that sends chunks to an async channel
///
/// This bridges the sync Write trait with async channels for streaming.
struct ChannelWriter {
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
rt: tokio::runtime::Handle,
@@ -484,7 +419,6 @@ impl std::io::Write for ChannelWriter {
impl Drop for ChannelWriter {
fn drop(&mut self) {
// Flush any remaining data when the writer is dropped
let _ = self.flush_buffer();
}
}
@@ -496,16 +430,13 @@ mod tests {
use std::sync::OnceLock;
use tempfile::TempDir;
/// Path to ventoy resources directory
static RESOURCE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../ventoy-img-rs/resources");
/// Initialize ventoy resources once
fn init_ventoy_resources() -> bool {
static INIT: OnceLock<bool> = OnceLock::new();
*INIT.get_or_init(|| {
let resource_path = std::path::Path::new(RESOURCE_DIR);
// Decompress xz files if needed
let core_xz = resource_path.join("core.img.xz");
let core_img = resource_path.join("core.img");
if core_xz.exists() && !core_img.exists() {
@@ -524,7 +455,6 @@ mod tests {
}
}
// Initialize resources
if let Err(e) = ventoy_img::resources::init_resources(resource_path) {
eprintln!("Failed to init ventoy resources: {}", e);
return false;
@@ -534,7 +464,6 @@ mod tests {
})
}
/// Decompress xz file using system command
fn decompress_xz(src: &std::path::Path, dst: &std::path::Path) -> std::io::Result<()> {
let output = Command::new("xz")
.args(["-d", "-k", "-c", src.to_str().unwrap()])
@@ -551,7 +480,6 @@ mod tests {
Ok(())
}
/// Ensure resources are initialized, skip test if failed
fn ensure_resources() -> bool {
if !init_ventoy_resources() {
eprintln!("Skipping test: ventoy resources not available");
@@ -602,15 +530,12 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Write a test file
let test_content = b"Hello, Ventoy!";
let test_file_path = temp_dir.path().join("test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive using ventoy-img directly
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -619,7 +544,6 @@ mod tests {
.await
.unwrap();
// Read file from drive
let read_data = drive.read_file("/test.txt").await.unwrap();
assert_eq!(read_data, test_content);
}
@@ -633,18 +557,14 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a directory
drive.mkdir("/mydir").await.unwrap();
// Write a test file
let test_content = b"Test file content for info check";
let test_file_path = temp_dir.path().join("info_test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -653,7 +573,6 @@ mod tests {
.await
.unwrap();
// Test get_file_info for file
let file_info = drive.get_file_info("/info_test.txt").await.unwrap();
assert!(file_info.is_some());
let file_info = file_info.unwrap();
@@ -661,14 +580,12 @@ mod tests {
assert_eq!(file_info.size, test_content.len() as u64);
assert!(!file_info.is_dir);
// Test get_file_info for directory
let dir_info = drive.get_file_info("/mydir").await.unwrap();
assert!(dir_info.is_some());
let dir_info = dir_info.unwrap();
assert_eq!(dir_info.name, "mydir");
assert!(dir_info.is_dir);
// Test get_file_info for non-existent file
let not_found = drive.get_file_info("/nonexistent.txt").await.unwrap();
assert!(not_found.is_none());
}
@@ -682,16 +599,13 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create test data that spans multiple chunks (>64KB)
let test_size = 200 * 1024; // 200 KB
let test_size = 200 * 1024;
let test_content: Vec<u8> = (0..test_size).map(|i| (i % 256) as u8).collect();
let test_file_path = temp_dir.path().join("large_file.bin");
std::fs::write(&test_file_path, &test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
let file_path_clone = test_file_path.clone();
tokio::task::spawn_blocking(move || {
@@ -701,18 +615,15 @@ mod tests {
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/large_file.bin").await.unwrap();
assert_eq!(file_size, test_size as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.len(), test_content.len());
assert_eq!(received_data, test_content);
}
@@ -726,15 +637,12 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a small test file
let test_content = b"Small file for streaming test";
let test_file_path = temp_dir.path().join("small.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -743,18 +651,15 @@ mod tests {
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/small.txt").await.unwrap();
assert_eq!(file_size, test_content.len() as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.as_slice(), test_content);
}
}

View File

@@ -1,35 +1,53 @@
//! ConfigFS file operations for USB Gadget
use std::fs::{self, File, OpenOptions};
use std::io::Write;
use std::path::Path;
use std::process::Command;
use crate::error::{AppError, Result};
/// ConfigFS base path for USB gadgets
pub const CONFIGFS_PATH: &str = "/sys/kernel/config/usb_gadget";
/// Default gadget name
pub const DEFAULT_GADGET_NAME: &str = "one-kvm";
/// USB Vendor ID (Linux Foundation) - default value
pub const DEFAULT_USB_VENDOR_ID: u16 = 0x1d6b;
/// USB Product ID (Multifunction Composite Gadget) - default value
pub const DEFAULT_USB_PRODUCT_ID: u16 = 0x0104;
/// USB device version - default value
pub const DEFAULT_USB_BCD_DEVICE: u16 = 0x0100;
/// USB spec version (USB 2.0)
pub const USB_BCD_USB: u16 = 0x0200;
/// Check if ConfigFS is available
pub fn is_configfs_available() -> bool {
Path::new(CONFIGFS_PATH).exists()
}
/// Find available UDC (USB Device Controller)
/// Loads `libcomposite` if needed; does not mount configfs.
pub fn ensure_libcomposite_loaded() -> Result<()> {
if is_configfs_available() {
return Ok(());
}
if !Path::new("/sys/module/libcomposite").exists() {
let status = Command::new("modprobe")
.arg("libcomposite")
.status()
.map_err(|e| {
AppError::Internal(format!("Failed to run modprobe libcomposite: {}", e))
})?;
if !status.success() {
return Err(AppError::Internal(format!(
"modprobe libcomposite failed with status {}",
status
)));
}
}
if is_configfs_available() {
Ok(())
} else {
Err(AppError::Internal(
"libcomposite is not available after modprobe; check configfs mount and kernel support"
.to_string(),
))
}
}
pub fn find_udc() -> Option<String> {
let udc_path = Path::new("/sys/class/udc");
if !udc_path.exists() {
@@ -43,40 +61,17 @@ pub fn find_udc() -> Option<String> {
.next()
}
/// Check if UDC is known to have low endpoint resources
pub fn is_low_endpoint_udc(name: &str) -> bool {
let name = name.to_ascii_lowercase();
name.contains("musb") || name.contains("musb-hdrc")
}
/// Resolve preferred UDC name if available, otherwise auto-detect
pub fn resolve_udc_name(preferred: Option<&str>) -> Option<String> {
if let Some(name) = preferred {
let path = Path::new("/sys/class/udc").join(name);
if path.exists() {
return Some(name.to_string());
}
}
find_udc()
}
/// Write string content to a file
///
/// For sysfs files, this function appends a newline and flushes
/// to ensure the kernel processes the write immediately.
///
/// IMPORTANT: sysfs attributes require a single atomic write() syscall.
/// The kernel processes the value on the first write(), so we must
/// build the complete buffer (including newline) before writing.
/// Sysfs/configfs: one write syscall with final buffer (incl. newline when needed).
pub fn write_file(path: &Path, content: &str) -> Result<()> {
// For sysfs files (especially write-only ones like forced_eject),
// we need to use simple O_WRONLY without O_TRUNC
// O_TRUNC may fail on special files or require read permission
let mut file = OpenOptions::new()
.write(true)
.open(path)
.or_else(|e| {
// If open fails, try create (for regular files)
if path.exists() {
Err(e)
} else {
@@ -85,9 +80,6 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
})
.map_err(|e| AppError::Internal(format!("Failed to open {}: {}", path.display(), e)))?;
// Build complete buffer with newline, then write in single syscall.
// This is critical for sysfs - multiple write() calls may cause
// the kernel to only process partial data or return EINVAL.
let data: std::borrow::Cow<[u8]> = if content.ends_with('\n') {
content.as_bytes().into()
} else {
@@ -99,14 +91,12 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
file.write_all(&data)
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
// Explicitly flush to ensure sysfs processes the write
file.flush()
.map_err(|e| AppError::Internal(format!("Failed to flush {}: {}", path.display(), e)))?;
Ok(())
}
/// Write binary content to a file
pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
let mut file = File::create(path)
.map_err(|e| AppError::Internal(format!("Failed to create {}: {}", path.display(), e)))?;
@@ -117,14 +107,6 @@ pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
Ok(())
}
/// Read string content from a file
pub fn read_file(path: &Path) -> Result<String> {
fs::read_to_string(path)
.map(|s| s.trim().to_string())
.map_err(|e| AppError::Internal(format!("Failed to read {}: {}", path.display(), e)))
}
/// Create directory if not exists
pub fn create_dir(path: &Path) -> Result<()> {
fs::create_dir_all(path).map_err(|e| {
AppError::Internal(format!(
@@ -135,7 +117,6 @@ pub fn create_dir(path: &Path) -> Result<()> {
})
}
/// Remove directory
pub fn remove_dir(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_dir(path).map_err(|e| {
@@ -149,7 +130,6 @@ pub fn remove_dir(path: &Path) -> Result<()> {
Ok(())
}
/// Remove file
pub fn remove_file(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_file(path).map_err(|e| {
@@ -159,7 +139,6 @@ pub fn remove_file(path: &Path) -> Result<()> {
Ok(())
}
/// Create symlink
pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> {
std::os::unix::fs::symlink(src, dest).map_err(|e| {
AppError::Internal(format!(

View File

@@ -1,11 +1,7 @@
//! USB Endpoint allocation management
use crate::error::{AppError, Result};
/// Default maximum endpoints for typical UDC
pub const DEFAULT_MAX_ENDPOINTS: u8 = 16;
/// Endpoint allocator - manages UDC endpoint resources
#[derive(Debug, Clone)]
pub struct EndpointAllocator {
max_endpoints: u8,
@@ -13,7 +9,6 @@ pub struct EndpointAllocator {
}
impl EndpointAllocator {
/// Create a new endpoint allocator
pub fn new(max_endpoints: u8) -> Self {
Self {
max_endpoints,
@@ -21,7 +16,6 @@ impl EndpointAllocator {
}
}
/// Allocate endpoints for a function
pub fn allocate(&mut self, count: u8) -> Result<()> {
if self.used_endpoints + count > self.max_endpoints {
return Err(AppError::Internal(format!(
@@ -34,27 +28,22 @@ impl EndpointAllocator {
Ok(())
}
/// Release endpoints
pub fn release(&mut self, count: u8) {
self.used_endpoints = self.used_endpoints.saturating_sub(count);
}
/// Get available endpoint count
pub fn available(&self) -> u8 {
self.max_endpoints.saturating_sub(self.used_endpoints)
}
/// Get used endpoint count
pub fn used(&self) -> u8 {
self.used_endpoints
}
/// Get maximum endpoint count
pub fn max(&self) -> u8 {
self.max_endpoints
}
/// Check if can allocate
pub fn can_allocate(&self, count: u8) -> bool {
self.available() >= count
}
@@ -82,7 +71,6 @@ mod tests {
alloc.allocate(4).unwrap();
assert_eq!(alloc.available(), 2);
// Should fail - not enough endpoints
assert!(alloc.allocate(3).is_err());
alloc.release(2);

View File

@@ -1,42 +1,17 @@
//! USB Gadget Function trait definition
use std::path::Path;
use crate::error::Result;
/// Function metadata
#[derive(Debug, Clone)]
pub struct FunctionMeta {
/// Function name (e.g., "hid.usb0")
pub name: String,
/// Human-readable description
pub description: String,
/// Number of endpoints used
pub endpoints: u8,
/// Whether the function is enabled
pub enabled: bool,
}
/// USB Gadget Function trait
pub trait GadgetFunction: Send + Sync {
/// Get function name (e.g., "hid.usb0", "mass_storage.usb0")
fn name(&self) -> &str;
/// Get number of endpoints required
fn endpoints_required(&self) -> u8;
/// Get function metadata
fn meta(&self) -> FunctionMeta;
/// Create function directory and configuration in ConfigFS
fn create(&self, gadget_path: &Path) -> Result<()>;
/// Link function to configuration
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()>;
/// Unlink function from configuration
fn unlink(&self, config_path: &Path) -> Result<()>;
/// Cleanup function directory
fn cleanup(&self, gadget_path: &Path) -> Result<()>;
}

View File

@@ -1,34 +1,24 @@
//! HID Function implementation for USB Gadget
use std::path::{Path, PathBuf};
use tracing::debug;
use super::configfs::{
create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file,
};
use super::function::{FunctionMeta, GadgetFunction};
use super::report_desc::{CONSUMER_CONTROL, KEYBOARD, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
use super::function::GadgetFunction;
use super::report_desc::{
CONSUMER_CONTROL, KEYBOARD, KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE,
};
use crate::error::Result;
/// HID function type
#[derive(Debug, Clone)]
pub enum HidFunctionType {
/// Keyboard (no LED feedback)
/// Uses 1 endpoint: IN
Keyboard,
/// Relative mouse (traditional mouse movement)
/// Uses 1 endpoint: IN
MouseRelative,
/// Absolute mouse (touchscreen-like positioning)
/// Uses 1 endpoint: IN
MouseAbsolute,
/// Consumer control (multimedia keys)
/// Uses 1 endpoint: IN
ConsumerControl,
}
impl HidFunctionType {
/// Get endpoints required for this function type
pub fn endpoints(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1,
@@ -38,28 +28,25 @@ impl HidFunctionType {
}
}
/// Get HID protocol
pub fn protocol(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1, // Keyboard
HidFunctionType::MouseRelative => 2, // Mouse
HidFunctionType::MouseAbsolute => 2, // Mouse
HidFunctionType::ConsumerControl => 0, // None
HidFunctionType::Keyboard => 1,
HidFunctionType::MouseRelative => 2,
HidFunctionType::MouseAbsolute => 2,
HidFunctionType::ConsumerControl => 0,
}
}
/// Get HID subclass
pub fn subclass(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1, // Boot interface
HidFunctionType::MouseRelative => 1, // Boot interface
HidFunctionType::MouseAbsolute => 0, // No boot interface
HidFunctionType::ConsumerControl => 0, // No boot interface
HidFunctionType::Keyboard => 1,
HidFunctionType::MouseRelative => 1,
HidFunctionType::MouseAbsolute => 0,
HidFunctionType::ConsumerControl => 0,
}
}
/// Get report length in bytes
pub fn report_length(&self) -> u8 {
pub fn report_length(&self, _keyboard_leds: bool) -> u8 {
match self {
HidFunctionType::Keyboard => 8,
HidFunctionType::MouseRelative => 4,
@@ -68,81 +55,71 @@ impl HidFunctionType {
}
}
/// Get report descriptor
pub fn report_desc(&self) -> &'static [u8] {
pub fn report_desc(&self, keyboard_leds: bool) -> &'static [u8] {
match self {
HidFunctionType::Keyboard => KEYBOARD,
HidFunctionType::Keyboard => {
if keyboard_leds {
KEYBOARD_WITH_LED
} else {
KEYBOARD
}
}
HidFunctionType::MouseRelative => MOUSE_RELATIVE,
HidFunctionType::MouseAbsolute => MOUSE_ABSOLUTE,
HidFunctionType::ConsumerControl => CONSUMER_CONTROL,
}
}
/// Get description
pub fn description(&self) -> &'static str {
match self {
HidFunctionType::Keyboard => "Keyboard",
HidFunctionType::MouseRelative => "Relative Mouse",
HidFunctionType::MouseAbsolute => "Absolute Mouse",
HidFunctionType::ConsumerControl => "Consumer Control",
}
}
}
/// HID Function for USB Gadget
#[derive(Debug, Clone)]
pub struct HidFunction {
/// Instance number (usb0, usb1, ...)
instance: u8,
/// Function type
func_type: HidFunctionType,
/// Cached function name (avoids repeated allocation)
name: String,
keyboard_leds: bool,
}
impl HidFunction {
/// Create a keyboard function
pub fn keyboard(instance: u8) -> Self {
pub fn keyboard(instance: u8, keyboard_leds: bool) -> Self {
Self {
instance,
func_type: HidFunctionType::Keyboard,
name: format!("hid.usb{}", instance),
keyboard_leds,
}
}
/// Create a relative mouse function
pub fn mouse_relative(instance: u8) -> Self {
Self {
instance,
func_type: HidFunctionType::MouseRelative,
name: format!("hid.usb{}", instance),
keyboard_leds: false,
}
}
/// Create an absolute mouse function
pub fn mouse_absolute(instance: u8) -> Self {
Self {
instance,
func_type: HidFunctionType::MouseAbsolute,
name: format!("hid.usb{}", instance),
keyboard_leds: false,
}
}
/// Create a consumer control function
pub fn consumer_control(instance: u8) -> Self {
Self {
instance,
func_type: HidFunctionType::ConsumerControl,
name: format!("hid.usb{}", instance),
keyboard_leds: false,
}
}
/// Get function path in gadget
fn function_path(&self, gadget_path: &Path) -> PathBuf {
gadget_path.join("functions").join(self.name())
}
/// Get expected device path (e.g., /dev/hidg0)
pub fn device_path(&self) -> PathBuf {
PathBuf::from(format!("/dev/hidg{}", self.instance))
}
@@ -157,20 +134,10 @@ impl GadgetFunction for HidFunction {
self.func_type.endpoints()
}
fn meta(&self) -> FunctionMeta {
FunctionMeta {
name: self.name().to_string(),
description: self.func_type.description().to_string(),
endpoints: self.endpoints_required(),
enabled: true,
}
}
fn create(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
create_dir(&func_path)?;
// Set HID parameters
write_file(
&func_path.join("protocol"),
&self.func_type.protocol().to_string(),
@@ -181,11 +148,13 @@ impl GadgetFunction for HidFunction {
)?;
write_file(
&func_path.join("report_length"),
&self.func_type.report_length().to_string(),
&self.func_type.report_length(self.keyboard_leds).to_string(),
)?;
// Write report descriptor
write_bytes(&func_path.join("report_desc"), self.func_type.report_desc())?;
write_bytes(
&func_path.join("report_desc"),
self.func_type.report_desc(self.keyboard_leds),
)?;
debug!(
"Created HID function: {} at {}",
@@ -232,14 +201,15 @@ mod tests {
assert_eq!(HidFunctionType::MouseRelative.endpoints(), 1);
assert_eq!(HidFunctionType::MouseAbsolute.endpoints(), 1);
assert_eq!(HidFunctionType::Keyboard.report_length(), 8);
assert_eq!(HidFunctionType::MouseRelative.report_length(), 4);
assert_eq!(HidFunctionType::MouseAbsolute.report_length(), 6);
assert_eq!(HidFunctionType::Keyboard.report_length(false), 8);
assert_eq!(HidFunctionType::Keyboard.report_length(true), 8);
assert_eq!(HidFunctionType::MouseRelative.report_length(false), 4);
assert_eq!(HidFunctionType::MouseAbsolute.report_length(false), 6);
}
#[test]
fn test_hid_function_names() {
let kb = HidFunction::keyboard(0);
let kb = HidFunction::keyboard(0, false);
assert_eq!(kb.name(), "hid.usb0");
assert_eq!(kb.device_path(), PathBuf::from("/dev/hidg0"));

View File

@@ -1,6 +1,3 @@
//! OTG Gadget Manager - unified management for USB Gadget functions
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tracing::{debug, error, info, warn};
@@ -11,15 +8,14 @@ use super::configfs::{
DEFAULT_USB_VENDOR_ID, USB_BCD_USB,
};
use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS};
use super::function::{FunctionMeta, GadgetFunction};
use super::function::GadgetFunction;
use super::hid::HidFunction;
use super::msd::MsdFunction;
use crate::error::{AppError, Result};
const REBIND_DELAY_MS: u64 = 300;
/// USB Gadget device descriptor configuration
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GadgetDescriptor {
pub vendor_id: u16,
pub product_id: u16,
@@ -42,44 +38,28 @@ impl Default for GadgetDescriptor {
}
}
/// OTG Gadget Manager - unified management for HID and MSD
pub struct OtgGadgetManager {
/// Gadget name
gadget_name: String,
/// Gadget path in ConfigFS
gadget_path: PathBuf,
/// Configuration path
config_path: PathBuf,
/// Device descriptor
descriptor: GadgetDescriptor,
/// Endpoint allocator
endpoint_allocator: EndpointAllocator,
/// HID instance counter
hid_instance: u8,
/// MSD instance counter
msd_instance: u8,
/// Registered functions
functions: Vec<Box<dyn GadgetFunction>>,
/// Function metadata
meta: HashMap<String, FunctionMeta>,
/// Bound UDC name
bound_udc: Option<String>,
/// Whether gadget was created by us
created_by_us: bool,
}
impl OtgGadgetManager {
/// Create a new gadget manager with default settings
pub fn new() -> Self {
Self::with_config(DEFAULT_GADGET_NAME, DEFAULT_MAX_ENDPOINTS)
}
/// Create a new gadget manager with custom configuration
pub fn with_config(gadget_name: &str, max_endpoints: u8) -> Self {
Self::with_descriptor(gadget_name, max_endpoints, GadgetDescriptor::default())
}
/// Create a new gadget manager with custom descriptor
pub fn with_descriptor(
gadget_name: &str,
max_endpoints: u8,
@@ -96,30 +76,24 @@ impl OtgGadgetManager {
endpoint_allocator: EndpointAllocator::new(max_endpoints),
hid_instance: 0,
msd_instance: 0,
// Pre-allocate for typical use: 3 HID (keyboard, rel mouse, abs mouse) + 1 MSD
functions: Vec::with_capacity(4),
meta: HashMap::with_capacity(4),
bound_udc: None,
created_by_us: false,
}
}
/// Check if ConfigFS is available
pub fn is_available() -> bool {
is_configfs_available()
}
/// Find available UDC
pub fn find_udc() -> Option<String> {
find_udc()
}
/// Check if gadget exists
pub fn gadget_exists(&self) -> bool {
self.gadget_path.exists()
}
/// Check if gadget is bound to UDC
pub fn is_bound(&self) -> bool {
let udc_file = self.gadget_path.join("UDC");
if let Ok(content) = fs::read_to_string(&udc_file) {
@@ -129,17 +103,14 @@ impl OtgGadgetManager {
}
}
/// Add keyboard function
/// Returns the expected device path (e.g., /dev/hidg0)
pub fn add_keyboard(&mut self) -> Result<PathBuf> {
let func = HidFunction::keyboard(self.hid_instance);
pub fn add_keyboard(&mut self, keyboard_leds: bool) -> Result<PathBuf> {
let func = HidFunction::keyboard(self.hid_instance, keyboard_leds);
let device_path = func.device_path();
self.add_function(Box::new(func))?;
self.hid_instance += 1;
Ok(device_path)
}
/// Add relative mouse function
pub fn add_mouse_relative(&mut self) -> Result<PathBuf> {
let func = HidFunction::mouse_relative(self.hid_instance);
let device_path = func.device_path();
@@ -148,7 +119,6 @@ impl OtgGadgetManager {
Ok(device_path)
}
/// Add absolute mouse function
pub fn add_mouse_absolute(&mut self) -> Result<PathBuf> {
let func = HidFunction::mouse_absolute(self.hid_instance);
let device_path = func.device_path();
@@ -157,7 +127,6 @@ impl OtgGadgetManager {
Ok(device_path)
}
/// Add consumer control function (multimedia keys)
pub fn add_consumer_control(&mut self) -> Result<PathBuf> {
let func = HidFunction::consumer_control(self.hid_instance);
let device_path = func.device_path();
@@ -166,7 +135,6 @@ impl OtgGadgetManager {
Ok(device_path)
}
/// Add MSD function (returns MsdFunction handle for LUN configuration)
pub fn add_msd(&mut self) -> Result<MsdFunction> {
let func = MsdFunction::new(self.msd_instance);
let func_clone = func.clone();
@@ -175,11 +143,9 @@ impl OtgGadgetManager {
Ok(func_clone)
}
/// Add a generic function
fn add_function(&mut self, func: Box<dyn GadgetFunction>) -> Result<()> {
let endpoints = func.endpoints_required();
// Check endpoint availability
if !self.endpoint_allocator.can_allocate(endpoints) {
return Err(AppError::Internal(format!(
"Not enough endpoints for function {}: need {}, available {}",
@@ -189,82 +155,62 @@ impl OtgGadgetManager {
)));
}
// Allocate endpoints
self.endpoint_allocator.allocate(endpoints)?;
// Store metadata
self.meta.insert(func.name().to_string(), func.meta());
// Store function
self.functions.push(func);
Ok(())
}
/// Setup the gadget (create directories and configure)
pub fn setup(&mut self) -> Result<()> {
info!("Setting up OTG USB Gadget: {}", self.gadget_name);
debug!("Setting up OTG USB Gadget: {}", self.gadget_name);
// Check ConfigFS availability
if !Self::is_available() {
return Err(AppError::Internal(
"ConfigFS not available. Is it mounted at /sys/kernel/config?".to_string(),
));
}
// Check if gadget already exists and is bound
if self.gadget_exists() {
if self.is_bound() {
info!("Gadget already exists and is bound, skipping setup");
debug!("Gadget already exists and is bound, skipping setup");
return Ok(());
}
warn!("Gadget exists but not bound, will reconfigure");
self.cleanup()?;
}
// Create gadget directory
create_dir(&self.gadget_path)?;
self.created_by_us = true;
// Set device descriptors
self.set_device_descriptors()?;
// Create strings
self.create_strings()?;
// Create configuration
self.create_configuration()?;
// Create and link all functions
for func in &self.functions {
func.create(&self.gadget_path)?;
func.link(&self.config_path, &self.gadget_path)?;
}
info!("OTG USB Gadget setup complete");
debug!("OTG USB Gadget setup complete");
Ok(())
}
/// Bind gadget to UDC
pub fn bind(&mut self) -> Result<()> {
let udc = Self::find_udc().ok_or_else(|| {
AppError::Internal("No USB Device Controller (UDC) found".to_string())
})?;
// Recreate config symlinks before binding to avoid kernel gadget issues after rebind
pub fn bind(&mut self, udc: &str) -> Result<()> {
if let Err(e) = self.recreate_config_links() {
warn!("Failed to recreate gadget config links before bind: {}", e);
}
info!("Binding gadget to UDC: {}", udc);
debug!("Binding gadget to UDC: {}", udc);
write_file(&self.gadget_path.join("UDC"), &udc)?;
self.bound_udc = Some(udc);
self.bound_udc = Some(udc.to_string());
std::thread::sleep(std::time::Duration::from_millis(REBIND_DELAY_MS));
Ok(())
}
/// Unbind gadget from UDC
pub fn unbind(&mut self) -> Result<()> {
if self.is_bound() {
write_file(&self.gadget_path.join("UDC"), "")?;
@@ -275,7 +221,6 @@ impl OtgGadgetManager {
Ok(())
}
/// Cleanup all resources
pub fn cleanup(&mut self) -> Result<()> {
if !self.gadget_exists() {
return Ok(());
@@ -283,29 +228,23 @@ impl OtgGadgetManager {
info!("Cleaning up OTG USB Gadget: {}", self.gadget_name);
// Unbind from UDC first
let _ = self.unbind();
// Unlink and cleanup functions
for func in self.functions.iter().rev() {
let _ = func.unlink(&self.config_path);
}
// Remove config strings
let config_strings = self.config_path.join("strings/0x409");
let _ = remove_dir(&config_strings);
let _ = remove_dir(&self.config_path);
// Cleanup functions
for func in self.functions.iter().rev() {
let _ = func.cleanup(&self.gadget_path);
}
// Remove gadget strings
let gadget_strings = self.gadget_path.join("strings/0x409");
let _ = remove_dir(&gadget_strings);
// Remove gadget directory
if let Err(e) = remove_dir(&self.gadget_path) {
warn!("Could not remove gadget directory: {}", e);
}
@@ -315,7 +254,6 @@ impl OtgGadgetManager {
Ok(())
}
/// Set USB device descriptors
fn set_device_descriptors(&self) -> Result<()> {
write_file(
&self.gadget_path.join("idVendor"),
@@ -333,14 +271,13 @@ impl OtgGadgetManager {
&self.gadget_path.join("bcdUSB"),
&format!("0x{:04x}", USB_BCD_USB),
)?;
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; // Composite device
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?;
write_file(&self.gadget_path.join("bDeviceSubClass"), "0x00")?;
write_file(&self.gadget_path.join("bDeviceProtocol"), "0x00")?;
debug!("Set device descriptors");
Ok(())
}
/// Create USB strings
fn create_strings(&self) -> Result<()> {
let strings_path = self.gadget_path.join("strings/0x409");
create_dir(&strings_path)?;
@@ -358,41 +295,23 @@ impl OtgGadgetManager {
Ok(())
}
/// Create configuration
fn create_configuration(&self) -> Result<()> {
create_dir(&self.config_path)?;
// Create config strings
let strings_path = self.config_path.join("strings/0x409");
create_dir(&strings_path)?;
write_file(&strings_path.join("configuration"), "Config 1: HID + MSD")?;
// Set max power (500mA)
write_file(&self.config_path.join("MaxPower"), "500")?;
debug!("Created configuration c.1");
Ok(())
}
/// Get function metadata
pub fn get_meta(&self) -> &HashMap<String, FunctionMeta> {
&self.meta
}
/// Get endpoint usage info
pub fn endpoint_info(&self) -> (u8, u8) {
(
self.endpoint_allocator.used(),
self.endpoint_allocator.max(),
)
}
/// Get gadget path
pub fn gadget_path(&self) -> &PathBuf {
&self.gadget_path
}
/// Recreate config symlinks from functions directory
fn recreate_config_links(&self) -> Result<()> {
let functions_path = self.gadget_path.join("functions");
if !functions_path.exists() || !self.config_path.exists() {
@@ -454,15 +373,10 @@ impl Drop for OtgGadgetManager {
}
}
/// Wait for HID devices to become available
///
/// Uses exponential backoff starting from 10ms, capped at 100ms,
/// to reduce CPU usage while still providing fast response.
pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
// Exponential backoff: start at 10ms, double each time, cap at 100ms
let mut delay_ms = 10u64;
const MAX_DELAY_MS: u64 = 100;
@@ -471,7 +385,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) ->
return true;
}
// Calculate remaining time to avoid overshooting timeout
let remaining = timeout.saturating_sub(start.elapsed());
let sleep_duration = std::time::Duration::from_millis(delay_ms).min(remaining);
@@ -481,7 +394,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) ->
tokio::time::sleep(sleep_duration).await;
// Exponential backoff with cap
delay_ms = (delay_ms * 2).min(MAX_DELAY_MS);
}
@@ -496,18 +408,16 @@ mod tests {
fn test_manager_creation() {
let manager = OtgGadgetManager::new();
assert_eq!(manager.gadget_name, DEFAULT_GADGET_NAME);
assert!(!manager.gadget_exists()); // Won't exist in test environment
assert!(!manager.gadget_exists());
}
#[test]
fn test_endpoint_tracking() {
let mut manager = OtgGadgetManager::with_config("test", 8);
// Keyboard uses 1 endpoint
let _ = manager.add_keyboard();
let _ = manager.add_keyboard(false);
assert_eq!(manager.endpoint_allocator.used(), 1);
// Mouse uses 1 endpoint each
let _ = manager.add_mouse_relative();
let _ = manager.add_mouse_absolute();
assert_eq!(manager.endpoint_allocator.used(), 3);

View File

@@ -1,21 +1,4 @@
//! OTG USB Gadget unified management module
//!
//! This module provides unified management for USB Gadget functions:
//! - HID (Keyboard, Mouse)
//! - MSD (Mass Storage Device)
//!
//! Architecture:
//! ```text
//! OtgService (high-level coordination)
//! └── OtgGadgetManager (gadget lifecycle)
//! ├── EndpointAllocator (manages UDC endpoints)
//! ├── HidFunction (keyboard, mouse_rel, mouse_abs)
//! └── MsdFunction (mass storage)
//! ```
//!
//! The recommended way to use this module is through `OtgService`, which provides
//! a high-level interface for enabling/disabling HID and MSD functions independently.
//! Both `HidController` and `MsdController` should share the same `OtgService` instance.
//! USB OTG composite gadget (HID + MSD).
pub mod configfs;
pub mod endpoint;
@@ -26,10 +9,6 @@ pub mod msd;
pub mod report_desc;
pub mod service;
pub use endpoint::EndpointAllocator;
pub use function::{FunctionMeta, GadgetFunction};
pub use hid::{HidFunction, HidFunctionType};
pub use manager::{wait_for_hid_devices, OtgGadgetManager};
pub use msd::{MsdFunction, MsdLunConfig};
pub use report_desc::{KEYBOARD, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
pub use service::{HidDevicePaths, OtgService, OtgServiceState};
pub use service::{HidDevicePaths, OtgService};

View File

@@ -1,25 +1,17 @@
//! MSD (Mass Storage Device) Function implementation for USB Gadget
use std::fs;
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};
use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_file};
use super::function::{FunctionMeta, GadgetFunction};
use super::function::GadgetFunction;
use crate::error::{AppError, Result};
/// MSD LUN configuration
#[derive(Debug, Clone)]
pub struct MsdLunConfig {
/// File/image path to expose
pub file: PathBuf,
/// Mount as CD-ROM
pub cdrom: bool,
/// Read-only mode
pub ro: bool,
/// Removable media
pub removable: bool,
/// Disable Force Unit Access
pub nofua: bool,
}
@@ -36,7 +28,6 @@ impl Default for MsdLunConfig {
}
impl MsdLunConfig {
/// Create CD-ROM configuration
pub fn cdrom(file: PathBuf) -> Self {
Self {
file,
@@ -47,7 +38,6 @@ impl MsdLunConfig {
}
}
/// Create disk configuration
pub fn disk(file: PathBuf, read_only: bool) -> Self {
Self {
file,
@@ -59,38 +49,26 @@ impl MsdLunConfig {
}
}
/// MSD Function for USB Gadget
#[derive(Debug, Clone)]
pub struct MsdFunction {
/// Instance number (usb0, usb1, ...)
instance: u8,
/// Cached function name (avoids repeated allocation)
name: String,
}
impl MsdFunction {
/// Create a new MSD function
pub fn new(instance: u8) -> Self {
Self {
instance,
name: format!("mass_storage.usb{}", instance),
}
}
/// Get function path in gadget
fn function_path(&self, gadget_path: &Path) -> PathBuf {
gadget_path.join("functions").join(self.name())
}
/// Get LUN path
fn lun_path(&self, gadget_path: &Path, lun: u8) -> PathBuf {
self.function_path(gadget_path).join(format!("lun.{}", lun))
}
/// Configure a LUN with specified settings (async version)
///
/// This is the preferred method for async contexts. It runs the blocking
/// file I/O and USB timing operations in a separate thread pool.
pub async fn configure_lun_async(
&self,
gadget_path: &Path,
@@ -106,17 +84,6 @@ impl MsdFunction {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Configure a LUN with specified settings
/// Note: This should be called after the gadget is set up
///
/// This implementation is based on PiKVM's MSD drive configuration.
/// Key improvements:
/// - Uses forced_eject when available (safer than clearing file directly)
/// - Reduced sleep times to minimize HID interference
/// - Better retry logic for EBUSY errors
///
/// **Note**: This is a blocking function. In async contexts, prefer
/// `configure_lun_async` to avoid blocking the runtime.
pub fn configure_lun(&self, gadget_path: &Path, lun: u8, config: &MsdLunConfig) -> Result<()> {
let lun_path = self.lun_path(gadget_path, lun);
@@ -124,7 +91,6 @@ impl MsdFunction {
create_dir(&lun_path)?;
}
// Batch read all current values to minimize syscalls
let read_attr = |attr: &str| -> String {
fs::read_to_string(lun_path.join(attr))
.unwrap_or_default()
@@ -137,28 +103,21 @@ impl MsdFunction {
let current_removable = read_attr("removable");
let current_nofua = read_attr("nofua");
// Prepare new values
let new_cdrom = if config.cdrom { "1" } else { "0" };
let new_ro = if config.ro { "1" } else { "0" };
let new_removable = if config.removable { "1" } else { "0" };
let new_nofua = if config.nofua { "1" } else { "0" };
// Disconnect current file first using forced_eject if available (PiKVM approach)
let forced_eject_path = lun_path.join("forced_eject");
if forced_eject_path.exists() {
// forced_eject is safer - it forcibly detaches regardless of host state
debug!("Using forced_eject to clear LUN {}", lun);
let _ = write_file(&forced_eject_path, "1");
} else {
// Fallback to clearing file directly
let _ = write_file(&lun_path.join("file"), "");
}
// Brief yield to allow USB stack to process the disconnect
// Reduced from 200ms to 50ms - let USB protocol handle timing
std::thread::sleep(std::time::Duration::from_millis(50));
// Write only changed attributes
let cdrom_changed = current_cdrom != new_cdrom;
if cdrom_changed {
debug!(
@@ -186,13 +145,11 @@ impl MsdFunction {
write_file(&lun_path.join("nofua"), new_nofua)?;
}
// If cdrom mode changed, brief yield for USB host
if cdrom_changed {
debug!("CDROM mode changed, brief yield for USB host");
std::thread::sleep(std::time::Duration::from_millis(50));
}
// Set file path (this triggers the actual mount) - with retry on EBUSY
if config.file.exists() {
let file_path = config.file.to_string_lossy();
let mut last_error = None;
@@ -210,7 +167,6 @@ impl MsdFunction {
return Ok(());
}
Err(e) => {
// Check if it's EBUSY (error code 16)
let is_busy = e.to_string().contains("Device or resource busy")
|| e.to_string().contains("os error 16");
@@ -220,7 +176,6 @@ impl MsdFunction {
lun,
attempt + 1
);
// Exponential backoff: 50, 100, 200, 400ms
std::thread::sleep(std::time::Duration::from_millis(50 << attempt));
last_error = Some(e);
continue;
@@ -231,7 +186,6 @@ impl MsdFunction {
}
}
// If we get here, all retries failed
if let Some(e) = last_error {
return Err(e);
}
@@ -242,9 +196,6 @@ impl MsdFunction {
Ok(())
}
/// Disconnect LUN (async version)
///
/// Preferred for async contexts.
pub async fn disconnect_lun_async(&self, gadget_path: &Path, lun: u8) -> Result<()> {
let gadget_path = gadget_path.to_path_buf();
let this = self.clone();
@@ -254,17 +205,10 @@ impl MsdFunction {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Disconnect LUN (clear file)
///
/// This method uses forced_eject when available, which is safer than
/// directly clearing the file path. Based on PiKVM's implementation.
/// See: https://docs.kernel.org/usb/mass-storage.html
pub fn disconnect_lun(&self, gadget_path: &Path, lun: u8) -> Result<()> {
let lun_path = self.lun_path(gadget_path, lun);
if lun_path.exists() {
// Prefer forced_eject if available (PiKVM approach)
// forced_eject forcibly detaches the backing file regardless of host state
let forced_eject_path = lun_path.join("forced_eject");
if forced_eject_path.exists() {
debug!(
@@ -282,7 +226,6 @@ impl MsdFunction {
}
}
} else {
// Fallback to clearing file directly
write_file(&lun_path.join("file"), "")?;
}
info!("LUN {} disconnected", lun);
@@ -291,7 +234,6 @@ impl MsdFunction {
Ok(())
}
/// Get current LUN file path
pub fn get_lun_file(&self, gadget_path: &Path, lun: u8) -> Option<PathBuf> {
let lun_path = self.lun_path(gadget_path, lun);
let file_path = lun_path.join("file");
@@ -306,7 +248,6 @@ impl MsdFunction {
None
}
/// Check if LUN is connected
pub fn is_lun_connected(&self, gadget_path: &Path, lun: u8) -> bool {
self.get_lun_file(gadget_path, lun).is_some()
}
@@ -318,39 +259,23 @@ impl GadgetFunction for MsdFunction {
}
fn endpoints_required(&self) -> u8 {
2 // IN + OUT for bulk transfers
}
fn meta(&self) -> FunctionMeta {
FunctionMeta {
name: self.name().to_string(),
description: if self.instance == 0 {
"Mass Storage Drive".to_string()
} else {
format!("Extra Drive #{}", self.instance)
},
endpoints: self.endpoints_required(),
enabled: true,
}
2
}
fn create(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
create_dir(&func_path)?;
// Set stall to 0 (workaround for some hosts)
let stall_path = func_path.join("stall");
if stall_path.exists() {
let _ = write_file(&stall_path, "0");
}
// LUN 0 is created automatically, but ensure it exists
let lun0_path = func_path.join("lun.0");
if !lun0_path.exists() {
create_dir(&lun0_path)?;
}
// Set default LUN 0 parameters
let _ = write_file(&lun0_path.join("cdrom"), "0");
let _ = write_file(&lun0_path.join("ro"), "0");
let _ = write_file(&lun0_path.join("removable"), "1");
@@ -382,12 +307,10 @@ impl GadgetFunction for MsdFunction {
fn cleanup(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
// Disconnect all LUNs first
for lun in 0..8 {
let _ = self.disconnect_lun(gadget_path, lun);
}
// Remove function directory
if let Err(e) = remove_dir(&func_path) {
warn!("Could not remove MSD function directory: {}", e);
}

View File

@@ -1,10 +1,3 @@
//! HID Report Descriptors
/// Keyboard HID Report Descriptor (no LED output - saves 1 endpoint)
/// Report format (8 bytes input):
/// [0] Modifier keys (8 bits)
/// [1] Reserved
/// [2-7] Key codes (6 keys)
pub const KEYBOARD: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x06, // Usage (Keyboard)
@@ -34,12 +27,46 @@ pub const KEYBOARD: &[u8] = &[
0xC0, // End Collection
];
/// Relative Mouse HID Report Descriptor (4 bytes report)
/// Report format:
/// [0] Buttons (5 bits) + padding (3 bits)
/// [1] X movement (signed 8-bit)
/// [2] Y movement (signed 8-bit)
/// [3] Wheel (signed 8-bit)
pub const KEYBOARD_WITH_LED: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x06, // Usage (Keyboard)
0xA1, 0x01, // Collection (Application)
// Modifier keys input (8 bits)
0x05, 0x07, // Usage Page (Key Codes)
0x19, 0xE0, // Usage Minimum (224) - Left Control
0x29, 0xE7, // Usage Maximum (231) - Right GUI
0x15, 0x00, // Logical Minimum (0)
0x25, 0x01, // Logical Maximum (1)
0x75, 0x01, // Report Size (1)
0x95, 0x08, // Report Count (8)
0x81, 0x02, // Input (Data, Variable, Absolute) - Modifier byte
// Reserved byte
0x95, 0x01, // Report Count (1)
0x75, 0x08, // Report Size (8)
0x81, 0x01, // Input (Constant) - Reserved byte
// LED output bits
0x95, 0x05, // Report Count (5)
0x75, 0x01, // Report Size (1)
0x05, 0x08, // Usage Page (LEDs)
0x19, 0x01, // Usage Minimum (1)
0x29, 0x05, // Usage Maximum (5)
0x91, 0x02, // Output (Data, Variable, Absolute)
// LED padding
0x95, 0x01, // Report Count (1)
0x75, 0x03, // Report Size (3)
0x91, 0x01, // Output (Constant)
// Key array (6 bytes)
0x95, 0x06, // Report Count (6)
0x75, 0x08, // Report Size (8)
0x15, 0x00, // Logical Minimum (0)
0x26, 0xFF, 0x00, // Logical Maximum (255)
0x05, 0x07, // Usage Page (Key Codes)
0x19, 0x00, // Usage Minimum (0)
0x2A, 0xFF, 0x00, // Usage Maximum (255)
0x81, 0x00, // Input (Data, Array) - Key array (6 keys)
0xC0, // End Collection
];
pub const MOUSE_RELATIVE: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x02, // Usage (Mouse)
@@ -79,12 +106,6 @@ pub const MOUSE_RELATIVE: &[u8] = &[
0xC0, // End Collection
];
/// Absolute Mouse HID Report Descriptor (6 bytes report)
/// Report format:
/// [0] Buttons (5 bits) + padding (3 bits)
/// [1-2] X position (16-bit, 0-32767)
/// [3-4] Y position (16-bit, 0-32767)
/// [5] Wheel (signed 8-bit)
pub const MOUSE_ABSOLUTE: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x02, // Usage (Mouse)
@@ -130,10 +151,6 @@ pub const MOUSE_ABSOLUTE: &[u8] = &[
0xC0, // End Collection
];
/// Consumer Control HID Report Descriptor (2 bytes report)
/// Report format:
/// [0-1] Consumer Control Usage (16-bit little-endian)
/// Supports: Play/Pause, Stop, Next/Prev Track, Mute, Volume Up/Down, etc.
pub const CONSUMER_CONTROL: &[u8] = &[
0x05, 0x0C, // Usage Page (Consumer)
0x09, 0x01, // Usage (Consumer Control)
@@ -155,6 +172,7 @@ mod tests {
#[test]
fn test_report_descriptor_sizes() {
assert!(!KEYBOARD.is_empty());
assert!(!KEYBOARD_WITH_LED.is_empty());
assert!(!MOUSE_RELATIVE.is_empty());
assert!(!MOUSE_ABSOLUTE.is_empty());
assert!(!CONSUMER_CONTROL.is_empty());

View File

@@ -1,437 +1,242 @@
//! OTG Service - unified gadget lifecycle management
//!
//! This module provides centralized management for USB OTG gadget functions.
//! It solves the ownership problem where both HID and MSD need access to the
//! same USB gadget but should be independently configurable.
//!
//! Architecture:
//! ```text
//! ┌─────────────────────────┐
//! │ OtgService │
//! │ ┌───────────────────┐ │
//! │ │ OtgGadgetManager │ │
//! │ └───────────────────┘ │
//! │ ↓ ↓ │
//! │ ┌─────┐ ┌─────┐ │
//! │ │ HID │ │ MSD │ │
//! │ └─────┘ └─────┘ │
//! └─────────────────────────┘
//! ↑ ↑
//! HidController MsdController
//! ```
use std::path::PathBuf;
use std::sync::atomic::{AtomicU8, Ordering};
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info, warn};
use super::manager::{wait_for_hid_devices, GadgetDescriptor, OtgGadgetManager};
use super::msd::MsdFunction;
use crate::config::{OtgDescriptorConfig, OtgHidFunctions};
use crate::config::{HidBackend, HidConfig, MsdConfig, OtgDescriptorConfig, OtgHidFunctions};
use crate::error::{AppError, Result};
/// Bitflags for requested functions (lock-free)
const FLAG_HID: u8 = 0b01;
const FLAG_MSD: u8 = 0b10;
/// HID device paths
#[derive(Debug, Clone, Default)]
pub struct HidDevicePaths {
pub keyboard: Option<PathBuf>,
pub mouse_relative: Option<PathBuf>,
pub mouse_absolute: Option<PathBuf>,
pub consumer: Option<PathBuf>,
pub udc: Option<String>,
pub keyboard_leds_enabled: bool,
}
impl HidDevicePaths {
pub fn existing_paths(&self) -> Vec<PathBuf> {
let mut paths = Vec::new();
if let Some(ref p) = self.keyboard {
paths.push(p.clone());
}
if let Some(ref p) = self.mouse_relative {
paths.push(p.clone());
}
if let Some(ref p) = self.mouse_absolute {
paths.push(p.clone());
}
if let Some(ref p) = self.consumer {
paths.push(p.clone());
}
paths
[
&self.keyboard,
&self.mouse_relative,
&self.mouse_absolute,
&self.consumer,
]
.into_iter()
.filter_map(|p| p.as_ref().cloned())
.collect()
}
}
/// OTG Service state
#[derive(Debug, Clone, Default)]
pub struct OtgServiceState {
/// Whether the gadget is created and bound
pub gadget_active: bool,
/// Whether HID functions are enabled
pub hid_enabled: bool,
/// Whether MSD function is enabled
pub msd_enabled: bool,
/// HID device paths (set after gadget setup)
pub hid_paths: Option<HidDevicePaths>,
/// HID function selection (set after gadget setup)
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct OtgDesiredState {
pub udc: Option<String>,
pub descriptor: GadgetDescriptor,
pub hid_functions: Option<OtgHidFunctions>,
/// Error message if setup failed
pub keyboard_leds: bool,
pub msd_enabled: bool,
pub max_endpoints: u8,
}
impl Default for OtgDesiredState {
fn default() -> Self {
Self {
udc: None,
descriptor: GadgetDescriptor::default(),
hid_functions: None,
keyboard_leds: false,
msd_enabled: false,
max_endpoints: super::endpoint::DEFAULT_MAX_ENDPOINTS,
}
}
}
impl OtgDesiredState {
pub(crate) fn from_config(hid: &HidConfig, msd: &MsdConfig) -> Result<Self> {
let hid_functions = if hid.backend == HidBackend::Otg {
let functions = hid.constrained_otg_functions();
Some(functions)
} else {
None
};
hid.validate_otg_endpoint_budget(msd.enabled)?;
Ok(Self {
udc: hid.resolved_otg_udc(),
descriptor: GadgetDescriptor::from(&hid.otg_descriptor),
hid_functions,
keyboard_leds: hid.effective_otg_keyboard_leds(),
msd_enabled: msd.enabled,
max_endpoints: hid
.resolved_otg_endpoint_limit()
.unwrap_or(super::endpoint::DEFAULT_MAX_ENDPOINTS),
})
}
#[inline]
pub fn hid_enabled(&self) -> bool {
self.hid_functions.is_some()
}
}
#[derive(Debug, Clone, Default)]
struct OtgServiceState {
pub gadget_active: bool,
pub hid_enabled: bool,
pub msd_enabled: bool,
pub configured_udc: Option<String>,
pub hid_paths: Option<HidDevicePaths>,
pub hid_functions: Option<OtgHidFunctions>,
pub keyboard_leds_enabled: bool,
pub max_endpoints: u8,
pub descriptor: Option<GadgetDescriptor>,
pub error: Option<String>,
}
/// OTG Service - unified gadget lifecycle management
///
/// This service owns the OtgGadgetManager and provides a high-level interface
/// for enabling/disabling HID and MSD functions. It ensures proper coordination
/// between the two subsystems and handles gadget lifecycle management.
pub struct OtgService {
/// The underlying gadget manager
manager: Mutex<Option<OtgGadgetManager>>,
/// Current state
state: RwLock<OtgServiceState>,
/// MSD function handle (for runtime LUN configuration)
msd_function: RwLock<Option<MsdFunction>>,
/// Requested functions flags (atomic, lock-free read/write)
requested_flags: AtomicU8,
/// Requested HID function set
hid_functions: RwLock<OtgHidFunctions>,
/// Current descriptor configuration
current_descriptor: RwLock<GadgetDescriptor>,
desired: RwLock<OtgDesiredState>,
}
impl OtgService {
/// Create a new OTG service
pub fn new() -> Self {
Self {
manager: Mutex::new(None),
state: RwLock::new(OtgServiceState::default()),
msd_function: RwLock::new(None),
requested_flags: AtomicU8::new(0),
hid_functions: RwLock::new(OtgHidFunctions::default()),
current_descriptor: RwLock::new(GadgetDescriptor::default()),
desired: RwLock::new(OtgDesiredState::default()),
}
}
/// Check if HID is requested (lock-free)
#[inline]
fn is_hid_requested(&self) -> bool {
self.requested_flags.load(Ordering::Acquire) & FLAG_HID != 0
}
/// Check if MSD is requested (lock-free)
#[inline]
fn is_msd_requested(&self) -> bool {
self.requested_flags.load(Ordering::Acquire) & FLAG_MSD != 0
}
/// Set HID requested flag (lock-free)
#[inline]
fn set_hid_requested(&self, requested: bool) {
if requested {
self.requested_flags.fetch_or(FLAG_HID, Ordering::Release);
} else {
self.requested_flags.fetch_and(!FLAG_HID, Ordering::Release);
}
}
/// Set MSD requested flag (lock-free)
#[inline]
fn set_msd_requested(&self, requested: bool) {
if requested {
self.requested_flags.fetch_or(FLAG_MSD, Ordering::Release);
} else {
self.requested_flags.fetch_and(!FLAG_MSD, Ordering::Release);
}
}
/// Check if OTG is available on this system
pub fn is_available() -> bool {
OtgGadgetManager::is_available() && OtgGadgetManager::find_udc().is_some()
}
/// Get current service state
pub async fn state(&self) -> OtgServiceState {
self.state.read().await.clone()
}
/// Check if gadget is active
pub async fn is_gadget_active(&self) -> bool {
self.state.read().await.gadget_active
}
/// Check if HID is enabled
pub async fn is_hid_enabled(&self) -> bool {
self.state.read().await.hid_enabled
}
/// Check if MSD is enabled
pub async fn is_msd_enabled(&self) -> bool {
self.state.read().await.msd_enabled
}
/// Get gadget path (for MSD LUN configuration)
pub async fn gadget_path(&self) -> Option<PathBuf> {
let manager = self.manager.lock().await;
manager.as_ref().map(|m| m.gadget_path().clone())
}
/// Get HID device paths
pub async fn hid_device_paths(&self) -> Option<HidDevicePaths> {
self.state.read().await.hid_paths.clone()
}
/// Get current HID function selection
pub async fn hid_functions(&self) -> OtgHidFunctions {
self.hid_functions.read().await.clone()
}
/// Update HID function selection
pub async fn update_hid_functions(&self, functions: OtgHidFunctions) -> Result<()> {
if functions.is_empty() {
return Err(AppError::BadRequest(
"OTG HID functions cannot be empty".to_string(),
));
}
{
let mut current = self.hid_functions.write().await;
if *current == functions {
return Ok(());
}
*current = functions;
}
// If HID is active, recreate gadget with new function set
if self.is_hid_requested() {
self.recreate_gadget().await?;
}
Ok(())
}
/// Get MSD function handle (for LUN configuration)
pub async fn msd_function(&self) -> Option<MsdFunction> {
self.msd_function.read().await.clone()
}
/// Enable HID functions
///
/// This will create the gadget if not already created, add HID functions,
/// and bind the gadget to UDC.
pub async fn enable_hid(&self) -> Result<HidDevicePaths> {
info!("Enabling HID functions via OtgService");
// Mark HID as requested (lock-free)
self.set_hid_requested(true);
// Check if already enabled and function set unchanged
let requested_functions = self.hid_functions.read().await.clone();
{
let state = self.state.read().await;
if state.hid_enabled && state.hid_functions.as_ref() == Some(&requested_functions) {
if let Some(ref paths) = state.hid_paths {
info!("HID already enabled, returning existing paths");
return Ok(paths.clone());
}
}
}
// Recreate gadget with both HID and MSD if needed
self.recreate_gadget().await?;
// Get HID paths from state
let state = self.state.read().await;
state
.hid_paths
.clone()
.ok_or_else(|| AppError::Internal("HID paths not set after gadget setup".to_string()))
pub async fn apply_config(&self, hid: &HidConfig, msd: &MsdConfig) -> Result<()> {
let desired = OtgDesiredState::from_config(hid, msd)?;
self.apply_desired_state(desired).await
}
/// Disable HID functions
///
/// This will unbind the gadget, remove HID functions, and optionally
/// recreate the gadget with only MSD if MSD is still enabled.
pub async fn disable_hid(&self) -> Result<()> {
info!("Disabling HID functions via OtgService");
// Mark HID as not requested (lock-free)
self.set_hid_requested(false);
// Check if HID is enabled
pub(crate) async fn apply_desired_state(&self, desired: OtgDesiredState) -> Result<()> {
{
let state = self.state.read().await;
if !state.hid_enabled {
info!("HID already disabled");
return Ok(());
}
let mut current = self.desired.write().await;
*current = desired;
}
// Recreate gadget without HID (or destroy if MSD also disabled)
self.recreate_gadget().await
self.reconcile_gadget().await
}
/// Enable MSD function
///
/// This will create the gadget if not already created, add MSD function,
/// and bind the gadget to UDC.
pub async fn enable_msd(&self) -> Result<MsdFunction> {
info!("Enabling MSD function via OtgService");
async fn reconcile_gadget(&self) -> Result<()> {
let desired = self.desired.read().await.clone();
// Mark MSD as requested (lock-free)
self.set_msd_requested(true);
// Check if already enabled
{
let state = self.state.read().await;
if state.msd_enabled {
let msd = self.msd_function.read().await;
if let Some(ref func) = *msd {
info!("MSD already enabled, returning existing function");
return Ok(func.clone());
}
}
}
// Recreate gadget with both HID and MSD if needed
self.recreate_gadget().await?;
// Get MSD function
let msd = self.msd_function.read().await;
msd.clone().ok_or_else(|| {
AppError::Internal("MSD function not set after gadget setup".to_string())
})
}
/// Disable MSD function
///
/// This will unbind the gadget, remove MSD function, and optionally
/// recreate the gadget with only HID if HID is still enabled.
pub async fn disable_msd(&self) -> Result<()> {
info!("Disabling MSD function via OtgService");
// Mark MSD as not requested (lock-free)
self.set_msd_requested(false);
// Check if MSD is enabled
{
let state = self.state.read().await;
if !state.msd_enabled {
info!("MSD already disabled");
return Ok(());
}
}
// Recreate gadget without MSD (or destroy if HID also disabled)
self.recreate_gadget().await
}
/// Recreate the gadget with currently requested functions
///
/// This is called whenever the set of enabled functions changes.
/// It will:
/// 1. Check if recreation is needed (function set changed)
/// 2. If needed: cleanup existing gadget
/// 3. Create new gadget with requested functions
/// 4. Setup and bind
async fn recreate_gadget(&self) -> Result<()> {
// Read requested flags atomically (lock-free)
let hid_requested = self.is_hid_requested();
let msd_requested = self.is_msd_requested();
let hid_functions = if hid_requested {
self.hid_functions.read().await.clone()
} else {
OtgHidFunctions::default()
};
info!(
"Recreating gadget with: HID={}, MSD={}",
hid_requested, msd_requested
debug!(
"Reconciling OTG gadget: HID={}, MSD={}, UDC={:?}",
desired.hid_enabled(),
desired.msd_enabled,
desired.udc
);
// Check if gadget already matches requested state
{
let state = self.state.read().await;
let functions_match = if hid_requested {
state.hid_functions.as_ref() == Some(&hid_functions)
} else {
state.hid_functions.is_none()
};
if state.gadget_active
&& state.hid_enabled == hid_requested
&& state.msd_enabled == msd_requested
&& functions_match
&& state.hid_enabled == desired.hid_enabled()
&& state.msd_enabled == desired.msd_enabled
&& state.configured_udc == desired.udc
&& state.hid_functions == desired.hid_functions
&& state.keyboard_leds_enabled == desired.keyboard_leds
&& state.max_endpoints == desired.max_endpoints
&& state.descriptor.as_ref() == Some(&desired.descriptor)
{
info!("Gadget already has requested functions, skipping recreate");
debug!("OTG gadget already matches desired state");
return Ok(());
}
}
// Cleanup existing gadget
{
let mut manager = self.manager.lock().await;
if let Some(mut m) = manager.take() {
info!("Cleaning up existing gadget before recreate");
debug!("Cleaning up existing gadget before OTG reconcile");
if let Err(e) = m.cleanup() {
warn!("Error cleaning up existing gadget: {}", e);
}
}
}
// Clear MSD function
*self.msd_function.write().await = None;
// Update state to inactive
{
let mut state = self.state.write().await;
state.gadget_active = false;
state.hid_enabled = false;
state.msd_enabled = false;
state.configured_udc = None;
state.hid_paths = None;
state.hid_functions = None;
state.keyboard_leds_enabled = false;
state.max_endpoints = super::endpoint::DEFAULT_MAX_ENDPOINTS;
state.descriptor = None;
state.error = None;
}
// If nothing requested, we're done
if !hid_requested && !msd_requested {
info!("No functions requested, gadget destroyed");
if !desired.hid_enabled() && !desired.msd_enabled {
info!("OTG desired state is empty, gadget removed");
return Ok(());
}
// Check if OTG is available
if !Self::is_available() {
let error = "OTG not available: ConfigFS not mounted or no UDC found".to_string();
let mut state = self.state.write().await;
state.error = Some(error.clone());
if let Err(e) = super::configfs::ensure_libcomposite_loaded() {
warn!("Failed to ensure libcomposite is available: {}", e);
}
if !OtgGadgetManager::is_available() {
let error = "OTG not available: ConfigFS not mounted".to_string();
self.state.write().await.error = Some(error.clone());
return Err(AppError::Internal(error));
}
// Create new gadget manager with current descriptor
let descriptor = self.current_descriptor.read().await.clone();
let udc = desired.udc.clone().ok_or_else(|| {
let error = "OTG not available: no UDC found".to_string();
AppError::Internal(error)
})?;
let mut manager = OtgGadgetManager::with_descriptor(
super::configfs::DEFAULT_GADGET_NAME,
super::endpoint::DEFAULT_MAX_ENDPOINTS,
descriptor,
desired.max_endpoints,
desired.descriptor.clone(),
);
let mut hid_paths = None;
// Add HID functions if requested
if hid_requested {
if hid_functions.is_empty() {
let error = "HID functions set is empty".to_string();
let mut state = self.state.write().await;
state.error = Some(error.clone());
return Err(AppError::BadRequest(error));
}
let mut paths = HidDevicePaths::default();
if let Some(hid_functions) = desired.hid_functions.clone() {
let mut paths = HidDevicePaths {
udc: Some(udc.clone()),
keyboard_leds_enabled: desired.keyboard_leds,
..Default::default()
};
if hid_functions.keyboard {
match manager.add_keyboard() {
match manager.add_keyboard(desired.keyboard_leds) {
Ok(kb) => paths.keyboard = Some(kb),
Err(e) => {
let error = format!("Failed to add keyboard HID function: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
self.state.write().await.error = Some(error.clone());
return Err(AppError::Internal(error));
}
}
@@ -442,8 +247,7 @@ impl OtgService {
Ok(rel) => paths.mouse_relative = Some(rel),
Err(e) => {
let error = format!("Failed to add relative mouse HID function: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
self.state.write().await.error = Some(error.clone());
return Err(AppError::Internal(error));
}
}
@@ -454,8 +258,7 @@ impl OtgService {
Ok(abs) => paths.mouse_absolute = Some(abs),
Err(e) => {
let error = format!("Failed to add absolute mouse HID function: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
self.state.write().await.error = Some(error.clone());
return Err(AppError::Internal(error));
}
}
@@ -466,8 +269,7 @@ impl OtgService {
Ok(consumer) => paths.consumer = Some(consumer),
Err(e) => {
let error = format!("Failed to add consumer HID function: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
self.state.write().await.error = Some(error.clone());
return Err(AppError::Internal(error));
}
}
@@ -477,8 +279,7 @@ impl OtgService {
debug!("HID functions added to gadget");
}
// Add MSD function if requested
let msd_func = if msd_requested {
let msd_func = if desired.msd_enabled {
match manager.add_msd() {
Ok(func) => {
debug!("MSD function added to gadget");
@@ -486,8 +287,7 @@ impl OtgService {
}
Err(e) => {
let error = format!("Failed to add MSD function: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
self.state.write().await.error = Some(error.clone());
return Err(AppError::Internal(error));
}
}
@@ -495,25 +295,19 @@ impl OtgService {
None
};
// Setup gadget
if let Err(e) = manager.setup() {
let error = format!("Failed to setup gadget: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
self.state.write().await.error = Some(error.clone());
return Err(AppError::Internal(error));
}
// Bind to UDC
if let Err(e) = manager.bind() {
let error = format!("Failed to bind gadget to UDC: {}", e);
let mut state = self.state.write().await;
state.error = Some(error.clone());
// Cleanup on failure
if let Err(e) = manager.bind(&udc) {
let error = format!("Failed to bind gadget to UDC {}: {}", udc, e);
self.state.write().await.error = Some(error.clone());
let _ = manager.cleanup();
return Err(AppError::Internal(error));
}
// Wait for HID devices to appear
if let Some(ref paths) = hid_paths {
let device_paths = paths.existing_paths();
if !device_paths.is_empty() && !wait_for_hid_devices(&device_paths, 2000).await {
@@ -521,103 +315,35 @@ impl OtgService {
}
}
// Store manager and update state
{
*self.manager.lock().await = Some(manager);
}
{
*self.msd_function.write().await = msd_func;
}
*self.manager.lock().await = Some(manager);
*self.msd_function.write().await = msd_func;
{
let mut state = self.state.write().await;
state.gadget_active = true;
state.hid_enabled = hid_requested;
state.msd_enabled = msd_requested;
state.hid_enabled = desired.hid_enabled();
state.msd_enabled = desired.msd_enabled;
state.configured_udc = Some(udc);
state.hid_paths = hid_paths;
state.hid_functions = if hid_requested {
Some(hid_functions)
} else {
None
};
state.hid_functions = desired.hid_functions;
state.keyboard_leds_enabled = desired.keyboard_leds;
state.max_endpoints = desired.max_endpoints;
state.descriptor = Some(desired.descriptor);
state.error = None;
}
info!("Gadget created successfully");
info!("OTG gadget reconciled successfully");
Ok(())
}
/// Update the descriptor configuration
///
/// This updates the stored descriptor and triggers a gadget recreation
/// if the gadget is currently active.
pub async fn update_descriptor(&self, config: &OtgDescriptorConfig) -> Result<()> {
let new_descriptor = GadgetDescriptor {
vendor_id: config.vendor_id,
product_id: config.product_id,
device_version: super::configfs::DEFAULT_USB_BCD_DEVICE,
manufacturer: config.manufacturer.clone(),
product: config.product.clone(),
serial_number: config
.serial_number
.clone()
.unwrap_or_else(|| "0123456789".to_string()),
};
// Update stored descriptor
*self.current_descriptor.write().await = new_descriptor;
// If gadget is active, recreate it with new descriptor
let state = self.state.read().await;
if state.gadget_active {
drop(state); // Release read lock before calling recreate
info!("Descriptor changed, recreating gadget");
self.force_recreate_gadget().await?;
}
Ok(())
}
/// Force recreate the gadget (used when descriptor changes)
async fn force_recreate_gadget(&self) -> Result<()> {
// Cleanup existing gadget
{
let mut manager = self.manager.lock().await;
if let Some(mut m) = manager.take() {
info!("Cleaning up existing gadget for descriptor change");
if let Err(e) = m.cleanup() {
warn!("Error cleaning up existing gadget: {}", e);
}
}
}
// Clear MSD function
*self.msd_function.write().await = None;
// Update state to inactive
{
let mut state = self.state.write().await;
state.gadget_active = false;
state.hid_enabled = false;
state.msd_enabled = false;
state.hid_paths = None;
state.hid_functions = None;
state.error = None;
}
// Recreate with current requested functions
self.recreate_gadget().await
}
/// Shutdown the OTG service and cleanup all resources
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down OTG service");
// Mark nothing as requested (lock-free)
self.requested_flags.store(0, Ordering::Release);
{
let mut desired = self.desired.write().await;
*desired = OtgDesiredState::default();
}
// Cleanup gadget
let mut manager = self.manager.lock().await;
if let Some(mut m) = manager.take() {
if let Err(e) = m.cleanup() {
@@ -625,7 +351,6 @@ impl OtgService {
}
}
// Clear state
*self.msd_function.write().await = None;
{
let mut state = self.state.write().await;
@@ -643,10 +368,19 @@ impl Default for OtgService {
}
}
impl Drop for OtgService {
fn drop(&mut self) {
// Gadget cleanup is handled by OtgGadgetManager's Drop
debug!("OtgService dropping");
impl From<&OtgDescriptorConfig> for GadgetDescriptor {
fn from(config: &OtgDescriptorConfig) -> Self {
Self {
vendor_id: config.vendor_id,
product_id: config.product_id,
device_version: super::configfs::DEFAULT_USB_BCD_DEVICE,
manufacturer: config.manufacturer.clone(),
product: config.product.clone(),
serial_number: config
.serial_number
.clone()
.unwrap_or_else(|| "0123456789".to_string()),
}
}
}
@@ -655,18 +389,8 @@ mod tests {
use super::*;
#[test]
fn test_service_creation() {
fn service_new_and_availability_probe() {
let _service = OtgService::new();
// Just test that creation doesn't panic
let _ = OtgService::is_available(); // Depends on environment
}
#[tokio::test]
async fn test_initial_state() {
let service = OtgService::new();
let state = service.state().await;
assert!(!state.gadget_active);
assert!(!state.hid_enabled);
assert!(!state.msd_enabled);
let _ = OtgService::is_available();
}
}

73
src/rtsp/auth.rs Normal file
View File

@@ -0,0 +1,73 @@
use base64::Engine;
use crate::config::RtspConfig;
use super::types::RtspRequest;
pub(crate) fn extract_basic_auth(req: &RtspRequest) -> Option<(String, String)> {
let value = req.headers.get("authorization")?;
let mut parts = value.split_whitespace();
let scheme = parts.next()?;
if !scheme.eq_ignore_ascii_case("basic") {
return None;
}
let b64 = parts.next()?;
let decoded = base64::engine::general_purpose::STANDARD.decode(b64).ok()?;
let raw = String::from_utf8(decoded).ok()?;
let (user, pass) = raw.split_once(':')?;
Some((user.to_string(), pass.to_string()))
}
pub(crate) fn rtsp_auth_credentials(config: &RtspConfig) -> Option<(String, String)> {
let username = config.username.as_ref()?.trim();
if username.is_empty() {
return None;
}
Some((
username.to_string(),
config.password.clone().unwrap_or_default(),
))
}
#[cfg(test)]
mod tests {
use super::*;
use rtsp_types as rtsp;
use std::collections::HashMap;
#[test]
fn rtsp_auth_requires_non_empty_username() {
let mut config = RtspConfig::default();
config.password = Some("secret".to_string());
assert!(rtsp_auth_credentials(&config).is_none());
config.username = Some("".to_string());
assert!(rtsp_auth_credentials(&config).is_none());
config.username = Some("user".to_string());
let credentials = rtsp_auth_credentials(&config).expect("expected credentials");
assert_eq!(credentials, ("user".to_string(), "secret".to_string()));
config.password = None;
let credentials = rtsp_auth_credentials(&config).expect("expected credentials");
assert_eq!(credentials, ("user".to_string(), "".to_string()));
}
#[test]
fn extract_basic_auth_roundtrip() {
let encoded = base64::engine::general_purpose::STANDARD.encode(b"alice:pwd");
let mut headers = HashMap::new();
headers.insert("authorization".to_string(), format!("Basic {}", encoded));
let req = RtspRequest {
method: rtsp::Method::Options,
uri: "*".to_string(),
version: rtsp::Version::V1_0,
headers,
};
assert_eq!(
extract_basic_auth(&req),
Some(("alice".to_string(), "pwd".to_string()))
);
}
}

96
src/rtsp/bitstream.rs Normal file
View File

@@ -0,0 +1,96 @@
use bytes::Bytes;
use crate::video::encoder::registry::VideoEncoderType;
use crate::video::shared_video_pipeline::EncodedVideoFrame;
use super::state::ParameterSets;
pub(crate) fn update_parameter_sets(params: &mut ParameterSets, frame: &EncodedVideoFrame) {
let nal_units = split_annexb_nal_units(frame.data.as_ref());
match frame.codec {
VideoEncoderType::H264 => {
for nal in nal_units {
match h264_nal_type(nal) {
Some(7) => params.h264_sps = Some(Bytes::copy_from_slice(nal)),
Some(8) => params.h264_pps = Some(Bytes::copy_from_slice(nal)),
_ => {}
}
}
}
VideoEncoderType::H265 => {
for nal in nal_units {
match h265_nal_type(nal) {
Some(32) => params.h265_vps = Some(Bytes::copy_from_slice(nal)),
Some(33) => params.h265_sps = Some(Bytes::copy_from_slice(nal)),
Some(34) => params.h265_pps = Some(Bytes::copy_from_slice(nal)),
_ => {}
}
}
}
_ => {}
}
}
fn split_annexb_nal_units(data: &[u8]) -> Vec<&[u8]> {
let mut nal_units = Vec::new();
let mut cursor = 0usize;
while let Some((start, start_code_len)) = find_annexb_start_code(data, cursor) {
let nal_start = start + start_code_len;
if nal_start >= data.len() {
break;
}
let next_start = find_annexb_start_code(data, nal_start)
.map(|(idx, _)| idx)
.unwrap_or(data.len());
let mut nal_end = next_start;
while nal_end > nal_start && data[nal_end - 1] == 0 {
nal_end -= 1;
}
if nal_end > nal_start {
nal_units.push(&data[nal_start..nal_end]);
}
cursor = next_start;
}
nal_units
}
fn find_annexb_start_code(data: &[u8], from: usize) -> Option<(usize, usize)> {
if from >= data.len() {
return None;
}
let mut i = from;
while i + 3 <= data.len() {
if i + 4 <= data.len()
&& data[i] == 0
&& data[i + 1] == 0
&& data[i + 2] == 0
&& data[i + 3] == 1
{
return Some((i, 4));
}
if data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 {
return Some((i, 3));
}
i += 1;
}
None
}
fn h264_nal_type(nal: &[u8]) -> Option<u8> {
nal.first().map(|value| value & 0x1f)
}
fn h265_nal_type(nal: &[u8]) -> Option<u8> {
nal.first().map(|value| (value >> 1) & 0x3f)
}

9
src/rtsp/codec.rs Normal file
View File

@@ -0,0 +1,9 @@
use crate::config::RtspCodec;
use crate::video::encoder::VideoCodecType;
pub(crate) fn rtsp_codec_to_video(codec: RtspCodec) -> VideoCodecType {
match codec {
RtspCodec::H264 => VideoCodecType::H264,
RtspCodec::H265 => VideoCodecType::H265,
}
}

View File

@@ -1,3 +1,14 @@
pub mod service;
//! RTSP TCP server exposing H.264/H.265 video from [`VideoStreamManager`](crate::video::VideoStreamManager).
mod auth;
mod bitstream;
mod codec;
mod protocol;
mod response;
mod sdp;
mod service;
mod state;
mod streaming;
mod types;
pub use service::{RtspService, RtspServiceStatus};

193
src/rtsp/protocol.rs Normal file
View File

@@ -0,0 +1,193 @@
use rtsp_types as rtsp;
use std::collections::HashMap;
use super::types::RtspRequest;
pub(crate) const OPTIONS_PUBLIC_CAPABILITIES: &str =
"OPTIONS, DESCRIBE, SETUP, PLAY, GET_PARAMETER, SET_PARAMETER, TEARDOWN";
pub(crate) fn strip_interleaved_frames_prefix(buffer: &mut Vec<u8>) -> bool {
if buffer.len() < 4 || buffer[0] != b'$' {
return false;
}
let payload_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize;
let frame_len = 4 + payload_len;
if buffer.len() < frame_len {
return false;
}
buffer.drain(0..frame_len);
true
}
pub(crate) fn take_rtsp_request_from_buffer(buffer: &mut Vec<u8>) -> Option<String> {
let delimiter = b"\r\n\r\n";
let pos = find_bytes(buffer, delimiter)?;
let req_end = pos + delimiter.len();
let req_bytes: Vec<u8> = buffer.drain(0..req_end).collect();
Some(String::from_utf8_lossy(&req_bytes).to_string())
}
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
haystack
.windows(needle.len())
.position(|window| window == needle)
}
pub(crate) fn parse_rtsp_request(raw: &str) -> Option<RtspRequest> {
let (message, consumed): (rtsp::Message<Vec<u8>>, usize) =
rtsp::Message::parse(raw.as_bytes()).ok()?;
if consumed != raw.len() {
return None;
}
let request = match message {
rtsp::Message::Request(req) => req,
_ => return None,
};
let uri = request
.request_uri()
.map(|value| value.as_str().to_string())
.unwrap_or_default();
let mut headers = HashMap::new();
for (name, value) in request.headers() {
headers.insert(name.to_string().to_ascii_lowercase(), value.to_string());
}
Some(RtspRequest {
method: request.method().clone(),
uri,
version: request.version(),
headers,
})
}
pub(crate) fn parse_interleaved_channel(transport: &str) -> Option<u8> {
let lower = transport.to_ascii_lowercase();
if let Some((_, v)) = lower.split_once("interleaved=") {
let head = v.split(';').next().unwrap_or(v);
let first = head.split('-').next().unwrap_or(head).trim();
return first.parse::<u8>().ok();
}
None
}
pub(crate) fn is_tcp_transport_request(transport: &str) -> bool {
transport
.split(',')
.map(str::trim)
.map(str::to_ascii_lowercase)
.any(|item| item.contains("rtp/avp/tcp") || item.contains("interleaved="))
}
pub(crate) fn is_valid_rtsp_path(method: &rtsp::Method, uri: &str, configured_path: &str) -> bool {
if matches!(method, rtsp::Method::Options) && uri.trim() == "*" {
return true;
}
let normalized_cfg = configured_path.trim_matches('/');
if normalized_cfg.is_empty() {
return false;
}
let request_path = extract_rtsp_path(uri);
if request_path == normalized_cfg {
return true;
}
if !matches!(method, rtsp::Method::Setup | rtsp::Method::Teardown) {
return false;
}
let control_track_path = format!("{}/trackID=0", normalized_cfg);
request_path == "trackID=0" || request_path == control_track_path
}
fn extract_rtsp_path(uri: &str) -> String {
let raw_path = if let Some((_, remainder)) = uri.split_once("://") {
match remainder.find('/') {
Some(idx) => &remainder[idx..],
None => "/",
}
} else {
uri
};
raw_path
.split('?')
.next()
.unwrap_or(raw_path)
.split('#')
.next()
.unwrap_or(raw_path)
.trim_matches('/')
.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rtsp_path_matching_follows_sdp_control_rules() {
assert!(is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live",
"live"
));
assert!(is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live/?token=1",
"/live/"
));
assert!(!is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live2",
"live"
));
assert!(!is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/",
"/"
));
assert!(is_valid_rtsp_path(
&rtsp::Method::Setup,
"rtsp://127.0.0.1/live/trackID=0",
"live"
));
assert!(is_valid_rtsp_path(
&rtsp::Method::Setup,
"rtsp://127.0.0.1/trackID=0",
"live"
));
assert!(!is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live/trackID=0",
"live"
));
assert!(is_valid_rtsp_path(&rtsp::Method::Options, "*", "live"));
}
#[test]
fn transport_parsing_detects_tcp_interleaved_requests() {
assert!(is_tcp_transport_request(
"RTP/AVP/TCP;unicast;interleaved=0-1"
));
assert!(is_tcp_transport_request("RTP/AVP;unicast;interleaved=2-3"));
assert!(!is_tcp_transport_request(
"RTP/AVP;unicast;client_port=8000-8001"
));
}
#[test]
fn options_public_includes_standard_methods() {
assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("GET_PARAMETER"));
assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("TEARDOWN"));
}
}

81
src/rtsp/response.rs Normal file
View File

@@ -0,0 +1,81 @@
use rtsp_types as rtsp;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::error::{AppError, Result};
use super::types::RtspRequest;
async fn serialize_and_write<W: AsyncWrite + Unpin>(
stream: &mut W,
response: rtsp::Response<Vec<u8>>,
) -> Result<()> {
let mut data = Vec::new();
response
.write(&mut data)
.map_err(|e| AppError::BadRequest(format!("failed to serialize RTSP response: {}", e)))?;
stream.write_all(&data).await?;
Ok(())
}
pub(crate) async fn send_simple_response<W: AsyncWrite + Unpin>(
stream: &mut W,
code: u16,
_reason: &str,
cseq: Option<&str>,
body: &str,
) -> Result<()> {
let mut builder = rtsp::Response::builder(rtsp::Version::V1_0, status_code_from_u16(code));
if let Some(cseq) = cseq {
builder = builder.header(rtsp::headers::CSEQ, cseq);
}
let response = builder.build(body.as_bytes().to_vec());
serialize_and_write(stream, response).await
}
pub(crate) async fn send_response<W: AsyncWrite + Unpin>(
stream: &mut W,
req: &RtspRequest,
code: u16,
_reason: &str,
extra_headers: Vec<(String, String)>,
body: &str,
session_id: &str,
) -> Result<()> {
let cseq = req
.headers
.get("cseq")
.cloned()
.unwrap_or_else(|| "1".to_string());
let mut builder = rtsp::Response::builder(req.version, status_code_from_u16(code))
.header(rtsp::headers::CSEQ, cseq.as_str());
if !session_id.is_empty() {
builder = builder.header(rtsp::headers::SESSION, session_id);
}
for (name, value) in extra_headers {
let header_name = rtsp::HeaderName::try_from(name.as_str()).map_err(|e| {
AppError::BadRequest(format!("invalid RTSP header name {}: {}", name, e))
})?;
builder = builder.header(header_name, value);
}
let response = builder.build(body.as_bytes().to_vec());
serialize_and_write(stream, response).await
}
pub(crate) fn status_code_from_u16(code: u16) -> rtsp::StatusCode {
match code {
200 => rtsp::StatusCode::Ok,
400 => rtsp::StatusCode::BadRequest,
401 => rtsp::StatusCode::Unauthorized,
404 => rtsp::StatusCode::NotFound,
405 => rtsp::StatusCode::MethodNotAllowed,
453 => rtsp::StatusCode::NotEnoughBandwidth,
455 => rtsp::StatusCode::MethodNotValidInThisState,
461 => rtsp::StatusCode::UnsupportedTransport,
_ => rtsp::StatusCode::InternalServerError,
}
}

224
src/rtsp/sdp.rs Normal file
View File

@@ -0,0 +1,224 @@
use base64::Engine;
use sdp_types as sdp;
use crate::config::RtspConfig;
use crate::video::encoder::VideoCodecType;
use crate::webrtc::rtp::parse_profile_level_id_from_sps;
use super::state::ParameterSets;
pub(crate) fn build_h264_fmtp(payload_type: u8, params: &ParameterSets) -> String {
let mut attrs = vec!["packetization-mode=1".to_string()];
if let Some(sps) = params.h264_sps.as_ref() {
if let Some(profile_level_id) = parse_profile_level_id_from_sps(sps) {
attrs.push(format!("profile-level-id={}", profile_level_id));
}
} else {
attrs.push("profile-level-id=42e01f".to_string());
}
if let (Some(sps), Some(pps)) = (params.h264_sps.as_ref(), params.h264_pps.as_ref()) {
let sps_b64 = base64::engine::general_purpose::STANDARD.encode(sps.as_ref());
let pps_b64 = base64::engine::general_purpose::STANDARD.encode(pps.as_ref());
attrs.push(format!("sprop-parameter-sets={},{}", sps_b64, pps_b64));
}
format!("{} {}", payload_type, attrs.join(";"))
}
pub(crate) fn build_h265_fmtp(payload_type: u8, params: &ParameterSets) -> String {
let mut attrs = Vec::new();
if let Some(vps) = params.h265_vps.as_ref() {
attrs.push(format!(
"sprop-vps={}",
base64::engine::general_purpose::STANDARD.encode(vps.as_ref())
));
}
if let Some(sps) = params.h265_sps.as_ref() {
attrs.push(format!(
"sprop-sps={}",
base64::engine::general_purpose::STANDARD.encode(sps.as_ref())
));
}
if let Some(pps) = params.h265_pps.as_ref() {
attrs.push(format!(
"sprop-pps={}",
base64::engine::general_purpose::STANDARD.encode(pps.as_ref())
));
}
if attrs.is_empty() {
format!("{} profile-id=1", payload_type)
} else {
format!("{} {}", payload_type, attrs.join(";"))
}
}
pub(crate) fn build_sdp(
config: &RtspConfig,
codec: VideoCodecType,
params: &ParameterSets,
) -> String {
let (payload_type, codec_name, fmtp_value) = match codec {
VideoCodecType::H264 => (96u8, "H264", build_h264_fmtp(96, params)),
VideoCodecType::H265 => (99u8, "H265", build_h265_fmtp(99, params)),
_ => {
tracing::warn!("RTSP SDP: unexpected VideoCodecType, falling back to H264");
(96u8, "H264", build_h264_fmtp(96, params))
}
};
let session = sdp::Session {
origin: sdp::Origin {
username: Some("-".to_string()),
sess_id: "0".to_string(),
sess_version: 0,
nettype: "IN".to_string(),
addrtype: "IP4".to_string(),
unicast_address: config.bind.clone(),
},
session_name: "One-KVM RTSP Stream".to_string(),
session_description: None,
uri: None,
emails: Vec::new(),
phones: Vec::new(),
connection: Some(sdp::Connection {
nettype: "IN".to_string(),
addrtype: "IP4".to_string(),
connection_address: "0.0.0.0".to_string(),
}),
bandwidths: Vec::new(),
times: vec![sdp::Time {
start_time: 0,
stop_time: 0,
repeats: Vec::new(),
}],
time_zones: Vec::new(),
key: None,
attributes: vec![sdp::Attribute {
attribute: "control".to_string(),
value: Some("*".to_string()),
}],
medias: vec![sdp::Media {
media: "video".to_string(),
port: 0,
num_ports: None,
proto: "RTP/AVP".to_string(),
fmt: payload_type.to_string(),
media_title: None,
connections: Vec::new(),
bandwidths: Vec::new(),
key: None,
attributes: vec![
sdp::Attribute {
attribute: "rtpmap".to_string(),
value: Some(format!("{} {}/90000", payload_type, codec_name)),
},
sdp::Attribute {
attribute: "fmtp".to_string(),
value: Some(fmtp_value),
},
sdp::Attribute {
attribute: "control".to_string(),
value: Some("trackID=0".to_string()),
},
],
}],
};
let mut output = Vec::new();
if let Err(err) = session.write(&mut output) {
tracing::warn!("Failed to serialize SDP with sdp-types: {}", err);
return String::new();
}
match String::from_utf8(output) {
Ok(sdp_text) => sdp_text,
Err(err) => {
tracing::warn!("Failed to convert SDP bytes to UTF-8: {}", err);
String::new()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::RtspConfig;
use bytes::Bytes;
#[test]
fn build_sdp_h264_is_parseable_with_expected_video_attributes() {
let config = RtspConfig::default();
let mut params = ParameterSets::default();
params.h264_sps = Some(Bytes::from_static(&[0x67, 0x42, 0xe0, 0x1f, 0x96, 0x54]));
params.h264_pps = Some(Bytes::from_static(&[0x68, 0xce, 0x06, 0xe2]));
let sdp_text = build_sdp(&config, VideoCodecType::H264, &params);
assert!(!sdp_text.is_empty());
let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed");
assert_eq!(session.session_name, "One-KVM RTSP Stream");
assert_eq!(session.medias.len(), 1);
let media = &session.medias[0];
assert_eq!(media.media, "video");
assert_eq!(media.proto, "RTP/AVP");
assert_eq!(media.fmt, "96");
let has_rtpmap = media.attributes.iter().any(|attr| {
attr.attribute == "rtpmap" && attr.value.as_deref() == Some("96 H264/90000")
});
assert!(has_rtpmap);
let fmtp_value = media
.attributes
.iter()
.find(|attr| attr.attribute == "fmtp")
.and_then(|attr| attr.value.as_deref())
.expect("missing fmtp value");
assert!(fmtp_value.starts_with("96 "));
assert!(fmtp_value.contains("packetization-mode=1"));
assert!(fmtp_value.contains("sprop-parameter-sets="));
}
#[test]
fn build_sdp_h265_is_parseable_with_expected_video_attributes() {
let config = RtspConfig::default();
let mut params = ParameterSets::default();
params.h265_vps = Some(Bytes::from_static(&[0x40, 0x01, 0x0c, 0x01]));
params.h265_sps = Some(Bytes::from_static(&[0x42, 0x01, 0x01, 0x60]));
params.h265_pps = Some(Bytes::from_static(&[0x44, 0x01, 0xc0, 0x73]));
let sdp_text = build_sdp(&config, VideoCodecType::H265, &params);
assert!(!sdp_text.is_empty());
let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed");
assert_eq!(session.medias.len(), 1);
let media = &session.medias[0];
assert_eq!(media.media, "video");
assert_eq!(media.proto, "RTP/AVP");
assert_eq!(media.fmt, "99");
let has_rtpmap = media.attributes.iter().any(|attr| {
attr.attribute == "rtpmap" && attr.value.as_deref() == Some("99 H265/90000")
});
assert!(has_rtpmap);
let fmtp_value = media
.attributes
.iter()
.find(|attr| attr.attribute == "fmtp")
.and_then(|attr| attr.value.as_deref())
.expect("missing fmtp value");
assert!(fmtp_value.starts_with("99 "));
assert!(fmtp_value.contains("sprop-vps="));
assert!(fmtp_value.contains("sprop-sps="));
assert!(fmtp_value.contains("sprop-pps="));
}
}

File diff suppressed because it is too large Load Diff

28
src/rtsp/state.rs Normal file
View File

@@ -0,0 +1,28 @@
use bytes::Bytes;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
#[derive(Default, Clone)]
pub(crate) struct ParameterSets {
pub h264_sps: Option<Bytes>,
pub h264_pps: Option<Bytes>,
pub h265_vps: Option<Bytes>,
pub h265_sps: Option<Bytes>,
pub h265_pps: Option<Bytes>,
}
#[derive(Clone)]
pub(crate) struct SharedRtspState {
pub active_client: Arc<Mutex<Option<SocketAddr>>>,
pub parameter_sets: Arc<RwLock<ParameterSets>>,
}
impl SharedRtspState {
pub fn new() -> Self {
Self {
active_client: Arc::new(Mutex::new(None)),
parameter_sets: Arc::new(RwLock::new(ParameterSets::default())),
}
}
}

367
src/rtsp/streaming.rs Normal file
View File

@@ -0,0 +1,367 @@
use bytes::Bytes;
use rand::Rng;
use rtp::packet::Packet;
use rtp::packetizer::Payloader;
use rtsp_types as rtsp;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::{sleep, Duration};
use webrtc::util::{Marshal, MarshalSize};
use crate::config::RtspCodec;
use crate::error::{AppError, Result};
use crate::video::encoder::registry::VideoEncoderType;
use crate::video::shared_video_pipeline::EncodedVideoFrame;
use crate::video::VideoStreamManager;
use crate::webrtc::h265_payloader::H265Payloader;
use super::bitstream::update_parameter_sets;
use super::protocol::{
parse_rtsp_request, strip_interleaved_frames_prefix, take_rtsp_request_from_buffer,
};
use super::response::send_response;
use super::state::SharedRtspState;
use super::types::RtspRequest;
pub(crate) const RTP_CLOCK_RATE: u32 = 90_000;
pub(crate) const RTP_MTU: usize = 1200;
pub(crate) const RTSP_BUF_SIZE: usize = 8192;
const RTSP_RESUBSCRIBE_DELAY_MS: u64 = 300;
pub(crate) async fn stream_video_interleaved(
stream: TcpStream,
video_manager: &Arc<VideoStreamManager>,
rtsp_codec: RtspCodec,
channel: u8,
shared: SharedRtspState,
session_id: String,
) -> Result<()> {
let (mut reader, mut writer) = stream.into_split();
let mut rx = video_manager
.subscribe_encoded_frames()
.await
.ok_or_else(|| {
AppError::VideoError("RTSP failed to subscribe encoded frames".to_string())
})?;
video_manager.request_keyframe().await.ok();
let payload_type = match rtsp_codec {
RtspCodec::H264 => 96,
RtspCodec::H265 => 99,
};
let mut sequence_number: u16 = rand::rng().random();
let ssrc: u32 = rand::rng().random();
let mut h264_payloader = rtp::codecs::h264::H264Payloader::default();
let mut h265_payloader = H265Payloader::new();
let mut ctrl_read_buf = [0u8; RTSP_BUF_SIZE];
let mut ctrl_buffer = Vec::with_capacity(RTSP_BUF_SIZE);
// 4-byte interleaved prefix + RTP header + payload shard (≤ RTP_MTU)
let mut interleaved_rtp_buf = Vec::with_capacity(4 + RTP_MTU + 96);
let mut last_rtp_timestamp: u32 = 0;
loop {
tokio::select! {
maybe_frame = rx.recv() => {
let Some(frame) = maybe_frame else {
tracing::warn!("RTSP encoded frame subscription ended, attempting to restart pipeline");
if let Some(new_rx) = video_manager.subscribe_encoded_frames().await {
rx = new_rx;
let _ = video_manager.request_keyframe().await;
tracing::info!("RTSP frame subscription recovered");
} else {
tracing::warn!(
"RTSP failed to resubscribe encoded frames, retrying in {}ms",
RTSP_RESUBSCRIBE_DELAY_MS
);
sleep(Duration::from_millis(RTSP_RESUBSCRIBE_DELAY_MS)).await;
}
continue;
};
if !is_frame_codec_match(&frame, &rtsp_codec) {
continue;
}
{
let mut params = shared.parameter_sets.write().await;
update_parameter_sets(&mut params, &frame);
}
let rtp_timestamp = monotonic_rtp_timestamp(
frame.pts_ms,
&mut last_rtp_timestamp,
frame.duration,
);
let payloads: Vec<Bytes> = match rtsp_codec {
RtspCodec::H264 => h264_payloader
.payload(RTP_MTU, &frame.data)
.map_err(|e| AppError::VideoError(format!("H264 payload failed: {}", e)))?,
RtspCodec::H265 => h265_payloader.payload(RTP_MTU, &frame.data),
};
if payloads.is_empty() {
continue;
}
let total_payloads = payloads.len();
for (idx, payload) in payloads.into_iter().enumerate() {
let marker = idx == total_payloads.saturating_sub(1);
let packet = Packet {
header: rtp::header::Header {
version: 2,
padding: false,
extension: false,
marker,
payload_type,
sequence_number,
timestamp: rtp_timestamp,
ssrc,
..Default::default()
},
payload,
};
sequence_number = sequence_number.wrapping_add(1);
send_interleaved_rtp(&mut writer, channel, &packet, &mut interleaved_rtp_buf)
.await?;
}
if frame.is_keyframe {
tracing::debug!("RTSP keyframe sent");
}
}
read_res = reader.read(&mut ctrl_read_buf) => {
let n = read_res?;
if n == 0 {
break;
}
ctrl_buffer.extend_from_slice(&ctrl_read_buf[..n]);
while strip_interleaved_frames_prefix(&mut ctrl_buffer) {}
while let Some(raw_req) = take_rtsp_request_from_buffer(&mut ctrl_buffer) {
let Some(req) = parse_rtsp_request(&raw_req) else {
continue;
};
if handle_play_control_request(&mut writer, &req, &session_id).await? {
return Ok(());
}
while strip_interleaved_frames_prefix(&mut ctrl_buffer) {}
}
}
}
}
Ok(())
}
pub(crate) async fn send_interleaved_rtp<W: AsyncWrite + Unpin>(
stream: &mut W,
channel: u8,
packet: &Packet,
marshal_buf: &mut Vec<u8>,
) -> Result<()> {
let rtp_len = packet.marshal_size();
let rtp_len_u16 = u16::try_from(rtp_len).map_err(|_| {
AppError::VideoError(format!(
"RTP packet too large for interleaved framing: {} bytes",
rtp_len
))
})?;
marshal_buf.clear();
marshal_buf.reserve(4 + rtp_len);
marshal_buf.extend_from_slice(&[b'$', channel, (rtp_len_u16 >> 8) as u8, rtp_len_u16 as u8]);
let body_off = marshal_buf.len();
marshal_buf.resize(body_off + rtp_len, 0);
let written = packet
.marshal_to(&mut marshal_buf[body_off..])
.map_err(|e| AppError::VideoError(format!("RTP marshal failed: {}", e)))?;
if written != rtp_len {
return Err(AppError::VideoError(format!(
"RTP marshal size mismatch: wrote {written}, expected {rtp_len}"
)));
}
stream.write_all(marshal_buf).await?;
Ok(())
}
pub(crate) async fn handle_play_control_request<W: AsyncWrite + Unpin>(
stream: &mut W,
req: &RtspRequest,
session_id: &str,
) -> Result<bool> {
use super::protocol::OPTIONS_PUBLIC_CAPABILITIES;
match &req.method {
rtsp::Method::Teardown => {
send_response(stream, req, 200, "OK", vec![], "", session_id).await?;
Ok(true)
}
rtsp::Method::Options => {
send_response(
stream,
req,
200,
"OK",
vec![(
"Public".to_string(),
OPTIONS_PUBLIC_CAPABILITIES.to_string(),
)],
"",
session_id,
)
.await?;
Ok(false)
}
rtsp::Method::GetParameter | rtsp::Method::SetParameter => {
send_response(stream, req, 200, "OK", vec![], "", session_id).await?;
Ok(false)
}
_ => {
send_response(
stream,
req,
405,
"Method Not Allowed",
vec![],
"",
session_id,
)
.await?;
Ok(false)
}
}
}
fn pts_to_rtp_timestamp(pts_ms: i64) -> u32 {
if pts_ms <= 0 {
return 0;
}
((pts_ms as u64 * RTP_CLOCK_RATE as u64) / 1000) as u32
}
fn rtp_timestamp_increment(frame_duration: Duration) -> u32 {
let inc = (frame_duration.as_secs_f64() * f64::from(RTP_CLOCK_RATE)).round() as u32;
inc.max(1)
}
fn monotonic_rtp_timestamp(pts_ms: i64, last: &mut u32, frame_duration: Duration) -> u32 {
let from_pts = pts_to_rtp_timestamp(pts_ms);
let inc = rtp_timestamp_increment(frame_duration);
let ts = if from_pts > *last {
from_pts
} else {
last.wrapping_add(inc)
};
*last = ts;
ts
}
fn is_frame_codec_match(frame: &EncodedVideoFrame, codec: &RtspCodec) -> bool {
matches!(
(frame.codec, codec),
(VideoEncoderType::H264, RtspCodec::H264) | (VideoEncoderType::H265, RtspCodec::H265)
)
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use tokio::io::{duplex, AsyncReadExt};
fn make_test_request(method: rtsp::Method) -> RtspRequest {
let mut headers = HashMap::new();
headers.insert("cseq".to_string(), "7".to_string());
RtspRequest {
method,
uri: "rtsp://127.0.0.1/live".to_string(),
version: rtsp::Version::V1_0,
headers,
}
}
async fn read_response_from_duplex(
mut client: tokio::io::DuplexStream,
) -> rtsp::Response<Vec<u8>> {
let mut buf = vec![0u8; 4096];
let n = client
.read(&mut buf)
.await
.expect("failed to read rtsp response");
assert!(n > 0);
let (message, consumed): (rtsp::Message<Vec<u8>>, usize) =
rtsp::Message::parse(&buf[..n]).expect("failed to parse rtsp response");
assert_eq!(consumed, n);
match message {
rtsp::Message::Response(response) => response,
_ => panic!("expected RTSP response"),
}
}
#[tokio::test]
async fn play_control_teardown_returns_ok_and_stop() {
let req = make_test_request(rtsp::Method::Teardown);
let (client, mut server) = duplex(4096);
let should_stop = handle_play_control_request(&mut server, &req, "session-1")
.await
.expect("control handling failed");
assert!(should_stop);
drop(server);
let response = read_response_from_duplex(client).await;
assert_eq!(response.status(), rtsp::StatusCode::Ok);
}
#[tokio::test]
async fn play_control_pause_returns_method_not_allowed() {
let req = make_test_request(rtsp::Method::Pause);
let (client, mut server) = duplex(4096);
let should_stop = handle_play_control_request(&mut server, &req, "session-1")
.await
.expect("control handling failed");
assert!(!should_stop);
drop(server);
let response = read_response_from_duplex(client).await;
assert_eq!(response.status(), rtsp::StatusCode::MethodNotAllowed);
}
#[test]
fn monotonic_rtp_timestamp_steps_when_pts_stays_zero() {
let d = Duration::from_millis(33);
let mut last = 0u32;
let a = monotonic_rtp_timestamp(0, &mut last, d);
let b = monotonic_rtp_timestamp(0, &mut last, d);
let c = monotonic_rtp_timestamp(0, &mut last, d);
assert!(a > 0);
assert!(b > a);
assert!(c > b);
}
#[test]
fn monotonic_rtp_timestamp_uses_pts_when_it_advances() {
let d = Duration::from_millis(33);
let mut last = 0u32;
let a = monotonic_rtp_timestamp(1000, &mut last, d);
assert_eq!(a, 90_000);
let b = monotonic_rtp_timestamp(2000, &mut last, d);
assert_eq!(b, 180_000);
}
}

53
src/rtsp/types.rs Normal file
View File

@@ -0,0 +1,53 @@
use rand::Rng;
use rtsp_types as rtsp;
use std::collections::HashMap;
use std::fmt;
#[derive(Debug, Clone, PartialEq)]
pub enum RtspServiceStatus {
Stopped,
Starting,
Running,
Error(String),
}
impl fmt::Display for RtspServiceStatus {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Stopped => write!(f, "stopped"),
Self::Starting => write!(f, "starting"),
Self::Running => write!(f, "running"),
Self::Error(err) => write!(f, "error: {}", err),
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct RtspRequest {
pub method: rtsp::Method,
pub uri: String,
pub version: rtsp::Version,
pub headers: HashMap<String, String>,
}
pub(crate) struct RtspConnectionState {
pub session_id: String,
pub setup_done: bool,
pub interleaved_channel: u8,
}
impl RtspConnectionState {
pub fn new() -> Self {
Self {
session_id: generate_session_id(),
setup_done: false,
interleaved_channel: 0,
}
}
}
pub(crate) fn generate_session_id() -> String {
let mut rng = rand::rng();
let value: u64 = rng.random();
format!("{:016x}", value)
}

View File

@@ -1,21 +1,11 @@
//! RustDesk BytesCodec - Variable-length framing for TCP messages
//!
//! RustDesk uses a custom variable-length encoding for message framing:
//! - Length <= 0x3F (63): 1-byte header, format `(len << 2)`
//! - Length <= 0x3FFF (16383): 2-byte LE header, format `(len << 2) | 0x1`
//! - Length <= 0x3FFFFF (4194303): 3-byte LE header, format `(len << 2) | 0x2`
//! - Length <= 0x3FFFFFFF (1073741823): 4-byte LE header, format `(len << 2) | 0x3`
//!
//! The low 2 bits of the first byte indicate the header length (+1).
//! Variable-length TCP framing (RustDesk wire format).
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Maximum packet length (1GB)
const MAX_PACKET_LENGTH: usize = 0x3FFFFFFF;
/// Encode a message with RustDesk's variable-length framing
pub fn encode_frame(data: &[u8]) -> io::Result<Vec<u8>> {
let len = data.len();
let mut buf = Vec::with_capacity(len + 4);
@@ -44,8 +34,6 @@ pub fn encode_frame(data: &[u8]) -> io::Result<Vec<u8>> {
Ok(buf)
}
/// Decode the header to get message length
/// Returns (header_length, message_length)
fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) {
let head_len = ((first_byte & 0x3) + 1) as usize;
@@ -64,21 +52,17 @@ fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) {
(head_len, msg_len)
}
/// Read a single framed message from an async reader
pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<BytesMut> {
// Read first byte to determine header length
let mut first_byte = [0u8; 1];
reader.read_exact(&mut first_byte).await?;
let head_len = ((first_byte[0] & 0x3) + 1) as usize;
// Read remaining header bytes if needed
let mut header_rest = [0u8; 3];
if head_len > 1 {
reader.read_exact(&mut header_rest[..head_len - 1]).await?;
}
// Calculate message length
let (_, msg_len) = decode_header(first_byte[0], &header_rest);
if msg_len > MAX_PACKET_LENGTH {
@@ -88,7 +72,6 @@ pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Byte
));
}
// Read message body
let mut buf = BytesMut::with_capacity(msg_len);
buf.resize(msg_len, 0);
reader.read_exact(&mut buf).await?;
@@ -96,7 +79,6 @@ pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Byte
Ok(buf)
}
/// Write a framed message to an async writer
pub async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, data: &[u8]) -> io::Result<()> {
let frame = encode_frame(data)?;
writer.write_all(&frame).await?;
@@ -104,10 +86,6 @@ pub async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, data: &[u8]) ->
Ok(())
}
/// Write a framed message using a reusable buffer (reduces allocations)
///
/// This version reuses the provided BytesMut buffer to avoid allocation on each call.
/// The buffer is cleared before use and will grow as needed.
pub async fn write_frame_buffered<W: AsyncWrite + Unpin>(
writer: &mut W,
data: &[u8],
@@ -120,11 +98,9 @@ pub async fn write_frame_buffered<W: AsyncWrite + Unpin>(
Ok(())
}
/// Encode a message with RustDesk's variable-length framing into an existing buffer
pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
let len = data.len();
// Reserve space for header (max 4 bytes) + data
buf.reserve(4 + len);
if len <= 0x3F {
@@ -149,7 +125,7 @@ pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
Ok(())
}
/// BytesCodec for stateful decoding (compatible with tokio-util codec)
/// Stateful decoder for `Framed`.
#[derive(Debug, Clone, Copy)]
pub struct BytesCodec {
state: DecodeState,
@@ -180,7 +156,6 @@ impl BytesCodec {
self.max_packet_length = n;
}
/// Decode from a BytesMut buffer (for use with Framed)
pub fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
let n = match self.state {
DecodeState::Head => match self.decode_head(src)? {
@@ -242,7 +217,6 @@ impl BytesCodec {
Ok(Some(src.split_to(n)))
}
/// Encode a message into a BytesMut buffer
pub fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> io::Result<()> {
let len = data.len();
@@ -276,7 +250,7 @@ mod tests {
fn test_encode_decode_small() {
let data = vec![1u8; 63];
let encoded = encode_frame(&data).unwrap();
assert_eq!(encoded.len(), 63 + 1); // 1 byte header
assert_eq!(encoded.len(), 63 + 1);
let mut codec = BytesCodec::new();
let mut buf = BytesMut::from(&encoded[..]);
@@ -288,7 +262,7 @@ mod tests {
fn test_encode_decode_medium() {
let data = vec![2u8; 1000];
let encoded = encode_frame(&data).unwrap();
assert_eq!(encoded.len(), 1000 + 2); // 2 byte header
assert_eq!(encoded.len(), 1000 + 2);
let mut codec = BytesCodec::new();
let mut buf = BytesMut::from(&encoded[..]);
@@ -300,7 +274,7 @@ mod tests {
fn test_encode_decode_large() {
let data = vec![3u8; 100000];
let encoded = encode_frame(&data).unwrap();
assert_eq!(encoded.len(), 100000 + 3); // 3 byte header
assert_eq!(encoded.len(), 100000 + 3);
let mut codec = BytesCodec::new();
let mut buf = BytesMut::from(&encoded[..]);

View File

@@ -1,57 +1,26 @@
//! RustDesk Configuration
//!
//! Configuration types for the RustDesk protocol integration.
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// RustDesk configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RustDeskConfig {
/// Enable RustDesk protocol
pub enabled: bool,
/// Rendezvous server address (hbbs), e.g., "rs.example.com" or "192.168.1.100:21116"
/// Required for RustDesk to function
pub rendezvous_server: String,
/// Relay server address (hbbr), if different from rendezvous server
/// Usually the same host as rendezvous server but different port (21117)
pub relay_server: Option<String>,
/// Relay server authentication key (licence_key)
/// Required if the relay server is configured with -k option
#[typeshare(skip)]
pub relay_key: Option<String>,
/// Device ID (9-digit number), auto-generated if empty
pub device_id: String,
/// Device password for client authentication
#[typeshare(skip)]
pub device_password: String,
/// Public key for encryption (Curve25519, base64 encoded), auto-generated
#[typeshare(skip)]
pub public_key: Option<String>,
/// Private key for encryption (Curve25519, base64 encoded), auto-generated
#[typeshare(skip)]
pub private_key: Option<String>,
/// Signing public key (Ed25519, base64 encoded), auto-generated
/// Used for SignedId verification by clients
#[typeshare(skip)]
pub signing_public_key: Option<String>,
/// Signing private key (Ed25519, base64 encoded), auto-generated
/// Used for signing SignedId messages
#[typeshare(skip)]
pub signing_private_key: Option<String>,
/// UUID for rendezvous server registration (persisted to avoid UUID_MISMATCH)
#[typeshare(skip)]
pub uuid: Option<String>,
}
@@ -75,8 +44,6 @@ impl Default for RustDeskConfig {
}
impl RustDeskConfig {
/// Check if the configuration is valid for starting the service
/// Returns true if enabled and has a valid server
pub fn is_valid(&self) -> bool {
self.enabled
&& !self.rendezvous_server.is_empty()
@@ -84,44 +51,35 @@ impl RustDeskConfig {
&& !self.device_password.is_empty()
}
/// Get the rendezvous server (user-configured)
pub fn effective_rendezvous_server(&self) -> &str {
&self.rendezvous_server
}
/// Generate a new random device ID
pub fn generate_device_id() -> String {
generate_device_id()
}
/// Generate a new random password
pub fn generate_password() -> String {
generate_random_password()
}
/// Get or generate the UUID for rendezvous registration
/// Returns (uuid_bytes, is_new) where is_new indicates if a new UUID was generated
pub fn ensure_uuid(&mut self) -> ([u8; 16], bool) {
if let Some(ref uuid_str) = self.uuid {
// Try to parse existing UUID
if let Ok(uuid) = uuid::Uuid::parse_str(uuid_str) {
return (*uuid.as_bytes(), false);
}
}
// Generate new UUID
let new_uuid = uuid::Uuid::new_v4();
self.uuid = Some(new_uuid.to_string());
(*new_uuid.as_bytes(), true)
}
/// Get the UUID bytes (returns None if not set)
pub fn get_uuid_bytes(&self) -> Option<[u8; 16]> {
self.uuid
.as_ref()
.and_then(|s| uuid::Uuid::parse_str(s).ok().map(|u| *u.as_bytes()))
}
/// Get the rendezvous server address with default port
pub fn rendezvous_addr(&self) -> String {
let server = &self.rendezvous_server;
if server.is_empty() {
@@ -133,7 +91,6 @@ impl RustDeskConfig {
}
}
/// Get the relay server address with default port
pub fn relay_addr(&self) -> Option<String> {
self.relay_server
.as_ref()
@@ -145,7 +102,6 @@ impl RustDeskConfig {
}
})
.or_else(|| {
// Default: same host as rendezvous server
let server = &self.rendezvous_server;
if !server.is_empty() {
let host = server.split(':').next().unwrap_or("");
@@ -161,7 +117,6 @@ impl RustDeskConfig {
}
}
/// Generate a random 9-digit device ID
pub fn generate_device_id() -> String {
use rand::Rng;
let mut rng = rand::rng();
@@ -169,7 +124,6 @@ pub fn generate_device_id() -> String {
id.to_string()
}
/// Generate a random 8-character password
pub fn generate_random_password() -> String {
use rand::Rng;
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
@@ -212,7 +166,6 @@ mod tests {
config.rendezvous_server = "example.com:21116".to_string();
assert_eq!(config.rendezvous_addr(), "example.com:21116");
// Empty server returns empty string
config.rendezvous_server = String::new();
assert_eq!(config.rendezvous_addr(), "");
}
@@ -224,17 +177,14 @@ mod tests {
..Default::default()
};
// Rendezvous server configured, relay defaults to same host
assert_eq!(config.relay_addr(), Some("example.com:21117".to_string()));
// Explicit relay server
config.relay_server = Some("relay.example.com".to_string());
assert_eq!(
config.relay_addr(),
Some("relay.example.com:21117".to_string())
);
// No rendezvous server, relay is None
config.rendezvous_server = String::new();
config.relay_server = None;
assert_eq!(config.relay_addr(), None);
@@ -247,10 +197,8 @@ mod tests {
..Default::default()
};
// When user sets a server, use it
assert_eq!(config.effective_rendezvous_server(), "custom.example.com");
// When empty, returns empty
config.rendezvous_server = String::new();
assert_eq!(config.effective_rendezvous_server(), "");
}

View File

@@ -1,12 +1,4 @@
//! RustDesk Connection Handler
//!
//! This module handles incoming connections from RustDesk clients.
//! It manages the connection lifecycle including:
//! - Connection establishment (P2P or via relay)
//! - Encrypted handshake
//! - Authentication
//! - Message routing (video, audio, input)
//! - Video frame streaming (shared with WebRTC)
//! Incoming RustDesk TCP sessions (handshake, AV, input).
use std::net::SocketAddr;
use std::sync::Arc;
@@ -22,7 +14,8 @@ use tokio::sync::{broadcast, mpsc, Mutex};
use tracing::{debug, error, info, warn};
use crate::audio::AudioController;
use crate::hid::{HidController, KeyEventType, KeyboardEvent, KeyboardModifiers};
use crate::hid::{CanonicalKey, HidController, KeyEventType, KeyboardEvent, KeyboardModifiers};
use crate::utils::hostname_from_etc;
use crate::video::codec_constraints::{
encoder_codec_to_id, encoder_codec_to_video_codec, video_codec_to_encoder_codec,
};
@@ -94,13 +87,6 @@ impl InputThrottler {
}
}
/// Get system hostname
fn get_hostname() -> String {
std::fs::read_to_string("/etc/hostname")
.map(|s| s.trim().to_string())
.unwrap_or_else(|_| "One-KVM".to_string())
}
/// Connection state
#[derive(Debug, Clone, PartialEq)]
pub enum ConnectionState {
@@ -652,22 +638,22 @@ impl Connection {
// H264 is preferred because it has the best hardware encoder support (RKMPP, VAAPI, etc.)
// and most RustDesk clients support H264 hardware decoding
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H264)
&& registry.is_format_available(VideoEncoderType::H264, false)
&& registry.is_codec_available(VideoEncoderType::H264)
{
return VideoEncoderType::H264;
}
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H265)
&& registry.is_format_available(VideoEncoderType::H265, false)
&& registry.is_codec_available(VideoEncoderType::H265)
{
return VideoEncoderType::H265;
}
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP8)
&& registry.is_format_available(VideoEncoderType::VP8, false)
&& registry.is_codec_available(VideoEncoderType::VP8)
{
return VideoEncoderType::VP8;
}
if constraints.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP9)
&& registry.is_format_available(VideoEncoderType::VP9, false)
&& registry.is_codec_available(VideoEncoderType::VP9)
{
return VideoEncoderType::VP9;
}
@@ -784,7 +770,7 @@ impl Connection {
}
let registry = EncoderRegistry::global();
if registry.is_format_available(new_codec, false) {
if registry.is_codec_available(new_codec) {
info!(
"Client requested codec switch: {:?} -> {:?}",
self.negotiated_codec, new_codec
@@ -1121,16 +1107,16 @@ impl Connection {
// Check which encoders are available (include software fallback)
let h264_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H264)
&& registry.is_format_available(VideoEncoderType::H264, false);
&& registry.is_codec_available(VideoEncoderType::H264);
let h265_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::H265)
&& registry.is_format_available(VideoEncoderType::H265, false);
&& registry.is_codec_available(VideoEncoderType::H265);
let vp8_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP8)
&& registry.is_format_available(VideoEncoderType::VP8, false);
&& registry.is_codec_available(VideoEncoderType::VP8);
let vp9_available = constraints
.is_webrtc_codec_allowed(crate::video::encoder::VideoCodecType::VP9)
&& registry.is_format_available(VideoEncoderType::VP9, false);
&& registry.is_codec_available(VideoEncoderType::VP9);
info!(
"Server encoding capabilities: H264={}, H265={}, VP8={}, VP9={}",
@@ -1165,7 +1151,7 @@ impl Connection {
let mut peer_info = PeerInfo::new();
peer_info.username = "one-kvm".to_string();
peer_info.hostname = get_hostname();
peer_info.hostname = hostname_from_etc();
peer_info.platform = RUSTDESK_COMPAT_PLATFORM.to_string();
peer_info.displays.push(display_info);
peer_info.current_display = 0;
@@ -1328,15 +1314,13 @@ impl Connection {
);
let caps_down = KeyboardEvent {
event_type: KeyEventType::Down,
key: 0x39, // USB HID CapsLock
key: CanonicalKey::CapsLock,
modifiers: KeyboardModifiers::default(),
is_usb_hid: true,
};
let caps_up = KeyboardEvent {
event_type: KeyEventType::Up,
key: 0x39,
key: CanonicalKey::CapsLock,
modifiers: KeyboardModifiers::default(),
is_usb_hid: true,
};
if let Err(e) = hid.send_keyboard(caps_down).await {
warn!("Failed to send CapsLock down: {}", e);
@@ -1351,7 +1335,7 @@ impl Connection {
if let Some(kb_event) = convert_key_event(ke) {
debug!(
"Converted to HID: key=0x{:02X}, event_type={:?}, modifiers={:02X}",
kb_event.key,
kb_event.key.to_hid_usage(),
kb_event.event_type,
kb_event.modifiers.to_hid_byte()
);
@@ -1788,7 +1772,7 @@ async fn run_audio_streaming(
}
// Subscribe to the audio Opus stream
let mut opus_rx = match audio_controller.subscribe_opus_async().await {
let mut opus_rx = match audio_controller.subscribe_opus().await {
Some(rx) => rx,
None => {
// Audio not available, wait and retry
@@ -1833,18 +1817,18 @@ async fn run_audio_streaming(
break 'subscribe_loop;
}
result = opus_rx.changed() => {
if result.is_err() {
// Pipeline was restarted
info!("Audio pipeline closed for connection {}, re-subscribing...", conn_id);
audio_adapter.reset();
tokio::time::sleep(Duration::from_millis(100)).await;
continue 'subscribe_loop;
}
let opus_frame = match opus_rx.borrow().clone() {
result = opus_rx.recv() => {
let opus_frame = match result {
Some(frame) => frame,
None => continue,
None => {
info!(
"Audio pipeline closed for connection {}, re-subscribing...",
conn_id
);
audio_adapter.reset();
tokio::time::sleep(Duration::from_millis(100)).await;
continue 'subscribe_loop;
}
};
// Convert OpusFrame to RustDesk AudioFrame message

View File

@@ -1,10 +1,4 @@
//! RustDesk Cryptography
//!
//! This module implements the NaCl-based cryptography used by RustDesk:
//! - Curve25519 for key exchange
//! - XSalsa20-Poly1305 for authenticated encryption
//! - Ed25519 for signatures
//! - Ed25519 to Curve25519 key conversion for unified keypair usage
//! NaCl crypto (RustDesk-compatible).
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use sodiumoxide::crypto::box_::{self, Nonce, PublicKey, SecretKey};
@@ -12,7 +6,6 @@ use sodiumoxide::crypto::secretbox;
use sodiumoxide::crypto::sign::{self, ed25519};
use thiserror::Error;
/// Cryptography errors
#[derive(Debug, Error)]
pub enum CryptoError {
#[error("Failed to initialize sodiumoxide")]
@@ -31,13 +24,10 @@ pub enum CryptoError {
KeyConversionFailed,
}
/// Initialize the cryptography library
/// Must be called before using any crypto functions
pub fn init() -> Result<(), CryptoError> {
sodiumoxide::init().map_err(|_| CryptoError::InitError)
}
/// A keypair for asymmetric encryption
#[derive(Clone)]
pub struct KeyPair {
pub public_key: PublicKey,
@@ -45,7 +35,6 @@ pub struct KeyPair {
}
impl KeyPair {
/// Generate a new random keypair
pub fn generate() -> Self {
let (public_key, secret_key) = box_::gen_keypair();
Self {
@@ -54,7 +43,6 @@ impl KeyPair {
}
}
/// Create from existing keys
pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result<Self, CryptoError> {
let pk = PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?;
let sk = SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?;
@@ -64,27 +52,22 @@ impl KeyPair {
})
}
/// Get public key as bytes
pub fn public_key_bytes(&self) -> &[u8] {
self.public_key.as_ref()
}
/// Get secret key as bytes
pub fn secret_key_bytes(&self) -> &[u8] {
self.secret_key.as_ref()
}
/// Encode public key as base64
pub fn public_key_base64(&self) -> String {
BASE64.encode(self.public_key_bytes())
}
/// Encode secret key as base64
pub fn secret_key_base64(&self) -> String {
BASE64.encode(self.secret_key_bytes())
}
/// Create from base64-encoded keys
pub fn from_base64(public_key: &str, secret_key: &str) -> Result<Self, CryptoError> {
let pk_bytes = BASE64
.decode(public_key)
@@ -96,15 +79,10 @@ impl KeyPair {
}
}
/// Generate a random nonce for box encryption
pub fn generate_nonce() -> Nonce {
box_::gen_nonce()
}
/// Encrypt data using public-key cryptography (NaCl box)
///
/// Uses the sender's secret key and receiver's public key for encryption.
/// Returns (nonce, ciphertext).
pub fn encrypt_box(
data: &[u8],
their_public_key: &PublicKey,
@@ -115,7 +93,6 @@ pub fn encrypt_box(
(nonce, ciphertext)
}
/// Decrypt data using public-key cryptography (NaCl box)
pub fn decrypt_box(
ciphertext: &[u8],
nonce: &Nonce,
@@ -126,14 +103,12 @@ pub fn decrypt_box(
.map_err(|_| CryptoError::DecryptionFailed)
}
/// Encrypt data with a precomputed shared key
pub fn encrypt_with_key(data: &[u8], key: &secretbox::Key) -> (secretbox::Nonce, Vec<u8>) {
let nonce = secretbox::gen_nonce();
let ciphertext = secretbox::seal(data, &nonce, key);
(nonce, ciphertext)
}
/// Decrypt data with a precomputed shared key
pub fn decrypt_with_key(
ciphertext: &[u8],
nonce: &secretbox::Nonce,
@@ -142,8 +117,6 @@ pub fn decrypt_with_key(
secretbox::open(ciphertext, nonce, key).map_err(|_| CryptoError::DecryptionFailed)
}
/// Compute a shared symmetric key from public/private keypair
/// This is the precomputed key for the NaCl box
pub fn precompute_key(
their_public_key: &PublicKey,
our_secret_key: &SecretKey,
@@ -151,23 +124,18 @@ pub fn precompute_key(
box_::precompute(their_public_key, our_secret_key)
}
/// Create a symmetric key from raw bytes
pub fn symmetric_key_from_slice(key: &[u8]) -> Result<secretbox::Key, CryptoError> {
secretbox::Key::from_slice(key).ok_or(CryptoError::InvalidKeyLength)
}
/// Parse a nonce from bytes
pub fn nonce_from_slice(bytes: &[u8]) -> Result<Nonce, CryptoError> {
Nonce::from_slice(bytes).ok_or(CryptoError::InvalidNonce)
}
/// Parse a public key from bytes
pub fn public_key_from_slice(bytes: &[u8]) -> Result<PublicKey, CryptoError> {
PublicKey::from_slice(bytes).ok_or(CryptoError::InvalidKeyLength)
}
/// Hash a password for storage/comparison
/// RustDesk uses simple SHA256 for password hashing
pub fn hash_password(password: &str, salt: &str) -> Vec<u8> {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
@@ -176,35 +144,24 @@ pub fn hash_password(password: &str, salt: &str) -> Vec<u8> {
hasher.finalize().to_vec()
}
/// RustDesk double hash for password verification
/// Client calculates: SHA256(SHA256(password + salt) + challenge)
/// This matches what the client sends in LoginRequest
pub fn hash_password_double(password: &str, salt: &str, challenge: &str) -> Vec<u8> {
use sha2::{Digest, Sha256};
// First hash: SHA256(password + salt)
let mut hasher1 = Sha256::new();
hasher1.update(password.as_bytes());
hasher1.update(salt.as_bytes());
let first_hash = hasher1.finalize();
// Second hash: SHA256(first_hash + challenge)
let mut hasher2 = Sha256::new();
hasher2.update(first_hash);
hasher2.update(challenge.as_bytes());
hasher2.finalize().to_vec()
}
/// Verify a password hash
pub fn verify_password(password: &str, salt: &str, expected_hash: &[u8]) -> bool {
let computed = hash_password(password, salt);
// Constant-time comparison would be better, but for our use case this is acceptable
computed == expected_hash
}
/// Decrypt symmetric key using Curve25519 secret key directly
///
/// This is used when we have a fresh Curve25519 keypair for the connection
/// (as per RustDesk protocol - each connection generates a new keypair)
pub fn decrypt_symmetric_key(
their_temp_public_key: &[u8],
sealed_symmetric_key: &[u8],
@@ -217,7 +174,6 @@ pub fn decrypt_symmetric_key(
let their_pk =
PublicKey::from_slice(their_temp_public_key).ok_or(CryptoError::InvalidKeyLength)?;
// Use zero nonce as per RustDesk protocol
let nonce = box_::Nonce([0u8; box_::NONCEBYTES]);
let key_bytes = box_::open(sealed_symmetric_key, &nonce, &their_pk, our_secret_key)
@@ -226,11 +182,7 @@ pub fn decrypt_symmetric_key(
secretbox::Key::from_slice(&key_bytes).ok_or(CryptoError::InvalidKeyLength)
}
/// Encrypt a message using the negotiated symmetric key
///
/// RustDesk uses a specific nonce format for session encryption
pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) -> Vec<u8> {
// Create nonce from counter (little-endian, padded to 24 bytes)
let mut nonce_bytes = [0u8; secretbox::NONCEBYTES];
nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes());
let nonce = secretbox::Nonce(nonce_bytes);
@@ -238,13 +190,11 @@ pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) ->
secretbox::seal(data, &nonce, key)
}
/// Decrypt a message using the negotiated symmetric key
pub fn decrypt_message(
ciphertext: &[u8],
key: &secretbox::Key,
nonce_counter: u64,
) -> Result<Vec<u8>, CryptoError> {
// Create nonce from counter (little-endian, padded to 24 bytes)
let mut nonce_bytes = [0u8; secretbox::NONCEBYTES];
nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes());
let nonce = secretbox::Nonce(nonce_bytes);
@@ -252,7 +202,6 @@ pub fn decrypt_message(
secretbox::open(ciphertext, &nonce, key).map_err(|_| CryptoError::DecryptionFailed)
}
/// Ed25519 signing keypair for RustDesk SignedId messages
#[derive(Clone)]
pub struct SigningKeyPair {
pub public_key: sign::PublicKey,
@@ -260,7 +209,6 @@ pub struct SigningKeyPair {
}
impl SigningKeyPair {
/// Generate a new random signing keypair
pub fn generate() -> Self {
let (public_key, secret_key) = sign::gen_keypair();
Self {
@@ -269,7 +217,6 @@ impl SigningKeyPair {
}
}
/// Create from existing keys
pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result<Self, CryptoError> {
let pk = sign::PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?;
let sk = sign::SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?;
@@ -279,27 +226,22 @@ impl SigningKeyPair {
})
}
/// Get public key as bytes
pub fn public_key_bytes(&self) -> &[u8] {
self.public_key.as_ref()
}
/// Get secret key as bytes
pub fn secret_key_bytes(&self) -> &[u8] {
self.secret_key.as_ref()
}
/// Encode public key as base64
pub fn public_key_base64(&self) -> String {
BASE64.encode(self.public_key_bytes())
}
/// Encode secret key as base64
pub fn secret_key_base64(&self) -> String {
BASE64.encode(self.secret_key_bytes())
}
/// Create from base64-encoded keys
pub fn from_base64(public_key: &str, secret_key: &str) -> Result<Self, CryptoError> {
let pk_bytes = BASE64
.decode(public_key)
@@ -310,42 +252,27 @@ impl SigningKeyPair {
Self::from_keys(&pk_bytes, &sk_bytes)
}
/// Sign a message
/// Returns the signature prepended to the message (as per RustDesk protocol)
pub fn sign(&self, message: &[u8]) -> Vec<u8> {
sign::sign(message, &self.secret_key)
}
/// Sign a message and return just the signature (64 bytes)
pub fn sign_detached(&self, message: &[u8]) -> [u8; 64] {
let sig = sign::sign_detached(message, &self.secret_key);
// Use as_ref() to access the signature bytes since the inner field is private
let sig_bytes: &[u8] = sig.as_ref();
let mut result = [0u8; 64];
result.copy_from_slice(sig_bytes);
result
}
/// Convert Ed25519 public key to Curve25519 public key for encryption
///
/// This allows using the same keypair for both signing and encryption,
/// which is required by RustDesk's protocol where clients encrypt the
/// symmetric key using the public key from IdPk.
pub fn to_curve25519_pk(&self) -> Result<PublicKey, CryptoError> {
ed25519::to_curve25519_pk(&self.public_key).map_err(|_| CryptoError::KeyConversionFailed)
}
/// Convert Ed25519 secret key to Curve25519 secret key for decryption
///
/// This allows decrypting messages that were encrypted using the
/// converted public key.
pub fn to_curve25519_sk(&self) -> Result<SecretKey, CryptoError> {
ed25519::to_curve25519_sk(&self.secret_key).map_err(|_| CryptoError::KeyConversionFailed)
}
}
/// Verify a signed message
/// Returns the original message if signature is valid
pub fn verify_signed(
signed_message: &[u8],
public_key: &sign::PublicKey,

View File

@@ -1,8 +1,3 @@
//! RustDesk Frame Adapters
//!
//! Converts One-KVM video/audio frames to RustDesk protocol format.
//! Optimized for zero-copy where possible and buffer reuse.
use bytes::Bytes;
use protobuf::Message as ProtobufMessage;
@@ -11,7 +6,6 @@ use super::protocol::hbb::message::{
CursorData, CursorPosition, EncodedVideoFrame, EncodedVideoFrames, Message, Misc, VideoFrame,
};
/// Video codec type for RustDesk
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VideoCodec {
H264,
@@ -22,7 +16,6 @@ pub enum VideoCodec {
}
impl VideoCodec {
/// Get the codec ID for the RustDesk protocol
pub fn to_codec_id(self) -> i32 {
match self {
VideoCodec::H264 => 0,
@@ -34,21 +27,15 @@ impl VideoCodec {
}
}
/// Video frame adapter for converting to RustDesk format
pub struct VideoFrameAdapter {
/// Current codec
codec: VideoCodec,
/// Frame sequence number
seq: u32,
/// Timestamp offset
timestamp_base: u64,
/// Cached H264 SPS/PPS (Annex B NAL without start code)
h264_sps: Option<Bytes>,
h264_pps: Option<Bytes>,
}
impl VideoFrameAdapter {
/// Create a new video frame adapter
pub fn new(codec: VideoCodec) -> Self {
Self {
codec,
@@ -59,14 +46,10 @@ impl VideoFrameAdapter {
}
}
/// Set codec type
pub fn set_codec(&mut self, codec: VideoCodec) {
self.codec = codec;
}
/// Convert encoded video data to RustDesk Message (zero-copy version)
///
/// This version takes Bytes directly to avoid copying the frame data.
pub fn encode_frame_from_bytes(
&mut self,
data: Bytes,
@@ -74,7 +57,6 @@ impl VideoFrameAdapter {
timestamp_ms: u64,
) -> Message {
let data = self.prepare_h264_frame(data, is_keyframe);
// Calculate relative timestamp
if self.seq == 0 {
self.timestamp_base = timestamp_ms;
}
@@ -87,11 +69,9 @@ impl VideoFrameAdapter {
self.seq = self.seq.wrapping_add(1);
// Wrap in EncodedVideoFrames container
let mut frames = EncodedVideoFrames::new();
frames.frames.push(frame);
// Create the appropriate VideoFrame variant based on codec
let mut video_frame = VideoFrame::new();
match self.codec {
VideoCodec::H264 => video_frame.union = Some(vf_union::Union::H264s(frames)),
@@ -111,7 +91,6 @@ impl VideoFrameAdapter {
return data;
}
// Parse SPS/PPS from Annex B data (without start codes)
let (sps, pps) = crate::webrtc::rtp::extract_sps_pps(&data);
let mut has_sps = false;
let mut has_pps = false;
@@ -125,7 +104,6 @@ impl VideoFrameAdapter {
has_pps = true;
}
// Inject cached SPS/PPS before IDR when missing
if is_keyframe && (!has_sps || !has_pps) {
if let (Some(sps), Some(pps)) = (self.h264_sps.as_ref(), self.h264_pps.as_ref()) {
let mut out = Vec::with_capacity(8 + sps.len() + pps.len() + data.len());
@@ -141,14 +119,10 @@ impl VideoFrameAdapter {
data
}
/// Convert encoded video data to RustDesk Message
pub fn encode_frame(&mut self, data: &[u8], is_keyframe: bool, timestamp_ms: u64) -> Message {
self.encode_frame_from_bytes(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms)
}
/// Encode frame to bytes for sending (zero-copy version)
///
/// Takes Bytes directly to avoid copying the frame data.
pub fn encode_frame_bytes_zero_copy(
&mut self,
data: Bytes,
@@ -159,7 +133,6 @@ impl VideoFrameAdapter {
Bytes::from(msg.write_to_bytes().unwrap_or_default())
}
/// Encode frame to bytes for sending
pub fn encode_frame_bytes(
&mut self,
data: &[u8],
@@ -169,24 +142,18 @@ impl VideoFrameAdapter {
self.encode_frame_bytes_zero_copy(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms)
}
/// Get current sequence number
pub fn seq(&self) -> u32 {
self.seq
}
}
/// Audio frame adapter for converting to RustDesk format
pub struct AudioFrameAdapter {
/// Sample rate
sample_rate: u32,
/// Channels
channels: u8,
/// Format sent flag
format_sent: bool,
}
impl AudioFrameAdapter {
/// Create a new audio frame adapter
pub fn new(sample_rate: u32, channels: u8) -> Self {
Self {
sample_rate,
@@ -195,7 +162,6 @@ impl AudioFrameAdapter {
}
}
/// Create audio format message (should be sent once before audio frames)
pub fn create_format_message(&mut self) -> Message {
self.format_sent = true;
@@ -211,12 +177,10 @@ impl AudioFrameAdapter {
msg
}
/// Check if format message has been sent
pub fn format_sent(&self) -> bool {
self.format_sent
}
/// Convert Opus audio data to RustDesk Message
pub fn encode_opus_frame(&self, data: &[u8]) -> Message {
let mut frame = AudioFrame::new();
frame.data = Bytes::copy_from_slice(data);
@@ -226,23 +190,19 @@ impl AudioFrameAdapter {
msg
}
/// Encode Opus frame to bytes for sending
pub fn encode_opus_bytes(&self, data: &[u8]) -> Bytes {
let msg = self.encode_opus_frame(data);
Bytes::from(msg.write_to_bytes().unwrap_or_default())
}
/// Reset state (call when restarting audio stream)
pub fn reset(&mut self) {
self.format_sent = false;
}
}
/// Cursor data adapter
pub struct CursorAdapter;
impl CursorAdapter {
/// Create cursor data message
pub fn encode_cursor(
id: u64,
hotx: i32,
@@ -264,7 +224,6 @@ impl CursorAdapter {
msg
}
/// Create cursor position message
pub fn encode_position(x: i32, y: i32) -> Message {
let mut pos = CursorPosition::new();
pos.x = x;
@@ -284,7 +243,6 @@ mod tests {
fn test_video_frame_encoding() {
let mut adapter = VideoFrameAdapter::new(VideoCodec::H264);
// Encode a keyframe
let data = vec![0x00, 0x00, 0x00, 0x01, 0x67]; // H264 SPS NAL
let msg = adapter.encode_frame(&data, true, 0);
@@ -324,7 +282,6 @@ mod tests {
fn test_audio_frame_encoding() {
let adapter = AudioFrameAdapter::new(48000, 2);
// Encode an Opus frame
let opus_data = vec![0xFC, 0x01, 0x02]; // Fake Opus data
let msg = adapter.encode_opus_frame(&opus_data);

View File

@@ -1,17 +1,11 @@
//! RustDesk HID Adapter
//!
//! Converts RustDesk HID events (KeyEvent, MouseEvent) to One-KVM HID events.
use super::protocol::hbb::message::key_event as ke_union;
use super::protocol::{ControlKey, KeyEvent, MouseEvent};
use crate::hid::{
KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent as OneKvmMouseEvent,
MouseEventType,
CanonicalKey, KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton,
MouseEvent as OneKvmMouseEvent, MouseEventType,
};
use protobuf::Enum;
/// Mouse event types from RustDesk protocol
/// mask = (button << 3) | event_type
pub mod mouse_type {
pub const MOVE: i32 = 0;
pub const DOWN: i32 = 1;
@@ -21,7 +15,6 @@ pub mod mouse_type {
pub const MOVE_RELATIVE: i32 = 5;
}
/// Mouse button IDs from RustDesk protocol (before left shift by 3)
pub mod mouse_button {
pub const LEFT: i32 = 0x01;
pub const RIGHT: i32 = 0x02;
@@ -30,9 +23,6 @@ pub mod mouse_button {
pub const FORWARD: i32 = 0x10;
}
/// Convert RustDesk MouseEvent to One-KVM MouseEvent(s)
/// Returns a Vec because a single RustDesk event may need multiple One-KVM events
/// (e.g., move + button + scroll)
pub fn convert_mouse_event(
event: &MouseEvent,
screen_width: u32,
@@ -41,23 +31,18 @@ pub fn convert_mouse_event(
) -> Vec<OneKvmMouseEvent> {
let mut events = Vec::new();
// Parse RustDesk mask format: (button << 3) | event_type
let event_type = event.mask & 0x07;
let button_id = event.mask >> 3;
let include_abs_move = !relative_mode;
match event_type {
mouse_type::MOVE => {
// RustDesk uses absolute coordinates
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
// Normalize to 0-32767 range for absolute mouse (USB HID standard)
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32;
// Move event - may have button held down (button_id > 0 means dragging)
// Just send move, button state is tracked separately by HID backend
events.push(OneKvmMouseEvent {
event_type: MouseEventType::MoveAbs,
x: abs_x,
@@ -67,7 +52,6 @@ pub fn convert_mouse_event(
});
}
mouse_type::MOVE_RELATIVE => {
// Relative movement uses delta values directly (dx, dy).
events.push(OneKvmMouseEvent {
event_type: MouseEventType::Move,
x: event.x,
@@ -78,7 +62,6 @@ pub fn convert_mouse_event(
}
mouse_type::DOWN => {
if include_abs_move {
// Button down - first move, then press
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -104,7 +87,6 @@ pub fn convert_mouse_event(
}
mouse_type::UP => {
if include_abs_move {
// Button up - first move, then release
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -130,7 +112,6 @@ pub fn convert_mouse_event(
}
mouse_type::WHEEL => {
if include_abs_move {
// Scroll event - move first, then scroll
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -144,9 +125,6 @@ pub fn convert_mouse_event(
});
}
// RustDesk encodes scroll direction in the y coordinate
// Positive y = scroll up, Negative y = scroll down
// The button_id field is not used for direction
let scroll = if event.y > 0 { 1i8 } else { -1i8 };
events.push(OneKvmMouseEvent {
event_type: MouseEventType::Scroll,
@@ -158,7 +136,6 @@ pub fn convert_mouse_event(
}
_ => {
if include_abs_move {
// Unknown event type, just move
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -177,7 +154,6 @@ pub fn convert_mouse_event(
events
}
/// Convert RustDesk button ID to One-KVM MouseButton
fn button_id_to_button(button_id: i32) -> Option<MouseButton> {
match button_id {
mouse_button::LEFT => Some(MouseButton::Left),
@@ -187,65 +163,44 @@ fn button_id_to_button(button_id: i32) -> Option<MouseButton> {
}
}
/// Convert RustDesk KeyEvent to One-KVM KeyboardEvent
///
/// RustDesk KeyEvent has two modes:
/// - down=true/false: Key state (pressed/released)
/// - press=true: Complete key press (down + up), used for typing
///
/// For press=true events, we only send Down and let the caller handle
/// the timing for Up event if needed. Most systems handle this correctly.
pub fn convert_key_event(event: &KeyEvent) -> Option<KeyboardEvent> {
// Determine if this is a key down or key up event
// press=true means "key was pressed" (down event)
// down=true means key is currently held down
// down=false with press=false means key was released
let event_type = if event.down || event.press {
KeyEventType::Down
} else {
KeyEventType::Up
};
// For modifier keys sent as ControlKey, don't include them in modifiers
// to avoid double-pressing. The modifier will be tracked by HID state.
let modifiers = if is_modifier_control_key(event) {
KeyboardModifiers::default()
} else {
parse_modifiers(event)
};
// Handle control keys
if let Some(ke_union::Union::ControlKey(ck)) = &event.union {
if let Some(key) = control_key_to_hid(ck.value()) {
let key = CanonicalKey::from_hid_usage(key)?;
return Some(KeyboardEvent {
event_type,
key,
modifiers,
is_usb_hid: true, // Already converted to USB HID code
});
}
}
// Handle character keys (chr field contains platform-specific keycode)
if let Some(ke_union::Union::Chr(chr)) = &event.union {
// chr contains USB HID scancode on Windows, X11 keycode on Linux
if let Some(key) = keycode_to_hid(*chr) {
let key = CanonicalKey::from_hid_usage(key)?;
return Some(KeyboardEvent {
event_type,
key,
modifiers,
is_usb_hid: true, // Already converted to USB HID code
});
}
}
// Handle unicode (for text input, we'd need to convert to scancodes)
// Unicode input requires more complex handling, skip for now
None
}
/// Check if the event is a modifier key sent as ControlKey
fn is_modifier_control_key(event: &KeyEvent) -> bool {
if let Some(ke_union::Union::ControlKey(ck)) = &event.union {
let val = ck.value();
@@ -260,7 +215,6 @@ fn is_modifier_control_key(event: &KeyEvent) -> bool {
false
}
/// Parse modifier keys from RustDesk KeyEvent into KeyboardModifiers
fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers {
let mut modifiers = KeyboardModifiers::default();
@@ -281,7 +235,6 @@ fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers {
modifiers
}
/// Convert RustDesk ControlKey to USB HID usage code
fn control_key_to_hid(key: i32) -> Option<u8> {
match key {
x if x == ControlKey::Alt as i32 => Some(0xE2), // Left Alt
@@ -342,67 +295,47 @@ fn control_key_to_hid(key: i32) -> Option<u8> {
}
}
/// Convert platform keycode to USB HID usage code
/// Handles Windows Virtual Key Codes, X11 keycodes, and ASCII codes
fn keycode_to_hid(keycode: u32) -> Option<u8> {
// First try ASCII code mapping (RustDesk often sends ASCII codes)
if let Some(hid) = ascii_to_hid(keycode) {
return Some(hid);
}
// Then try Windows Virtual Key Code mapping
if let Some(hid) = windows_vk_to_hid(keycode) {
return Some(hid);
}
// Fall back to X11 keycode mapping for Linux clients
x11_keycode_to_hid(keycode)
}
/// Convert ASCII code to USB HID usage code
fn ascii_to_hid(ascii: u32) -> Option<u8> {
match ascii {
// Lowercase letters a-z (ASCII 97-122)
97..=122 => {
// USB HID: a=0x04, b=0x05, ..., z=0x1D
Some((ascii - 97 + 0x04) as u8)
}
// Uppercase letters A-Z (ASCII 65-90)
65..=90 => {
// USB HID: A=0x04, B=0x05, ..., Z=0x1D (same as lowercase)
Some((ascii - 65 + 0x04) as u8)
}
// Numbers 0-9 (ASCII 48-57)
97..=122 => Some((ascii - 97 + 0x04) as u8),
65..=90 => Some((ascii - 65 + 0x04) as u8),
48 => Some(0x27), // 0
49..=57 => Some((ascii - 49 + 0x1E) as u8), // 1-9
// Common punctuation
32 => Some(0x2C), // Space
13 => Some(0x28), // Enter (CR)
10 => Some(0x28), // Enter (LF)
9 => Some(0x2B), // Tab
27 => Some(0x29), // Escape
8 => Some(0x2A), // Backspace
127 => Some(0x4C), // Delete
// Symbols (US keyboard layout)
45 => Some(0x2D), // -
61 => Some(0x2E), // =
91 => Some(0x2F), // [
93 => Some(0x30), // ]
92 => Some(0x31), // \
59 => Some(0x33), // ;
39 => Some(0x34), // '
96 => Some(0x35), // `
44 => Some(0x36), // ,
46 => Some(0x37), // .
47 => Some(0x38), // /
32 => Some(0x2C), // Space
13 => Some(0x28), // Enter (CR)
10 => Some(0x28), // Enter (LF)
9 => Some(0x2B), // Tab
27 => Some(0x29), // Escape
8 => Some(0x2A), // Backspace
127 => Some(0x4C), // Delete
45 => Some(0x2D), // -
61 => Some(0x2E), // =
91 => Some(0x2F), // [
93 => Some(0x30), // ]
92 => Some(0x31), // \
59 => Some(0x33), // ;
39 => Some(0x34), // '
96 => Some(0x35), // `
44 => Some(0x36), // ,
46 => Some(0x37), // .
47 => Some(0x38), // /
_ => None,
}
}
/// Convert Windows Virtual Key Code to USB HID usage code
fn windows_vk_to_hid(vk: u32) -> Option<u8> {
match vk {
// Letters A-Z (VK_A=0x41 to VK_Z=0x5A)
0x41..=0x5A => {
// USB HID: A=0x04, B=0x05, ..., Z=0x1D
let letter = (vk - 0x41) as u8;
Some(match letter {
0 => 0x04, // A
@@ -434,21 +367,16 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
_ => return None,
})
}
// Numbers 0-9 (VK_0=0x30 to VK_9=0x39)
0x30 => Some(0x27), // 0
0x31..=0x39 => Some((vk - 0x31 + 0x1E) as u8), // 1-9
// Numpad 0-9 (VK_NUMPAD0=0x60 to VK_NUMPAD9=0x69)
0x60 => Some(0x62), // Numpad 0
0x61..=0x69 => Some((vk - 0x61 + 0x59) as u8), // Numpad 1-9
// Numpad operators
0x6A => Some(0x55), // Numpad *
0x6B => Some(0x57), // Numpad +
0x6D => Some(0x56), // Numpad -
0x6E => Some(0x63), // Numpad .
0x6F => Some(0x54), // Numpad /
// Function keys F1-F12 (VK_F1=0x70 to VK_F12=0x7B)
0x6A => Some(0x55), // Numpad *
0x6B => Some(0x57), // Numpad +
0x6D => Some(0x56), // Numpad -
0x6E => Some(0x63), // Numpad .
0x6F => Some(0x54), // Numpad /
0x70..=0x7B => Some((vk - 0x70 + 0x3A) as u8),
// Special keys
0x08 => Some(0x2A), // Backspace
0x09 => Some(0x2B), // Tab
0x0D => Some(0x28), // Enter
@@ -464,7 +392,6 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
0x28 => Some(0x51), // Down Arrow
0x2D => Some(0x49), // Insert
0x2E => Some(0x4C), // Delete
// OEM keys (US keyboard layout)
0xBA => Some(0x33), // ; :
0xBB => Some(0x2E), // = +
0xBC => Some(0x36), // , <
@@ -476,66 +403,56 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
0xDC => Some(0x31), // \ |
0xDD => Some(0x30), // ] }
0xDE => Some(0x34), // ' "
// Lock keys
0x14 => Some(0x39), // Caps Lock
0x90 => Some(0x53), // Num Lock
0x91 => Some(0x47), // Scroll Lock
// Print Screen, Pause
0x2C => Some(0x46), // Print Screen
0x13 => Some(0x48), // Pause
_ => None,
}
}
/// Convert X11 keycode to USB HID usage code (for Linux clients)
fn x11_keycode_to_hid(keycode: u32) -> Option<u8> {
match keycode {
// Numbers: X11 keycode 10="1", 11="2", ..., 18="9", 19="0"
10..=18 => Some((keycode - 10 + 0x1E) as u8), // 1-9
19 => Some(0x27), // 0
// Punctuation
20 => Some(0x2D), // -
21 => Some(0x2E), // =
34 => Some(0x2F), // [
35 => Some(0x30), // ]
// Letters (X11 keycodes are row-based)
// Row 1: q(24) w(25) e(26) r(27) t(28) y(29) u(30) i(31) o(32) p(33)
24 => Some(0x14), // q
25 => Some(0x1A), // w
26 => Some(0x08), // e
27 => Some(0x15), // r
28 => Some(0x17), // t
29 => Some(0x1C), // y
30 => Some(0x18), // u
31 => Some(0x0C), // i
32 => Some(0x12), // o
33 => Some(0x13), // p
// Row 2: a(38) s(39) d(40) f(41) g(42) h(43) j(44) k(45) l(46)
38 => Some(0x04), // a
39 => Some(0x16), // s
40 => Some(0x07), // d
41 => Some(0x09), // f
42 => Some(0x0A), // g
43 => Some(0x0B), // h
44 => Some(0x0D), // j
45 => Some(0x0E), // k
46 => Some(0x0F), // l
47 => Some(0x33), // ;
48 => Some(0x34), // '
49 => Some(0x35), // `
51 => Some(0x31), // \
// Row 3: z(52) x(53) c(54) v(55) b(56) n(57) m(58)
52 => Some(0x1D), // z
53 => Some(0x1B), // x
54 => Some(0x06), // c
55 => Some(0x19), // v
56 => Some(0x05), // b
57 => Some(0x11), // n
58 => Some(0x10), // m
59 => Some(0x36), // ,
60 => Some(0x37), // .
61 => Some(0x38), // /
// Space
20 => Some(0x2D), // -
21 => Some(0x2E), // =
34 => Some(0x2F), // [
35 => Some(0x30), // ]
24 => Some(0x14), // q
25 => Some(0x1A), // w
26 => Some(0x08), // e
27 => Some(0x15), // r
28 => Some(0x17), // t
29 => Some(0x1C), // y
30 => Some(0x18), // u
31 => Some(0x0C), // i
32 => Some(0x12), // o
33 => Some(0x13), // p
38 => Some(0x04), // a
39 => Some(0x16), // s
40 => Some(0x07), // d
41 => Some(0x09), // f
42 => Some(0x0A), // g
43 => Some(0x0B), // h
44 => Some(0x0D), // j
45 => Some(0x0E), // k
46 => Some(0x0F), // l
47 => Some(0x33), // ;
48 => Some(0x34), // '
49 => Some(0x35), // `
51 => Some(0x31), // \
52 => Some(0x1D), // z
53 => Some(0x1B), // x
54 => Some(0x06), // c
55 => Some(0x19), // v
56 => Some(0x05), // b
57 => Some(0x11), // n
58 => Some(0x10), // m
59 => Some(0x36), // ,
60 => Some(0x37), // .
61 => Some(0x38), // /
65 => Some(0x2C),
_ => None,
}
@@ -573,7 +490,6 @@ mod tests {
let events = convert_mouse_event(&event, 1920, 1080, false);
assert!(events.len() >= 2);
// Should have a button down event
assert!(events
.iter()
.any(|e| e.event_type == MouseEventType::Down && e.button == Some(MouseButton::Left)));
@@ -608,6 +524,6 @@ mod tests {
let kb_event = result.unwrap();
assert_eq!(kb_event.event_type, KeyEventType::Down);
assert_eq!(kb_event.key, 0x28); // Return key USB HID code
assert_eq!(kb_event.key, CanonicalKey::Enter);
}
}

View File

@@ -1,17 +1,4 @@
//! RustDesk Protocol Integration Module
//!
//! This module implements the RustDesk client protocol, enabling One-KVM devices
//! to be accessed via standard RustDesk clients through existing hbbs/hbbr servers.
//!
//! ## Architecture
//!
//! - `config`: Configuration types for RustDesk settings
//! - `protocol`: Protobuf message wrappers and serialization
//! - `crypto`: NaCl cryptography (key generation, encryption, signatures)
//! - `rendezvous`: Communication with hbbs rendezvous server
//! - `connection`: Client session handling
//! - `frame_adapters`: Video/audio frame conversion to RustDesk format
//! - `hid_adapter`: RustDesk HID events to One-KVM conversion
//! RustDesk peer protocol (hbbs / hbbr).
pub mod bytes_codec;
pub mod config;
@@ -44,19 +31,13 @@ use self::connection::ConnectionManager;
use self::protocol::{make_local_addr, make_relay_response, make_request_relay};
use self::rendezvous::{AddrMangle, RendezvousMediator, RendezvousStatus};
/// Relay connection timeout
const RELAY_CONNECT_TIMEOUT_MS: u64 = 10_000;
/// RustDesk service status
#[derive(Debug, Clone, PartialEq)]
pub enum ServiceStatus {
/// Service is stopped
Stopped,
/// Service is starting
Starting,
/// Service is running and registered with rendezvous server
Running,
/// Service encountered an error
Error(String),
}
@@ -71,15 +52,8 @@ impl std::fmt::Display for ServiceStatus {
}
}
/// Default port for direct TCP connections (same as RustDesk)
const DIRECT_LISTEN_PORT: u16 = 21118;
/// RustDesk Service
///
/// Manages the RustDesk protocol integration, including:
/// - Registration with hbbs rendezvous server
/// - Accepting connections from RustDesk clients
/// - Streaming video/audio and receiving HID input
pub struct RustDeskService {
config: Arc<RwLock<RustDeskConfig>>,
status: Arc<RwLock<ServiceStatus>>,
@@ -95,7 +69,6 @@ pub struct RustDeskService {
}
impl RustDeskService {
/// Create a new RustDesk service instance
pub fn new(
config: RustDeskConfig,
video_manager: Arc<VideoStreamManager>,
@@ -120,42 +93,34 @@ impl RustDeskService {
}
}
/// Get the port for direct TCP connections
pub fn listen_port(&self) -> u16 {
*self.listen_port.read()
}
/// Get current service status
pub fn status(&self) -> ServiceStatus {
self.status.read().clone()
}
/// Get current configuration
pub fn config(&self) -> RustDeskConfig {
self.config.read().clone()
}
/// Update configuration
pub fn update_config(&self, config: RustDeskConfig) {
*self.config.write() = config;
}
/// Get rendezvous status
pub fn rendezvous_status(&self) -> Option<RendezvousStatus> {
self.rendezvous.read().as_ref().map(|r| r.status())
}
/// Get device ID
pub fn device_id(&self) -> String {
self.config.read().device_id.clone()
}
/// Get connection count
pub fn connection_count(&self) -> usize {
self.connection_manager.connection_count()
}
/// Start the RustDesk service
pub async fn start(&self) -> anyhow::Result<()> {
let config = self.config.read().clone();
@@ -181,74 +146,44 @@ impl RustDeskService {
config.rendezvous_addr()
);
// Initialize crypto
if let Err(e) = crypto::init() {
error!("Failed to initialize crypto: {}", e);
*self.status.write() = ServiceStatus::Error(e.to_string());
return Err(e.into());
}
// Create and start rendezvous mediator with relay callback
let mediator = Arc::new(RendezvousMediator::new(config.clone()));
// Set the keypair on connection manager (Curve25519 for encryption)
let keypair = mediator.ensure_keypair();
self.connection_manager.set_keypair(keypair);
// Set the signing keypair on connection manager (Ed25519 for SignedId)
let signing_keypair = mediator.ensure_signing_keypair();
self.connection_manager.set_signing_keypair(signing_keypair);
// Set the HID controller on connection manager
self.connection_manager.set_hid(self.hid.clone());
// Set the audio controller on connection manager for audio streaming
self.connection_manager.set_audio(self.audio.clone());
// Set the video manager on connection manager for video streaming
self.connection_manager
.set_video_manager(self.video_manager.clone());
*self.rendezvous.write() = Some(mediator.clone());
// Start TCP listener BEFORE the rendezvous mediator to ensure port is set correctly
// This prevents race condition where mediator starts registration with wrong port
let (tcp_handles, listen_port) = self.start_tcp_listener_with_port().await?;
*self.tcp_listener_handle.write() = Some(tcp_handles);
// Set the listen port on mediator before starting the registration loop
mediator.set_listen_port(listen_port);
// Create relay request handler
let connection_manager = self.connection_manager.clone();
let video_manager = self.video_manager.clone();
let hid = self.hid.clone();
let audio = self.audio.clone();
let service_config = self.config.clone();
// Set the punch callback on the mediator (tries P2P first, then relay)
let connection_manager_punch = self.connection_manager.clone();
let video_manager_punch = self.video_manager.clone();
let hid_punch = self.hid.clone();
let audio_punch = self.audio.clone();
let service_config_punch = self.config.clone();
mediator.set_punch_callback(Arc::new(
mediator.set_punch_callback(Arc::new({
let connection_manager = connection_manager.clone();
let service_config = service_config.clone();
move |peer_addr, rendezvous_addr, relay_server, uuid, socket_addr, device_id| {
let conn_mgr = connection_manager_punch.clone();
let video = video_manager_punch.clone();
let hid = hid_punch.clone();
let audio = audio_punch.clone();
let config = service_config_punch.clone();
let conn_mgr = connection_manager.clone();
let config = service_config.clone();
tokio::spawn(async move {
// Get relay_key from config (no public server fallback)
let relay_key = {
let cfg = config.read();
cfg.relay_key.clone().unwrap_or_default()
};
// Try P2P direct connection first
if let Some(addr) = peer_addr {
info!("Attempting P2P direct connection to {}", addr);
match punch::try_direct_connection(addr).await {
@@ -265,7 +200,7 @@ impl RustDeskService {
}
}
// Fall back to relay
let relay_key = rustdesk_relay_key(&config);
if let Err(e) = handle_relay_request(
&rendezvous_addr,
&relay_server,
@@ -274,34 +209,23 @@ impl RustDeskService {
&device_id,
&relay_key,
conn_mgr,
video,
hid,
audio,
)
.await
{
error!("Failed to handle relay request: {}", e);
}
});
},
));
}
}));
// Set the relay callback on the mediator
mediator.set_relay_callback(Arc::new(
mediator.set_relay_callback(Arc::new({
let connection_manager = connection_manager.clone();
let service_config = service_config.clone();
move |rendezvous_addr, relay_server, uuid, socket_addr, device_id| {
let conn_mgr = connection_manager.clone();
let video = video_manager.clone();
let hid = hid.clone();
let audio = audio.clone();
let config = service_config.clone();
tokio::spawn(async move {
// Get relay_key from config (no public server fallback)
let relay_key = {
let cfg = config.read();
cfg.relay_key.clone().unwrap_or_default()
};
let relay_key = rustdesk_relay_key(&config);
if let Err(e) = handle_relay_request(
&rendezvous_addr,
&relay_server,
@@ -310,19 +234,15 @@ impl RustDeskService {
&device_id,
&relay_key,
conn_mgr,
video,
hid,
audio,
)
.await
{
error!("Failed to handle relay request: {}", e);
}
});
},
));
}
}));
// Set the intranet callback on the mediator for same-LAN connections
let connection_manager2 = self.connection_manager.clone();
mediator.set_intranet_callback(Arc::new(
move |rendezvous_addr, peer_socket_addr, local_addr, relay_server, device_id| {
@@ -345,7 +265,6 @@ impl RustDeskService {
},
));
// Spawn rendezvous task
let status = self.status.clone();
let handle = tokio::spawn(async move {
loop {
@@ -357,7 +276,6 @@ impl RustDeskService {
Err(e) => {
error!("Rendezvous mediator error: {}", e);
*status.write() = ServiceStatus::Error(e.to_string());
// Wait before retry
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
*status.write() = ServiceStatus::Starting;
}
@@ -372,10 +290,7 @@ impl RustDeskService {
Ok(())
}
/// Start TCP listener for direct peer connections
/// Returns the join handle and the port that was bound
async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(Vec<JoinHandle<()>>, u16)> {
// Try to bind to the default port, or find an available port
let (listeners, listen_port) = match self.bind_direct_listeners(DIRECT_LISTEN_PORT) {
Ok(result) => result,
Err(err) => {
@@ -453,7 +368,6 @@ impl RustDeskService {
Ok((listeners, listen_port))
}
/// Stop the RustDesk service
pub async fn stop(&self) -> anyhow::Result<()> {
if self.status() == ServiceStatus::Stopped {
return Ok(());
@@ -461,23 +375,18 @@ impl RustDeskService {
info!("Stopping RustDesk service");
// Send shutdown signal (this will stop the TCP listener)
let _ = self.shutdown_tx.send(());
// Close all connections
self.connection_manager.close_all();
// Stop rendezvous mediator
if let Some(mediator) = self.rendezvous.read().as_ref() {
mediator.stop();
}
// Wait for rendezvous task to finish
if let Some(handle) = self.rendezvous_handle.write().take() {
handle.abort();
}
// Wait for TCP listener task to finish
if let Some(handles) = self.tcp_listener_handle.write().take() {
for handle in handles {
handle.abort();
@@ -490,15 +399,12 @@ impl RustDeskService {
Ok(())
}
/// Restart the service with new configuration
pub async fn restart(&self, config: RustDeskConfig) -> anyhow::Result<()> {
self.stop().await?;
self.update_config(config);
self.start().await
}
/// Save keypair and UUID to config
/// Returns the updated config if changes were made
pub fn save_credentials(&self) -> Option<RustDeskConfig> {
if let Some(mediator) = self.rendezvous.read().as_ref() {
let kp = mediator.ensure_keypair();
@@ -506,7 +412,6 @@ impl RustDeskService {
let mut config = self.config.write();
let mut changed = false;
// Save encryption keypair (Curve25519)
let pk = kp.public_key_base64();
let sk = kp.secret_key_base64();
if config.public_key.as_ref() != Some(&pk) || config.private_key.as_ref() != Some(&sk) {
@@ -515,7 +420,6 @@ impl RustDeskService {
changed = true;
}
// Save signing keypair (Ed25519)
let signing_pk = skp.public_key_base64();
let signing_sk = skp.secret_key_base64();
if config.signing_public_key.as_ref() != Some(&signing_pk)
@@ -526,7 +430,6 @@ impl RustDeskService {
changed = true;
}
// Save UUID if it was newly generated
if mediator.uuid_needs_save() {
let mediator_config = mediator.config();
if let Some(uuid) = mediator_config.uuid {
@@ -545,21 +448,16 @@ impl RustDeskService {
None
}
/// Save keypair to config (deprecated, use save_credentials instead)
#[deprecated(note = "Use save_credentials instead")]
pub fn save_keypair(&self) {
let _ = self.save_credentials();
}
}
/// Handle relay request from rendezvous server
///
/// Correct flow based on RustDesk protocol:
/// 1. Connect to RENDEZVOUS server (not relay!)
/// 2. Send RelayResponse with client's socket_addr
/// 3. Connect to RELAY server
/// 4. Accept connection without waiting for response
#[allow(clippy::too_many_arguments)]
fn rustdesk_relay_key(config: &Arc<RwLock<RustDeskConfig>>) -> String {
config.read().relay_key.clone().unwrap_or_default()
}
async fn handle_relay_request(
rendezvous_addr: &str,
relay_server: &str,
@@ -568,16 +466,12 @@ async fn handle_relay_request(
device_id: &str,
relay_key: &str,
connection_manager: Arc<ConnectionManager>,
_video_manager: Arc<VideoStreamManager>,
_hid: Arc<HidController>,
_audio: Arc<AudioController>,
) -> anyhow::Result<()> {
info!(
"Handling relay request: rendezvous={}, relay={}, uuid={}",
rendezvous_addr, relay_server, uuid
);
// Step 1: Connect to RENDEZVOUS server and send RelayResponse
let rendezvous_socket_addr: SocketAddr = tokio::net::lookup_host(rendezvous_addr)
.await?
.next()
@@ -597,8 +491,7 @@ async fn handle_relay_request(
rendezvous_socket_addr
);
// Send RelayResponse to rendezvous server with client's socket_addr
// IMPORTANT: Include our device ID so rendezvous server can look up and sign our public key
// Rendezvous looks up our pk by device id (must set `id`, not raw pk on wire).
let relay_response = make_relay_response(uuid, socket_addr, relay_server, device_id);
let bytes = relay_response
.write_to_bytes()
@@ -606,10 +499,8 @@ async fn handle_relay_request(
bytes_codec::write_frame(&mut rendezvous_stream, &bytes).await?;
debug!("Sent RelayResponse to rendezvous server for uuid={}", uuid);
// Close rendezvous connection - we don't need to wait for response
drop(rendezvous_stream);
// Step 2: Connect to RELAY server and send RequestRelay to identify ourselves
let relay_addr: SocketAddr = tokio::net::lookup_host(relay_server)
.await?
.next()
@@ -624,9 +515,7 @@ async fn handle_relay_request(
info!("Connected to relay server at {}", relay_addr);
// Send RequestRelay to relay server with our uuid, licence_key, and peer's socket_addr
// The licence_key is required if the relay server is configured with -k option
// The socket_addr is CRITICAL - the relay server uses it to match us with the peer
// Relay pairs peers by uuid + mangled peer socket_addr (required when hbbr uses -k).
let request_relay = make_request_relay(uuid, relay_key, socket_addr);
let bytes = request_relay
.write_to_bytes()
@@ -634,10 +523,8 @@ async fn handle_relay_request(
bytes_codec::write_frame(&mut stream, &bytes).await?;
debug!("Sent RequestRelay to relay server for uuid={}", uuid);
// Decode peer address for logging
let peer_addr = rendezvous::AddrMangle::decode(socket_addr).unwrap_or(relay_addr);
// Step 3: Accept connection - relay server bridges the connection
connection_manager
.accept_connection(stream, peer_addr)
.await?;
@@ -649,14 +536,6 @@ async fn handle_relay_request(
Ok(())
}
/// Handle intranet/same-LAN connection request
///
/// When the server determines that the client and peer are on the same intranet
/// (same public IP or both on LAN), it sends FetchLocalAddr to the peer.
/// The peer must:
/// 1. Open a TCP connection to the rendezvous server
/// 2. Send LocalAddr with our local address
/// 3. Accept the peer connection over that same TCP stream
async fn handle_intranet_request(
rendezvous_addr: &str,
peer_socket_addr: &[u8],
@@ -670,11 +549,9 @@ async fn handle_intranet_request(
rendezvous_addr, local_addr, device_id
);
// Decode peer address for logging
let peer_addr = AddrMangle::decode(peer_socket_addr);
debug!("Peer address from FetchLocalAddr: {:?}", peer_addr);
// Connect to rendezvous server via TCP with timeout
let mut stream =
tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(rendezvous_addr))
.await
@@ -685,7 +562,6 @@ async fn handle_intranet_request(
rendezvous_addr
);
// Build LocalAddr message with our local address (mangled)
let local_addr_bytes = AddrMangle::encode(local_addr);
let msg = make_local_addr(
peer_socket_addr,
@@ -698,24 +574,16 @@ async fn handle_intranet_request(
.write_to_bytes()
.map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?;
// Send LocalAddr using RustDesk's variable-length framing
bytes_codec::write_frame(&mut stream, &bytes).await?;
info!("Sent LocalAddr to rendezvous server, waiting for peer connection");
// Now the rendezvous server will forward this to the client,
// and the client will connect to us through this same TCP stream.
// The server proxies the connection between client and peer.
// Get peer address for logging/connection tracking
let effective_peer_addr = peer_addr.unwrap_or_else(|| {
// If we can't decode the peer address, use the rendezvous server address
rendezvous_addr
.parse()
.unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap())
});
// Accept the connection - the stream is now a proxied connection to the client
connection_manager
.accept_connection(stream, effective_peer_addr)
.await?;

View File

@@ -1,18 +1,12 @@
//! RustDesk Protocol Messages
//!
//! This module provides the compiled protobuf messages for the RustDesk protocol.
//! Messages are generated from rendezvous.proto and message.proto at build time.
//! Uses protobuf-rust (same as RustDesk server) for full compatibility.
//! Protobuf wrappers (`protos/` → `OUT_DIR`).
use protobuf::Message;
// Include the generated protobuf code
#[path = ""]
pub mod hbb {
include!(concat!(env!("OUT_DIR"), "/protos/mod.rs"));
}
// Re-export commonly used types
pub use hbb::rendezvous::{
punch_hole_response, relay_response, rendezvous_message, ConfigUpdate, ConnType,
FetchLocalAddr, HealthCheck, KeyExchange, LocalAddr, NatType, OnlineRequest, OnlineResponse,
@@ -21,7 +15,6 @@ pub use hbb::rendezvous::{
RequestRelay, SoftwareUpdate, TestNatRequest, TestNatResponse,
};
// Re-export message.proto types
pub use hbb::message::{
key_event, login_response, message, misc, AudioFormat, AudioFrame, Auth2FA, Clipboard,
ControlKey, CursorData, CursorPosition, DisplayInfo, EncodedVideoFrame, EncodedVideoFrames,
@@ -30,7 +23,6 @@ pub use hbb::message::{
SupportedResolutions, TestDelay, VideoFrame, WindowsSessions,
};
/// Helper to create a RendezvousMessage with RegisterPeer
pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage {
let mut rp = RegisterPeer::new();
rp.id = id.to_string();
@@ -41,7 +33,6 @@ pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage {
msg
}
/// Helper to create a RendezvousMessage with RegisterPk
pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> RendezvousMessage {
let mut rpk = RegisterPk::new();
rpk.id = id.to_string();
@@ -54,7 +45,6 @@ pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> Rende
msg
}
/// Helper to create a PunchHoleSent message
pub fn make_punch_hole_sent(
socket_addr: &[u8],
id: &str,
@@ -74,10 +64,7 @@ pub fn make_punch_hole_sent(
msg
}
/// Helper to create a RelayResponse message (sent to rendezvous server)
/// IMPORTANT: The union field should be `Id` (our device ID), NOT `Pk`.
/// The rendezvous server will look up our registered public key using this ID,
/// sign it with the server's private key, and set the `pk` field before forwarding to client.
/// Use `id` (device id), not raw `pk`; hbbs fills `pk` when forwarding.
pub fn make_relay_response(
uuid: &str,
socket_addr: &[u8],
@@ -96,13 +83,7 @@ pub fn make_relay_response(
msg
}
/// Helper to create a RequestRelay message (sent to relay server to identify ourselves)
///
/// The `licence_key` is required if the relay server is configured with a key.
/// If the key doesn't match, the relay server will silently reject the connection.
///
/// IMPORTANT: `socket_addr` is the peer's encoded socket address (from FetchLocalAddr/RelayResponse).
/// The relay server uses this to match the two peers connecting to the same relay session.
/// `socket_addr` must be the peer's mangled addr; `licence_key` required if hbbr uses `-k`.
pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) -> RendezvousMessage {
let mut rr = RequestRelay::new();
rr.uuid = uuid.to_string();
@@ -114,8 +95,6 @@ pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) ->
msg
}
/// Helper to create a LocalAddr response message
/// This is sent in response to FetchLocalAddr when a peer on the same LAN wants to connect
pub fn make_local_addr(
socket_addr: &[u8],
local_addr: &[u8],
@@ -135,12 +114,10 @@ pub fn make_local_addr(
msg
}
/// Decode a RendezvousMessage from bytes
pub fn decode_rendezvous_message(buf: &[u8]) -> Result<RendezvousMessage, protobuf::Error> {
RendezvousMessage::parse_from_bytes(buf)
}
/// Decode a Message (session message) from bytes
pub fn decode_message(buf: &[u8]) -> Result<hbb::message::Message, protobuf::Error> {
hbb::message::Message::parse_from_bytes(buf)
}

Some files were not shown because too many files have changed in this diff Show More