Compare commits

...

21 Commits

Author SHA1 Message Date
mofeng-git
3ee3df77b8 chroe: 不再配置 iceCandidatePoolSize,沿用浏览器默认 2026-05-05 01:19:17 +08:00
mofeng-git
8ec2f25e82 chore: bump version to v0.2.0 2026-05-05 00:59:16 +08:00
mofeng-git
c27d3a6703 fix:改进atx usb 继电器适配;修复 webrtc 无法建立连接问题;网页样式优化 2026-05-05 00:52:16 +08:00
mofeng-git
6723f432a3 feat: 允许通过环境变量手动指定前端资源路径,删除 debug 分支默认资源路径 2026-05-04 17:53:27 +08:00
mofeng-git
12a3f1c947 feat: 增加设备丢失自恢复机制
增加音频设备丢失自恢复机制,完善视频设备丢失自恢复机制

降级部分日志级别,GOSTC key打印脱敏

代码格式化
2026-05-02 10:55:05 +08:00
mofeng-git
52754c862b feat: 优化网页消息提醒样式 2026-05-01 21:46:32 +08:00
mofeng-git
e51d243324 feat: 增加 MSD 虚拟盘文件路径编码 2026-05-01 21:27:03 +08:00
mofeng-git
a1ebd34083 feat: 外部扩展程序输出日志级别修改为 info 级别 2026-05-01 21:21:54 +08:00
mofeng-git
89b19ea7dd refactor: 修改为同步请求 2026-05-01 20:06:22 +08:00
mofeng-git
0d47d8395d refactor: 重构视频采集状态与错误分类公共逻辑 2026-05-01 17:56:56 +08:00
mofeng-git
d82c863f40 refactor: 精简依赖 2026-05-01 17:41:11 +08:00
mofeng-git
d8e7de74a6 refactor: 删除部分多余的代码和注释 2026-05-01 17:31:04 +08:00
SilentWind
74035f8e12 Merge pull request #247 from tedaimengtech/main
Update: Add CQU Mirror Information
2026-04-29 14:13:25 +08:00
tedaimeng
8d45186eba Add Mirror Download Services to README
Added a section for Mirror Download Services and updated sponsors.
2026-04-29 13:12:35 +08:00
tedaimeng
c484580b8f Update README with CQUMirror details
Added CQUMirror information and links to the README.
2026-04-29 13:09:23 +08:00
SilentWind
56bce7937c Add GNU General Public License v3 2026-04-29 10:38:30 +08:00
mofeng-git
07b982d1d2 feat: 完善 USB UVC 设备异常处理,添加 USB 设备复位功能 2026-04-27 16:37:04 +08:00
mofeng-git
9065e01225 feat: 优化控制台页面状态工具栏在不同宽度网页下的自适应能力 2026-04-25 20:32:44 +08:00
mofeng-git
cc3cc15774 refactor: 删除部分多余的 Ventoy 逻辑 2026-04-20 14:07:28 +08:00
mofeng-git
fcb39c73fc refactor: 删除未使用的公共 STUN/TURN 逻辑 2026-04-20 10:15:53 +08:00
mofeng-git
7c703b8b4b feat: 深入适配 RK628D CSI 采集卡的设备识别、参数读取、自恢复和音频采集 2026-04-19 11:26:21 +08:00
185 changed files with 10537 additions and 12074 deletions

View File

@@ -1,6 +1,6 @@
[package]
name = "one-kvm"
version = "0.1.9"
version = "0.2.0"
edition = "2021"
authors = ["SilentWind"]
description = "A open and lightweight IP-KVM solution written in Rust"
@@ -16,8 +16,8 @@ tokio-util = { version = "0.7", features = ["rt"] }
# Web framework
axum = { version = "0.8", features = ["ws", "multipart", "tokio"] }
axum-extra = { version = "0.12", features = ["typed-header", "cookie"] }
tower-http = { version = "0.6", features = ["fs", "cors", "trace", "compression-gzip"] }
axum-extra = { version = "0.12", features = ["cookie"] }
tower-http = { version = "0.6", features = ["cors", "trace"] }
# Database - Use bundled SQLite for static linking
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
@@ -29,7 +29,6 @@ serde_json = "1"
# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "tracing-log"] }
tracing-log = "0.2"
# Error handling
thiserror = "2"
@@ -41,7 +40,6 @@ rand = "0.9"
# Utilities
uuid = { version = "1", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
base64 = "0.22"
nix = { version = "0.30", features = ["fs", "net", "hostname", "poll"] }
@@ -51,7 +49,7 @@ reqwest = { version = "0.13", features = ["stream", "rustls", "json"], default-f
urlencoding = "2"
# Static file embedding
rust-embed = { version = "8", features = ["compression"] }
rust-embed = { version = "8", features = ["compression", "debug-embed"] }
mime_guess = "2"
# TLS/HTTPS
@@ -62,8 +60,8 @@ axum-server = { version = "0.8", features = ["tls-rustls"] }
# CLI argument parsing
clap = { version = "4", features = ["derive"] }
# Time
time = "0.3"
# Time (cookie max_age + RFC3339 timestamps)
time = { version = "0.3", features = ["serde", "formatting", "parsing"] }
# Video capture (V4L2)
v4l2r = "0.0.7"
@@ -125,13 +123,10 @@ libyuv = { path = "res/vcpkg/libyuv" }
typeshare = "1.0"
[dev-dependencies]
tokio-test = "0.4"
tempfile = "3"
[build-dependencies]
protobuf-codegen = "3.7"
toml = "0.9"
cc = "1"
[profile.release]
opt-level = 3

674
LICENSE Normal file
View File

@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

View File

@@ -222,6 +222,9 @@ One-KVM builds on many great open-source projects; a lot of time goes into testi
### Sponsors
**Mirror Download Services:**
- **[Chongqing University Open Source Software Mirror](https://mirrors.cqu.edu.cn/)** — provides mirror download services
**File hosting**
- **[Huang1111 public-interest program](https://pan.huang1111.cn/s/mxkx3T1)** — login-free downloads

View File

@@ -218,6 +218,9 @@ One-KVM 已上架飞牛 **应用市场**,在 NAS 上直接搜索安装即可
本项目得到以下赞助商的支持:
**镜像下载服务:**
- **[重庆大学开源软件镜像站](https://mirrors.cqu.edu.cn/)** - 提供镜像站下载服务
**文件存储服务:**
- **[Huang1111公益计划](https://pan.huang1111.cn/s/mxkx3T1)** - 提供免登录下载服务
@@ -227,4 +230,4 @@ One-KVM 已上架飞牛 **应用市场**,在 NAS 上直接搜索安装即可
![林枫云](https://docs.one-kvm.cn/img/36076FEFF0898A80EBD5756D28F4076C.png)
林枫云主营国内外地域的精品线路业务服务器、高主频游戏服务器和大带宽服务器。
林枫云主营国内外地域的精品线路业务服务器、高主频游戏服务器和大带宽服务器。

View File

@@ -64,25 +64,6 @@ fn generate_secrets() {
pub mod ice {
/// Google public STUN server URL (hardcoded)
pub const STUN_SERVER: &str = "stun:stun.l.google.com:19302";
/// TURN server URLs - not provided, users must configure their own
pub const TURN_URLS: &str = "";
/// TURN authentication username
pub const TURN_USERNAME: &str = "";
/// TURN authentication password
pub const TURN_PASSWORD: &str = "";
/// Always returns true since we have STUN
pub const fn is_configured() -> bool {
true
}
/// Always returns false since TURN is not provided
pub const fn has_turn() -> bool {
false
}
}
/// RustDesk public server configuration - NOT PROVIDED

View File

@@ -4,6 +4,9 @@ version = "0.8.0"
edition = "2021"
description = "Hardware video codec for IP-KVM (Windows/Linux)"
[package.metadata.cargo-machete]
ignored = ["serde"]
[features]
default = []
rkmpp = []
@@ -17,6 +20,3 @@ serde_json = "1.0"
[build-dependencies]
cc = "1.0"
bindgen = "0.59"
[dev-dependencies]
env_logger = "0.10"

View File

@@ -45,4 +45,4 @@ pub use error::{Result, VentoyError};
pub use exfat::FileInfo;
pub use image::VentoyImage;
pub use partition::{parse_size, PartitionLayout};
pub use resources::{get_resource_dir, init_resources, is_initialized, required_files};
pub use resources::{init_resources, is_initialized, required_files};

View File

@@ -5,7 +5,7 @@
use crate::error::{Result, VentoyError};
use std::fs;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::sync::OnceLock;
/// Resource file names
@@ -151,13 +151,6 @@ pub fn get_ventoy_disk_img() -> Result<&'static [u8]> {
})
}
/// Get the resource directory path for a given data directory
///
/// Returns `{data_dir}/ventoy`
pub fn get_resource_dir(data_dir: &Path) -> PathBuf {
data_dir.join("ventoy")
}
/// List required resource files
pub fn required_files() -> &'static [&'static str] {
&[BOOT_IMG_NAME, CORE_IMG_NAME, VENTOY_DISK_IMG_NAME]
@@ -166,22 +159,6 @@ pub fn required_files() -> &'static [&'static str] {
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
fn create_test_resources(dir: &Path) {
// Create boot.img (512 bytes)
let mut boot = std::fs::File::create(dir.join(BOOT_IMG_NAME)).unwrap();
boot.write_all(&[0u8; 512]).unwrap();
// Create core.img (fake, 1KB)
let mut core = std::fs::File::create(dir.join(CORE_IMG_NAME)).unwrap();
core.write_all(&[0u8; 1024]).unwrap();
// Create ventoy.disk.img (fake, 1KB)
let mut ventoy = std::fs::File::create(dir.join(VENTOY_DISK_IMG_NAME)).unwrap();
ventoy.write_all(&[0u8; 1024]).unwrap();
}
#[test]
fn test_required_files() {
@@ -191,11 +168,4 @@ mod tests {
assert!(files.contains(&"core.img"));
assert!(files.contains(&"ventoy.disk.img"));
}
#[test]
fn test_get_resource_dir() {
let data_dir = Path::new("/var/lib/one-kvm");
let resource_dir = get_resource_dir(data_dir);
assert_eq!(resource_dir, PathBuf::from("/var/lib/one-kvm/ventoy"));
}
}

View File

@@ -8,5 +8,4 @@ license = "BSD-3-Clause"
[dependencies]
[build-dependencies]
cc = "1.0"
bindgen = "0.59"

View File

@@ -7,6 +7,7 @@ use gpio_cdev::{Chip, LineHandle, LineRequestFlags};
use serialport::SerialPort;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::os::fd::AsRawFd;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
@@ -18,6 +19,10 @@ use crate::error::{AppError, Result};
pub type SharedSerialHandle = Arc<Mutex<Box<dyn SerialPort>>>;
const USB_RELAY_MAX_CHANNEL: u8 = 8;
const USB_RELAY_REPORT_LEN: usize = 9;
const HIDIOCSFEATURE_9: libc::c_ulong = 0xC009_4806; // _IOC(_IOC_READ|_IOC_WRITE, 'H', 0x06, 9)
/// Timing constants for ATX operations
pub mod timing {
use std::time::Duration;
@@ -129,12 +134,23 @@ impl AtxKeyExecutor {
}
}
AtxDriverType::UsbRelay => {
if self.config.pin == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if self.config.pin > u8::MAX as u32 {
return Err(AppError::Config(format!(
"USB relay channel must be <= {}",
u8::MAX
)));
}
if self.config.pin > USB_RELAY_MAX_CHANNEL as u32 {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
}
AtxDriverType::Gpio | AtxDriverType::None => {}
}
@@ -292,26 +308,64 @@ impl AtxKeyExecutor {
u8::MAX
))
})?;
if channel == 0 {
return Err(AppError::Config(
"USB relay channel must be 1-based (>= 1)".to_string(),
));
}
if channel > USB_RELAY_MAX_CHANNEL {
return Err(AppError::Config(format!(
"USB HID relay channel must be <= {}",
USB_RELAY_MAX_CHANNEL
)));
}
// Standard HID relay command format
let cmd = if on {
[0x00, channel + 1, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00]
} else {
[0x00, channel + 1, 0xFD, 0x00, 0x00, 0x00, 0x00, 0x00]
};
let cmd = Self::build_usb_relay_command(channel, on);
let mut guard = self.usb_relay_handle.lock().unwrap();
let device = guard
.as_mut()
.ok_or_else(|| AppError::Internal("USB relay not initialized".to_string()))?;
device
.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("USB relay write failed: {}", e)))?;
if let Err(feature_err) = Self::send_usb_relay_feature_report(device, &cmd) {
debug!(
"USB relay feature report failed ({}), falling back to hidraw write",
feature_err
);
device.write_all(&cmd).map_err(|write_err| {
AppError::Internal(format!(
"USB relay feature report failed: {}; raw write failed: {}",
feature_err, write_err
))
})?;
device
.flush()
.map_err(|e| AppError::Internal(format!("USB relay flush failed: {}", e)))?;
}
Ok(())
}
fn build_usb_relay_command(channel: u8, on: bool) -> [u8; USB_RELAY_REPORT_LEN] {
let mut cmd = [0x00; USB_RELAY_REPORT_LEN];
cmd[1] = if on { 0xFF } else { 0xFD };
cmd[2] = channel;
cmd
}
fn send_usb_relay_feature_report(
device: &File,
report: &[u8; USB_RELAY_REPORT_LEN],
) -> std::io::Result<()> {
// Linux hidraw feature reports include the report ID as the first byte.
let rc = unsafe { libc::ioctl(device.as_raw_fd(), HIDIOCSFEATURE_9, report.as_ptr()) };
if rc < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(())
}
}
/// Pulse Serial relay
async fn pulse_serial(&self, duration: Duration) -> Result<()> {
info!(
@@ -367,6 +421,8 @@ impl AtxKeyExecutor {
port.write_all(&cmd)
.map_err(|e| AppError::Internal(format!("Serial relay write failed: {}", e)))?;
port.flush()
.map_err(|e| AppError::Internal(format!("Serial relay flush failed: {}", e)))?;
Ok(())
}
@@ -453,7 +509,7 @@ mod tests {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 0,
pin: 1,
active_level: ActiveLevel::High, // Ignored for USB relay
baud_rate: 9600,
};
@@ -481,6 +537,18 @@ mod tests {
assert_eq!(timing::RESET_PRESS.as_millis(), 500);
}
#[test]
fn test_usb_relay_command_format() {
assert_eq!(
AtxKeyExecutor::build_usb_relay_command(1, true),
[0x00, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
assert_eq!(
AtxKeyExecutor::build_usb_relay_command(1, false),
[0x00, 0xFD, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
);
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_zero() {
let config = AtxKeyConfig {
@@ -495,6 +563,34 @@ mod tests {
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_usb_relay_channel_zero() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: 0,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_usb_relay_channel_overflow() {
let config = AtxKeyConfig {
driver: AtxDriverType::UsbRelay,
device: "/dev/hidraw0".to_string(),
pin: USB_RELAY_MAX_CHANNEL as u32 + 1,
active_level: ActiveLevel::High,
baud_rate: 9600,
};
let mut executor = AtxKeyExecutor::new(config);
let err = executor.init().await.unwrap_err();
assert!(matches!(err, AppError::Config(_)));
}
#[tokio::test]
async fn test_executor_init_rejects_serial_channel_overflow() {
let config = AtxKeyConfig {

View File

@@ -15,6 +15,7 @@
//!
//! - **GPIO**: Uses Linux GPIO character device (/dev/gpiochipX) for direct hardware control
//! - **USB Relay**: Uses HID USB relay modules for isolated switching
//! - **Serial Relay**: Uses LCUS-style serial relay modules
//!
//! # Example
//!
@@ -59,9 +60,25 @@ pub use types::{
};
pub use wol::send_wol;
fn hidraw_uevent_is_usb_relay(uevent: &str) -> bool {
let upper = uevent.to_ascii_uppercase();
upper.contains("000016C0:000005DF")
|| upper.contains("16C0:05DF")
|| upper.contains("PRODUCT=16C0/5DF")
|| upper.contains("USBRELAY")
|| upper.contains("USB RELAY")
}
fn is_usb_relay_hidraw(name: &str) -> bool {
let uevent_path = format!("/sys/class/hidraw/{}/device/uevent", name);
std::fs::read_to_string(uevent_path)
.map(|uevent| hidraw_uevent_is_usb_relay(&uevent))
.unwrap_or(false)
}
/// Discover available ATX devices on the system
///
/// Scans for GPIO chips and USB HID relay devices in a single pass.
/// Scans for GPIO chips, LCUS USB HID relay devices, and serial relay ports.
pub fn discover_devices() -> AtxDevices {
let mut devices = AtxDevices::default();
@@ -72,7 +89,7 @@ pub fn discover_devices() -> AtxDevices {
let name_str = name.to_string_lossy();
if name_str.starts_with("gpiochip") {
devices.gpio_chips.push(format!("/dev/{}", name_str));
} else if name_str.starts_with("hidraw") {
} else if name_str.starts_with("hidraw") && is_usb_relay_hidraw(&name_str) {
devices.usb_relays.push(format!("/dev/{}", name_str));
} else if name_str.starts_with("ttyUSB") || name_str.starts_with("ttyACM") {
devices.serial_ports.push(format!("/dev/{}", name_str));
@@ -96,6 +113,20 @@ mod tests {
let _devices = discover_devices();
}
#[test]
fn test_hidraw_uevent_detects_usb_relay_id() {
assert!(hidraw_uevent_is_usb_relay(
"HID_ID=0003:000016C0:000005DF\nHID_NAME=www.dcttech.com USBRelay2\n"
));
}
#[test]
fn test_hidraw_uevent_rejects_unrelated_hid() {
assert!(!hidraw_uevent_is_usb_relay(
"HID_ID=0003:0000046D:0000C534\nHID_NAME=Logitech USB Receiver\n"
));
}
#[test]
fn test_module_exports() {
// Verify all public exports are accessible

View File

@@ -61,7 +61,7 @@ pub struct AtxKeyConfig {
pub device: String,
/// Pin or channel number:
/// - For GPIO: GPIO pin number
/// - For USB Relay: relay channel (0-based)
/// - For USB Relay: relay channel (1-based)
/// - For Serial Relay (LCUS): relay channel (1-based)
pub pin: u32,
/// Active level (only applicable to GPIO, ignored for USB Relay)

View File

@@ -1,5 +1,3 @@
//! ALSA audio capture implementation
use alsa::pcm::{Access, Format, Frames, HwParams, State, IO};
use alsa::{Direction, ValueOr, PCM};
use bytes::Bytes;
@@ -14,30 +12,23 @@ use crate::error::{AppError, Result};
use crate::utils::LogThrottler;
use crate::{error_throttled, warn_throttled};
/// Audio capture configuration
#[derive(Debug, Clone)]
pub struct AudioConfig {
/// ALSA device name (e.g., "hw:0,0" or "default")
pub device_name: String,
/// Sample rate in Hz
pub sample_rate: u32,
/// Number of channels (1 = mono, 2 = stereo)
pub channels: u32,
/// Samples per frame (for Opus, typically 480 for 10ms at 48kHz)
pub frame_size: u32,
/// Buffer size in frames
pub buffer_frames: u32,
/// Period size in frames
pub period_frames: u32,
}
impl Default for AudioConfig {
fn default() -> Self {
Self {
device_name: "default".to_string(),
device_name: String::new(),
sample_rate: 48000,
channels: 2,
frame_size: 960, // 20ms at 48kHz (good for Opus)
frame_size: 960,
buffer_frames: 4096,
period_frames: 960,
}
@@ -45,59 +36,33 @@ impl Default for AudioConfig {
}
impl AudioConfig {
/// Create config for a specific device
pub fn for_device(device: &AudioDeviceInfo) -> Self {
let sample_rate = if device.sample_rates.contains(&48000) {
48000
} else {
*device.sample_rates.first().unwrap_or(&48000)
};
let channels = if device.channels.contains(&2) {
2
} else {
*device.channels.first().unwrap_or(&2)
};
Self {
device_name: device.name.clone(),
sample_rate,
channels,
frame_size: sample_rate / 50, // 20ms
..Default::default()
}
}
/// Bytes per sample (16-bit signed)
pub fn bytes_per_sample(&self) -> u32 {
2 * self.channels
}
/// Bytes per frame
pub fn bytes_per_frame(&self) -> usize {
(self.frame_size * self.bytes_per_sample()) as usize
}
}
/// Audio frame data
#[derive(Debug, Clone)]
pub struct AudioFrame {
/// Raw PCM data (S16LE interleaved)
pub data: Bytes,
/// Sample rate
pub sample_rate: u32,
/// Number of channels
pub channels: u32,
/// Number of samples per channel
pub samples: u32,
/// Frame sequence number
pub sequence: u64,
/// Capture timestamp
pub timestamp: Instant,
}
impl AudioFrame {
/// One capture block: `sample_rate` must be the **hardware** rate (e.g. ALSA `actual_rate`).
pub fn new_interleaved(data: Bytes, channels: u32, sample_rate: u32, sequence: u64) -> Self {
let bps = 2 * channels;
Self {
@@ -111,7 +76,6 @@ impl AudioFrame {
}
}
/// Audio capture state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureState {
Stopped,
@@ -119,7 +83,6 @@ pub enum CaptureState {
Error,
}
/// ALSA audio capturer
pub struct AudioCapturer {
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
@@ -128,15 +91,13 @@ pub struct AudioCapturer {
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
capture_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Log throttler to prevent log flooding
log_throttler: LogThrottler,
}
impl AudioCapturer {
/// Create a new audio capturer
pub fn new(config: AudioConfig) -> Self {
let (state_tx, state_rx) = watch::channel(CaptureState::Stopped);
let (frame_tx, _) = broadcast::channel(16); // Buffer size 16 for low latency
let (frame_tx, _) = broadcast::channel(16);
Self {
config,
@@ -150,28 +111,24 @@ impl AudioCapturer {
}
}
/// Get current state
pub fn state(&self) -> CaptureState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<CaptureState> {
self.state_rx.clone()
}
/// Subscribe to audio frames
pub fn subscribe(&self) -> broadcast::Receiver<AudioFrame> {
self.frame_tx.subscribe()
}
/// Start capturing
pub async fn start(&self) -> Result<()> {
if self.state() == CaptureState::Running {
return Ok(());
}
info!(
debug!(
"Starting audio capture on {} at {}Hz {}ch",
self.config.device_name, self.config.sample_rate, self.config.channels
);
@@ -186,14 +143,27 @@ impl AudioCapturer {
let log_throttler = self.log_throttler.clone();
let handle = tokio::task::spawn_blocking(move || {
capture_loop(config, state, frame_tx, stop_flag, sequence, log_throttler);
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
});
*self.capture_handle.lock().await = Some(handle);
Ok(())
}
/// Stop capturing
pub async fn stop(&self) -> Result<()> {
info!("Stopping audio capture");
self.stop_flag.store(true, Ordering::SeqCst);
@@ -206,38 +176,11 @@ impl AudioCapturer {
Ok(())
}
/// Check if running
pub fn is_running(&self) -> bool {
self.state() == CaptureState::Running
}
}
/// Main capture loop
fn capture_loop(
config: AudioConfig,
state: Arc<watch::Sender<CaptureState>>,
frame_tx: broadcast::Sender<AudioFrame>,
stop_flag: Arc<AtomicBool>,
sequence: Arc<AtomicU64>,
log_throttler: LogThrottler,
) {
let result = run_capture(
&config,
&state,
&frame_tx,
&stop_flag,
&sequence,
&log_throttler,
);
if let Err(e) = result {
error_throttled!(log_throttler, "capture_error", "Audio capture error: {}", e);
let _ = state.send(CaptureState::Error);
} else {
let _ = state.send(CaptureState::Stopped);
}
}
fn run_capture(
config: &AudioConfig,
state: &watch::Sender<CaptureState>,
@@ -246,7 +189,6 @@ fn run_capture(
sequence: &AtomicU64,
log_throttler: &LogThrottler,
) -> Result<()> {
// Open ALSA device
let pcm = PCM::new(&config.device_name, Direction::Capture, false).map_err(|e| {
AppError::AudioError(format!(
"Failed to open audio device {}: {}",
@@ -254,7 +196,6 @@ fn run_capture(
))
})?;
// Configure hardware parameters
{
let hwp = HwParams::any(&pcm)
.map_err(|e| AppError::AudioError(format!("Failed to get HwParams: {}", e)))?;
@@ -281,31 +222,34 @@ fn run_capture(
.map_err(|e| AppError::AudioError(format!("Failed to apply hw params: {}", e)))?;
}
// Get actual configuration
let actual_rate = pcm
.hw_params_current()
.map(|h| h.get_rate().unwrap_or(config.sample_rate))
.unwrap_or(config.sample_rate);
if actual_rate != config.sample_rate {
info!(
"ALSA sample rate differs from requested ({}Hz vs {}Hz); streamer will resample to 48000Hz for Opus",
actual_rate, config.sample_rate
);
} else {
info!(
"Audio capture configured: {}Hz {}ch (requested {}Hz)",
actual_rate, config.channels, config.sample_rate
);
let hw_now = pcm.hw_params_current().map_err(|e| {
AppError::AudioError(format!("Failed to read hw_params after apply: {}", e))
})?;
let actual_rate = hw_now
.get_rate()
.map_err(|e| AppError::AudioError(format!("Failed to read sample rate: {}", e)))?;
let actual_ch = hw_now
.get_channels()
.map_err(|e| AppError::AudioError(format!("Failed to read channels: {}", e)))?;
if actual_rate != 48_000 {
return Err(AppError::AudioError(format!(
"Audio capture requires 48000 Hz; device is {} Hz",
actual_rate
)));
}
if actual_ch != 2 {
return Err(AppError::AudioError(format!(
"Audio capture requires 2 channels (stereo); device has {}",
actual_ch
)));
}
debug!("Audio capture: 48000 Hz, 2 ch");
// Prepare for capture
pcm.prepare()
.map_err(|e| AppError::AudioError(format!("Failed to prepare PCM: {}", e)))?;
let _ = state.send(CaptureState::Running);
// Sized from actual period — `readi` may return up to ~one period of frames per call.
let period_frames = pcm
.hw_params_current()
.ok()
@@ -317,9 +261,7 @@ fn run_capture(
let bytes_per_frame = (config.channels as usize) * 2;
let mut buffer = vec![0u8; buf_frames * bytes_per_frame];
// Capture loop
while !stop_flag.load(Ordering::Relaxed) {
// Check PCM state
match pcm.state() {
State::XRun => {
warn_throttled!(log_throttler, "xrun", "Audio buffer overrun, recovering");
@@ -338,9 +280,7 @@ fn run_capture(
_ => {}
}
// Get IO handle and read audio data directly as bytes
// Note: Use io() instead of io_checked() because USB audio devices
// typically don't support mmap, which io_checked() requires
// io_bytes: USB capture often lacks mmap (io_checked requires it).
let io: IO<u8> = pcm.io_bytes();
match io.readi(&mut buffer) {
@@ -349,19 +289,16 @@ fn run_capture(
continue;
}
// Calculate actual byte count
let byte_count = frames_read * config.channels as usize * 2;
// Directly use the buffer slice (already in correct byte format)
let seq = sequence.fetch_add(1, Ordering::Relaxed);
let frame = AudioFrame::new_interleaved(
Bytes::copy_from_slice(&buffer[..byte_count]),
config.channels,
actual_rate,
48_000,
seq,
);
// Send to subscribers
if frame_tx.receiver_count() > 0 {
if let Err(e) = frame_tx.send(frame) {
debug!("No audio receivers: {}", e);
@@ -369,15 +306,15 @@ fn run_capture(
}
}
Err(e) => {
// Check for buffer overrun (EPIPE = 32 on Linux)
let desc = e.to_string();
if desc.contains("EPIPE") || desc.contains("Broken pipe") {
// Buffer overrun
if is_device_lost_error(&desc) {
return Err(AppError::AudioError(format!(
"Audio device lost while reading {}: {}",
config.device_name, e
)));
} else if desc.contains("EPIPE") || desc.contains("Broken pipe") {
warn_throttled!(log_throttler, "buffer_overrun", "Audio buffer overrun");
let _ = pcm.prepare();
} else if desc.contains("No such device") || desc.contains("ENODEV") {
// Device disconnected - use longer throttle for this
error_throttled!(log_throttler, "no_device", "Audio read error: {}", e);
} else {
error_throttled!(log_throttler, "read_error", "Audio read error: {}", e);
}
@@ -388,3 +325,10 @@ fn run_capture(
info!("Audio capture stopped");
Ok(())
}
fn is_device_lost_error(desc: &str) -> bool {
desc.contains("No such device")
|| desc.contains("ENODEV")
|| desc.contains("ENXIO")
|| desc.contains("ESHUTDOWN")
}

View File

@@ -1,35 +1,38 @@
//! Audio controller for high-level audio management
//!
//! Provides device enumeration, selection, quality control, and streaming management.
//! Device selection, quality presets, streaming.
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::info;
use tracing::{debug, info, warn};
use super::capture::AudioConfig;
use super::device::{enumerate_audio_devices_with_current, AudioDeviceInfo};
use super::device::{
enumerate_audio_devices, enumerate_audio_devices_with_current, find_best_audio_device,
AudioDeviceInfo,
};
use super::encoder::{OpusConfig, OpusFrame};
use super::monitor::{AudioHealthMonitor, AudioHealthStatus};
use super::streamer::{AudioStreamer, AudioStreamerConfig};
use super::monitor::AudioHealthMonitor;
use super::streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};
use crate::error::{AppError, Result};
use crate::events::EventBus;
use crate::events::{EventBus, StreamDeviceLostKind, SystemEvent};
const AUDIO_RECOVERY_RETRY_DELAY: Duration = Duration::from_secs(1);
type AudioRecoveredCallback = Arc<dyn Fn() + Send + Sync>;
/// Audio quality presets
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioQuality {
/// Low bandwidth voice (32kbps)
Voice,
/// Balanced quality (64kbps) - default
#[default]
Balanced,
/// High quality audio (128kbps)
High,
}
impl AudioQuality {
/// Get the bitrate for this quality level
pub fn bitrate(&self) -> u32 {
match self {
AudioQuality::Voice => 32000,
@@ -38,17 +41,6 @@ impl AudioQuality {
}
}
/// Parse from string
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"voice" | "low" => AudioQuality::Voice,
"high" | "music" => AudioQuality::High,
_ => AudioQuality::Balanced,
}
}
/// Convert to OpusConfig
pub fn to_opus_config(&self) -> OpusConfig {
match self {
AudioQuality::Voice => OpusConfig::voice(),
@@ -58,6 +50,22 @@ impl AudioQuality {
}
}
impl FromStr for AudioQuality {
type Err = AppError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.trim().to_lowercase().as_str() {
"voice" => Ok(Self::Voice),
"balanced" => Ok(Self::Balanced),
"high" => Ok(Self::High),
_ => Err(AppError::BadRequest(format!(
"invalid audio quality {:?} (expected voice, balanced, or high)",
s.trim()
))),
}
}
}
impl std::fmt::Display for AudioQuality {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -68,17 +76,10 @@ impl std::fmt::Display for AudioQuality {
}
}
/// Audio controller configuration
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
/// These are optimal for Opus encoding and match WebRTC requirements.
#[derive(Debug, Clone)]
pub struct AudioControllerConfig {
/// Whether audio is enabled
pub enabled: bool,
/// Selected device name
pub device: String,
/// Audio quality preset
pub quality: AudioQuality,
}
@@ -86,74 +87,347 @@ impl Default for AudioControllerConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
device: String::new(),
quality: AudioQuality::Balanced,
}
}
}
/// Current audio status
#[derive(Debug, Clone, Serialize)]
pub struct AudioStatus {
/// Whether audio feature is enabled
pub enabled: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Currently selected device
pub device: Option<String>,
/// Current quality preset
pub quality: AudioQuality,
/// Number of connected subscribers
pub subscriber_count: usize,
/// Error message if any
pub error: Option<String>,
}
/// Audio controller
///
/// High-level interface for audio management, providing:
/// - Device enumeration and selection
/// - Quality control
/// - Stream start/stop
/// - Status reporting
pub struct AudioController {
config: RwLock<AudioControllerConfig>,
streamer: RwLock<Option<Arc<AudioStreamer>>>,
devices: RwLock<Vec<AudioDeviceInfo>>,
event_bus: RwLock<Option<Arc<EventBus>>>,
last_error: RwLock<Option<String>>,
/// Health monitor for error tracking and recovery
config: Arc<RwLock<AudioControllerConfig>>,
streamer: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
devices: Arc<RwLock<Vec<AudioDeviceInfo>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
}
impl AudioController {
/// Create a new audio controller with configuration
pub fn new(config: AudioControllerConfig) -> Self {
Self {
config: RwLock::new(config),
streamer: RwLock::new(None),
devices: RwLock::new(Vec::new()),
event_bus: RwLock::new(None),
last_error: RwLock::new(None),
monitor: Arc::new(AudioHealthMonitor::with_defaults()),
config: Arc::new(RwLock::new(config)),
streamer: Arc::new(RwLock::new(None)),
devices: Arc::new(RwLock::new(Vec::new())),
event_bus: Arc::new(RwLock::new(None)),
monitor: Arc::new(AudioHealthMonitor::new()),
recovery_in_progress: Arc::new(AtomicBool::new(false)),
recovered_callback: Arc::new(RwLock::new(None)),
}
}
/// Set event bus for internal state notifications.
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
*self.event_bus.write().await = Some(event_bus);
}
/// Mark the device-info snapshot as stale.
pub async fn set_recovered_callback(&self, callback: Arc<dyn Fn() + Send + Sync>) {
*self.recovered_callback.write().await = Some(callback);
}
async fn mark_device_info_dirty(&self) {
if let Some(ref bus) = *self.event_bus.read().await {
bus.mark_device_info_dirty();
}
}
/// List available audio capture devices
async fn publish_state(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
state: &str,
device: Option<String>,
reason: Option<&str>,
next_retry_ms: Option<u64>,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamStateChanged {
state: state.to_string(),
device,
reason: reason.map(str::to_string),
next_retry_ms,
});
bus.mark_device_info_dirty();
}
}
async fn publish_device_lost(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
reason: &str,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: device.to_string(),
reason: reason.to_string(),
});
}
}
async fn publish_reconnecting(
event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>,
device: &str,
attempt: u32,
) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamReconnecting {
device: device.to_string(),
attempt,
});
}
}
async fn publish_recovered(event_bus: &Arc<RwLock<Option<Arc<EventBus>>>>, device: &str) {
if let Some(ref bus) = *event_bus.read().await {
bus.publish(SystemEvent::StreamRecovered {
device: device.to_string(),
});
}
}
fn select_recovery_device(
devices: &[AudioDeviceInfo],
preferred: &str,
) -> Option<AudioDeviceInfo> {
if !preferred.trim().is_empty() {
if let Some(device) = devices.iter().find(|d| d.name == preferred) {
return Some(device.clone());
}
}
devices
.iter()
.find(|d| d.is_hdmi && d.sample_rates.contains(&48_000) && d.channels.contains(&2))
.or_else(|| {
devices
.iter()
.find(|d| d.sample_rates.contains(&48_000) && d.channels.contains(&2))
})
.or_else(|| devices.first())
.cloned()
}
fn spawn_stream_monitor_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
streamer: Arc<AudioStreamer>,
device: String,
) {
let mut state_rx = streamer.state_watch();
tokio::spawn(async move {
loop {
if state_rx.changed().await.is_err() {
return;
}
if *state_rx.borrow() != AudioStreamState::Error {
continue;
}
{
let current = streamer_slot.read().await;
if !current
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &streamer))
{
return;
}
}
let reason = format!("Audio device lost: {}", device);
monitor.report_error(&reason, "device_lost").await;
Self::spawn_recovery_task_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
device,
reason,
);
return;
}
});
}
fn spawn_recovery_task_from_parts(
config: Arc<RwLock<AudioControllerConfig>>,
streamer_slot: Arc<RwLock<Option<Arc<AudioStreamer>>>>,
event_bus: Arc<RwLock<Option<Arc<EventBus>>>>,
monitor: Arc<AudioHealthMonitor>,
recovery_in_progress: Arc<AtomicBool>,
recovered_callback: Arc<RwLock<Option<AudioRecoveredCallback>>>,
lost_device: String,
reason: String,
) {
if recovery_in_progress.swap(true, Ordering::SeqCst) {
debug!("Audio recovery already in progress");
return;
}
tokio::spawn(async move {
warn!("Audio recovery started for {}: {}", lost_device, reason);
Self::publish_device_lost(&event_bus, &lost_device, &reason).await;
Self::publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_device_lost"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
let mut attempt = 0u32;
loop {
if !recovery_in_progress.load(Ordering::SeqCst) {
debug!("Audio recovery canceled");
return;
}
if streamer_slot
.read()
.await
.as_ref()
.is_some_and(|s| s.is_running())
{
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
let cfg = config.read().await.clone();
if !cfg.enabled {
recovery_in_progress.store(false, Ordering::SeqCst);
return;
}
attempt = attempt.saturating_add(1);
Self::publish_reconnecting(&event_bus, &lost_device, attempt).await;
Self::publish_state(
&event_bus,
"device_lost",
Some(lost_device.clone()),
Some("audio_reconnecting"),
Some(AUDIO_RECOVERY_RETRY_DELAY.as_millis() as u64),
)
.await;
tokio::time::sleep(AUDIO_RECOVERY_RETRY_DELAY).await;
let devices = match enumerate_audio_devices() {
Ok(devices) => devices,
Err(e) => {
debug!(
"Audio recovery enumerate failed (attempt {}): {}",
attempt, e
);
continue;
}
};
let Some(device) = Self::select_recovery_device(&devices, &cfg.device) else {
debug!("No audio devices found during recovery attempt {}", attempt);
continue;
};
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device.name.clone(),
..Default::default()
},
opus: cfg.quality.to_opus_config(),
};
let new_streamer = Arc::new(AudioStreamer::with_config(streamer_config));
match new_streamer.start().await {
Ok(()) => {
{
let mut cfg = config.write().await;
cfg.device = device.name.clone();
}
*streamer_slot.write().await = Some(new_streamer.clone());
monitor.report_recovered().await;
Self::publish_recovered(&event_bus, &device.name).await;
if let Some(callback) = recovered_callback.read().await.clone() {
callback();
}
Self::publish_state(
&event_bus,
"streaming",
Some(device.name.clone()),
None,
None,
)
.await;
recovery_in_progress.store(false, Ordering::SeqCst);
info!(
"Audio device recovered with {} after {} attempts",
device.name, attempt
);
Self::spawn_stream_monitor_from_parts(
config,
streamer_slot,
event_bus,
monitor,
recovery_in_progress,
recovered_callback,
new_streamer,
device.name,
);
return;
}
Err(e) => {
debug!(
"Audio recovery start failed with {} (attempt {}): {}",
device.name, attempt, e
);
}
}
}
});
}
fn spawn_recovery_task(&self, lost_device: String, reason: String) {
Self::spawn_recovery_task_from_parts(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
self.monitor.clone(),
self.recovery_in_progress.clone(),
self.recovered_callback.clone(),
lost_device,
reason,
);
}
fn spawn_stream_monitor(&self, streamer: Arc<AudioStreamer>, device: String) {
Self::spawn_stream_monitor_from_parts(
self.config.clone(),
self.streamer.clone(),
self.event_bus.clone(),
self.monitor.clone(),
self.recovery_in_progress.clone(),
self.recovered_callback.clone(),
streamer,
device,
);
}
pub async fn list_devices(&self) -> Result<Vec<AudioDeviceInfo>> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
@@ -165,41 +439,23 @@ impl AudioController {
Ok(devices)
}
/// Refresh device list and cache it
pub async fn refresh_devices(&self) -> Result<()> {
// Get current device if streaming (it may be busy and unable to be opened)
let current_device = if self.is_streaming().await {
Some(self.config.read().await.device.clone())
} else {
None
};
let devices = enumerate_audio_devices_with_current(current_device.as_deref())?;
*self.devices.write().await = devices;
Ok(())
}
/// Get cached device list
pub async fn get_cached_devices(&self) -> Vec<AudioDeviceInfo> {
self.devices.read().await.clone()
}
/// Select audio device
pub async fn select_device(&self, device: &str) -> Result<()> {
// Validate device exists
let devices = self.list_devices().await?;
let found = devices
.iter()
.any(|d| d.name == device || d.description.contains(device));
if !found && device != "default" {
if !found {
return Err(AppError::AudioError(format!(
"Audio device not found: {}",
device
)));
}
// Update config
{
let mut config = self.config.write().await;
config.device = device.to_string();
@@ -207,7 +463,6 @@ impl AudioController {
info!("Audio device selected: {}", device);
// If streaming, restart with new device
if self.is_streaming().await {
self.stop_streaming().await?;
self.start_streaming().await?;
@@ -216,15 +471,12 @@ impl AudioController {
Ok(())
}
/// Set audio quality
pub async fn set_quality(&self, quality: AudioQuality) -> Result<()> {
// Update config
{
let mut config = self.config.write().await;
config.quality = quality;
}
// Update streamer if running
if let Some(ref streamer) = *self.streamer.read().await {
streamer.set_bitrate(quality.bitrate()).await?;
}
@@ -237,76 +489,94 @@ impl AudioController {
Ok(())
}
/// Start audio streaming
pub async fn start_streaming(&self) -> Result<()> {
let config = self.config.read().await.clone();
if !config.enabled {
return Err(AppError::AudioError("Audio is disabled".to_string()));
{
let config = self.config.read().await;
if !config.enabled {
return Err(AppError::AudioError("Audio is disabled".to_string()));
}
}
// Check if already streaming
if self.is_streaming().await {
return Ok(());
}
info!("Starting audio streaming with device: {}", config.device);
// Clear any previous error
*self.last_error.write().await = None;
// Create streamer config (fixed 48kHz stereo)
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: config.device.clone(),
..Default::default()
},
opus: config.quality.to_opus_config(),
let mut select_error = None;
let (device_name, quality) = {
let mut cfg = self.config.write().await;
if cfg.device.trim().is_empty() {
match find_best_audio_device() {
Ok(best) => cfg.device = best.name,
Err(e) => {
select_error = Some(format!("Failed to select audio device: {}", e));
}
}
}
(cfg.device.clone(), cfg.quality)
};
if let Some(error_msg) = select_error {
self.monitor.report_error(&error_msg, "start_failed").await;
self.spawn_recovery_task("auto".to_string(), error_msg.clone());
self.mark_device_info_dirty().await;
return Err(AppError::AudioError(error_msg));
}
debug!("Starting audio streaming with device: {}", device_name);
self.monitor.prepare_retry_attempt();
let streamer_config = AudioStreamerConfig {
capture: AudioConfig {
device_name: device_name.clone(),
..Default::default()
},
opus: quality.to_opus_config(),
};
// Create and start streamer
let streamer = Arc::new(AudioStreamer::with_config(streamer_config));
if let Err(e) = streamer.start().await {
let error_msg = format!("Failed to start audio: {}", e);
*self.last_error.write().await = Some(error_msg.clone());
// Report error to health monitor
self.monitor
.report_error(Some(&config.device), &error_msg, "start_failed")
.await;
self.monitor.report_error(&error_msg, "start_failed").await;
self.spawn_recovery_task(device_name.clone(), error_msg.clone());
self.mark_device_info_dirty().await;
return Err(AppError::AudioError(error_msg));
}
let streamer_for_monitor = streamer.clone();
*self.streamer.write().await = Some(streamer);
self.spawn_stream_monitor(streamer_for_monitor, device_name.clone());
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered(Some(&config.device)).await;
self.monitor.report_recovered().await;
}
self.recovery_in_progress.store(false, Ordering::SeqCst);
self.mark_device_info_dirty().await;
info!("Audio streaming started");
Ok(())
}
/// Stop audio streaming
pub async fn stop_streaming(&self) -> Result<()> {
self.recovery_in_progress.store(false, Ordering::SeqCst);
if let Some(streamer) = self.streamer.write().await.take() {
streamer.stop().await?;
}
self.monitor.reset().await;
self.mark_device_info_dirty().await;
info!("Audio streaming stopped");
Ok(())
}
/// Check if currently streaming
pub async fn is_streaming(&self) -> bool {
if let Some(ref streamer) = *self.streamer.read().await {
streamer.is_running()
@@ -315,46 +585,37 @@ impl AudioController {
}
}
/// Get current status
pub async fn status(&self) -> AudioStatus {
let config = self.config.read().await;
let streaming = self.is_streaming().await;
let error = self.last_error.read().await.clone();
let (enabled, device_str, quality) = {
let c = self.config.read().await;
(c.enabled, c.device.clone(), c.quality)
};
let error = self.monitor.error_message().await;
let subscriber_count = if let Some(ref streamer) = *self.streamer.read().await {
streamer.stats().await.subscriber_count
let (streaming, subscriber_count) = if let Some(ref streamer) = *self.streamer.read().await
{
let streaming = streamer.is_running();
let subscriber_count = streamer.stats().subscriber_count;
(streaming, subscriber_count)
} else {
0
(false, 0)
};
AudioStatus {
enabled: config.enabled,
enabled,
streaming,
device: if streaming || config.enabled {
Some(config.device.clone())
device: if streaming || enabled {
Some(device_str)
} else {
None
},
quality: config.quality,
quality,
subscriber_count,
error,
}
}
/// Subscribe to Opus frames (for WebSocket clients)
pub fn subscribe_opus(&self) -> Option<tokio::sync::watch::Receiver<Option<Arc<OpusFrame>>>> {
// Use try_read to avoid blocking - this is called from sync context sometimes
if let Ok(guard) = self.streamer.try_read() {
guard.as_ref().map(|s| s.subscribe_opus())
} else {
None
}
}
/// Subscribe to Opus frames (async version)
pub async fn subscribe_opus_async(
&self,
) -> Option<tokio::sync::watch::Receiver<Option<Arc<OpusFrame>>>> {
pub async fn subscribe_opus(&self) -> Option<tokio::sync::mpsc::Receiver<Arc<OpusFrame>>> {
self.streamer
.read()
.await
@@ -362,7 +623,6 @@ impl AudioController {
.map(|s| s.subscribe_opus())
}
/// Enable or disable audio
pub async fn set_enabled(&self, enabled: bool) -> Result<()> {
{
let mut config = self.config.write().await;
@@ -377,21 +637,15 @@ impl AudioController {
Ok(())
}
/// Update full configuration
pub async fn update_config(&self, new_config: AudioControllerConfig) -> Result<()> {
let was_streaming = self.is_streaming().await;
// Stop streaming if running (device/quality/enabled may all change)
if was_streaming {
self.stop_streaming().await?;
}
// Update config
*self.config.write().await = new_config.clone();
// Start whenever audio is enabled — not only when we were already streaming.
// Otherwise PATCH /config/audio alone leaves enabled=true with no capture until
// POST /audio/start, which races WebRTC reconnect and matches "apply twice" reports.
if new_config.enabled {
self.start_streaming().await?;
}
@@ -399,25 +653,9 @@ impl AudioController {
Ok(())
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
self.stop_streaming().await
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<AudioHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> AudioHealthStatus {
self.monitor.status().await
}
/// Check if the audio is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Default for AudioController {
@@ -439,12 +677,23 @@ mod tests {
#[test]
fn test_audio_quality_from_str() {
assert_eq!(AudioQuality::from_str("voice"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("low"), AudioQuality::Voice);
assert_eq!(AudioQuality::from_str("balanced"), AudioQuality::Balanced);
assert_eq!(AudioQuality::from_str("high"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("music"), AudioQuality::High);
assert_eq!(AudioQuality::from_str("unknown"), AudioQuality::Balanced);
assert_eq!(
"voice".parse::<AudioQuality>().unwrap(),
AudioQuality::Voice
);
assert_eq!(
"balanced".parse::<AudioQuality>().unwrap(),
AudioQuality::Balanced
);
assert_eq!("high".parse::<AudioQuality>().unwrap(), AudioQuality::High);
}
#[test]
fn test_audio_quality_from_str_rejects_aliases_and_unknown() {
assert!("low".parse::<AudioQuality>().is_err());
assert!("music".parse::<AudioQuality>().is_err());
assert!("unknown".parse::<AudioQuality>().is_err());
assert!("".parse::<AudioQuality>().is_err());
}
#[tokio::test]

View File

@@ -1,5 +1,3 @@
//! Audio device enumeration using ALSA
use alsa::pcm::HwParams;
use alsa::{Direction, PCM};
use serde::Serialize;
@@ -7,54 +5,30 @@ use tracing::{debug, info, warn};
use crate::error::{AppError, Result};
/// Audio device information
#[derive(Debug, Clone, Serialize)]
pub struct AudioDeviceInfo {
/// Device name (e.g., "hw:0,0" or "default")
pub name: String,
/// Human-readable description
pub description: String,
/// Card index
pub card_index: i32,
/// Device index
pub device_index: i32,
/// Supported sample rates
pub sample_rates: Vec<u32>,
/// Supported channel counts
pub channels: Vec<u32>,
/// Is this a capture device
pub is_capture: bool,
/// Is this an HDMI audio device (likely from capture card)
pub is_hdmi: bool,
/// USB bus info for matching with video devices (e.g., "1-1" from USB path)
pub usb_bus: Option<String>,
}
impl AudioDeviceInfo {
/// Get ALSA device name
pub fn alsa_name(&self) -> String {
format!("hw:{},{}", self.card_index, self.device_index)
}
}
/// Get USB bus info for an audio card by reading sysfs
/// Returns the USB port path like "1-1" or "1-2.3"
fn get_usb_bus_info(card_index: i32) -> Option<String> {
if card_index < 0 {
return None;
}
// Read the device symlink: /sys/class/sound/cardX/device -> ../../usb1/1-1/1-1:1.0
let device_path = format!("/sys/class/sound/card{}/device", card_index);
let link_target = std::fs::read_link(&device_path).ok()?;
let link_str = link_target.to_string_lossy();
// Extract USB port from path like "../../usb1/1-1/1-1:1.0" or "../../1-1/1-1:1.0"
// We want the "1-1" part (USB bus-port)
for component in link_str.split('/') {
// Match patterns like "1-1", "1-2", "1-1.2", "2-1.3.1"
if component.contains('-') && !component.contains(':') {
// Verify it looks like a USB port (starts with digit)
if component
.chars()
.next()
@@ -69,22 +43,15 @@ fn get_usb_bus_info(card_index: i32) -> Option<String> {
None
}
/// Enumerate available audio capture devices
pub fn enumerate_audio_devices() -> Result<Vec<AudioDeviceInfo>> {
enumerate_audio_devices_with_current(None)
}
/// Enumerate available audio capture devices, with option to include a currently-in-use device
///
/// # Arguments
/// * `current_device` - Optional device name that is currently in use. This device will be
/// included in the list even if it cannot be opened (because it's already open by us).
pub fn enumerate_audio_devices_with_current(
current_device: Option<&str>,
) -> Result<Vec<AudioDeviceInfo>> {
let mut devices = Vec::new();
// Try to enumerate cards
let cards = alsa::card::Iter::new();
for card_result in cards {
@@ -102,104 +69,71 @@ pub fn enumerate_audio_devices_with_current(
debug!("Found audio card {}: {}", card_index, card_longname);
// Check if this looks like an HDMI capture device
let is_hdmi = card_longname.to_lowercase().contains("hdmi")
|| card_longname.to_lowercase().contains("capture")
|| card_longname.to_lowercase().contains("usb");
let long_lower = card_longname.to_lowercase();
let is_hdmi = long_lower.contains("hdmi")
|| long_lower.contains("capture")
|| long_lower.contains("usb");
// Get USB bus info for this card
let usb_bus = get_usb_bus_info(card_index);
// Try to open each device on this card for capture
for device_index in 0..8 {
let device_name = format!("hw:{},{}", card_index, device_index);
// Check if this is the currently-in-use device
let is_current_device = current_device == Some(device_name.as_str());
// Try to open for capture
let mut push_info =
|sample_rates: Vec<u32>, channels: Vec<u32>, description: String| {
devices.push(AudioDeviceInfo {
name: device_name.clone(),
description,
card_index,
device_index,
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
};
match PCM::new(&device_name, Direction::Capture, false) {
Ok(pcm) => {
// Query capabilities
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() && !channels.is_empty() {
devices.push(AudioDeviceInfo {
name: device_name,
description: format!("{} - Device {}", card_longname, device_index),
card_index,
device_index,
push_info(
sample_rates,
channels,
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
format!("{} - Device {}", card_longname, device_index),
);
}
}
Err(_) => {
// Device doesn't exist or can't be opened for capture
// But if it's the current device, include it anyway (it's busy because we're using it)
if is_current_device {
debug!(
"Device {} is busy (in use by us), adding with default caps",
device_name
);
devices.push(AudioDeviceInfo {
name: device_name,
description: format!(
"{} - Device {} (in use)",
card_longname, device_index
),
card_index,
device_index,
// Use common default capabilities for HDMI capture devices
sample_rates: vec![44100, 48000],
channels: vec![2],
is_capture: true,
is_hdmi,
usb_bus: usb_bus.clone(),
});
push_info(
vec![44100, 48000],
vec![2],
format!("{} - Device {} (in use)", card_longname, device_index),
);
}
continue;
}
}
}
}
// Also check for "default" device
if let Ok(pcm) = PCM::new("default", Direction::Capture, false) {
let (sample_rates, channels) = query_device_caps(&pcm);
if !sample_rates.is_empty() {
devices.insert(
0,
AudioDeviceInfo {
name: "default".to_string(),
description: "Default Audio Device".to_string(),
card_index: -1,
device_index: -1,
sample_rates,
channels,
is_capture: true,
is_hdmi: false,
usb_bus: None,
},
);
}
}
info!("Found {} audio capture devices", devices.len());
Ok(devices)
}
/// Query device capabilities
fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
let hwp = match HwParams::any(pcm) {
Ok(h) => h,
Err(_) => return (vec![], vec![]),
};
// Common sample rates to check
let common_rates = [8000, 16000, 22050, 44100, 48000, 96000];
let mut supported_rates = Vec::new();
@@ -209,7 +143,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
}
}
// Check channel counts
let mut supported_channels = Vec::new();
for ch in 1..=8 {
if hwp.test_channels(ch).is_ok() {
@@ -220,8 +153,6 @@ fn query_device_caps(pcm: &PCM) -> (Vec<u32>, Vec<u32>) {
(supported_rates, supported_channels)
}
/// Find the best audio device for capture
/// Prefers HDMI/capture devices over built-in microphones
pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
let devices = enumerate_audio_devices()?;
@@ -231,23 +162,24 @@ pub fn find_best_audio_device() -> Result<AudioDeviceInfo> {
));
}
// First, look for HDMI/capture card devices that support 48kHz stereo
let mut first_48k_stereo: Option<&AudioDeviceInfo> = None;
for device in &devices {
if device.is_hdmi && device.sample_rates.contains(&48000) && device.channels.contains(&2) {
if !device.sample_rates.contains(&48000) || !device.channels.contains(&2) {
continue;
}
if device.is_hdmi {
info!("Selected HDMI audio device: {}", device.description);
return Ok(device.clone());
}
}
// Then look for any device supporting 48kHz stereo
for device in &devices {
if device.sample_rates.contains(&48000) && device.channels.contains(&2) {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
if first_48k_stereo.is_none() {
first_48k_stereo = Some(device);
}
}
if let Some(device) = first_48k_stereo {
info!("Selected audio device: {}", device.description);
return Ok(device.clone());
}
// Fall back to first device
let device = devices.into_iter().next().unwrap();
warn!(
"Using fallback audio device: {} (may not support optimal settings)",
@@ -262,10 +194,8 @@ mod tests {
#[test]
fn test_enumerate_devices() {
// This test may not find devices in CI environment
let result = enumerate_audio_devices();
println!("Audio devices: {:?}", result);
// Just verify it doesn't panic
assert!(result.is_ok());
}
}

View File

@@ -1,26 +1,19 @@
//! Opus audio encoder for WebRTC
//! Opus encoder.
use audiopus::coder::GenericCtl;
use audiopus::{coder::Encoder, Application, Bitrate, Channels, SampleRate};
use bytes::Bytes;
use std::time::Instant;
use tracing::info;
use tracing::debug;
use super::capture::AudioFrame;
use crate::error::{AppError, Result};
/// Opus encoder configuration
#[derive(Debug, Clone)]
pub struct OpusConfig {
/// Sample rate (must be 8000, 12000, 16000, 24000, or 48000)
pub sample_rate: u32,
/// Channels (1 or 2)
pub channels: u32,
/// Target bitrate in bps
pub bitrate: u32,
/// Application mode
pub application: OpusApplication,
/// Enable forward error correction
pub fec: bool,
}
@@ -29,7 +22,7 @@ impl Default for OpusConfig {
Self {
sample_rate: 48000,
channels: 2,
bitrate: 64000, // 64 kbps
bitrate: 64000,
application: OpusApplication::Audio,
fec: true,
}
@@ -37,7 +30,6 @@ impl Default for OpusConfig {
}
impl OpusConfig {
/// Create config for voice (lower latency)
pub fn voice() -> Self {
Self {
application: OpusApplication::Voip,
@@ -46,7 +38,6 @@ impl OpusConfig {
}
}
/// Create config for music (higher quality)
pub fn music() -> Self {
Self {
application: OpusApplication::Audio,
@@ -82,30 +73,18 @@ impl OpusConfig {
}
}
/// Opus application mode
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpusApplication {
/// Voice over IP
Voip,
/// General audio
Audio,
/// Low delay mode
LowDelay,
}
/// Encoded Opus frame
#[derive(Debug, Clone)]
pub struct OpusFrame {
/// Encoded Opus data
pub data: Bytes,
/// Duration in milliseconds
pub duration_ms: u32,
/// Sequence number
pub sequence: u64,
/// Timestamp
pub timestamp: Instant,
/// RTP timestamp (samples)
pub rtp_timestamp: u32,
}
impl OpusFrame {
@@ -118,20 +97,14 @@ impl OpusFrame {
}
}
/// Opus encoder
pub struct OpusEncoder {
config: OpusConfig,
encoder: Encoder,
/// Output buffer
output_buffer: Vec<u8>,
/// Frame counter for RTP timestamp
frame_count: u64,
/// Samples per frame
samples_per_frame: u32,
}
impl OpusEncoder {
/// Create a new Opus encoder
pub fn new(config: OpusConfig) -> Result<Self> {
let sample_rate = config.to_audiopus_sample_rate();
let channels = config.to_audiopus_channels();
@@ -140,7 +113,6 @@ impl OpusEncoder {
let mut encoder = Encoder::new(sample_rate, channels, application)
.map_err(|e| AppError::AudioError(format!("Failed to create Opus encoder: {:?}", e)))?;
// Configure encoder
encoder
.set_bitrate(Bitrate::BitsPerSecond(config.bitrate as i32))
.map_err(|e| AppError::AudioError(format!("Failed to set bitrate: {:?}", e)))?;
@@ -151,10 +123,7 @@ impl OpusEncoder {
.map_err(|e| AppError::AudioError(format!("Failed to enable FEC: {:?}", e)))?;
}
// Calculate samples per frame (20ms at sample_rate)
let samples_per_frame = config.sample_rate / 50;
info!(
debug!(
"Opus encoder created: {}Hz {}ch {}bps",
config.sample_rate, config.channels, config.bitrate
);
@@ -162,18 +131,11 @@ impl OpusEncoder {
Ok(Self {
config,
encoder,
output_buffer: vec![0u8; 4000], // Max Opus frame size
output_buffer: vec![0u8; 4000],
frame_count: 0,
samples_per_frame,
})
}
/// Create with default configuration
pub fn default_config() -> Result<Self> {
Self::new(OpusConfig::default())
}
/// Encode PCM audio data (S16LE interleaved)
pub fn encode(&mut self, pcm_data: &[i16]) -> Result<OpusFrame> {
let encoded_len = self
.encoder
@@ -182,7 +144,6 @@ impl OpusEncoder {
let samples = pcm_data.len() as u32 / self.config.channels;
let duration_ms = (samples * 1000) / self.config.sample_rate;
let rtp_timestamp = (self.frame_count * self.samples_per_frame as u64) as u32;
self.frame_count += 1;
@@ -190,27 +151,18 @@ impl OpusEncoder {
data: Bytes::copy_from_slice(&self.output_buffer[..encoded_len]),
duration_ms,
sequence: self.frame_count - 1,
timestamp: Instant::now(),
rtp_timestamp,
})
}
/// Encode from AudioFrame
///
/// Uses zero-copy conversion from bytes to i16 samples via bytemuck.
pub fn encode_frame(&mut self, frame: &AudioFrame) -> Result<OpusFrame> {
// Zero-copy: directly cast bytes to i16 slice
// AudioFrame.data is S16LE format, which matches native little-endian i16
let samples: &[i16] = bytemuck::cast_slice(&frame.data);
self.encode(samples)
}
/// Get encoder configuration
pub fn config(&self) -> &OpusConfig {
&self.config
}
/// Reset encoder state
pub fn reset(&mut self) -> Result<()> {
self.encoder
.reset_state()
@@ -219,7 +171,6 @@ impl OpusEncoder {
Ok(())
}
/// Set bitrate dynamically
pub fn set_bitrate(&mut self, bitrate: u32) -> Result<()> {
self.encoder
.set_bitrate(Bitrate::BitsPerSecond(bitrate as i32))
@@ -228,15 +179,6 @@ impl OpusEncoder {
}
}
/// Audio encoder statistics
#[derive(Debug, Clone, Default)]
pub struct EncoderStats {
pub frames_encoded: u64,
pub bytes_output: u64,
pub avg_frame_size: usize,
pub current_bitrate: u32,
}
#[cfg(test)]
mod tests {
use super::*;
@@ -261,13 +203,12 @@ mod tests {
let config = OpusConfig::default();
let mut encoder = OpusEncoder::new(config).unwrap();
// 20ms of stereo silence at 48kHz
let silence = vec![0i16; 960 * 2];
let result = encoder.encode(&silence);
assert!(result.is_ok());
let frame = result.unwrap();
assert!(!frame.is_empty());
assert!(frame.len() < silence.len() * 2); // Should be compressed
assert!(frame.len() < silence.len() * 2);
}
}

View File

@@ -1,24 +1,15 @@
//! Audio capture and encoding module
//!
//! This module provides:
//! - ALSA audio capture
//! - Opus encoding for WebRTC
//! - Audio device enumeration
//! - Audio streaming pipeline
//! - High-level audio controller
//! - Device health monitoring
//! ALSA capture, Opus encode, device enumeration, streaming, controller, health monitor.
pub mod capture;
pub mod controller;
pub mod device;
pub mod encoder;
pub mod monitor;
pub mod resample;
pub mod streamer;
pub use capture::{AudioCapturer, AudioConfig, AudioFrame};
pub use controller::{AudioController, AudioControllerConfig, AudioQuality, AudioStatus};
pub use device::{enumerate_audio_devices, enumerate_audio_devices_with_current, AudioDeviceInfo};
pub use encoder::{OpusConfig, OpusEncoder, OpusFrame};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus, AudioMonitorConfig};
pub use monitor::{AudioHealthMonitor, AudioHealthStatus};
pub use streamer::{AudioStreamState, AudioStreamer, AudioStreamerConfig};

View File

@@ -1,114 +1,58 @@
//! Audio device health monitoring
//!
//! This module provides health monitoring for audio capture devices, including:
//! - Device connectivity checks
//! - Automatic reconnection on failure
//! - Error tracking
//! - Log throttling to prevent log flooding
//! Audio device health and logging throttle for repeated failures.
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::utils::LogThrottler;
/// Audio health status
const LOG_THROTTLE_SECS: u64 = 5;
#[derive(Debug, Clone, PartialEq, Default)]
pub enum AudioHealthStatus {
/// Device is healthy and operational
#[default]
Healthy,
/// Device has an error, attempting recovery
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
/// Number of recovery attempts made
retry_count: u32,
},
/// Device is disconnected or not available
Disconnected,
}
/// Audio health monitor configuration
#[derive(Debug, Clone)]
pub struct AudioMonitorConfig {
/// Retry interval when device is lost (milliseconds)
pub retry_interval_ms: u64,
/// Maximum retry attempts before giving up (0 = infinite)
pub max_retries: u32,
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for AudioMonitorConfig {
fn default() -> Self {
Self {
retry_interval_ms: 1000,
max_retries: 0, // infinite retry
log_throttle_secs: 5,
}
}
}
/// Audio health monitor
///
/// Monitors audio device health and manages error recovery.
pub struct AudioHealthMonitor {
/// Current health status
status: RwLock<AudioHealthStatus>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Configuration
config: AudioMonitorConfig,
/// Current retry count
retry_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
/// Hide `error_message` while a new capture attempt is in flight (internal error state unchanged).
suppress_display: AtomicBool,
}
impl AudioHealthMonitor {
/// Create a new audio health monitor with the specified configuration
pub fn new(config: AudioMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
pub fn new() -> Self {
Self {
status: RwLock::new(AudioHealthStatus::Healthy),
throttler: LogThrottler::with_secs(throttle_secs),
config,
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
retry_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
suppress_display: AtomicBool::new(false),
}
}
/// Create a new audio health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(AudioMonitorConfig::default())
/// Clears the error string exposed via [`Self::error_message`] until the next outcome (`report_error` or recovery).
pub fn prepare_retry_attempt(&self) {
self.suppress_display.store(true, Ordering::Relaxed);
}
/// Report an error from audio operations
///
/// This method is called when an audio operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Updates in-memory error state
///
/// # Arguments
///
/// * `device` - The audio device name (if known)
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, _device: Option<&str>, reason: &str, error_code: &str) {
pub async fn report_error(&self, reason: &str, error_code: &str) {
self.suppress_display.store(false, Ordering::Relaxed);
let count = self.retry_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("audio_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!(
@@ -117,34 +61,22 @@ impl AudioHealthMonitor {
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = AudioHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
retry_count: count,
};
}
/// Report that the device has recovered
///
/// This method is called when the audio device successfully reconnects.
/// It resets the error state.
///
/// # Arguments
///
/// * `device` - The audio device name
pub async fn report_recovered(&self, _device: Option<&str>) {
pub async fn report_recovered(&self) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != AudioHealthStatus::Healthy {
let retry_count = self.retry_count.load(Ordering::Relaxed);
info!("Audio recovered after {} retries", retry_count);
// Reset state
self.suppress_display.store(false, Ordering::Relaxed);
self.retry_count.store(0, Ordering::Relaxed);
self.throttler.clear("audio_");
*self.last_error_code.write().await = None;
@@ -152,58 +84,30 @@ impl AudioHealthMonitor {
}
}
/// Get the current health status
pub async fn status(&self) -> AudioHealthStatus {
self.status.read().await.clone()
}
/// Get the current retry count
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.suppress_display.store(false, Ordering::Relaxed);
self.retry_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
*self.status.write().await = AudioHealthStatus::Healthy;
self.throttler.clear_all();
}
/// Get the configuration
pub fn config(&self) -> &AudioMonitorConfig {
&self.config
pub async fn status(&self) -> AudioHealthStatus {
self.status.read().await.clone()
}
/// Check if we should continue retrying
///
/// Returns `false` if max_retries is set and we've exceeded it.
pub fn should_retry(&self) -> bool {
if self.config.max_retries == 0 {
return true; // Infinite retry
}
self.retry_count.load(Ordering::Relaxed) < self.config.max_retries
pub fn retry_count(&self) -> u32 {
self.retry_count.load(Ordering::Relaxed)
}
/// Get the retry interval
pub fn retry_interval(&self) -> Duration {
Duration::from_millis(self.config.retry_interval_ms)
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, AudioHealthStatus::Error { .. })
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
if self.suppress_display.load(Ordering::Relaxed) {
return None;
}
match &*self.status.read().await {
AudioHealthStatus::Error { reason, .. } => Some(reason.clone()),
_ => None,
@@ -213,7 +117,7 @@ impl AudioHealthMonitor {
impl Default for AudioHealthMonitor {
fn default() -> Self {
Self::with_defaults()
Self::new()
}
}
@@ -223,32 +127,25 @@ mod tests {
#[tokio::test]
async fn test_initial_status() {
let monitor = AudioHealthMonitor::with_defaults();
assert!(monitor.is_healthy().await);
let monitor = AudioHealthMonitor::new();
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_report_error() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
monitor
.report_error(Some("hw:0,0"), "Device not found", "device_disconnected")
.report_error("Device not found", "device_disconnected")
.await;
assert!(monitor.is_error().await);
assert_eq!(monitor.retry_count(), 1);
if let AudioHealthStatus::Error {
reason,
error_code,
retry_count,
} = monitor.status().await
{
if let AudioHealthStatus::Error { reason, error_code } = monitor.status().await {
assert_eq!(reason, "Device not found");
assert_eq!(error_code, "device_disconnected");
assert_eq!(retry_count, 1);
} else {
panic!("Expected Error status");
}
@@ -256,39 +153,52 @@ mod tests {
#[tokio::test]
async fn test_report_recovered() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
// First report an error
monitor
.report_error(Some("default"), "Capture failed", "capture_error")
.report_error("Capture failed", "capture_error")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered(Some("default")).await;
assert!(monitor.is_healthy().await);
monitor.report_recovered().await;
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_retry_count_increments() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
for i in 1..=5 {
monitor.report_error(None, "Error", "io_error").await;
monitor.report_error("Error", "io_error").await;
assert_eq!(monitor.retry_count(), i);
}
}
#[tokio::test]
async fn test_reset() {
let monitor = AudioHealthMonitor::with_defaults();
let monitor = AudioHealthMonitor::new();
monitor.report_error(None, "Error", "io_error").await;
monitor.report_error("Error", "io_error").await;
assert!(monitor.is_error().await);
monitor.reset().await;
assert!(monitor.is_healthy().await);
assert!(!monitor.is_error().await);
assert_eq!(monitor.retry_count(), 0);
}
#[tokio::test]
async fn test_prepare_retry_hides_error_until_next_failure() {
let monitor = AudioHealthMonitor::new();
monitor.report_error("bad", "e").await;
assert_eq!(monitor.error_message().await.as_deref(), Some("bad"));
monitor.prepare_retry_attempt();
assert!(monitor.is_error().await);
assert!(monitor.error_message().await.is_none());
monitor.report_error("still bad", "e").await;
assert_eq!(monitor.error_message().await.as_deref(), Some("still bad"));
}
}

View File

@@ -1,202 +0,0 @@
//! Resample capture PCM to 48 kHz stereo for Opus (fixed 20 ms / 960×2 samples).
const OUT_RATE: f64 = 48000.0;
const OPUS_STEREO_SAMPLES: usize = 960 * 2;
enum PipelineState {
/// Native 48 kHz interleaved stereo: only buffer and slice into 20 ms blocks (no float work).
Stereo48kPassthrough,
/// Other rates / mono: linear interpolation to 48 kHz stereo.
Resample {
in_rate: u32,
in_channels: u32,
next_out_frame: u64,
buffer_start_frame: u64,
},
}
/// Converts incoming interleaved PCM to 48 kHz stereo, then exposes fixed 960×2-sample chunks.
pub struct Opus48kPcmBuffer {
state: PipelineState,
pending: Vec<i16>,
}
impl Opus48kPcmBuffer {
pub fn new(in_rate: u32, in_channels: u32) -> Self {
let ch = in_channels.max(1);
let rate = in_rate.max(1);
let state = if rate == 48000 && ch == 2 {
PipelineState::Stereo48kPassthrough
} else {
PipelineState::Resample {
in_rate: rate,
in_channels: ch,
next_out_frame: 0,
buffer_start_frame: 0,
}
};
Self {
state,
pending: Vec::new(),
}
}
/// True when input is already 48 kHz stereo (no interpolation loop).
#[cfg(test)]
pub fn is_passthrough(&self) -> bool {
matches!(self.state, PipelineState::Stereo48kPassthrough)
}
/// Append one capture block (`sample_rate` must match the rate this buffer was built for).
pub fn push_interleaved(&mut self, data: &[i16]) {
self.pending.extend_from_slice(data);
}
/// Drain as many 960×2 stereo S16LE samples (20 ms @ 48 kHz) as possible.
pub fn pop_opus_frames(&mut self, out: &mut Vec<i16>) {
match &mut self.state {
PipelineState::Stereo48kPassthrough => {
while self.pending.len() >= OPUS_STEREO_SAMPLES {
out.extend_from_slice(&self.pending[..OPUS_STEREO_SAMPLES]);
self.pending.drain(..OPUS_STEREO_SAMPLES);
}
}
PipelineState::Resample {
in_rate,
in_channels,
next_out_frame,
buffer_start_frame,
} => {
let ch = *in_channels as usize;
if ch == 0 {
return;
}
loop {
let batch_start = *next_out_frame;
let mut block = Vec::with_capacity(OPUS_STEREO_SAMPLES);
let mut complete = true;
for i in 0u64..960 {
let k = batch_start + i;
let p_abs = (k as f64) * (*in_rate as f64) / OUT_RATE;
let f_abs = p_abs.floor() as u64;
let frac = p_abs - f_abs as f64;
let f_rel = f_abs.saturating_sub(*buffer_start_frame) as usize;
if f_rel + 1 >= self.pending.len() / ch {
complete = false;
break;
}
let base0 = f_rel * ch;
let base1 = (f_rel + 1) * ch;
let (l, r) = if *in_channels >= 2 {
let l0 = self.pending[base0] as f64;
let l1 = self.pending[base1] as f64;
let r0 = self.pending[base0 + 1] as f64;
let r1 = self.pending[base1 + 1] as f64;
(l0 + frac * (l1 - l0), r0 + frac * (r1 - r0))
} else {
let m0 = self.pending[base0] as f64;
let m1 = self.pending[base1] as f64;
let v = m0 + frac * (m1 - m0);
(v, v)
};
block.push(clamp_f64_to_i16(l));
block.push(clamp_f64_to_i16(r));
}
if !complete || block.len() != OPUS_STEREO_SAMPLES {
break;
}
out.extend_from_slice(&block);
*next_out_frame = batch_start + 960;
trim_resample_prefix(
&mut self.pending,
*in_rate,
*next_out_frame,
buffer_start_frame,
ch,
);
}
}
}
}
}
fn trim_resample_prefix(
pending: &mut Vec<i16>,
in_rate: u32,
next_out_frame: u64,
buffer_start_frame: &mut u64,
ch: usize,
) {
if pending.is_empty() {
return;
}
let p_next = (next_out_frame as f64) * (in_rate as f64) / OUT_RATE;
let need_abs = p_next.floor() as u64;
let keep_from_abs = need_abs.saturating_sub(1);
if keep_from_abs <= *buffer_start_frame {
return;
}
let drop_frames = (keep_from_abs - *buffer_start_frame) as usize;
let drop_samples = drop_frames.saturating_mul(ch).min(pending.len());
if drop_samples > 0 {
pending.drain(0..drop_samples);
*buffer_start_frame += drop_frames as u64;
}
}
#[inline]
fn clamp_f64_to_i16(v: f64) -> i16 {
v.round().clamp(i16::MIN as f64, i16::MAX as f64) as i16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn passthrough_48k_identity_tone_length() {
let mut buf = Opus48kPcmBuffer::new(48000, 2);
assert!(buf.is_passthrough());
let mut chunk = vec![0i16; 960 * 2];
for i in 0..960 {
let s = (i as f32 * 0.1).sin() * 3000.0;
chunk[2 * i] = s as i16;
chunk[2 * i + 1] = s as i16;
}
buf.push_interleaved(&chunk);
let mut out = Vec::new();
buf.pop_opus_frames(&mut out);
assert_eq!(out.len(), 960 * 2);
}
#[test]
fn upsample_44k_to_48k_chunk() {
let mut buf = Opus48kPcmBuffer::new(44100, 2);
assert!(!buf.is_passthrough());
let mut chunk = vec![0i16; 882 * 2];
for i in 0..882 {
chunk[2 * i] = (i as i16).wrapping_mul(10);
chunk[2 * i + 1] = (i as i16).wrapping_mul(-7);
}
buf.push_interleaved(&chunk);
let mut out = Vec::new();
buf.pop_opus_frames(&mut out);
assert_eq!(out.len(), 960 * 2, "expected one 20ms Opus block");
}
#[test]
fn mono_48k_not_passthrough() {
let buf = Opus48kPcmBuffer::new(48000, 1);
assert!(!buf.is_passthrough());
}
}

View File

@@ -1,46 +1,36 @@
//! Audio streaming pipeline
//!
//! Coordinates audio capture and Opus encoding, distributing encoded
//! frames to multiple subscribers via broadcast channel.
//! ALSA 48 kHz stereo → Opus 20 ms frames, fan-out per subscriber.
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, watch, Mutex, RwLock};
use tracing::{error, info, warn};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::{broadcast, mpsc, watch, Mutex as AsyncMutex, RwLock};
use tracing::{debug, error, info, warn};
use super::capture::{AudioCapturer, AudioConfig, AudioFrame, CaptureState};
use super::encoder::{OpusConfig, OpusEncoder, OpusFrame};
use super::resample::Opus48kPcmBuffer;
use crate::error::{AppError, Result};
use bytemuck;
use bytes::Bytes;
use std::time::Duration;
/// 48 kHz stereo: 20 ms = 960 × 2 samples (S16LE).
const OPUS_STEREO_SAMPLES: usize = 960 * 2;
/// Audio stream state
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AudioStreamState {
/// Stream is stopped
#[default]
Stopped,
/// Stream is starting up
Starting,
/// Stream is running
Running,
/// Stream encountered an error
Error,
}
/// Audio streamer configuration
#[derive(Debug, Clone, Default)]
pub struct AudioStreamerConfig {
/// Audio capture configuration
pub capture: AudioConfig,
/// Opus encoder configuration
pub opus: OpusConfig,
}
impl AudioStreamerConfig {
/// Create config for a specific device with default quality
pub fn for_device(device_name: &str) -> Self {
Self {
capture: AudioConfig {
@@ -51,90 +41,75 @@ impl AudioStreamerConfig {
}
}
/// Create config with specified bitrate
pub fn with_bitrate(mut self, bitrate: u32) -> Self {
self.opus.bitrate = bitrate;
self
}
}
/// Audio stream statistics
#[derive(Debug, Clone, Default)]
pub struct AudioStreamStats {
/// Frames encoded to Opus
/// Number of active subscribers
pub subscriber_count: usize,
}
/// Audio streamer
///
/// Manages the audio capture -> encode -> broadcast pipeline.
pub struct AudioStreamer {
config: RwLock<AudioStreamerConfig>,
state: watch::Sender<AudioStreamState>,
state_rx: watch::Receiver<AudioStreamState>,
capturer: RwLock<Option<Arc<AudioCapturer>>>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: watch::Sender<Option<Arc<OpusFrame>>>,
stats: Arc<Mutex<AudioStreamStats>>,
sequence: AtomicU64,
stream_start_time: RwLock<Option<Instant>>,
encoder: Arc<AsyncMutex<Option<OpusEncoder>>>,
opus_subscribers: Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
stop_flag: Arc<AtomicBool>,
}
impl AudioStreamer {
/// Create a new audio streamer with default configuration
pub fn new() -> Self {
Self::with_config(AudioStreamerConfig::default())
}
/// Create a new audio streamer with specified configuration
pub fn with_config(config: AudioStreamerConfig) -> Self {
let (state_tx, state_rx) = watch::channel(AudioStreamState::Stopped);
let (opus_tx, _opus_rx) = watch::channel(None);
Self {
config: RwLock::new(config),
state: state_tx,
state_rx,
capturer: RwLock::new(None),
encoder: Arc::new(Mutex::new(None)),
opus_tx,
stats: Arc::new(Mutex::new(AudioStreamStats::default())),
sequence: AtomicU64::new(0),
stream_start_time: RwLock::new(None),
encoder: Arc::new(AsyncMutex::new(None)),
opus_subscribers: Arc::new(Mutex::new(Vec::new())),
stop_flag: Arc::new(AtomicBool::new(false)),
}
}
/// Get current state
pub fn state(&self) -> AudioStreamState {
*self.state_rx.borrow()
}
/// Subscribe to state changes
pub fn state_watch(&self) -> watch::Receiver<AudioStreamState> {
self.state_rx.clone()
}
/// Subscribe to Opus frames
pub fn subscribe_opus(&self) -> watch::Receiver<Option<Arc<OpusFrame>>> {
self.opus_tx.subscribe()
pub fn subscribe_opus(&self) -> mpsc::Receiver<Arc<OpusFrame>> {
let (tx, rx) = mpsc::channel::<Arc<OpusFrame>>(128);
self.opus_subscribers.lock().unwrap().push(tx);
rx
}
/// Get number of active subscribers
pub fn subscriber_count(&self) -> usize {
self.opus_tx.receiver_count()
self.opus_subscribers
.lock()
.unwrap()
.iter()
.filter(|s| !s.is_closed())
.count()
}
/// Get current statistics
pub async fn stats(&self) -> AudioStreamStats {
let mut stats = self.stats.lock().await.clone();
stats.subscriber_count = self.subscriber_count();
stats
pub fn stats(&self) -> AudioStreamStats {
AudioStreamStats {
subscriber_count: self.subscriber_count(),
}
}
/// Update configuration (only when stopped)
pub async fn set_config(&self, config: AudioStreamerConfig) -> Result<()> {
if self.state() != AudioStreamState::Stopped {
return Err(AppError::AudioError(
@@ -145,12 +120,9 @@ impl AudioStreamer {
Ok(())
}
/// Update bitrate dynamically (can be done while streaming)
pub async fn set_bitrate(&self, bitrate: u32) -> Result<()> {
// Update config
self.config.write().await.opus.bitrate = bitrate;
// Update encoder if running
if let Some(ref mut encoder) = *self.encoder.lock().await {
encoder.set_bitrate(bitrate)?;
}
@@ -159,7 +131,6 @@ impl AudioStreamer {
Ok(())
}
/// Start the audio stream
pub async fn start(&self) -> Result<()> {
if self.state() == AudioStreamState::Running {
return Ok(());
@@ -178,42 +149,77 @@ impl AudioStreamer {
config.opus.bitrate
);
// Create capturer
let capturer = Arc::new(AudioCapturer::new(config.capture.clone()));
*self.capturer.write().await = Some(capturer.clone());
// Create encoder
let encoder = OpusEncoder::new(config.opus.clone())?;
*self.encoder.lock().await = Some(encoder);
// Start capture
capturer.start().await?;
// Reset stats
{
let mut stats = self.stats.lock().await;
*stats = AudioStreamStats::default();
let mut capture_state = capturer.state_watch();
let startup_result = tokio::time::timeout(Duration::from_secs(2), async {
loop {
let current_state = *capture_state.borrow();
match current_state {
CaptureState::Running => return Ok(()),
CaptureState::Error => {
return Err(AppError::AudioError(
"Audio capture failed to start".to_string(),
))
}
CaptureState::Stopped => {
if capture_state.changed().await.is_err() {
return Err(AppError::AudioError(
"Audio capture stopped during startup".to_string(),
));
}
}
}
}
})
.await;
match startup_result {
Ok(Ok(())) => {}
Ok(Err(e)) => {
let _ = capturer.stop().await;
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
let _ = self.state.send(AudioStreamState::Error);
return Err(e);
}
Err(_) => {
let _ = capturer.stop().await;
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
let _ = self.state.send(AudioStreamState::Error);
return Err(AppError::AudioError(
"Timed out waiting for audio capture to start".to_string(),
));
}
}
// Record start time
*self.stream_start_time.write().await = Some(Instant::now());
self.sequence.store(0, Ordering::SeqCst);
// Start encoding task
let capturer_for_task = capturer.clone();
let encoder = self.encoder.clone();
let opus_tx = self.opus_tx.clone();
let opus_subscribers = self.opus_subscribers.clone();
let state = self.state.clone();
let stop_flag = self.stop_flag.clone();
tokio::spawn(async move {
Self::stream_task(capturer_for_task, encoder, opus_tx, state, stop_flag).await;
Self::stream_task(
capturer_for_task,
encoder,
opus_subscribers,
state,
stop_flag,
)
.await;
});
Ok(())
}
/// Stop the audio stream
pub async fn stop(&self) -> Result<()> {
if self.state() == AudioStreamState::Stopped {
return Ok(());
@@ -221,74 +227,82 @@ impl AudioStreamer {
info!("Stopping audio stream");
// Signal stop
self.stop_flag.store(true, Ordering::SeqCst);
// Stop capturer
if let Some(ref capturer) = *self.capturer.read().await {
capturer.stop().await?;
}
// Clear resources
*self.capturer.write().await = None;
*self.encoder.lock().await = None;
*self.stream_start_time.write().await = None;
self.opus_subscribers.lock().unwrap().clear();
let _ = self.state.send(AudioStreamState::Stopped);
info!("Audio stream stopped");
Ok(())
}
/// Check if streaming
pub fn is_running(&self) -> bool {
self.state() == AudioStreamState::Running
}
/// Internal streaming task
async fn fanout_opus(
subscribers: &Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
frame: Arc<OpusFrame>,
) {
let txs: Vec<_> = {
let g = subscribers.lock().unwrap();
if g.is_empty() {
return;
}
g.clone()
};
for tx in &txs {
let _ = tx.send(frame.clone()).await;
}
if txs.iter().any(|tx| tx.is_closed()) {
let mut g = subscribers.lock().unwrap();
g.retain(|tx| !tx.is_closed());
}
}
async fn stream_task(
capturer: Arc<AudioCapturer>,
encoder: Arc<Mutex<Option<OpusEncoder>>>,
opus_tx: watch::Sender<Option<Arc<OpusFrame>>>,
encoder: Arc<AsyncMutex<Option<OpusEncoder>>>,
opus_subscribers: Arc<Mutex<Vec<mpsc::Sender<Arc<OpusFrame>>>>>,
state: watch::Sender<AudioStreamState>,
stop_flag: Arc<AtomicBool>,
) {
let mut pcm_rx = capturer.subscribe();
let _ = state.send(AudioStreamState::Running);
info!("Audio stream task started");
debug!("Audio stream task started (48 kHz stereo → Opus, mpsc fan-out)");
let mut to_48k: Option<Opus48kPcmBuffer> = None;
let mut queued_48k: Vec<i16> = Vec::new();
let mut pending: Vec<i16> = Vec::new();
loop {
// Check stop flag (atomic, no async lock needed)
if stop_flag.load(Ordering::Relaxed) {
break;
}
// Check capturer state
if capturer.state() == CaptureState::Error {
error!("Audio capture error, stopping stream");
let _ = state.send(AudioStreamState::Error);
break;
}
// Receive PCM frame with timeout
let recv_result =
tokio::time::timeout(std::time::Duration::from_secs(2), pcm_rx.recv()).await;
match recv_result {
Ok(Ok(audio_frame)) => {
if to_48k.is_none() {
to_48k = Some(Opus48kPcmBuffer::new(
audio_frame.sample_rate,
audio_frame.channels,
));
if audio_frame.sample_rate != 48_000 || audio_frame.channels != 2 {
warn!(
"Skip non48 kHz/stereo PCM ({} Hz, {} ch)",
audio_frame.sample_rate, audio_frame.channels
);
continue;
}
let pipeline = match to_48k.as_mut() {
Some(p) => p,
None => continue,
};
let samples: &[i16] = match bytemuck::try_cast_slice(&audio_frame.data) {
Ok(s) => s,
@@ -298,16 +312,16 @@ impl AudioStreamer {
}
};
if !samples.is_empty() {
pipeline.push_interleaved(samples);
pending.extend_from_slice(samples);
}
pipeline.pop_opus_frames(&mut queued_48k);
while queued_48k.len() >= 960 * 2 {
let pcm_20ms =
Bytes::copy_from_slice(bytemuck::cast_slice(&queued_48k[..960 * 2]));
queued_48k.drain(..960 * 2);
while pending.len() >= OPUS_STEREO_SAMPLES {
let pcm_20ms = Bytes::copy_from_slice(bytemuck::cast_slice(
&pending[..OPUS_STEREO_SAMPLES],
));
pending.drain(..OPUS_STEREO_SAMPLES);
let frame_48k = AudioFrame::new_interleaved(pcm_20ms, 2, 48000, 0);
let frame_48k = AudioFrame::new_interleaved(pcm_20ms, 2, 48_000, 0);
let opus_result = {
let mut enc_guard = encoder.lock().await;
@@ -318,9 +332,7 @@ impl AudioStreamer {
match opus_result {
Some(Ok(opus_frame)) => {
if opus_tx.receiver_count() > 0 {
let _ = opus_tx.send(Some(Arc::new(opus_frame)));
}
Self::fanout_opus(&opus_subscribers, Arc::new(opus_frame)).await;
}
Some(Err(e)) => {
error!("Opus encode error: {}", e);
@@ -337,19 +349,23 @@ impl AudioStreamer {
break;
}
Ok(Err(broadcast::error::RecvError::Lagged(n))) => {
warn!("Audio receiver lagged by {} frames", n);
warn!("PCM receiver lagged by {} frames", n);
}
Err(_) => {
// Timeout - check if still capturing
if capturer.state() != CaptureState::Running {
info!("Audio capture stopped, ending stream task");
let _ = state.send(AudioStreamState::Error);
break;
}
}
}
}
let _ = state.send(AudioStreamState::Stopped);
if stop_flag.load(Ordering::Relaxed) {
let _ = state.send(AudioStreamState::Stopped);
} else {
opus_subscribers.lock().unwrap().clear();
}
info!("Audio stream task ended");
}
}

View File

@@ -8,20 +8,16 @@ use axum::{
use axum_extra::extract::CookieJar;
use std::sync::Arc;
use crate::error::ErrorResponse;
use crate::state::AppState;
use crate::web::ErrorResponse;
/// Session cookie name
pub const SESSION_COOKIE: &str = "one_kvm_session";
/// Extract session ID from request
pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap) -> Option<String> {
// First try cookie
if let Some(cookie) = cookies.get(SESSION_COOKIE) {
return Some(cookie.value().to_string());
}
// Then try Authorization header (Bearer token)
if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
if let Ok(auth_str) = auth_header.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
@@ -33,7 +29,6 @@ pub fn extract_session_id(cookies: &CookieJar, headers: &axum::http::HeaderMap)
None
}
/// Authentication middleware
pub async fn auth_middleware(
State(state): State<Arc<AppState>>,
cookies: CookieJar,
@@ -41,29 +36,23 @@ pub async fn auth_middleware(
next: Next,
) -> Result<Response, StatusCode> {
let raw_path = request.uri().path();
// When this middleware is mounted under /api, Axum strips the prefix for the inner router.
// Normalize the path so checks work whether it is mounted or not.
// Mounted under /api: inner path may lack prefix; normalize for whitelist checks.
let path = raw_path.strip_prefix("/api").unwrap_or(raw_path);
// Check if system is initialized
if !state.config.is_initialized() {
// Allow only setup-related endpoints when not initialized
if is_setup_public_endpoint(path) {
return Ok(next.run(request).await);
}
}
// Public endpoints that don't require auth
if is_public_endpoint(path) {
return Ok(next.run(request).await);
}
// Extract session ID
let session_id = extract_session_id(&cookies, request.headers());
if let Some(session_id) = session_id {
if let Ok(Some(session)) = state.sessions.get(&session_id).await {
// Add session to request extensions
request.extensions_mut().insert(session);
return Ok(next.run(request).await);
}
@@ -87,9 +76,7 @@ fn unauthorized_response(message: &str) -> Response {
(StatusCode::UNAUTHORIZED, Json(body)).into_response()
}
/// Check if endpoint is public (no auth required)
fn is_public_endpoint(path: &str) -> bool {
// Note: paths here are relative to /api since middleware is applied within the nested router
matches!(
path,
"/" | "/auth/login" | "/health" | "/setup" | "/setup/init"
@@ -102,7 +89,6 @@ fn is_public_endpoint(path: &str) -> bool {
|| path.ends_with(".svg")
}
/// Setup-only endpoints allowed before initialization.
fn is_setup_public_endpoint(path: &str) -> bool {
matches!(
path,

View File

@@ -5,7 +5,6 @@ use argon2::{
use crate::error::{AppError, Result};
/// Hash a password using Argon2
pub fn hash_password(password: &str) -> Result<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
@@ -16,7 +15,6 @@ pub fn hash_password(password: &str) -> Result<String> {
.map_err(|e| AppError::Internal(format!("Password hashing failed: {}", e)))
}
/// Verify a password against a hash
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| AppError::Internal(format!("Invalid password hash: {}", e)))?;

View File

@@ -1,156 +1,104 @@
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use std::collections::HashMap;
use std::sync::Arc;
use time::{Duration, OffsetDateTime};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::error::Result;
/// Session data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
pub id: String,
pub user_id: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339")]
pub expires_at: OffsetDateTime,
pub data: Option<serde_json::Value>,
}
impl Session {
/// Check if session is expired
pub fn is_expired(&self) -> bool {
Utc::now() > self.expires_at
OffsetDateTime::now_utc() > self.expires_at
}
}
/// Session store backed by SQLite
#[derive(Clone)]
pub struct SessionStore {
pool: Pool<Sqlite>,
inner: Arc<RwLock<HashMap<String, Session>>>,
default_ttl: Duration,
}
impl SessionStore {
/// Create a new session store
pub fn new(pool: Pool<Sqlite>, ttl_secs: i64) -> Self {
pub fn new(ttl_secs: i64) -> Self {
Self {
pool,
inner: Arc::new(RwLock::new(HashMap::new())),
default_ttl: Duration::seconds(ttl_secs),
}
}
/// Create a new session
pub async fn create(&self, user_id: &str) -> Result<Session> {
let now = OffsetDateTime::now_utc();
let session = Session {
id: Uuid::new_v4().to_string(),
user_id: user_id.to_string(),
created_at: Utc::now(),
expires_at: Utc::now() + self.default_ttl,
created_at: now,
expires_at: now + self.default_ttl,
data: None,
};
sqlx::query(
r#"
INSERT INTO sessions (id, user_id, created_at, expires_at, data)
VALUES (?1, ?2, ?3, ?4, ?5)
"#,
)
.bind(&session.id)
.bind(&session.user_id)
.bind(session.created_at.to_rfc3339())
.bind(session.expires_at.to_rfc3339())
.bind(session.data.as_ref().map(|d| d.to_string()))
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
guard.insert(session.id.clone(), session.clone());
Ok(session)
}
/// Get a session by ID
pub async fn get(&self, session_id: &str) -> Result<Option<Session>> {
let row: Option<(String, String, String, String, Option<String>)> = sqlx::query_as(
"SELECT id, user_id, created_at, expires_at, data FROM sessions WHERE id = ?1",
)
.bind(session_id)
.fetch_optional(&self.pool)
.await?;
match row {
Some((id, user_id, created_at, expires_at, data)) => {
let session = Session {
id,
user_id,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
expires_at: DateTime::parse_from_rfc3339(&expires_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
data: data.and_then(|d| serde_json::from_str(&d).ok()),
};
if session.is_expired() {
self.delete(&session.id).await?;
Ok(None)
} else {
Ok(Some(session))
}
}
None => Ok(None),
let mut guard = self.inner.write().await;
let Some(session) = guard.get(session_id).cloned() else {
return Ok(None);
};
if session.is_expired() {
guard.remove(session_id);
return Ok(None);
}
Ok(Some(session))
}
/// Delete a session
pub async fn delete(&self, session_id: &str) -> Result<()> {
sqlx::query("DELETE FROM sessions WHERE id = ?1")
.bind(session_id)
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
guard.remove(session_id);
Ok(())
}
/// Delete all expired sessions
pub async fn cleanup_expired(&self) -> Result<u64> {
let now = Utc::now().to_rfc3339();
let result = sqlx::query("DELETE FROM sessions WHERE expires_at < ?1")
.bind(now)
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
let mut guard = self.inner.write().await;
let before = guard.len();
guard.retain(|_, s| !s.is_expired());
Ok((before - guard.len()) as u64)
}
/// Delete all sessions
pub async fn delete_all(&self) -> Result<u64> {
let result = sqlx::query("DELETE FROM sessions")
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
let mut guard = self.inner.write().await;
let n = guard.len() as u64;
guard.clear();
Ok(n)
}
/// Delete all sessions for a specific user
pub async fn delete_by_user_id(&self, user_id: &str) -> Result<u64> {
let result = sqlx::query("DELETE FROM sessions WHERE user_id = ?1")
.bind(user_id)
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
}
/// List all session IDs
pub async fn list_ids(&self) -> Result<Vec<String>> {
let rows: Vec<(String,)> = sqlx::query_as("SELECT id FROM sessions")
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(|(id,)| id).collect())
let guard = self.inner.read().await;
Ok(guard.keys().cloned().collect())
}
/// Extend session expiration
pub async fn extend(&self, session_id: &str) -> Result<()> {
let new_expires = Utc::now() + self.default_ttl;
sqlx::query("UPDATE sessions SET expires_at = ?1 WHERE id = ?2")
.bind(new_expires.to_rfc3339())
.bind(session_id)
.execute(&self.pool)
.await?;
let mut guard = self.inner.write().await;
if let Some(session) = guard.get_mut(session_id) {
if session.is_expired() {
guard.remove(session_id);
} else {
session.expires_at = OffsetDateTime::now_utc() + self.default_ttl;
}
}
Ok(())
}
}

View File

@@ -1,123 +1,99 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Sqlite};
use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime;
use uuid::Uuid;
use super::password::{hash_password, verify_password};
use crate::error::{AppError, Result};
/// User row type from database
type UserRow = (String, String, String, String, String);
type UserRow = (String, String, String);
/// User data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub id: String,
pub username: String,
#[serde(skip_serializing)]
pub password_hash: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl User {
/// Convert from database row to User
fn from_row(row: UserRow) -> Self {
let (id, username, password_hash, created_at, updated_at) = row;
let (id, username, password_hash) = row;
Self {
id,
username,
password_hash,
created_at: DateTime::parse_from_rfc3339(&created_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
updated_at: DateTime::parse_from_rfc3339(&updated_at)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
}
}
}
/// User store backed by SQLite
#[derive(Clone)]
pub struct UserStore {
pool: Pool<Sqlite>,
}
impl UserStore {
/// Create a new user store
pub fn new(pool: Pool<Sqlite>) -> Self {
Self { pool }
}
/// Create a new user
pub async fn create(&self, username: &str, password: &str) -> Result<User> {
// Check if username already exists
if self.get_by_username(username).await?.is_some() {
return Err(AppError::BadRequest(format!(
"Username '{}' already exists",
username
)));
/// The single local user, or `None` if none exists. Errors if more than one row is present.
pub async fn single_user(&self) -> Result<Option<User>> {
let mut rows: Vec<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash FROM users ORDER BY rowid ASC LIMIT 2",
)
.fetch_all(&self.pool)
.await?;
match rows.len() {
0 => Ok(None),
1 => Ok(Some(User::from_row(rows.remove(0)))),
_ => Err(AppError::Internal(
"Multiple user accounts in database; this build supports only one".to_string(),
)),
}
}
pub async fn create_first_user(&self, username: &str, password: &str) -> Result<User> {
if self.single_user().await?.is_some() {
return Err(AppError::BadRequest(
"A user account already exists".to_string(),
));
}
let password_hash = hash_password(password)?;
let now = Utc::now();
let user = User {
id: Uuid::new_v4().to_string(),
username: username.to_string(),
password_hash,
created_at: now,
updated_at: now,
};
sqlx::query(
r#"
INSERT INTO users (id, username, password_hash, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5)
INSERT INTO users (id, username, password_hash)
VALUES (?1, ?2, ?3)
"#,
)
.bind(&user.id)
.bind(&user.username)
.bind(&user.password_hash)
.bind(user.created_at.to_rfc3339())
.bind(user.updated_at.to_rfc3339())
.execute(&self.pool)
.await?;
Ok(user)
}
/// Get user by ID
pub async fn get(&self, user_id: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE id = ?1",
)
.bind(user_id)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Get user by username
pub async fn get_by_username(&self, username: &str) -> Result<Option<User>> {
let row: Option<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users WHERE username = ?1",
)
.bind(username)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(User::from_row))
}
/// Verify user credentials
pub async fn verify(&self, username: &str, password: &str) -> Result<Option<User>> {
let user = match self.get_by_username(username).await? {
Some(user) => user,
let user = match self.single_user().await? {
Some(u) => u,
None => return Ok(None),
};
if user.username != username {
return Ok(None);
}
if verify_password(password, &user.password_hash)? {
Ok(Some(user))
} else {
@@ -125,15 +101,23 @@ impl UserStore {
}
}
/// Update user password
pub async fn update_password(&self, user_id: &str, new_password: &str) -> Result<()> {
let user = self
.single_user()
.await?
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
if user.id != user_id {
return Err(AppError::AuthError("Invalid session".to_string()));
}
let password_hash = hash_password(new_password)?;
let now = Utc::now();
let now = OffsetDateTime::now_utc();
let result =
sqlx::query("UPDATE users SET password_hash = ?1, updated_at = ?2 WHERE id = ?3")
.bind(&password_hash)
.bind(now.to_rfc3339())
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
.bind(user_id)
.execute(&self.pool)
.await?;
@@ -145,21 +129,24 @@ impl UserStore {
Ok(())
}
/// Update username
pub async fn update_username(&self, user_id: &str, new_username: &str) -> Result<()> {
if let Some(existing) = self.get_by_username(new_username).await? {
if existing.id != user_id {
return Err(AppError::BadRequest(format!(
"Username '{}' already exists",
new_username
)));
}
let user = self
.single_user()
.await?
.ok_or_else(|| AppError::NotFound("User not found".to_string()))?;
if user.id != user_id {
return Err(AppError::AuthError("Invalid session".to_string()));
}
let now = Utc::now();
if new_username == user.username {
return Ok(());
}
let now = OffsetDateTime::now_utc();
let result = sqlx::query("UPDATE users SET username = ?1, updated_at = ?2 WHERE id = ?3")
.bind(new_username)
.bind(now.to_rfc3339())
.bind(now.format(&Rfc3339).expect("RFC3339 format"))
.bind(user_id)
.execute(&self.pool)
.await?;
@@ -170,37 +157,4 @@ impl UserStore {
Ok(())
}
/// List all users
pub async fn list(&self) -> Result<Vec<User>> {
let rows: Vec<UserRow> = sqlx::query_as(
"SELECT id, username, password_hash, created_at, updated_at FROM users ORDER BY created_at",
)
.fetch_all(&self.pool)
.await?;
Ok(rows.into_iter().map(User::from_row).collect())
}
/// Delete user by ID
pub async fn delete(&self, user_id: &str) -> Result<()> {
let result = sqlx::query("DELETE FROM users WHERE id = ?1")
.bind(user_id)
.execute(&self.pool)
.await?;
if result.rows_affected() == 0 {
return Err(AppError::NotFound("User not found".to_string()));
}
Ok(())
}
/// Check if any users exist
pub async fn has_users(&self) -> Result<bool> {
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
.fetch_one(&self.pool)
.await?;
Ok(count.0 > 0)
}
}

View File

@@ -1,5 +1,7 @@
mod persistence;
mod schema;
mod store;
pub use persistence::ConfigChange;
pub use schema::*;
pub use store::ConfigStore;

View File

@@ -0,0 +1,5 @@
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}

View File

@@ -1,12 +1,85 @@
use crate::video::encoder::BitratePreset;
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
// Re-export ExtensionsConfig from extensions module
// Re-export domain config types that are embedded in AppConfig.
// These are simple data types defined in their respective modules;
// keeping the re-export here is acceptable since they flow inward.
pub use crate::extensions::ExtensionsConfig;
// Re-export RustDeskConfig from rustdesk module
pub use crate::rustdesk::config::RustDeskConfig;
/// Bitrate preset for video encoding
///
/// Simplifies bitrate configuration by providing three intuitive presets
/// plus a custom option for advanced users.
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
#[derive(Default)]
pub enum BitratePreset {
/// Speed priority: 1 Mbps, lowest latency, smaller GOP
Speed,
/// Balanced: 4 Mbps, good quality/latency tradeoff
#[default]
Balanced,
/// Quality priority: 8 Mbps, best visual quality
Quality,
/// Custom bitrate in kbps (for advanced users)
Custom(u32),
}
impl BitratePreset {
/// Get bitrate value in kbps
pub fn bitrate_kbps(&self) -> u32 {
match self {
Self::Speed => 1000,
Self::Balanced => 4000,
Self::Quality => 8000,
Self::Custom(kbps) => *kbps,
}
}
/// Get recommended GOP size based on preset
pub fn gop_size(&self, fps: u32) -> u32 {
match self {
Self::Speed => (fps / 2).max(15),
Self::Balanced => fps,
Self::Quality => fps * 2,
Self::Custom(_) => fps,
}
}
/// Get quality preset name for encoder configuration
pub fn quality_level(&self) -> &'static str {
match self {
Self::Speed => "low",
Self::Balanced => "medium",
Self::Quality => "high",
Self::Custom(_) => "medium",
}
}
/// Create from kbps value, mapping to nearest preset or Custom
pub fn from_kbps(kbps: u32) -> Self {
match kbps {
0..=1500 => Self::Speed,
1501..=6000 => Self::Balanced,
6001..=10000 => Self::Quality,
_ => Self::Custom(kbps),
}
}
}
impl std::fmt::Display for BitratePreset {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Speed => write!(f, "Speed (1 Mbps)"),
Self::Balanced => write!(f, "Balanced (4 Mbps)"),
Self::Quality => write!(f, "Quality (8 Mbps)"),
Self::Custom(kbps) => write!(f, "Custom ({} kbps)", kbps),
}
}
}
/// Main application configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -179,27 +252,13 @@ pub enum OtgEndpointBudget {
}
impl OtgEndpointBudget {
pub fn default_for_udc_name(udc: Option<&str>) -> Self {
if udc.is_some_and(crate::otg::configfs::is_low_endpoint_udc) {
Self::Five
} else {
Self::Six
}
}
pub fn resolved(self, udc: Option<&str>) -> Self {
/// Resolve endpoint limit assuming a known budget variant (not Auto).
pub fn endpoint_limit_raw(&self) -> Option<u8> {
match self {
Self::Auto => Self::default_for_udc_name(udc),
other => other,
}
}
pub fn endpoint_limit(self, udc: Option<&str>) -> Option<u8> {
match self.resolved(udc) {
Self::Five => Some(5),
Self::Six => Some(6),
Self::Unlimited => None,
Self::Auto => unreachable!("auto budget must be resolved before use"),
Self::Auto => None, // resolved via `HidConfig::resolved_otg_endpoint_limit`
}
}
}
@@ -356,32 +415,23 @@ impl Default for HidConfig {
}
impl HidConfig {
/// Resolve effective OTG HID functions from profile + custom selection.
/// Pure logic, no external dependency.
pub fn effective_otg_functions(&self) -> OtgHidFunctions {
self.otg_profile.resolve_functions(&self.otg_functions)
}
pub fn resolved_otg_udc(&self) -> Option<String> {
crate::otg::configfs::resolve_udc_name(self.otg_udc.as_deref())
}
pub fn resolved_otg_endpoint_budget(&self) -> OtgEndpointBudget {
self.otg_endpoint_budget
.resolved(self.resolved_otg_udc().as_deref())
}
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
self.otg_endpoint_budget
.endpoint_limit(self.resolved_otg_udc().as_deref())
}
/// Whether keyboard LED feedback is effectively enabled.
pub fn effective_otg_keyboard_leds(&self) -> bool {
self.otg_keyboard_leds && self.effective_otg_functions().keyboard
}
/// Effective HID functions after applying all constraints.
pub fn constrained_otg_functions(&self) -> OtgHidFunctions {
self.effective_otg_functions()
}
/// Calculate required endpoint count for the current function selection.
pub fn effective_otg_required_endpoints(&self, msd_enabled: bool) -> u8 {
let functions = self.effective_otg_functions();
let mut endpoints = functions.endpoint_cost(self.effective_otg_keyboard_leds());
@@ -391,6 +441,7 @@ impl HidConfig {
endpoints
}
/// Validate endpoint budget for the current OTG configuration (UDC-aware when budget is Auto).
pub fn validate_otg_endpoint_budget(&self, msd_enabled: bool) -> crate::error::Result<()> {
if self.backend != HidBackend::Otg {
return Ok(());
@@ -403,8 +454,9 @@ impl HidConfig {
));
}
let resolved_limit = self.resolved_otg_endpoint_limit();
let required = self.effective_otg_required_endpoints(msd_enabled);
if let Some(limit) = self.resolved_otg_endpoint_limit() {
if let Some(limit) = resolved_limit {
if required > limit {
return Err(crate::error::AppError::BadRequest(format!(
"OTG selection requires {} endpoints, but the configured limit is {}",
@@ -415,6 +467,40 @@ impl HidConfig {
Ok(())
}
/// Effective OTG UDC name (for change detection / service).
#[inline]
pub fn resolved_otg_udc(&self) -> Option<String> {
if self.backend != HidBackend::Otg {
return None;
}
self.otg_udc
.as_ref()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.or_else(|| crate::otg::OtgGadgetManager::find_udc())
}
/// Resolved endpoint limit used for OTG gadget allocator / validation.
#[inline]
pub fn resolved_otg_endpoint_limit(&self) -> Option<u8> {
if self.backend != HidBackend::Otg {
return None;
}
match self.otg_endpoint_budget {
OtgEndpointBudget::Five => Some(5),
OtgEndpointBudget::Six => Some(6),
OtgEndpointBudget::Unlimited => None,
OtgEndpointBudget::Auto => {
let udc = self.resolved_otg_udc().unwrap_or_default();
if crate::otg::configfs::is_low_endpoint_udc(&udc) {
Some(5)
} else {
Some(6)
}
}
}
}
}
/// MSD configuration
@@ -511,7 +597,7 @@ impl Default for AudioConfig {
fn default() -> Self {
Self {
enabled: false,
device: "default".to_string(),
device: String::new(),
quality: "balanced".to_string(),
}
}
@@ -606,21 +692,6 @@ pub enum EncoderType {
}
impl EncoderType {
/// Convert to EncoderBackend for registry queries
pub fn to_backend(&self) -> Option<crate::video::encoder::registry::EncoderBackend> {
use crate::video::encoder::registry::EncoderBackend;
match self {
EncoderType::Auto => None,
EncoderType::Software => Some(EncoderBackend::Software),
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
EncoderType::Qsv => Some(EncoderBackend::Qsv),
EncoderType::Amf => Some(EncoderBackend::Amf),
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
}
}
/// Get display name for UI
pub fn display_name(&self) -> &'static str {
match self {
@@ -687,19 +758,17 @@ impl Default for StreamConfig {
}
impl StreamConfig {
/// Check if using public ICE servers (user left fields empty)
/// Whether built-in / public ICE is used (no custom STUN or TURN URL configured).
pub fn is_using_public_ice_servers(&self) -> bool {
use crate::webrtc::config::public_ice;
self.stun_server
let no_custom_stun = self
.stun_server
.as_ref()
.map(|s| s.is_empty())
.unwrap_or(true)
&& self
.turn_server
.as_ref()
.map(|s| s.is_empty())
.unwrap_or(true)
&& public_ice::is_configured()
.map_or(true, |s| s.trim().is_empty());
let no_custom_turn = self
.turn_server
.as_ref()
.map_or(true, |s| s.trim().is_empty());
no_custom_stun && no_custom_turn
}
}

View File

@@ -1,11 +1,10 @@
use arc_swap::ArcSwap;
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
use std::path::Path;
use sqlx::{Pool, Sqlite};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
use super::persistence::ConfigChange;
use super::AppConfig;
use crate::error::{AppError, Result};
@@ -23,127 +22,23 @@ pub struct ConfigStore {
write_lock: Arc<Mutex<()>>,
}
/// Configuration change event
#[derive(Debug, Clone)]
pub struct ConfigChange {
pub key: String,
}
impl ConfigStore {
/// Create a new configuration store
pub async fn new(db_path: &Path) -> Result<Self> {
// Ensure parent directory exists
if let Some(parent) = db_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePoolOptions::new()
// SQLite uses single-writer mode, 2 connections is sufficient for embedded devices
// One for reads, one for writes to avoid blocking
.max_connections(2)
// Set reasonable timeouts for embedded environments
.acquire_timeout(Duration::from_secs(5))
.idle_timeout(Duration::from_secs(300))
.connect(&db_url)
.await?;
// Initialize database schema
Self::init_schema(&pool).await?;
// Load or create default config
let config = Self::load_config(&pool).await?;
let cache = Arc::new(ArcSwap::from_pointee(config));
let (change_tx, _) = broadcast::channel(16);
pub fn new(pool: Pool<Sqlite>) -> Result<Self> {
// Load or create default config synchronously wrapper
// (actual DB load is async, handled in init())
Ok(Self {
pool,
cache,
change_tx,
cache: Arc::new(ArcSwap::from_pointee(AppConfig::default())),
change_tx: broadcast::channel(16).0,
write_lock: Arc::new(Mutex::new(())),
})
}
/// Initialize database schema
async fn init_schema(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
expires_at TEXT NOT NULL,
data TEXT
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
permissions TEXT NOT NULL,
expires_at TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
last_used TEXT
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS wol_history (
mac_address TEXT PRIMARY KEY,
updated_at INTEGER NOT NULL
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
ON wol_history(updated_at DESC)
"#,
)
.execute(pool)
.await?;
/// Load configuration from database (call after new())
pub async fn load(&self) -> Result<()> {
let config = Self::load_config(&self.pool).await?;
self.cache.store(Arc::new(config));
Ok(())
}
@@ -244,16 +139,12 @@ impl ConfigStore {
pub fn is_initialized(&self) -> bool {
self.cache.load().initialized
}
/// Get database pool for session management
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::DatabasePool;
use tempfile::tempdir;
#[tokio::test]
@@ -261,7 +152,11 @@ mod tests {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test.db");
let store = ConfigStore::new(&db_path).await.unwrap();
let db = DatabasePool::new(&db_path).await.unwrap();
db.init_schema().await.unwrap();
let store = ConfigStore::new(db.clone_pool()).unwrap();
store.load().await.unwrap();
// Check default config (now lock-free, no await needed)
let config = store.get();
@@ -282,7 +177,8 @@ mod tests {
assert_eq!(config.web.http_port, 9000);
// Create new store instance and verify persistence
let store2 = ConfigStore::new(&db_path).await.unwrap();
let store2 = ConfigStore::new(db.clone_pool()).unwrap();
store2.load().await.unwrap();
let config = store2.get();
assert!(config.initialized);
assert_eq!(config.web.http_port, 9000);

3
src/db/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
mod pool;
pub use pool::DatabasePool;

119
src/db/pool.rs Normal file
View File

@@ -0,0 +1,119 @@
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite};
use std::path::Path;
use std::time::Duration;
use crate::error::Result;
#[derive(Clone)]
pub struct DatabasePool {
pool: Pool<Sqlite>,
}
impl DatabasePool {
pub async fn new(db_path: &Path) -> Result<Self> {
if let Some(parent) = db_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let db_url = format!("sqlite:{}?mode=rwc", db_path.display());
let pool = SqlitePoolOptions::new()
.max_connections(4)
.acquire_timeout(Duration::from_secs(5))
.idle_timeout(Duration::from_secs(300))
.connect(&db_url)
.await?;
Ok(Self { pool })
}
pub async fn init_schema(&self) -> Result<()> {
self.create_config_table().await?;
self.create_users_table().await?;
self.create_api_tokens_table().await?;
self.create_wol_history_table().await?;
Ok(())
}
async fn create_config_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_users_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_api_tokens_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
permissions TEXT NOT NULL,
expires_at TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
last_used TEXT
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
async fn create_wol_history_table(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS wol_history (
mac_address TEXT PRIMARY KEY,
updated_at INTEGER NOT NULL
)
"#,
)
.execute(&self.pool)
.await?;
sqlx::query(
r#"
CREATE INDEX IF NOT EXISTS idx_wol_history_updated_at
ON wol_history(updated_at DESC)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
pub fn clone_pool(&self) -> Pool<Sqlite> {
self.pool.clone()
}
}

View File

@@ -1,12 +1,5 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::Serialize;
use thiserror::Error;
/// Application-wide error type
#[derive(Error, Debug)]
pub enum AppError {
#[error("Authentication failed: {0}")]
@@ -15,17 +8,14 @@ pub enum AppError {
#[error("Not authenticated")]
Unauthorized,
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Persistence error: {0}")]
Persistence(String),
#[error("Internal error: {0}")]
Internal(String),
@@ -42,8 +32,9 @@ pub enum AppError {
#[error("Video error: {0}")]
VideoError(String),
#[error("Video device lost [{device}]: {reason}")]
VideoDeviceLost { device: String, reason: String },
/// No input signal while opening capture; `kind` is `SignalStatus` as string (`from_str`).
#[error("Capture has no valid signal: {kind}")]
CaptureNoSignal { kind: String },
#[error("Audio error: {0}")]
AudioError(String),
@@ -62,37 +53,10 @@ pub enum AppError {
ServiceUnavailable(String),
}
/// Error response body (unified success format)
#[derive(Serialize)]
pub struct ErrorResponse {
pub success: bool,
pub message: String,
}
impl AppError {
fn status_code(&self) -> StatusCode {
// Always return 200 OK - success/failure is indicated by the success field
StatusCode::OK
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let status = self.status_code();
let body = ErrorResponse {
success: false,
message: self.to_string(),
};
tracing::error!(
error_type = std::any::type_name_of_val(&self),
error_message = %body.message,
"Request failed"
);
(status, Json(body)).into_response()
}
}
/// Result type alias for handlers
pub type Result<T> = std::result::Result<T, AppError>;
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
AppError::Persistence(err.to_string())
}
}

View File

@@ -1,41 +1,28 @@
//! Event system for real-time state notifications
//!
//! This module provides a global event bus for broadcasting system events
//! to WebSocket clients and other subscribers.
//! Event bus: [`SystemEvent`] fan-out to WebSocket subscribers and internal tasks.
pub mod types;
use self::types::EXACT_EVENT_TOPICS;
pub use types::{
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, MsdDeviceInfo, SystemEvent,
TtydDeviceInfo, VideoDeviceInfo,
AtxDeviceInfo, AudioDeviceInfo, ClientStats, HidDeviceInfo, LedState, MsdDeviceInfo,
StreamDeviceLostKind, SystemEvent, TtydDeviceInfo, VideoDeviceInfo,
};
use tokio::sync::broadcast;
/// Event channel capacity (ring buffer size)
const EVENT_CHANNEL_CAPACITY: usize = 256;
const EXACT_TOPICS: &[&str] = &[
"stream.mode_switching",
"stream.state_changed",
"stream.config_changing",
"stream.config_applied",
"stream.device_lost",
"stream.reconnecting",
"stream.recovered",
"stream.webrtc_ready",
"stream.stats_update",
"stream.mode_changed",
"stream.mode_ready",
"webrtc.ice_candidate",
"webrtc.ice_complete",
"msd.upload_progress",
"msd.download_progress",
"system.device_info",
"error",
];
const PREFIX_TOPICS: &[&str] = &["stream.*", "webrtc.*", "msd.*", "system.*"];
fn collect_prefix_wildcards(exact: &[&'static str]) -> Vec<String> {
use std::collections::BTreeSet;
let mut segments = BTreeSet::new();
for name in exact {
if let Some((seg, _)) = name.split_once('.') {
segments.insert(seg);
}
}
segments.into_iter().map(|s| format!("{}.*", s)).collect()
}
fn make_sender() -> broadcast::Sender<SystemEvent> {
let (tx, _rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
@@ -48,50 +35,23 @@ fn topic_prefix(event_name: &str) -> Option<String> {
.map(|(prefix, _)| format!("{}.*", prefix))
}
/// Global event bus for broadcasting system events
///
/// The event bus uses tokio's broadcast channel to distribute events
/// to multiple subscribers. Events are delivered to all active subscribers.
///
/// # Example
///
/// ```no_run
/// use one_kvm::events::{EventBus, SystemEvent};
///
/// let bus = EventBus::new();
///
/// // Publish an event
/// bus.publish(SystemEvent::StreamStateChanged {
/// state: "streaming".to_string(),
/// device: Some("/dev/video0".to_string()),
/// });
///
/// // Subscribe to events
/// let mut rx = bus.subscribe();
/// tokio::spawn(async move {
/// while let Ok(event) = rx.recv().await {
/// println!("Received event: {:?}", event);
/// }
/// });
/// ```
pub struct EventBus {
tx: broadcast::Sender<SystemEvent>,
exact_topics: std::collections::HashMap<&'static str, broadcast::Sender<SystemEvent>>,
prefix_topics: std::collections::HashMap<&'static str, broadcast::Sender<SystemEvent>>,
prefix_topics: std::collections::HashMap<String, broadcast::Sender<SystemEvent>>,
device_info_dirty_tx: broadcast::Sender<()>,
}
impl EventBus {
/// Create a new event bus
pub fn new() -> Self {
let tx = make_sender();
let exact_topics = EXACT_TOPICS
let exact_topics = EXACT_EVENT_TOPICS
.iter()
.map(|topic| (*topic, make_sender()))
.collect();
let prefix_topics = PREFIX_TOPICS
.iter()
.map(|topic| (*topic, make_sender()))
let prefix_topics = collect_prefix_wildcards(EXACT_EVENT_TOPICS)
.into_iter()
.map(|topic| (topic, make_sender()))
.collect();
let (device_info_dirty_tx, _dirty_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
@@ -103,10 +63,6 @@ impl EventBus {
}
}
/// Publish an event to all subscribers
///
/// If there are no active subscribers, the event is silently dropped.
/// This is by design - events are fire-and-forget notifications.
pub fn publish(&self, event: SystemEvent) {
let event_name = event.event_name();
@@ -115,28 +71,18 @@ impl EventBus {
}
if let Some(prefix) = topic_prefix(event_name) {
if let Some(tx) = self.prefix_topics.get(prefix.as_str()) {
if let Some(tx) = self.prefix_topics.get(&prefix) {
let _ = tx.send(event.clone());
}
}
// If no subscribers, send returns Err which is normal
let _ = self.tx.send(event);
}
/// Subscribe to events
///
/// Returns a receiver that will receive all future events.
/// The receiver uses a ring buffer, so if a subscriber falls too far
/// behind, it will receive a `Lagged` error and miss some events.
pub fn subscribe(&self) -> broadcast::Receiver<SystemEvent> {
self.tx.subscribe()
}
/// Subscribe to a specific topic.
///
/// Supports exact event names, namespace wildcards like `stream.*`, and
/// `*` for the full event stream.
pub fn subscribe_topic(&self, topic: &str) -> Option<broadcast::Receiver<SystemEvent>> {
if topic == "*" {
return Some(self.tx.subscribe());
@@ -149,22 +95,14 @@ impl EventBus {
self.exact_topics.get(topic).map(|tx| tx.subscribe())
}
/// Mark the device-info snapshot as stale.
///
/// This is an internal trigger used to refresh the latest `system.device_info`
/// snapshot without exposing another public WebSocket event.
pub fn mark_device_info_dirty(&self) {
let _ = self.device_info_dirty_tx.send(());
}
/// Subscribe to internal device-info refresh triggers.
pub fn subscribe_device_info_dirty(&self) -> broadcast::Receiver<()> {
self.device_info_dirty_tx.subscribe()
}
/// Get the current number of active subscribers
///
/// Useful for monitoring and debugging.
pub fn subscriber_count(&self) -> usize {
self.tx.receiver_count()
}
@@ -188,6 +126,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
@@ -205,6 +145,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
});
let event1 = rx1.recv().await.unwrap();
@@ -222,6 +164,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
@@ -236,6 +180,8 @@ mod tests {
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
let event = rx.recv().await.unwrap();
@@ -253,10 +199,11 @@ mod tests {
let bus = EventBus::new();
assert_eq!(bus.subscriber_count(), 0);
// Should not panic when publishing with no subscribers
bus.publish(SystemEvent::StreamStateChanged {
state: "ready".to_string(),
device: None,
reason: None,
next_retry_ms: None,
});
}
}

View File

@@ -1,359 +1,234 @@
//! System event types
//!
//! Defines all event types that can be broadcast through the event bus.
//! [`SystemEvent`] and device snapshot types (WebSocket / JSON).
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::hid::LedState;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct LedState {
pub num_lock: bool,
pub caps_lock: bool,
pub scroll_lock: bool,
pub compose: bool,
pub kana: bool,
}
// ============================================================================
// Device Info Structures (for system.device_info event)
// ============================================================================
/// Video device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoDeviceInfo {
/// Whether video device is available
pub available: bool,
/// Device path (e.g., /dev/video0)
pub device: Option<String>,
/// Pixel format (e.g., "MJPEG", "YUYV")
pub format: Option<String>,
/// Resolution (width, height)
pub resolution: Option<(u32, u32)>,
/// Frames per second
pub fps: u32,
/// Whether stream is currently active
pub online: bool,
/// Current streaming mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
pub stream_mode: String,
/// Whether video config is currently being changed (frontend should skip mode sync)
pub config_changing: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// HID device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HidDeviceInfo {
/// Whether HID backend is available
pub available: bool,
/// Backend type: "otg", "ch9329", "none"
pub backend: String,
/// Whether backend is initialized and ready
pub initialized: bool,
/// Whether backend is currently online
pub online: bool,
/// Whether absolute mouse positioning is supported
pub supports_absolute_mouse: bool,
/// Whether keyboard LED/status feedback is enabled.
pub keyboard_leds_enabled: bool,
/// Last known keyboard LED state.
pub led_state: LedState,
/// Device path (e.g., serial port for CH9329)
pub device: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
/// Error code if any, None if OK
pub error_code: Option<String>,
}
/// MSD device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdDeviceInfo {
/// Whether MSD is available
pub available: bool,
/// Operating mode: "none", "image", "drive"
pub mode: String,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image ID
pub image_id: Option<String>,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// ATX device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AtxDeviceInfo {
/// Whether ATX controller is available
pub available: bool,
/// Backend type: "gpio", "usb_relay", "none"
pub backend: String,
/// Whether backend is initialized
pub initialized: bool,
/// Whether power is currently on
pub power_on: bool,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// Audio device information
///
/// Note: Sample rate is fixed at 48000Hz and channels at 2 (stereo).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioDeviceInfo {
/// Whether audio is enabled/available
pub available: bool,
/// Whether audio is currently streaming
pub streaming: bool,
/// Current audio device name
pub device: Option<String>,
/// Quality preset: "voice", "balanced", "high"
pub quality: String,
/// Error message if any, None if OK
pub error: Option<String>,
}
/// ttyd status information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtydDeviceInfo {
/// Whether ttyd binary is available
pub available: bool,
/// Whether ttyd is currently running
pub running: bool,
}
/// Per-client statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientStats {
/// Client ID
pub id: String,
/// Current FPS for this client (frames sent in last second)
pub fps: u32,
/// Connected duration (seconds)
pub connected_secs: u64,
}
/// System event enumeration
///
/// All events are tagged with their event name for serialization.
/// The `serde(tag = "event", content = "data")` attribute creates a
/// JSON structure like:
/// ```json
/// {
/// "event": "stream.state_changed",
/// "data": { "state": "streaming", "device": "/dev/video0" }
/// }
/// ```
/// Video vs audio source for [`SystemEvent::StreamDeviceLost`] (WebSocket `stream.device_lost`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamDeviceLostKind {
Video,
Audio,
}
/// JSON: `{"event": "<name>", "data": { ... }}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
#[allow(clippy::large_enum_variant)]
pub enum SystemEvent {
// ============================================================================
// Video Stream Events
// ============================================================================
/// Stream mode switching started (transactional, correlates all following events)
///
/// Sent immediately after a mode switch request is accepted.
/// Clients can use `transition_id` to correlate subsequent `stream.*` events.
#[serde(rename = "stream.mode_switching")]
StreamModeSwitching {
/// Unique transition ID for this mode switch transaction
transition_id: String,
/// Target mode: "mjpeg", "h264", "h265", "vp8", "vp9"
to_mode: String,
/// Previous mode: "mjpeg", "h264", "h265", "vp8", "vp9"
from_mode: String,
},
/// Stream state changed (e.g., started, stopped, error)
#[serde(rename = "stream.state_changed")]
StreamStateChanged {
/// Current state: "uninitialized", "ready", "streaming", "no_signal", "error"
state: String,
/// Device path if available
device: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
next_retry_ms: Option<u64>,
},
/// Stream configuration is being changed
///
/// Sent before applying new configuration to notify clients that
/// the stream will be interrupted temporarily.
#[serde(rename = "stream.config_changing")]
StreamConfigChanging {
/// Optional transition ID if this config change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Reason for change: "device_switch", "resolution_change", "format_change"
reason: String,
},
/// Stream configuration has been applied successfully
///
/// Sent after new configuration is active. Clients can reconnect now.
#[serde(rename = "stream.config_applied")]
StreamConfigApplied {
/// Optional transition ID if this config change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Device path
device: String,
/// Resolution (width, height)
resolution: (u32, u32),
/// Pixel format: "mjpeg", "yuyv", etc.
format: String,
/// Frames per second
fps: u32,
},
/// Stream device was lost (disconnected or error)
#[serde(rename = "stream.device_lost")]
StreamDeviceLost {
/// Device path that was lost
kind: StreamDeviceLostKind,
device: String,
/// Reason for loss
reason: String,
},
/// Stream device is reconnecting
#[serde(rename = "stream.reconnecting")]
StreamReconnecting {
/// Device path being reconnected
device: String,
/// Retry attempt number
attempt: u32,
},
StreamReconnecting { device: String, attempt: u32 },
/// Stream device has recovered
#[serde(rename = "stream.recovered")]
StreamRecovered {
/// Device path that was recovered
device: String,
},
StreamRecovered { device: String },
/// WebRTC is ready to accept connections
///
/// Sent after video frame source is connected to WebRTC pipeline.
/// Clients should wait for this event before attempting to create WebRTC sessions.
#[serde(rename = "stream.webrtc_ready")]
WebRTCReady {
/// Optional transition ID if this readiness is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// Current video codec
codec: String,
/// Whether hardware encoding is being used
hardware: bool,
},
/// WebRTC ICE candidate (server -> client trickle)
#[serde(rename = "webrtc.ice_candidate")]
WebRTCIceCandidate {
/// WebRTC session ID
session_id: String,
/// ICE candidate data
candidate: crate::webrtc::signaling::IceCandidate,
candidate: serde_json::Value,
},
/// WebRTC ICE gathering complete (server -> client)
#[serde(rename = "webrtc.ice_complete")]
WebRTCIceComplete {
/// WebRTC session ID
session_id: String,
},
WebRTCIceComplete { session_id: String },
/// Stream statistics update (sent periodically for client stats)
#[serde(rename = "stream.stats_update")]
StreamStatsUpdate {
/// Number of connected clients
clients: u64,
/// Per-client statistics (client_id -> client stats)
/// Each client's FPS reflects the actual frames sent in the last second
clients_stat: HashMap<String, ClientStats>,
},
/// Stream mode changed (MJPEG <-> WebRTC)
///
/// Sent when the streaming mode is switched. Clients should disconnect
/// from the current stream and reconnect using the new mode.
#[serde(rename = "stream.mode_changed")]
StreamModeChanged {
/// Optional transition ID if this change is part of a mode switch transaction
#[serde(skip_serializing_if = "Option::is_none")]
transition_id: Option<String>,
/// New mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
mode: String,
/// Previous mode: "mjpeg", "h264", "h265", "vp8", or "vp9"
previous_mode: String,
},
/// Stream mode switching completed (transactional end marker)
///
/// Sent when the backend considers the new mode ready for clients to connect.
#[serde(rename = "stream.mode_ready")]
StreamModeReady {
/// Unique transition ID for this mode switch transaction
transition_id: String,
/// Active mode after switch: "mjpeg", "h264", "h265", "vp8", "vp9"
mode: String,
},
StreamModeReady { transition_id: String, mode: String },
// ============================================================================
// MSD (Mass Storage Device) Events
// ============================================================================
/// File upload progress (for large file uploads)
#[serde(rename = "msd.upload_progress")]
MsdUploadProgress {
/// Upload operation ID
upload_id: String,
/// Filename being uploaded
filename: String,
/// Bytes uploaded so far
bytes_uploaded: u64,
/// Total file size
total_bytes: u64,
/// Progress percentage (0.0 - 100.0)
progress_pct: f32,
},
/// Image download progress (for URL downloads)
#[serde(rename = "msd.download_progress")]
MsdDownloadProgress {
/// Download operation ID
download_id: String,
/// Source URL
url: String,
/// Target filename
filename: String,
/// Bytes downloaded so far
bytes_downloaded: u64,
/// Total file size (None if unknown)
total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
progress_pct: Option<f32>,
/// Download status: "started", "in_progress", "completed", "failed"
status: String,
},
/// Complete device information (sent on WebSocket connect and state changes)
#[serde(rename = "system.device_info")]
DeviceInfo {
/// Video device information
video: VideoDeviceInfo,
/// HID device information
hid: HidDeviceInfo,
/// MSD device information (None if MSD not enabled)
msd: Option<MsdDeviceInfo>,
/// ATX device information (None if ATX not enabled)
atx: Option<AtxDeviceInfo>,
/// Audio device information (None if audio not enabled)
audio: Option<AudioDeviceInfo>,
/// ttyd status information
ttyd: TtydDeviceInfo,
},
/// WebSocket error notification (for connection-level errors like lag)
#[serde(rename = "error")]
Error {
/// Error message
message: String,
},
Error { message: String },
}
/// One entry per [`SystemEvent::event_name`]. `EventBus` builds `*.`-wildcard channels from the first segment; names without `.` (e.g. `error`) have no wildcard channel.
pub(crate) const EXACT_EVENT_TOPICS: &[&str] = &[
"stream.mode_switching",
"stream.state_changed",
"stream.config_changing",
"stream.config_applied",
"stream.device_lost",
"stream.reconnecting",
"stream.recovered",
"stream.webrtc_ready",
"stream.stats_update",
"stream.mode_changed",
"stream.mode_ready",
"webrtc.ice_candidate",
"webrtc.ice_complete",
"msd.upload_progress",
"msd.download_progress",
"system.device_info",
"error",
];
impl SystemEvent {
/// Get the event name (for filtering/routing)
pub fn event_name(&self) -> &'static str {
match self {
Self::StreamModeSwitching { .. } => "stream.mode_switching",
@@ -375,27 +250,6 @@ impl SystemEvent {
Self::Error { .. } => "error",
}
}
/// Check if event name matches a topic pattern
///
/// Supports wildcards:
/// - `*` matches all events
/// - `stream.*` matches all stream events
/// - `stream.state_changed` matches exact event
pub fn matches_topic(&self, topic: &str) -> bool {
if topic == "*" {
return true;
}
let event_name = self.event_name();
if topic.ends_with(".*") {
let prefix = topic.trim_end_matches(".*");
event_name.starts_with(prefix)
} else {
event_name == topic
}
}
}
#[cfg(test)]
@@ -407,22 +261,145 @@ mod tests {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: Some("/dev/video0".to_string()),
reason: None,
next_retry_ms: None,
};
assert_eq!(event.event_name(), "stream.state_changed");
}
#[test]
fn test_matches_topic() {
let event = SystemEvent::StreamStateChanged {
state: "streaming".to_string(),
device: None,
fn stream_device_lost_json_snake_case_kind() {
let event = SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Audio,
device: "hw:0,0".to_string(),
reason: "test".to_string(),
};
let v = serde_json::to_value(&event).unwrap();
let data = v.get("data").unwrap();
assert_eq!(data.get("kind").and_then(|x| x.as_str()), Some("audio"));
assert_eq!(data.get("device").and_then(|x| x.as_str()), Some("hw:0,0"));
}
assert!(event.matches_topic("*"));
assert!(event.matches_topic("stream.*"));
assert!(event.matches_topic("stream.state_changed"));
assert!(!event.matches_topic("msd.*"));
assert!(!event.matches_topic("stream.config_changed"));
#[test]
fn exact_topics_covers_all_variants() {
use std::collections::HashSet;
let samples = vec![
SystemEvent::StreamModeSwitching {
transition_id: String::new(),
to_mode: String::new(),
from_mode: String::new(),
},
SystemEvent::StreamStateChanged {
state: String::new(),
device: None,
reason: None,
next_retry_ms: None,
},
SystemEvent::StreamConfigChanging {
transition_id: None,
reason: String::new(),
},
SystemEvent::StreamConfigApplied {
transition_id: None,
device: String::new(),
resolution: (0, 0),
format: String::new(),
fps: 0,
},
SystemEvent::StreamDeviceLost {
kind: StreamDeviceLostKind::Video,
device: String::new(),
reason: String::new(),
},
SystemEvent::StreamReconnecting {
device: String::new(),
attempt: 0,
},
SystemEvent::StreamRecovered {
device: String::new(),
},
SystemEvent::WebRTCReady {
transition_id: None,
codec: String::new(),
hardware: false,
},
SystemEvent::StreamStatsUpdate {
clients: 0,
clients_stat: HashMap::new(),
},
SystemEvent::StreamModeChanged {
transition_id: None,
mode: String::new(),
previous_mode: String::new(),
},
SystemEvent::StreamModeReady {
transition_id: String::new(),
mode: String::new(),
},
SystemEvent::WebRTCIceCandidate {
session_id: String::new(),
candidate: serde_json::Value::Null,
},
SystemEvent::WebRTCIceComplete {
session_id: String::new(),
},
SystemEvent::MsdUploadProgress {
upload_id: String::new(),
filename: String::new(),
bytes_uploaded: 0,
total_bytes: 0,
progress_pct: 0.0,
},
SystemEvent::MsdDownloadProgress {
download_id: String::new(),
url: String::new(),
filename: String::new(),
bytes_downloaded: 0,
total_bytes: None,
progress_pct: None,
status: String::new(),
},
SystemEvent::DeviceInfo {
video: VideoDeviceInfo {
available: false,
device: None,
format: None,
resolution: None,
fps: 0,
online: false,
stream_mode: String::new(),
config_changing: false,
error: None,
},
hid: HidDeviceInfo {
available: false,
backend: String::new(),
initialized: false,
online: false,
supports_absolute_mouse: false,
keyboard_leds_enabled: false,
led_state: LedState::default(),
device: None,
error: None,
error_code: None,
},
msd: None,
atx: None,
audio: None,
ttyd: TtydDeviceInfo {
available: false,
running: false,
},
},
SystemEvent::Error {
message: String::new(),
},
];
let from_enum: HashSet<_> = samples.iter().map(|e| e.event_name()).collect();
let from_const: HashSet<_> = super::EXACT_EVENT_TOPICS.iter().copied().collect();
assert_eq!(from_enum, from_const);
}
#[test]

View File

@@ -1,5 +1,3 @@
//! Extension process manager
use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::process::Stdio;
@@ -12,25 +10,18 @@ use tokio::sync::RwLock;
use super::types::*;
use crate::events::EventBus;
/// Maximum number of log lines to keep per extension
const LOG_BUFFER_SIZE: usize = 200;
/// Number of log lines to buffer before flushing to shared storage
const LOG_BATCH_SIZE: usize = 16;
/// Unix socket path for ttyd
pub const TTYD_SOCKET_PATH: &str = "/var/run/one-kvm/ttyd.sock";
/// Extension process with log buffer
struct ExtensionProcess {
child: Child,
logs: Arc<RwLock<VecDeque<String>>>,
}
/// Extension manager handles lifecycle of external processes
pub struct ExtensionManager {
processes: RwLock<HashMap<ExtensionId, ExtensionProcess>>,
/// Cached availability status (checked once at startup)
availability: HashMap<ExtensionId, bool>,
event_bus: RwLock<Option<Arc<EventBus>>>,
}
@@ -42,9 +33,7 @@ impl Default for ExtensionManager {
}
impl ExtensionManager {
/// Create a new extension manager with cached availability
pub fn new() -> Self {
// Check availability once at startup
let availability = ExtensionId::all()
.iter()
.map(|id| (*id, Path::new(id.binary_path()).exists()))
@@ -57,7 +46,6 @@ impl ExtensionManager {
}
}
/// Set event bus for ttyd status notifications.
pub async fn set_event_bus(&self, event_bus: Arc<EventBus>) {
*self.event_bus.write().await = Some(event_bus);
}
@@ -72,12 +60,10 @@ impl ExtensionManager {
}
}
/// Check if the binary for an extension is available (cached)
pub fn check_available(&self, id: ExtensionId) -> bool {
*self.availability.get(&id).unwrap_or(&false)
}
/// Get the current status of an extension
pub async fn status(&self, id: ExtensionId) -> ExtensionStatus {
if !self.check_available(id) {
return ExtensionStatus::Unavailable;
@@ -117,27 +103,20 @@ impl ExtensionManager {
ExtensionStatus::Stopped
}
/// Start an extension with the given configuration
pub async fn start(&self, id: ExtensionId, config: &ExtensionsConfig) -> Result<(), String> {
if !self.check_available(id) {
return Err(format!(
"{} not found at {}",
id.display_name(),
id.binary_path()
));
return Err(format!("{} not found at {}", id, id.binary_path()));
}
// Stop existing process first
self.stop(id).await.ok();
// Build command arguments
let args = self.build_args(id, config).await?;
tracing::info!(
"Starting extension {}: {} {}",
id,
id.binary_path(),
args.join(" ")
Self::redact_args_for_log(&args).join(" ")
);
let mut child = Command::new(id.binary_path())
@@ -146,11 +125,10 @@ impl ExtensionManager {
.stderr(Stdio::piped())
.kill_on_drop(true)
.spawn()
.map_err(|e| format!("Failed to start {}: {}", id.display_name(), e))?;
.map_err(|e| format!("Failed to start {}: {}", id, e))?;
let logs = Arc::new(RwLock::new(VecDeque::with_capacity(LOG_BUFFER_SIZE)));
// Spawn log collector for stdout
if let Some(stdout) = child.stdout.take() {
let logs_clone = logs.clone();
let id_clone = id;
@@ -159,7 +137,6 @@ impl ExtensionManager {
});
}
// Spawn log collector for stderr
if let Some(stderr) = child.stderr.take() {
let logs_clone = logs.clone();
let id_clone = id;
@@ -179,7 +156,6 @@ impl ExtensionManager {
Ok(())
}
/// Stop an extension
pub async fn stop(&self, id: ExtensionId) -> Result<(), String> {
let mut processes = self.processes.write().await;
if let Some(mut proc) = processes.remove(&id) {
@@ -193,7 +169,6 @@ impl ExtensionManager {
Ok(())
}
/// Get recent logs for an extension
pub async fn logs(&self, id: ExtensionId, lines: usize) -> Vec<String> {
let processes = self.processes.read().await;
if let Some(proc) = processes.get(&id) {
@@ -205,7 +180,6 @@ impl ExtensionManager {
}
}
/// Collect logs from a stream with batched writes to reduce lock contention
async fn collect_logs<R: tokio::io::AsyncRead + Unpin>(
id: ExtensionId,
reader: R,
@@ -218,16 +192,14 @@ impl ExtensionManager {
loop {
match lines.next_line().await {
Ok(Some(line)) => {
tracing::debug!("[{}] {}", id, line);
tracing::info!("[{}] {}", id, line);
local_buffer.push(line);
// Flush when batch is full
if local_buffer.len() >= LOG_BATCH_SIZE {
Self::flush_logs(&logs, &mut local_buffer).await;
}
}
Ok(None) => {
// Stream ended, flush remaining logs
if !local_buffer.is_empty() {
Self::flush_logs(&logs, &mut local_buffer).await;
}
@@ -241,7 +213,6 @@ impl ExtensionManager {
}
}
/// Flush buffered logs to shared storage
async fn flush_logs(logs: &RwLock<VecDeque<String>>, buffer: &mut Vec<String>) {
let mut logs = logs.write().await;
for line in buffer.drain(..) {
@@ -252,7 +223,6 @@ impl ExtensionManager {
}
}
/// Build command arguments for an extension
async fn build_args(
&self,
id: ExtensionId,
@@ -262,18 +232,16 @@ impl ExtensionManager {
ExtensionId::Ttyd => {
let c = &config.ttyd;
// Prepare socket directory and clean up old socket (async)
Self::prepare_ttyd_socket().await?;
let mut args = vec![
"-i".to_string(),
TTYD_SOCKET_PATH.to_string(), // Unix socket
TTYD_SOCKET_PATH.to_string(),
"-b".to_string(),
"/api/terminal".to_string(), // Base path for reverse proxy
"-W".to_string(), // Writable (allow input)
"/api/terminal".to_string(),
"-W".to_string(),
];
// Add shell as last argument
args.push(c.shell.clone());
Ok(args)
}
@@ -289,15 +257,12 @@ impl ExtensionManager {
let mut args = Vec::new();
// Add TLS flag
if c.tls {
args.push("--tls=true".to_string());
}
// Server address (validated non-empty above)
args.extend(["-addr".to_string(), c.addr.trim().to_string()]);
// Add client key
args.extend(["-key".to_string(), c.key.clone()]);
Ok(args)
@@ -316,24 +281,19 @@ impl ExtensionManager {
c.network_secret.clone(),
];
// Add peer URLs
for peer in &c.peer_urls {
if !peer.is_empty() {
args.extend(["--peers".to_string(), peer.clone()]);
}
}
// Add virtual IP: use -d for DHCP if empty, or -i for specific IP
if let Some(ref ip) = c.virtual_ip {
if !ip.is_empty() {
// Use specific IP with -i (must include CIDR, e.g., 10.0.0.1/24)
args.extend(["-i".to_string(), ip.clone()]);
} else {
// Empty string means use DHCP
args.push("-d".to_string());
}
} else {
// None means use DHCP
args.push("-d".to_string());
}
@@ -342,11 +302,37 @@ impl ExtensionManager {
}
}
/// Prepare ttyd socket directory and clean up old socket file
fn redact_args_for_log(args: &[String]) -> Vec<String> {
let mut redacted = Vec::with_capacity(args.len());
let mut redact_next = false;
for arg in args {
if redact_next {
redacted.push("****".to_string());
redact_next = false;
continue;
}
if arg == "-key" || arg == "--key" {
redacted.push(arg.clone());
redact_next = true;
} else if let Some((flag, _)) = arg.split_once('=') {
if flag == "-key" || flag == "--key" {
redacted.push(format!("{}=****", flag));
} else {
redacted.push(arg.clone());
}
} else {
redacted.push(arg.clone());
}
}
redacted
}
async fn prepare_ttyd_socket() -> Result<(), String> {
let socket_path = Path::new(TTYD_SOCKET_PATH);
// Ensure socket directory exists
if let Some(socket_dir) = socket_path.parent() {
if !socket_dir.exists() {
tokio::fs::create_dir_all(socket_dir)
@@ -355,7 +341,6 @@ impl ExtensionManager {
}
}
// Remove old socket file if exists
if tokio::fs::try_exists(TTYD_SOCKET_PATH)
.await
.unwrap_or(false)
@@ -368,9 +353,7 @@ impl ExtensionManager {
Ok(())
}
/// Health check - restart crashed processes that should be running
pub async fn health_check(&self, config: &ExtensionsConfig) {
// Collect extensions that need restart check
let checks: Vec<_> = ExtensionId::all()
.iter()
.filter_map(|id| {
@@ -393,7 +376,6 @@ impl ExtensionManager {
})
.collect();
// Check which ones need restart (single read lock)
let needs_restart: Vec<_> = {
let processes = self.processes.read().await;
checks
@@ -408,7 +390,6 @@ impl ExtensionManager {
.collect()
};
// Restart all crashed extensions in parallel
let restart_futures: Vec<_> = needs_restart
.into_iter()
.map(|id| async move {
@@ -422,14 +403,12 @@ impl ExtensionManager {
futures::future::join_all(restart_futures).await;
}
/// Start all enabled extensions in parallel
pub async fn start_enabled(&self, config: &ExtensionsConfig) {
use futures::Future;
use std::pin::Pin;
let mut start_futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + '_>>> = Vec::new();
// Collect enabled extensions
if config.ttyd.enabled && self.check_available(ExtensionId::Ttyd) {
start_futures.push(Box::pin(async {
if let Err(e) = self.start(ExtensionId::Ttyd, config).await {
@@ -461,11 +440,9 @@ impl ExtensionManager {
}));
}
// Start all in parallel
futures::future::join_all(start_futures).await;
}
/// Stop all running extensions in parallel
pub async fn stop_all(&self) {
let stop_futures: Vec<_> = ExtensionId::all().iter().map(|id| self.stop(*id)).collect();
futures::future::join_all(stop_futures).await;

View File

@@ -1,5 +1,3 @@
//! Extensions module - manage external processes like ttyd, gostc, easytier
mod manager;
mod types;

View File

@@ -1,23 +1,16 @@
//! Extension types and configurations
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Extension identifier (fixed set of supported extensions)
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ExtensionId {
/// Web terminal (ttyd)
Ttyd,
/// NAT traversal client (gostc)
Gostc,
/// P2P VPN (easytier)
Easytier,
}
impl ExtensionId {
/// Get the binary path for this extension
pub fn binary_path(&self) -> &'static str {
match self {
Self::Ttyd => "/usr/bin/ttyd",
@@ -26,16 +19,6 @@ impl ExtensionId {
}
}
/// Get the display name for this extension
pub fn display_name(&self) -> &'static str {
match self {
Self::Ttyd => "Web Terminal",
Self::Gostc => "GOSTC Tunnel",
Self::Easytier => "EasyTier VPN",
}
}
/// Get all extension IDs
pub fn all() -> &'static [ExtensionId] {
&[Self::Ttyd, Self::Gostc, Self::Easytier]
}
@@ -64,25 +47,14 @@ impl std::str::FromStr for ExtensionId {
}
}
/// Extension running status
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "state", content = "data", rename_all = "lowercase")]
pub enum ExtensionStatus {
/// Binary not found at expected path
Unavailable,
/// Extension is stopped
Stopped,
/// Extension is running
Running {
/// Process ID
pid: u32,
},
/// Extension failed to start
Failed {
/// Error message
error: String,
},
Running { pid: u32 },
Failed { error: String },
}
impl ExtensionStatus {
@@ -91,16 +63,11 @@ impl ExtensionStatus {
}
}
/// ttyd configuration (Web Terminal)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TtydConfig {
/// Enable auto-start
pub enabled: bool,
/// Port to listen on
pub port: u16,
/// Shell to execute
pub shell: String,
}
@@ -108,25 +75,19 @@ impl Default for TtydConfig {
fn default() -> Self {
Self {
enabled: false,
port: 7681,
shell: "/bin/bash".to_string(),
}
}
}
/// gostc configuration (NAT traversal based on FRP)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GostcConfig {
/// Enable auto-start
pub enabled: bool,
/// Server address (hostname or IP)
pub addr: String,
/// Client key from GOSTC management panel
#[serde(skip_serializing_if = "String::is_empty")]
pub key: String,
/// Enable TLS
pub tls: bool,
}
@@ -141,28 +102,21 @@ impl Default for GostcConfig {
}
}
/// EasyTier configuration (P2P VPN)
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
#[derive(Default)]
pub struct EasytierConfig {
/// Enable auto-start
pub enabled: bool,
/// Network name
pub network_name: String,
/// Network secret/password
#[serde(skip_serializing_if = "String::is_empty")]
pub network_secret: String,
/// Peer node URLs
#[serde(skip_serializing_if = "Vec::is_empty")]
pub peer_urls: Vec<String>,
/// Virtual IP address (optional, auto-assigned if not set)
#[serde(skip_serializing_if = "Option::is_none")]
pub virtual_ip: Option<String>,
}
/// Combined extensions configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
@@ -172,53 +126,37 @@ pub struct ExtensionsConfig {
pub easytier: EasytierConfig,
}
/// Extension info with status and config
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
}
/// ttyd extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TtydInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: TtydConfig,
}
/// gostc extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GostcInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: GostcConfig,
}
/// easytier extension info
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EasytierInfo {
/// Whether binary exists
pub available: bool,
/// Current status
pub status: ExtensionStatus,
/// Configuration
pub config: EasytierConfig,
}
/// All extensions status response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionsStatus {
@@ -227,7 +165,6 @@ pub struct ExtensionsStatus {
pub easytier: EasytierInfo,
}
/// Extension logs response
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionLogs {

View File

@@ -1,73 +1,32 @@
//! HID backend trait definition
//! `HidBackend` trait plus serde `HidBackendType` (OTG | CH9329 | disabled).
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::watch;
use super::otg::LedState;
use super::types::{ConsumerEvent, KeyboardEvent, MouseEvent};
use crate::error::Result;
use crate::events::LedState;
/// Default CH9329 baud rate
fn default_ch9329_baud_rate() -> u32 {
9600
}
/// HID backend type
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
#[derive(Default)]
pub enum HidBackendType {
/// USB OTG gadget mode
Otg,
/// CH9329 serial HID controller
Ch9329 {
/// Serial port path
port: String,
/// Baud rate (default: 9600)
#[serde(default = "default_ch9329_baud_rate")]
baud_rate: u32,
},
/// No HID backend (disabled)
#[default]
None,
}
impl HidBackendType {
/// Check if OTG backend is available on this system
pub fn otg_available() -> bool {
// Check for USB gadget support
std::path::Path::new("/sys/class/udc").exists()
}
/// Detect the best available backend
pub fn detect() -> Self {
// Check for OTG gadget support
if Self::otg_available() {
return Self::Otg;
}
// Check for common CH9329 serial ports
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
];
for port in &common_ports {
if std::path::Path::new(port).exists() {
return Self::Ch9329 {
port: port.to_string(),
baud_rate: 9600, // Use default baud rate for auto-detection
};
}
}
Self::None
}
/// Get backend name as string
pub fn name_str(&self) -> &str {
match self {
Self::Otg => "otg",
@@ -77,76 +36,40 @@ impl HidBackendType {
}
}
/// Current runtime status reported by a HID backend.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HidBackendRuntimeSnapshot {
/// Whether the backend has been initialized and can accept requests.
pub initialized: bool,
/// Whether the backend is currently online and communicating successfully.
pub online: bool,
/// Whether absolute mouse positioning is supported.
pub supports_absolute_mouse: bool,
/// Whether keyboard LED/status feedback is currently enabled.
pub keyboard_leds_enabled: bool,
/// Last known keyboard LED state.
pub led_state: LedState,
/// Screen resolution for absolute mouse mode.
pub screen_resolution: Option<(u32, u32)>,
/// Device identifier associated with the backend, if any.
pub device: Option<String>,
/// Current user-facing error, if any.
pub error: Option<String>,
/// Current programmatic error code, if any.
pub error_code: Option<String>,
}
/// HID backend trait
#[async_trait]
pub trait HidBackend: Send + Sync {
/// Initialize the backend
async fn init(&self) -> Result<()>;
/// Send a keyboard event
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()>;
/// Send a mouse event
async fn send_mouse(&self, event: MouseEvent) -> Result<()>;
/// Send a consumer control event (multimedia keys)
/// Default implementation returns an error (not supported)
async fn send_consumer(&self, _event: ConsumerEvent) -> Result<()> {
Err(crate::error::AppError::BadRequest(
"Consumer control not supported by this backend".to_string(),
))
}
/// Reset all inputs (release all keys/buttons)
async fn reset(&self) -> Result<()>;
/// Shutdown the backend
async fn shutdown(&self) -> Result<()>;
/// Get the current backend runtime snapshot.
fn runtime_snapshot(&self) -> HidBackendRuntimeSnapshot;
/// Subscribe to backend runtime changes.
fn subscribe_runtime(&self) -> watch::Receiver<()>;
/// Set screen resolution (for absolute mouse)
fn set_screen_resolution(&mut self, _width: u32, _height: u32) {}
}
/// HID backend information
#[derive(Debug, Clone, Serialize)]
pub struct HidBackendInfo {
/// Backend name
pub name: String,
/// Backend type
pub backend_type: String,
/// Is initialized
pub initialized: bool,
/// Supports absolute mouse
pub absolute_mouse: bool,
/// Screen resolution (if absolute mouse)
pub resolution: Option<(u32, u32)>,
fn set_screen_resolution(&self, _width: u32, _height: u32) {}
}

View File

@@ -1,9 +1,4 @@
//! CH9329 Serial HID Controller backend
//!
//! CH9329 is a USB HID chip controlled via UART from WCH (沁恒).
//! It supports keyboard, mouse (absolute + relative), and custom HID device emulation.
//!
//! ## Protocol Format
//! CH9329 over UART — WCH *Serial Communication Protocol V1.0*.
//! ```text
//! ┌──────┬──────┬──────┬────────┬──────────────┬──────────┐
//! │Header│ ADDR │ CMD │ LEN │ DATA │ SUM │
@@ -11,16 +6,11 @@
//! │57 AB │ 00 │ xx │ N │ N bytes │Checksum │
//! └──────┴──────┴──────┴────────┴──────────────┴──────────┘
//! ```
//!
//! Checksum: Sum of ALL bytes including header (modulo 256)
//!
//! ## Reference
//! Based on WCH CH9329 Serial Communication Protocol V1.0
//! Sum of all octets modulo 256 (including header).
use async_trait::async_trait;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;
@@ -29,82 +19,51 @@ use tokio::sync::watch;
use tracing::{info, trace, warn};
use super::backend::{HidBackend, HidBackendRuntimeSnapshot};
use super::otg::LedState;
use super::types::{KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType};
use crate::error::{AppError, Result};
use crate::events::LedState;
// ============================================================================
// Constants and Command Codes
// ============================================================================
/// CH9329 packet header
const PACKET_HEADER: [u8; 2] = [0x57, 0xAB];
/// Default address (accepts any address)
const DEFAULT_ADDR: u8 = 0x00;
/// Default baud rate for CH9329
pub const DEFAULT_BAUD_RATE: u32 = 9600;
/// Response timeout in milliseconds
const RESPONSE_TIMEOUT_MS: u64 = 500;
/// Maximum data length in a packet
const MAX_DATA_LEN: usize = 64;
/// CH9329 absolute mouse resolution
const CH9329_MOUSE_RESOLUTION: u32 = 4096;
/// How often the worker probes the chip when idle.
const PROBE_INTERVAL_MS: u64 = 100;
/// How long the worker waits before reopening the serial port after a failure.
const RECONNECT_DELAY_MS: u64 = 2000;
/// Initial startup wait for the worker to confirm CH9329 is reachable.
const INIT_WAIT_MS: u64 = 3000;
/// CH9329 command codes
pub mod cmd {
/// Get chip version, USB status, and LED status
pub const GET_INFO: u8 = 0x01;
/// Send standard keyboard data (8 bytes)
pub const SEND_KB_GENERAL_DATA: u8 = 0x02;
/// Send multimedia keyboard data
pub const SEND_KB_MEDIA_DATA: u8 = 0x03;
/// Send absolute mouse data
pub const SEND_MS_ABS_DATA: u8 = 0x04;
/// Send relative mouse data
pub const SEND_MS_REL_DATA: u8 = 0x05;
/// Send custom HID data
pub const SEND_MY_HID_DATA: u8 = 0x06;
/// Restore factory default configuration
pub const SET_DEFAULT_CFG: u8 = 0x0C;
/// Software reset
pub const RESET: u8 = 0x0F;
}
/// Response command mask (success = cmd | 0x80, error = cmd | 0xC0)
const RESPONSE_SUCCESS_MASK: u8 = 0x80;
const RESPONSE_ERROR_MASK: u8 = 0xC0;
/// CH9329 error codes
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Ch9329Error {
/// Command executed successfully
Success = 0x00,
/// Serial receive timeout
Timeout = 0xE1,
/// Invalid packet header
InvalidHeader = 0xE2,
/// Invalid command code
InvalidCommand = 0xE3,
/// Checksum mismatch
ChecksumError = 0xE4,
/// Parameter error
ParameterError = 0xE5,
/// Execution failed
OperationFailed = 0xE6,
}
@@ -137,29 +96,17 @@ impl std::fmt::Display for Ch9329Error {
}
}
// ============================================================================
// Chip Information
// ============================================================================
/// CH9329 chip information
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChipInfo {
/// Chip version (e.g., "V1.0", "V1.1")
pub version: String,
/// Raw version byte
pub version_raw: u8,
/// USB connection status
pub usb_connected: bool,
/// Num Lock LED state
pub num_lock: bool,
/// Caps Lock LED state
pub caps_lock: bool,
/// Scroll Lock LED state
pub scroll_lock: bool,
}
impl ChipInfo {
/// Parse chip info from response data (8 bytes)
pub fn from_response(data: &[u8]) -> Option<Self> {
if data.len() < 8 {
return None;
@@ -181,7 +128,6 @@ impl ChipInfo {
}
}
/// Keyboard LED status
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LedStatus {
pub num_lock: bool,
@@ -199,98 +145,21 @@ impl From<u8> for LedStatus {
}
}
// ============================================================================
// Configuration
// ============================================================================
/// CH9329 work mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
#[derive(Default)]
pub enum WorkMode {
/// Mode 0: Standard USB Keyboard + Mouse (default)
#[default]
KeyboardMouse = 0x00,
/// Mode 1: Standard USB Keyboard only
KeyboardOnly = 0x01,
/// Mode 2: Standard USB Mouse only
MouseOnly = 0x02,
/// Mode 3: Custom HID device
CustomHid = 0x03,
}
/// CH9329 serial communication mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
#[derive(Default)]
pub enum SerialMode {
/// Mode 0: Protocol transmission mode (default)
#[default]
Protocol = 0x00,
/// Mode 1: ASCII mode
Ascii = 0x01,
/// Mode 2: Transparent mode
Transparent = 0x02,
}
/// CH9329 configuration parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ch9329Config {
/// Work mode
pub work_mode: WorkMode,
/// Serial communication mode
pub serial_mode: SerialMode,
/// Device address (0x00-0xFE, 0xFF = broadcast)
pub address: u8,
/// Baud rate
pub baud_rate: u32,
/// USB VID
pub vid: u16,
/// USB PID
pub pid: u16,
}
impl Default for Ch9329Config {
fn default() -> Self {
Self {
work_mode: WorkMode::KeyboardMouse,
serial_mode: SerialMode::Protocol,
address: 0x00,
baud_rate: 9600,
vid: 0x1A86,
pid: 0xE129,
}
}
}
// ============================================================================
// Response Parsing
// ============================================================================
/// Parsed response from CH9329
#[derive(Debug)]
pub struct Response {
/// Address byte
pub address: u8,
/// Command code (with response bits)
pub cmd: u8,
/// Data payload
pub data: Vec<u8>,
/// Whether this is an error response
pub is_error: bool,
/// Error code (if is_error)
pub error_code: Option<Ch9329Error>,
}
impl Response {
/// Parse a response from raw bytes
pub fn parse(bytes: &[u8]) -> Option<Self> {
// Minimum: Header(2) + Addr(1) + Cmd(1) + Len(1) + Sum(1) = 6
if bytes.len() < 6 {
return None;
}
// Check header
if bytes[0] != PACKET_HEADER[0] || bytes[1] != PACKET_HEADER[1] {
return None;
}
@@ -299,12 +168,10 @@ impl Response {
let cmd = bytes[3];
let len = bytes[4] as usize;
// Check if we have enough bytes
if bytes.len() < 5 + len + 1 {
return None;
}
// Verify checksum
let expected_checksum = bytes[5 + len];
let calculated_checksum = bytes[..5 + len]
.iter()
@@ -335,19 +202,13 @@ impl Response {
})
}
/// Check if the response indicates success
pub fn is_success(&self) -> bool {
!self.is_error && (self.data.is_empty() || self.data[0] == Ch9329Error::Success as u8)
}
}
/// Maximum packet size (header 2 + addr 1 + cmd 1 + len 1 + data 64 + checksum 1 = 70)
const MAX_PACKET_SIZE: usize = 70;
// ============================================================================
// CH9329 Backend Implementation
// ============================================================================
struct Ch9329RuntimeState {
initialized: AtomicBool,
online: AtomicBool,
@@ -424,47 +285,28 @@ enum WorkerCommand {
Shutdown,
}
/// CH9329 HID backend
pub struct Ch9329Backend {
/// Serial port path
port_path: String,
/// Baud rate
baud_rate: u32,
/// Worker command sender
worker_tx: Mutex<Option<mpsc::Sender<WorkerCommand>>>,
/// Background worker thread
worker_handle: Mutex<Option<thread::JoinHandle<()>>>,
/// Current keyboard state
keyboard_state: Mutex<KeyboardReport>,
/// Current mouse button state
mouse_buttons: AtomicU8,
/// Screen width for absolute mouse coordinate conversion
screen_width: u32,
/// Screen height for absolute mouse coordinate conversion
screen_height: u32,
/// Cached chip information
screen_resolution: RwLock<(u32, u32)>,
chip_info: Arc<RwLock<Option<ChipInfo>>>,
/// LED status cache
led_status: Arc<RwLock<LedStatus>>,
/// Device address (default 0x00)
address: u8,
/// Last absolute mouse X position (CH9329 coordinate: 0-4095)
last_abs_x: AtomicU16,
/// Last absolute mouse Y position (CH9329 coordinate: 0-4095)
last_abs_y: AtomicU16,
/// Whether relative mouse mode is active (set by incoming events)
relative_mouse_active: AtomicBool,
/// Shared runtime status updated only by the worker.
runtime: Arc<Ch9329RuntimeState>,
}
impl Ch9329Backend {
/// Create a new CH9329 backend with default baud rate (9600)
pub fn new(port_path: &str) -> Result<Self> {
Self::with_baud_rate(port_path, DEFAULT_BAUD_RATE)
}
/// Create a new CH9329 backend with custom baud rate
pub fn with_baud_rate(port_path: &str, baud_rate: u32) -> Result<Self> {
Ok(Self {
port_path: port_path.to_string(),
@@ -473,8 +315,7 @@ impl Ch9329Backend {
worker_handle: Mutex::new(None),
keyboard_state: Mutex::new(KeyboardReport::default()),
mouse_buttons: AtomicU8::new(0),
screen_width: 1920,
screen_height: 1080,
screen_resolution: RwLock::new((1920, 1080)),
chip_info: Arc::new(RwLock::new(None)),
led_status: Arc::new(RwLock::new(LedStatus::default())),
address: DEFAULT_ADDR,
@@ -489,12 +330,10 @@ impl Ch9329Backend {
self.runtime.set_error(reason, error_code);
}
/// Check if the serial port device file exists
pub fn check_port_exists(&self) -> bool {
std::path::Path::new(&self.port_path).exists()
}
/// Convert serialport error to HidError
fn serial_error_to_hid_error(e: serialport::Error, operation: &str) -> AppError {
let error_code = match e.kind() {
serialport::ErrorKind::NoDevice => "port_not_found",
@@ -518,16 +357,11 @@ impl Ch9329Backend {
}
}
/// Calculate checksum for CH9329 packet (sum of ALL bytes including header)
#[inline]
fn calculate_checksum(data: &[u8]) -> u8 {
data.iter().fold(0u8, |acc, &x| acc.wrapping_add(x))
}
/// Build a CH9329 packet into a stack-allocated buffer
///
/// Packet format: `[Header 0x57 0xAB] [Address] [Command] [Length] [Data] [Checksum]`
/// Returns the packet buffer and the actual length
#[inline]
fn build_packet_buf(address: u8, cmd: u8, data: &[u8]) -> ([u8; MAX_PACKET_SIZE], usize) {
debug_assert!(
@@ -539,25 +373,18 @@ impl Ch9329Backend {
let packet_len = 6 + data.len();
let mut packet = [0u8; MAX_PACKET_SIZE];
// Header (2 bytes)
packet[0] = PACKET_HEADER[0];
packet[1] = PACKET_HEADER[1];
// Address (1 byte)
packet[2] = address;
// Command (1 byte)
packet[3] = cmd;
// Length (1 byte) - data length only
packet[4] = len;
// Data (N bytes)
packet[5..5 + data.len()].copy_from_slice(data);
// Checksum (1 byte) - sum of ALL bytes including header
let checksum = Self::calculate_checksum(&packet[..5 + data.len()]);
packet[5 + data.len()] = checksum;
(packet, packet_len)
}
/// Build a CH9329 packet (legacy Vec version for compatibility)
fn build_packet(address: u8, cmd: u8, data: &[u8]) -> Vec<u8> {
let (buf, len) = Self::build_packet_buf(address, cmd, data);
buf[..len].to_vec()
@@ -984,10 +811,6 @@ impl Ch9329Backend {
}
}
// ============================================================================
// HidBackend Trait Implementation
// ============================================================================
#[async_trait]
impl HidBackend for Ch9329Backend {
async fn init(&self) -> Result<()> {
@@ -1060,7 +883,6 @@ impl HidBackend for Ch9329Backend {
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
let usb_key = event.key.to_hid_usage();
// Handle modifier keys separately
if event.key.is_modifier() {
let mut state = self.keyboard_state.lock();
@@ -1078,7 +900,6 @@ impl HidBackend for Ch9329Backend {
} else {
let mut state = self.keyboard_state.lock();
// Update modifiers from event
state.modifiers = event.modifiers.to_hid_byte();
match event.event_type {
@@ -1104,19 +925,15 @@ impl HidBackend for Ch9329Backend {
match event.event_type {
MouseEventType::Move => {
// Relative movement - send delta directly without inversion
self.relative_mouse_active.store(true, Ordering::Relaxed);
let dx = event.x.clamp(-127, 127) as i8;
let dy = event.y.clamp(-127, 127) as i8;
self.send_mouse_relative(buttons, dx, dy, 0)?;
}
MouseEventType::MoveAbs => {
// Absolute movement
self.relative_mouse_active.store(false, Ordering::Relaxed);
// Frontend sends 0-32767 (HID standard), CH9329 expects 0-4095
let x = ((event.x.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16;
let y = ((event.y.clamp(0, 32767) as u32) * CH9329_MOUSE_RESOLUTION / 32768) as u16;
// Store last absolute position for click events
self.last_abs_x.store(x, Ordering::Relaxed);
self.last_abs_y.store(y, Ordering::Relaxed);
self.send_mouse_absolute(buttons, x, y, 0)?;
@@ -1153,7 +970,6 @@ impl HidBackend for Ch9329Backend {
if self.relative_mouse_active.load(Ordering::Relaxed) {
self.send_mouse_relative(buttons, 0, 0, event.scroll)?;
} else {
// Use absolute mouse for scroll with last position
let x = self.last_abs_x.load(Ordering::Relaxed);
let y = self.last_abs_y.load(Ordering::Relaxed);
self.send_mouse_absolute(buttons, x, y, event.scroll)?;
@@ -1165,7 +981,6 @@ impl HidBackend for Ch9329Backend {
}
async fn reset(&self) -> Result<()> {
// Reset keyboard
{
let mut state = self.keyboard_state.lock();
state.clear();
@@ -1174,14 +989,12 @@ impl HidBackend for Ch9329Backend {
self.send_keyboard_report(&report)?;
}
// Reset mouse
self.mouse_buttons.store(0, Ordering::Relaxed);
self.last_abs_x.store(0, Ordering::Relaxed);
self.last_abs_y.store(0, Ordering::Relaxed);
self.relative_mouse_active.store(false, Ordering::Relaxed);
self.send_mouse_absolute(0, 0, 0, 0)?;
// Reset media keys
let _ = self.release_media_keys();
info!("CH9329 HID state reset");
@@ -1233,7 +1046,7 @@ impl HidBackend for Ch9329Backend {
kana: false,
}
},
screen_resolution: Some((self.screen_width, self.screen_height)),
screen_resolution: Some(*self.screen_resolution.read()),
device: Some(self.port_path.clone()),
error: error.as_ref().map(|(reason, _)| reason.clone()),
error_code: error.as_ref().map(|(_, code)| code.clone()),
@@ -1244,125 +1057,21 @@ impl HidBackend for Ch9329Backend {
self.runtime.subscribe()
}
fn set_screen_resolution(&mut self, width: u32, height: u32) {
self.screen_width = width;
self.screen_height = height;
fn set_screen_resolution(&self, width: u32, height: u32) {
*self.screen_resolution.write() = (width, height);
self.runtime.notify();
}
}
// ============================================================================
// Detection and Helpers
// ============================================================================
/// Detect CH9329 on common serial ports
pub fn detect_ch9329() -> Option<String> {
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
"/dev/ttyS0",
];
// Try multiple baud rates
let baud_rates = [9600, 115200];
for port_path in &common_ports {
if !std::path::Path::new(port_path).exists() {
continue;
}
for &baud_rate in &baud_rates {
if let Ok(mut port) = serialport::new(*port_path, baud_rate)
.timeout(Duration::from_millis(200))
.open()
{
// Build GET_INFO packet manually (address = 0x00)
let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03];
if port.write_all(&packet).is_ok() {
std::thread::sleep(Duration::from_millis(50));
let mut response = [0u8; 16];
if let Ok(n) = port.read(&mut response) {
// Check for valid CH9329 response header
if n >= 6
&& response[0] == PACKET_HEADER[0]
&& response[1] == PACKET_HEADER[1]
{
info!("CH9329 detected on {} @ {} baud", port_path, baud_rate);
return Some(port_path.to_string());
}
}
}
}
}
}
None
}
/// Detect CH9329 and return both path and working baud rate
pub fn detect_ch9329_with_baud() -> Option<(String, u32)> {
let common_ports = [
"/dev/ttyUSB0",
"/dev/ttyUSB1",
"/dev/ttyAMA0",
"/dev/serial0",
"/dev/ttyS0",
];
let baud_rates = [9600, 115200, 57600, 38400, 19200];
for port_path in &common_ports {
if !std::path::Path::new(port_path).exists() {
continue;
}
for &baud_rate in &baud_rates {
if let Ok(mut port) = serialport::new(*port_path, baud_rate)
.timeout(Duration::from_millis(200))
.open()
{
let packet = [0x57, 0xAB, 0x00, cmd::GET_INFO, 0x00, 0x03];
if port.write_all(&packet).is_ok() {
std::thread::sleep(Duration::from_millis(50));
let mut response = [0u8; 16];
if let Ok(n) = port.read(&mut response) {
if n >= 6
&& response[0] == PACKET_HEADER[0]
&& response[1] == PACKET_HEADER[1]
{
info!("CH9329 detected on {} @ {} baud", port_path, baud_rate);
return Some((port_path.to_string(), baud_rate));
}
}
}
}
}
}
None
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_packet_building() {
// Test GET_INFO packet (no data)
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::GET_INFO, &[]);
assert_eq!(packet, vec![0x57, 0xAB, 0x00, 0x01, 0x00, 0x03]);
// Test keyboard packet (8 bytes data)
let data = [0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]; // 'A' key
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_KB_GENERAL_DATA, &data);
@@ -1372,7 +1081,6 @@ mod tests {
assert_eq!(packet[3], cmd::SEND_KB_GENERAL_DATA); // Command
assert_eq!(packet[4], 8); // Length (8 data bytes)
assert_eq!(&packet[5..13], &data); // Data
// Checksum = 0x57 + 0xAB + 0x00 + 0x02 + 0x08 + 0x00 + 0x00 + 0x04 + ... = 0x10
let expected_checksum: u8 = packet[..13]
.iter()
.fold(0u8, |acc: u8, &x| acc.wrapping_add(x));
@@ -1381,7 +1089,6 @@ mod tests {
#[test]
fn test_relative_mouse_packet() {
// Test relative mouse: move right 50 pixels
let data = [0x01, 0x00, 50u8, 0x00, 0x00];
let packet = Ch9329Backend::build_packet(DEFAULT_ADDR, cmd::SEND_MS_REL_DATA, &data);
@@ -1397,12 +1104,10 @@ mod tests {
#[test]
fn test_checksum_calculation() {
// Known packet: GET_INFO
let packet = [0x57u8, 0xAB, 0x00, 0x01, 0x00];
let checksum = Ch9329Backend::calculate_checksum(&packet);
assert_eq!(checksum, 0x03);
// Known packet: Keyboard 'A' press
let packet = [
0x57u8, 0xAB, 0x00, 0x02, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00,
];
@@ -1412,7 +1117,6 @@ mod tests {
#[test]
fn test_response_parsing() {
// Valid GET_INFO response
let response_bytes = [
0x57, 0xAB, // Header
0x00, // Address
@@ -1422,9 +1126,7 @@ mod tests {
0xE0, // Checksum (calculated)
];
// Note: checksum in test is just placeholder, parse will validate
let _result = Response::parse(&response_bytes);
// This will fail because checksum doesn't match, but structure is tested
}
#[test]

View File

@@ -2,21 +2,17 @@
//!
//! Reference: USB HID Usage Tables 1.12, Section 15 (Consumer Page 0x0C)
/// Consumer Control Usage codes for multimedia keys
pub mod usage {
// Transport Controls
pub const PLAY_PAUSE: u16 = 0x00CD;
pub const STOP: u16 = 0x00B7;
pub const NEXT_TRACK: u16 = 0x00B5;
pub const PREV_TRACK: u16 = 0x00B6;
// Volume Controls
pub const MUTE: u16 = 0x00E2;
pub const VOLUME_UP: u16 = 0x00E9;
pub const VOLUME_DOWN: u16 = 0x00EA;
}
/// Check if a usage code is valid
pub fn is_valid_usage(usage: u16) -> bool {
matches!(
usage,

View File

@@ -42,23 +42,19 @@ use super::{
MouseEventType,
};
/// Message types
pub const MSG_KEYBOARD: u8 = 0x01;
pub const MSG_MOUSE: u8 = 0x02;
pub const MSG_CONSUMER: u8 = 0x03;
/// Keyboard event types
pub const KB_EVENT_DOWN: u8 = 0x00;
pub const KB_EVENT_UP: u8 = 0x01;
/// Mouse event types
pub const MS_EVENT_MOVE: u8 = 0x00;
pub const MS_EVENT_MOVE_ABS: u8 = 0x01;
pub const MS_EVENT_DOWN: u8 = 0x02;
pub const MS_EVENT_UP: u8 = 0x03;
pub const MS_EVENT_SCROLL: u8 = 0x04;
/// Parsed HID event from DataChannel
#[derive(Debug, Clone)]
pub enum HidChannelEvent {
Keyboard(KeyboardEvent),
@@ -66,7 +62,6 @@ pub enum HidChannelEvent {
Consumer(ConsumerEvent),
}
/// Parse a binary HID message from DataChannel
pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.is_empty() {
warn!("Empty HID message");
@@ -86,7 +81,6 @@ pub fn parse_hid_message(data: &[u8]) -> Option<HidChannelEvent> {
}
}
/// Parse keyboard message payload
fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 3 {
warn!("Keyboard message too short: {} bytes", data.len());
@@ -129,7 +123,6 @@ fn parse_keyboard_message(data: &[u8]) -> Option<HidChannelEvent> {
}))
}
/// Parse mouse message payload
fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 6 {
warn!("Mouse message too short: {} bytes", data.len());
@@ -148,11 +141,9 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
}
};
// Parse coordinates as i16 LE (works for both relative and absolute)
let x = i16::from_le_bytes([data[1], data[2]]) as i32;
let y = i16::from_le_bytes([data[3], data[4]]) as i32;
// Button or scroll delta
let (button, scroll) = match event_type {
MouseEventType::Down | MouseEventType::Up => {
let btn = match data[5] {
@@ -178,7 +169,6 @@ fn parse_mouse_message(data: &[u8]) -> Option<HidChannelEvent> {
}))
}
/// Parse consumer control message payload
fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
if data.len() < 2 {
warn!("Consumer message too short: {} bytes", data.len());
@@ -190,7 +180,6 @@ fn parse_consumer_message(data: &[u8]) -> Option<HidChannelEvent> {
Some(HidChannelEvent::Consumer(ConsumerEvent { usage }))
}
/// Encode a keyboard event to binary format (for sending to client if needed)
pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
let event_type = match event.event_type {
KeyEventType::Down => KB_EVENT_DOWN,
@@ -207,40 +196,6 @@ pub fn encode_keyboard_event(event: &KeyboardEvent) -> Vec<u8> {
]
}
/// Encode a mouse event to binary format (for sending to client if needed)
pub fn encode_mouse_event(event: &MouseEvent) -> Vec<u8> {
let event_type = match event.event_type {
MouseEventType::Move => MS_EVENT_MOVE,
MouseEventType::MoveAbs => MS_EVENT_MOVE_ABS,
MouseEventType::Down => MS_EVENT_DOWN,
MouseEventType::Up => MS_EVENT_UP,
MouseEventType::Scroll => MS_EVENT_SCROLL,
};
let x_bytes = (event.x as i16).to_le_bytes();
let y_bytes = (event.y as i16).to_le_bytes();
let extra = match event.event_type {
MouseEventType::Down | MouseEventType::Up => event
.button
.as_ref()
.map(|b| match b {
MouseButton::Left => 0u8,
MouseButton::Middle => 1u8,
MouseButton::Right => 2u8,
MouseButton::Back => 3u8,
MouseButton::Forward => 4u8,
})
.unwrap_or(0),
MouseEventType::Scroll => event.scroll as u8,
_ => 0,
};
vec![
MSG_MOUSE, event_type, x_bytes[0], x_bytes[1], y_bytes[0], y_bytes[1], extra,
]
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,10 +1,6 @@
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// Shared canonical keyboard key identifiers used across frontend and backend.
///
/// The enum names intentionally mirror `KeyboardEvent.code` style values so the
/// browser, virtual keyboard, and HID backend can all speak the same language.
#[typeshare]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CanonicalKey {
@@ -128,11 +124,6 @@ pub enum CanonicalKey {
}
impl CanonicalKey {
/// Convert the canonical key to a stable wire code.
///
/// The wire code intentionally matches the USB HID usage for keyboard page
/// keys so existing low-level behavior stays intact while the semantic type
/// becomes explicit.
pub const fn to_hid_usage(self) -> u8 {
match self {
Self::KeyA => 0x04,
@@ -255,7 +246,6 @@ impl CanonicalKey {
}
}
/// Convert a wire code / USB HID usage to its canonical key.
pub const fn from_hid_usage(usage: u8) -> Option<Self> {
match usage {
0x04 => Some(Self::KeyA),

View File

@@ -1,15 +1,4 @@
//! HID (Human Interface Device) control module
//!
//! This module provides keyboard and mouse control for remote KVM:
//! - USB OTG gadget mode (native Linux USB gadget)
//! - CH9329 serial HID controller
//!
//! Architecture:
//! ```text
//! Web Client --> WebSocket/DataChannel --> HID Events --> Backend --> Target PC
//! |
//! [OTG | CH9329]
//! ```
//! HID path: browser (WebSocket or WebRTC DataChannel) → queue → OTG gadget or CH9329.
pub mod backend;
pub mod ch9329;
@@ -20,51 +9,26 @@ pub mod otg;
pub mod types;
pub mod websocket;
pub use crate::events::LedState;
pub use backend::{HidBackend, HidBackendRuntimeSnapshot, HidBackendType};
pub use keyboard::CanonicalKey;
pub use otg::LedState;
pub use types::{
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardModifiers, MouseButton, MouseEvent,
MouseEventType,
};
/// HID backend information
#[derive(Debug, Clone)]
pub struct HidInfo {
/// Backend name
pub name: String,
/// Whether backend is initialized
pub initialized: bool,
/// Whether absolute mouse positioning is supported
pub supports_absolute_mouse: bool,
/// Screen resolution for absolute mouse
pub screen_resolution: Option<(u32, u32)>,
}
/// Unified HID runtime state used by snapshots and events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HidRuntimeState {
/// Whether a backend is configured and expected to exist.
pub available: bool,
/// Stable backend key: "otg", "ch9329", "none".
pub backend: String,
/// Whether the backend is currently initialized and operational.
pub initialized: bool,
/// Whether the backend is currently online.
pub online: bool,
/// Whether absolute mouse positioning is supported.
pub supports_absolute_mouse: bool,
/// Whether keyboard LED/status feedback is enabled.
pub keyboard_leds_enabled: bool,
/// Last known keyboard LED state.
pub led_state: LedState,
/// Screen resolution for absolute mouse mode.
pub screen_resolution: Option<(u32, u32)>,
/// Device path associated with the backend, if any.
pub device: Option<String>,
/// Current user-facing error, if any.
pub error: Option<String>,
/// Current programmatic error code, if any.
pub error_code: Option<String>,
}
@@ -140,45 +104,29 @@ const HID_EVENT_QUEUE_CAPACITY: usize = 64;
const HID_EVENT_SEND_TIMEOUT_MS: u64 = 30;
#[derive(Debug)]
enum HidEvent {
enum QueuedHidEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
Consumer(ConsumerEvent),
Reset,
}
/// HID controller managing keyboard and mouse input
pub struct HidController {
/// OTG Service reference (only used when backend is OTG)
otg_service: Option<Arc<OtgService>>,
/// Active backend
backend: Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
/// Backend type (mutable for reload)
backend_type: Arc<RwLock<HidBackendType>>,
/// Event bus for broadcasting state changes (optional)
events: Arc<tokio::sync::RwLock<Option<Arc<EventBus>>>>,
/// Unified HID runtime state.
runtime_state: Arc<RwLock<HidRuntimeState>>,
/// HID event queue sender (non-blocking)
hid_tx: mpsc::Sender<HidEvent>,
/// HID event queue receiver (moved into worker on first start)
hid_rx: Mutex<Option<mpsc::Receiver<HidEvent>>>,
/// Coalesced mouse move (latest)
hid_tx: mpsc::Sender<QueuedHidEvent>,
hid_rx: Mutex<Option<mpsc::Receiver<QueuedHidEvent>>>,
pending_move: Arc<parking_lot::Mutex<Option<MouseEvent>>>,
/// Pending move flag (fast path)
pending_move_flag: Arc<AtomicBool>,
/// Worker task handle
hid_worker: Mutex<Option<JoinHandle<()>>>,
/// Backend runtime subscription task handle
runtime_worker: Mutex<Option<JoinHandle<()>>>,
/// Backend initialization fast flag
backend_available: Arc<AtomicBool>,
}
impl HidController {
/// Create a new HID controller with specified backend
///
/// For OTG backend, otg_service should be provided to support hot-reload
pub fn new(backend_type: HidBackendType, otg_service: Option<Arc<OtgService>>) -> Self {
let (hid_tx, hid_rx) = mpsc::channel(HID_EVENT_QUEUE_CAPACITY);
Self {
@@ -199,12 +147,10 @@ impl HidController {
}
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: Arc<EventBus>) {
*self.events.write().await = Some(events);
}
/// Initialize the HID backend
pub async fn init(&self) -> Result<()> {
let backend_type = self.backend_type.read().await.clone();
let backend: Arc<dyn HidBackend> = match backend_type {
@@ -256,7 +202,6 @@ impl HidController {
*self.backend.write().await = Some(backend);
self.sync_runtime_state_from_backend().await;
// Start HID event worker (once)
self.start_event_worker().await;
self.restart_runtime_worker().await;
@@ -264,12 +209,10 @@ impl HidController {
Ok(())
}
/// Shutdown the HID backend and release resources
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down HID controller");
self.stop_runtime_worker().await;
// Close the backend
if let Some(backend) = self.backend.write().await.take() {
if let Err(e) = backend.shutdown().await {
warn!("Error shutting down HID backend: {}", e);
@@ -290,17 +233,15 @@ impl HidController {
Ok(())
}
/// Send keyboard event
pub async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
"HID backend not available".to_string(),
));
}
self.enqueue_event(HidEvent::Keyboard(event)).await
self.enqueue_event(QueuedHidEvent::Keyboard(event)).await
}
/// Send mouse event
pub async fn send_mouse(&self, event: MouseEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
@@ -312,81 +253,55 @@ impl HidController {
event.event_type,
MouseEventType::Move | MouseEventType::MoveAbs
) {
// Best-effort: drop/merge move events if queue is full
self.enqueue_mouse_move(event)
} else {
self.enqueue_event(HidEvent::Mouse(event)).await
self.enqueue_event(QueuedHidEvent::Mouse(event)).await
}
}
/// Send consumer control event (multimedia keys)
pub async fn send_consumer(&self, event: ConsumerEvent) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Err(AppError::BadRequest(
"HID backend not available".to_string(),
));
}
self.enqueue_event(HidEvent::Consumer(event)).await
self.enqueue_event(QueuedHidEvent::Consumer(event)).await
}
/// Reset all keys (release all pressed keys)
pub async fn reset(&self) -> Result<()> {
if !self.backend_available.load(Ordering::Acquire) {
return Ok(());
}
// Reset is important but best-effort; enqueue to avoid blocking
self.enqueue_event(HidEvent::Reset).await
self.enqueue_event(QueuedHidEvent::Reset).await
}
/// Check if backend is available
pub async fn is_available(&self) -> bool {
self.backend_available.load(Ordering::Acquire)
}
/// Get backend type
pub async fn backend_type(&self) -> HidBackendType {
self.backend_type.read().await.clone()
}
/// Get backend info
pub async fn info(&self) -> Option<HidInfo> {
let state = self.runtime_state.read().await.clone();
if !state.available {
return None;
}
Some(HidInfo {
name: state.backend,
initialized: state.initialized,
supports_absolute_mouse: state.supports_absolute_mouse,
screen_resolution: state.screen_resolution,
})
}
/// Get current HID runtime state snapshot.
pub async fn snapshot(&self) -> HidRuntimeState {
self.runtime_state.read().await.clone()
}
/// Reload the HID backend with new type
pub async fn reload(&self, new_backend_type: HidBackendType) -> Result<()> {
info!("Reloading HID backend: {:?}", new_backend_type);
self.backend_available.store(false, Ordering::Release);
self.stop_runtime_worker().await;
// Shutdown existing backend first
if let Some(backend) = self.backend.write().await.take() {
if let Err(e) = backend.shutdown().await {
warn!("Error shutting down old HID backend: {}", e);
}
}
// Create and initialize new backend
let new_backend: Option<Arc<dyn HidBackend>> = match new_backend_type {
HidBackendType::Otg => {
info!("Initializing OTG HID backend");
// Get OtgService reference
let otg_service = match self.otg_service.as_ref() {
Some(svc) => svc,
None => {
@@ -398,28 +313,25 @@ impl HidController {
};
match otg_service.hid_device_paths().await {
Some(handles) => {
// Create OtgBackend from handles
match otg::OtgBackend::from_handles(handles) {
Ok(backend) => {
let backend = Arc::new(backend);
match backend.init().await {
Ok(_) => {
info!("OTG backend initialized successfully");
Some(backend)
}
Err(e) => {
warn!("Failed to initialize OTG backend: {}", e);
None
}
Some(handles) => match otg::OtgBackend::from_handles(handles) {
Ok(backend) => {
let backend = Arc::new(backend);
match backend.init().await {
Ok(_) => {
info!("OTG backend initialized successfully");
Some(backend)
}
Err(e) => {
warn!("Failed to initialize OTG backend: {}", e);
None
}
}
Err(e) => {
warn!("Failed to create OTG backend: {}", e);
None
}
}
}
Err(e) => {
warn!("Failed to create OTG backend: {}", e);
None
}
},
None => {
warn!("OTG HID paths are not available");
None
@@ -470,7 +382,6 @@ impl HidController {
info!("HID backend reloaded successfully: {:?}", new_backend_type);
self.start_event_worker().await;
// Update backend_type on success
*self.backend_type.write().await = new_backend_type.clone();
self.sync_runtime_state_from_backend().await;
@@ -481,7 +392,6 @@ impl HidController {
warn!("HID backend reload resulted in no active backend");
self.backend_available.store(false, Ordering::Release);
// Update backend_type even on failure (to reflect the attempted change)
*self.backend_type.write().await = new_backend_type.clone();
let current = self.runtime_state.read().await.clone();
@@ -541,11 +451,10 @@ impl HidController {
process_hid_event(event, &backend).await;
// After each event, flush latest move if pending
if pending_move_flag.swap(false, Ordering::AcqRel) {
let move_event = { pending_move.lock().take() };
if let Some(move_event) = move_event {
process_hid_event(HidEvent::Mouse(move_event), &backend).await;
process_hid_event(QueuedHidEvent::Mouse(move_event), &backend).await;
}
}
}
@@ -595,7 +504,7 @@ impl HidController {
}
fn enqueue_mouse_move(&self, event: MouseEvent) -> Result<()> {
match self.hid_tx.try_send(HidEvent::Mouse(event.clone())) {
match self.hid_tx.try_send(QueuedHidEvent::Mouse(event.clone())) {
Ok(_) => Ok(()),
Err(mpsc::error::TrySendError::Full(_)) => {
*self.pending_move.lock() = Some(event);
@@ -608,11 +517,10 @@ impl HidController {
}
}
async fn enqueue_event(&self, event: HidEvent) -> Result<()> {
async fn enqueue_event(&self, event: QueuedHidEvent) -> Result<()> {
match self.hid_tx.try_send(event) {
Ok(_) => Ok(()),
Err(mpsc::error::TrySendError::Full(ev)) => {
// For non-move events, wait briefly to avoid dropping critical input
let tx = self.hid_tx.clone();
let send_result = tokio::time::timeout(
Duration::from_millis(HID_EVENT_SEND_TIMEOUT_MS),
@@ -649,7 +557,10 @@ async fn apply_backend_runtime_state(
apply_runtime_state(runtime_state, events, next).await;
}
async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>) {
async fn process_hid_event(
event: QueuedHidEvent,
backend: &Arc<RwLock<Option<Arc<dyn HidBackend>>>>,
) {
let backend_opt = backend.read().await.clone();
let backend = match backend_opt {
Some(b) => b,
@@ -660,10 +571,10 @@ async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn
let result = tokio::task::spawn_blocking(move || {
futures::executor::block_on(async move {
match event {
HidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await,
HidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await,
HidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await,
HidEvent::Reset => backend_for_send.reset().await,
QueuedHidEvent::Keyboard(ev) => backend_for_send.send_keyboard(ev).await,
QueuedHidEvent::Mouse(ev) => backend_for_send.send_mouse(ev).await,
QueuedHidEvent::Consumer(ev) => backend_for_send.send_consumer(ev).await,
QueuedHidEvent::Reset => backend_for_send.reset().await,
}
})
})
@@ -682,12 +593,6 @@ async fn process_hid_event(event: HidEvent, backend: &Arc<RwLock<Option<Arc<dyn
}
}
impl Default for HidController {
fn default() -> Self {
Self::new(HidBackendType::None, None)
}
}
fn device_for_backend_type(backend_type: &HidBackendType) -> Option<String> {
match backend_type {
HidBackendType::Ch9329 { port, .. } => Some(port.clone()),

View File

@@ -1,28 +1,11 @@
//! OTG USB Gadget HID backend
//! Linux gadget HID: `/dev/hidg*` opened from [`crate::otg::OtgService`].
//! Typical nodes: hidg0 keyboard, hidg1 relative mouse, hidg2 absolute, hidg3 consumer control.
//!
//! This backend uses Linux USB Gadget API to emulate USB HID devices.
//! It opens the HID gadget device nodes created by `OtgService`.
//! Depending on the configured OTG profile, this may include:
//! - hidg0: Keyboard
//! - hidg1: Relative Mouse
//! - hidg2: Absolute Mouse
//! - hidg3: Consumer Control Keyboard
//!
//! Requirements:
//! - USB OTG/Device controller (UDC)
//! - ConfigFS with USB gadget support
//! - Root privileges for gadget setup
//!
//! Error Recovery:
//! This module implements automatic device reconnection based on PiKVM's approach.
//! When ESHUTDOWN or EAGAIN errors occur (common during MSD operations), the device
//! file handles are closed and reopened on the next operation.
//! See: https://github.com/raspberrypi/linux/issues/4373
//! Polled timed writes (JetKVM-style). Treat `ESHUTDOWN` (108) by closing handles and reopening; keep fd on `EAGAIN` (11). Host/gadget teardown during MSD resembles PiKVM. <https://github.com/raspberrypi/linux/issues/4373>
use async_trait::async_trait;
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::os::unix::fs::OpenOptionsExt;
@@ -40,9 +23,9 @@ use super::types::{
ConsumerEvent, KeyEventType, KeyboardEvent, KeyboardReport, MouseEvent, MouseEventType,
};
use crate::error::{AppError, Result};
use crate::events::LedState;
use crate::otg::{wait_for_hid_devices, HidDevicePaths};
/// Device type for ensure_device operations
#[derive(Debug, Clone, Copy)]
enum DeviceType {
Keyboard,
@@ -51,23 +34,7 @@ enum DeviceType {
ConsumerControl,
}
/// Keyboard LED state
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct LedState {
/// Num Lock LED
pub num_lock: bool,
/// Caps Lock LED
pub caps_lock: bool,
/// Scroll Lock LED
pub scroll_lock: bool,
/// Compose LED
pub compose: bool,
/// Kana LED
pub kana: bool,
}
impl LedState {
/// Create from raw byte
pub fn from_byte(b: u8) -> Self {
Self {
num_lock: b & 0x01 != 0,
@@ -78,7 +45,6 @@ impl LedState {
}
}
/// Convert to raw byte
pub fn to_byte(&self) -> u8 {
let mut b = 0u8;
if self.num_lock {
@@ -100,76 +66,37 @@ impl LedState {
}
}
/// OTG HID backend with 4 devices
///
/// This backend opens HID device files created by OtgService.
/// It does NOT manage the USB gadget itself - that's handled by OtgService.
///
/// ## Error Recovery
///
/// Based on PiKVM's implementation, this backend automatically handles:
/// - EAGAIN (errno 11): Resource temporarily unavailable - just retry later, don't close device
/// - ESHUTDOWN (errno 108): Transport endpoint shutdown - close and reopen device
///
/// When ESHUTDOWN occurs, the device file handle is closed and will be
/// reopened on the next operation attempt.
/// Opens `/dev/hidg*` nodes provisioned by `OtgService`; gadget lifecycle is not handled here.
pub struct OtgBackend {
/// Keyboard device path (/dev/hidg0)
keyboard_path: Option<PathBuf>,
/// Relative mouse device path (/dev/hidg1)
mouse_rel_path: Option<PathBuf>,
/// Absolute mouse device path (/dev/hidg2)
mouse_abs_path: Option<PathBuf>,
/// Consumer control device path (/dev/hidg3)
consumer_path: Option<PathBuf>,
/// Keyboard device file
keyboard_dev: Mutex<Option<File>>,
/// Relative mouse device file
mouse_rel_dev: Mutex<Option<File>>,
/// Absolute mouse device file
mouse_abs_dev: Mutex<Option<File>>,
/// Consumer control device file
consumer_dev: Mutex<Option<File>>,
/// Whether keyboard LED/status feedback is enabled.
keyboard_leds_enabled: bool,
/// Current keyboard state
keyboard_state: Mutex<KeyboardReport>,
/// Current mouse button state
mouse_buttons: AtomicU8,
/// Last known LED state (using parking_lot::RwLock for sync access)
led_state: Arc<parking_lot::RwLock<LedState>>,
/// Screen resolution for absolute mouse (using parking_lot::RwLock for sync access)
screen_resolution: parking_lot::RwLock<Option<(u32, u32)>>,
/// UDC name for state checking (e.g., "fcc00000.usb")
udc_name: Arc<parking_lot::RwLock<Option<String>>>,
/// Whether the backend has been initialized.
initialized: AtomicBool,
/// Whether the device is currently online (UDC configured and devices accessible)
online: AtomicBool,
/// Last backend error state.
last_error: parking_lot::RwLock<Option<(String, String)>>,
/// Last error log time for throttling (using parking_lot for sync)
last_error_log: parking_lot::Mutex<std::time::Instant>,
/// Error count since last successful operation (for log throttling)
error_count: AtomicU8,
/// Consecutive EAGAIN count (for offline threshold detection)
eagain_count: AtomicU8,
/// Runtime change notifier.
runtime_notify_tx: watch::Sender<()>,
/// Runtime monitor stop flag.
runtime_worker_stop: Arc<AtomicBool>,
/// Runtime monitor thread.
runtime_worker: Mutex<Option<thread::JoinHandle<()>>>,
}
/// Write timeout in milliseconds (same as JetKVM's hidWriteTimeout)
const HID_WRITE_TIMEOUT_MS: i32 = 20;
impl OtgBackend {
/// Create OTG backend from device paths provided by OtgService
///
/// This is the ONLY way to create an OtgBackend - it no longer manages
/// the USB gadget itself. The gadget must already be set up by OtgService.
/// Gadget must already exist; paths come from `OtgService`.
pub fn from_handles(paths: HidDevicePaths) -> Result<Self> {
let (runtime_notify_tx, _runtime_notify_rx) = watch::channel(());
Ok(Self {
@@ -234,7 +161,6 @@ impl OtgBackend {
}
}
/// Log throttled error message (max once per second)
fn log_throttled_error(&self, msg: &str) {
let mut last_log = self.last_error_log.lock();
let now = std::time::Instant::now();
@@ -251,24 +177,17 @@ impl OtgBackend {
}
}
/// Reset error count on successful operation
fn reset_error_count(&self) {
self.error_count.store(0, Ordering::Relaxed);
// Also reset EAGAIN count - successful operation means device is working
self.eagain_count.store(0, Ordering::Relaxed);
}
/// Write data to HID device with timeout (JetKVM style)
///
/// Uses poll() to wait for device to be ready for writing.
/// If timeout expires, silently drops the data (acceptable for mouse movement).
/// Returns Ok(true) if write succeeded, Ok(false) if timed out (silently dropped).
/// Poll-based write with `HID_WRITE_TIMEOUT_MS`; timeout → drop (JetKVM-style).
fn write_with_timeout(&self, file: &mut File, data: &[u8]) -> std::io::Result<bool> {
let mut pollfd = [PollFd::new(file.as_fd(), PollFlags::POLLOUT)];
match poll(&mut pollfd, PollTimeout::from(HID_WRITE_TIMEOUT_MS as u16)) {
Ok(1) => {
// Device ready, check for errors
if let Some(revents) = pollfd[0].revents() {
if revents.contains(PollFlags::POLLERR) || revents.contains(PollFlags::POLLHUP)
{
@@ -278,12 +197,10 @@ impl OtgBackend {
));
}
}
// Write the data
file.write_all(data)?;
Ok(true)
}
Ok(0) => {
// Timeout - silently drop (JetKVM behavior)
trace!("HID write timeout, dropping data");
Ok(false)
}
@@ -292,7 +209,6 @@ impl OtgBackend {
}
}
/// Set the UDC name for state checking
pub fn set_udc_name(&self, udc: &str) {
*self.udc_name.write() = Some(udc.to_string());
}
@@ -324,15 +240,11 @@ impl OtgBackend {
}
}
/// Check if the UDC is in "configured" state
///
/// This is based on PiKVM's `__is_udc_configured()` method.
/// The UDC state file indicates whether the USB host has enumerated and configured the gadget.
/// `true` when `/sys/class/udc/<name>/state` reads `configured` (PiKVM-style).
pub fn is_udc_configured(&self) -> bool {
Self::read_udc_configured(&self.udc_name)
}
/// Find the first available UDC
fn find_udc() -> Option<String> {
let udc_path = PathBuf::from("/sys/class/udc");
if let Ok(entries) = fs::read_dir(&udc_path) {
@@ -345,12 +257,7 @@ impl OtgBackend {
None
}
/// Ensure a device is open and ready for I/O
///
/// This method is based on PiKVM's `__ensure_device()` pattern:
/// 1. Check if device path exists, close handle if not
/// 2. If handle is None but path exists, reopen the device
/// 3. Return whether the device is ready for I/O
/// PiKVM-style: drop handle if node missing; reopen when path reappears.
fn ensure_device(&self, device_type: DeviceType) -> Result<()> {
let (path_opt, dev_mutex) = match device_type {
DeviceType::Keyboard => (&self.keyboard_path, &self.keyboard_dev),
@@ -372,9 +279,7 @@ impl OtgBackend {
}
};
// Check if device path exists
if !path.exists() {
// Close the device if open (device was removed)
let mut dev = dev_mutex.lock();
if dev.is_some() {
debug!(
@@ -392,7 +297,6 @@ impl OtgBackend {
});
}
// If device is not open, try to open it
let mut dev = dev_mutex.lock();
if dev.is_none() {
match Self::open_device(path) {
@@ -415,7 +319,6 @@ impl OtgBackend {
Ok(())
}
/// Open a HID device file with read/write access
fn open_device(path: &PathBuf) -> Result<File> {
OpenOptions::new()
.read(true)
@@ -431,16 +334,15 @@ impl OtgBackend {
})
}
/// Convert I/O error to HidError with appropriate error code
fn io_error_code(e: &std::io::Error) -> &'static str {
match e.raw_os_error() {
Some(32) => "epipe", // EPIPE - broken pipe
Some(108) => "eshutdown", // ESHUTDOWN - transport endpoint shutdown
Some(11) => "eagain", // EAGAIN - resource temporarily unavailable
Some(6) => "enxio", // ENXIO - no such device or address
Some(19) => "enodev", // ENODEV - no such device
Some(5) => "eio", // EIO - I/O error
Some(2) => "enoent", // ENOENT - no such file or directory
Some(32) => "epipe",
Some(108) => "eshutdown",
Some(11) => "eagain",
Some(6) => "enxio",
Some(19) => "enodev",
Some(5) => "eio",
Some(2) => "enoent",
_ => "io_error",
}
}
@@ -455,7 +357,6 @@ impl OtgBackend {
}
}
/// Check if all HID device files exist
pub fn check_devices_exist(&self) -> bool {
self.keyboard_path.as_ref().is_none_or(|p| p.exists())
&& self.mouse_rel_path.as_ref().is_none_or(|p| p.exists())
@@ -463,7 +364,6 @@ impl OtgBackend {
&& self.consumer_path.as_ref().is_none_or(|p| p.exists())
}
/// Get list of missing device paths
pub fn get_missing_devices(&self) -> Vec<String> {
let mut missing = Vec::new();
if let Some(ref path) = self.keyboard_path {
@@ -484,17 +384,11 @@ impl OtgBackend {
missing
}
/// Send keyboard report (8 bytes)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// Uses write_with_timeout to avoid blocking on busy devices.
fn send_keyboard_report(&self, report: &KeyboardReport) -> Result<()> {
if self.keyboard_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::Keyboard)?;
let mut dev = self.keyboard_dev.lock();
@@ -508,7 +402,6 @@ impl OtgBackend {
Ok(())
}
Ok(false) => {
// Timeout - silently dropped (JetKVM behavior)
self.log_throttled_error("HID keyboard write timeout, dropped");
Ok(())
}
@@ -517,7 +410,6 @@ impl OtgBackend {
match error_code {
Some(108) => {
// ESHUTDOWN - endpoint closed, need to reopen device
self.eagain_count.store(0, Ordering::Relaxed);
debug!("Keyboard ESHUTDOWN, closing for recovery");
*dev = None;
@@ -531,7 +423,6 @@ impl OtgBackend {
))
}
Some(11) => {
// EAGAIN after poll - should be rare, silently drop
trace!("Keyboard EAGAIN after poll, dropping");
Ok(())
}
@@ -559,17 +450,11 @@ impl OtgBackend {
}
}
/// Send relative mouse report (4 bytes: buttons, dx, dy, wheel)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// Uses write_with_timeout to avoid blocking on busy devices.
fn send_mouse_report_relative(&self, buttons: u8, dx: i8, dy: i8, wheel: i8) -> Result<()> {
if self.mouse_rel_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::MouseRelative)?;
let mut dev = self.mouse_rel_dev.lock();
@@ -582,10 +467,7 @@ impl OtgBackend {
trace!("Sent relative mouse report: {:02X?}", data);
Ok(())
}
Ok(false) => {
// Timeout - silently dropped (JetKVM behavior)
Ok(())
}
Ok(false) => Ok(()),
Err(e) => {
let error_code = e.raw_os_error();
@@ -603,10 +485,7 @@ impl OtgBackend {
"Failed to write mouse report",
))
}
Some(11) => {
// EAGAIN after poll - should be rare, silently drop
Ok(())
}
Some(11) => Ok(()),
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Relative mouse write error: {}", e);
@@ -631,17 +510,11 @@ impl OtgBackend {
}
}
/// Send absolute mouse report (6 bytes: buttons, x_lo, x_hi, y_lo, y_hi, wheel)
///
/// This method ensures the device is open before writing, and handles
/// ESHUTDOWN errors by closing the device handle for later reconnection.
/// Uses write_with_timeout to avoid blocking on busy devices.
fn send_mouse_report_absolute(&self, buttons: u8, x: u16, y: u16, wheel: i8) -> Result<()> {
if self.mouse_abs_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::MouseAbsolute)?;
let mut dev = self.mouse_abs_dev.lock();
@@ -660,10 +533,7 @@ impl OtgBackend {
self.reset_error_count();
Ok(())
}
Ok(false) => {
// Timeout - silently dropped (JetKVM behavior)
Ok(())
}
Ok(false) => Ok(()),
Err(e) => {
let error_code = e.raw_os_error();
@@ -681,10 +551,7 @@ impl OtgBackend {
"Failed to write mouse report",
))
}
Some(11) => {
// EAGAIN after poll - should be rare, silently drop
Ok(())
}
Some(11) => Ok(()),
_ => {
self.eagain_count.store(0, Ordering::Relaxed);
warn!("Absolute mouse write error: {}", e);
@@ -709,35 +576,27 @@ impl OtgBackend {
}
}
/// Send consumer control report (2 bytes: usage_lo, usage_hi)
///
/// Sends a consumer control usage code and then releases it (sends 0x0000).
/// Press (`usage`) then release (`0x0000`).
fn send_consumer_report(&self, usage: u16) -> Result<()> {
if self.consumer_path.is_none() {
return Ok(());
}
// Ensure device is ready
self.ensure_device(DeviceType::ConsumerControl)?;
let mut dev = self.consumer_dev.lock();
if let Some(ref mut file) = *dev {
// Send the usage code
let data = [(usage & 0xFF) as u8, (usage >> 8) as u8];
match self.write_with_timeout(file, &data) {
Ok(true) => {
trace!("Sent consumer report: {:02X?}", data);
// Send release (0x0000)
let release = [0u8, 0u8];
let _ = self.write_with_timeout(file, &release);
self.mark_online();
self.reset_error_count();
Ok(())
}
Ok(false) => {
// Timeout - silently dropped
Ok(())
}
Ok(false) => Ok(()),
Err(e) => {
let error_code = e.raw_os_error();
match error_code {
@@ -753,10 +612,7 @@ impl OtgBackend {
"Failed to write consumer report",
))
}
Some(11) => {
// EAGAIN after poll - silently drop
Ok(())
}
Some(11) => Ok(()),
_ => {
warn!("Consumer control write error: {}", e);
self.record_error(
@@ -780,12 +636,10 @@ impl OtgBackend {
}
}
/// Send consumer control event
pub fn send_consumer(&self, event: ConsumerEvent) -> Result<()> {
self.send_consumer_report(event.usage)
}
/// Get last known LED state
pub fn led_state(&self) -> LedState {
*self.led_state.read()
}
@@ -973,19 +827,17 @@ impl OtgBackend {
#[async_trait]
impl HidBackend for OtgBackend {
async fn init(&self) -> Result<()> {
info!("Initializing OTG HID backend");
debug!("Initializing OTG HID backend");
// Auto-detect UDC name for state checking only if OtgService did not provide one
if self.udc_name.read().is_none() {
if let Some(udc) = Self::find_udc() {
info!("Auto-detected UDC: {}", udc);
debug!("Auto-detected UDC: {}", udc);
self.set_udc_name(&udc);
}
} else if let Some(udc) = self.udc_name.read().clone() {
info!("Using configured UDC: {}", udc);
debug!("Using configured UDC: {}", udc);
}
// Wait for devices to appear (they should already exist from OtgService)
let mut device_paths = Vec::new();
if let Some(ref path) = self.keyboard_path {
device_paths.push(path.clone());
@@ -1010,51 +862,46 @@ impl HidBackend for OtgBackend {
return Err(AppError::Internal("HID devices did not appear".into()));
}
// Open keyboard device
if let Some(ref path) = self.keyboard_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.keyboard_dev.lock() = Some(file);
info!("Keyboard device opened: {}", path.display());
debug!("Keyboard device opened: {}", path.display());
} else {
warn!("Keyboard device not found: {}", path.display());
}
}
// Open relative mouse device
if let Some(ref path) = self.mouse_rel_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.mouse_rel_dev.lock() = Some(file);
info!("Relative mouse device opened: {}", path.display());
debug!("Relative mouse device opened: {}", path.display());
} else {
warn!("Relative mouse device not found: {}", path.display());
}
}
// Open absolute mouse device
if let Some(ref path) = self.mouse_abs_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.mouse_abs_dev.lock() = Some(file);
info!("Absolute mouse device opened: {}", path.display());
debug!("Absolute mouse device opened: {}", path.display());
} else {
warn!("Absolute mouse device not found: {}", path.display());
}
}
// Open consumer control device (optional, may not exist on older setups)
if let Some(ref path) = self.consumer_path {
if path.exists() {
let file = Self::open_device(path)?;
*self.consumer_dev.lock() = Some(file);
info!("Consumer control device opened: {}", path.display());
debug!("Consumer control device opened: {}", path.display());
} else {
debug!("Consumer control device not found: {}", path.display());
}
}
// Mark as online if all devices opened successfully
self.initialized.store(true, Ordering::Relaxed);
self.notify_runtime_changed();
self.start_runtime_worker();
@@ -1066,7 +913,6 @@ impl HidBackend for OtgBackend {
async fn send_keyboard(&self, event: KeyboardEvent) -> Result<()> {
let usb_key = event.key.to_hid_usage();
// Handle modifier keys separately
if event.key.is_modifier() {
let mut state = self.keyboard_state.lock();
@@ -1084,7 +930,6 @@ impl HidBackend for OtgBackend {
} else {
let mut state = self.keyboard_state.lock();
// Update modifiers from event
state.modifiers = event.modifiers.to_hid_byte();
match event.event_type {
@@ -1110,15 +955,12 @@ impl HidBackend for OtgBackend {
match event.event_type {
MouseEventType::Move => {
// Relative movement - use hidg1
let dx = event.x.clamp(-127, 127) as i8;
let dy = event.y.clamp(-127, 127) as i8;
self.send_mouse_report_relative(buttons, dx, dy, 0)?;
}
MouseEventType::MoveAbs => {
// Absolute movement - use hidg2
// Frontend sends 0-32767 range directly (standard HID absolute mouse range)
// Don't send button state with move - buttons are handled separately on relative device
// Coordinates 032767; buttons are sent only on the relative endpoint.
let x = event.x.clamp(0, 32767) as u16;
let y = event.y.clamp(0, 32767) as u16;
self.send_mouse_report_absolute(0, x, y, 0)?;
@@ -1127,7 +969,6 @@ impl HidBackend for OtgBackend {
if let Some(button) = event.button {
let bit = button.to_hid_bit();
let new_buttons = self.mouse_buttons.fetch_or(bit, Ordering::Relaxed) | bit;
// Send on relative device for button clicks
self.send_mouse_report_relative(new_buttons, 0, 0, 0)?;
}
}
@@ -1147,7 +988,6 @@ impl HidBackend for OtgBackend {
}
async fn reset(&self) -> Result<()> {
// Reset keyboard
{
let mut state = self.keyboard_state.lock();
state.clear();
@@ -1156,7 +996,6 @@ impl HidBackend for OtgBackend {
self.send_keyboard_report(&report)?;
}
// Reset mouse
self.mouse_buttons.store(0, Ordering::Relaxed);
self.send_mouse_report_relative(0, 0, 0, 0)?;
self.send_mouse_report_absolute(0, 0, 0, 0)?;
@@ -1168,16 +1007,13 @@ impl HidBackend for OtgBackend {
async fn shutdown(&self) -> Result<()> {
self.stop_runtime_worker();
// Reset before closing
self.reset().await?;
// Close devices
*self.keyboard_dev.lock() = None;
*self.mouse_rel_dev.lock() = None;
*self.mouse_abs_dev.lock() = None;
*self.consumer_dev.lock() = None;
// Gadget cleanup is handled by OtgService, not here
self.initialized.store(false, Ordering::Relaxed);
self.online.store(false, Ordering::Relaxed);
self.clear_error();
@@ -1199,31 +1035,18 @@ impl HidBackend for OtgBackend {
self.send_consumer_report(event.usage)
}
fn set_screen_resolution(&mut self, width: u32, height: u32) {
fn set_screen_resolution(&self, width: u32, height: u32) {
*self.screen_resolution.write() = Some((width, height));
self.notify_runtime_changed();
}
}
/// Check if OTG HID gadget is available
pub fn is_otg_available() -> bool {
// Check for existing HID devices (they should be created by OtgService)
let kb = PathBuf::from("/dev/hidg0");
let mouse_rel = PathBuf::from("/dev/hidg1");
let mouse_abs = PathBuf::from("/dev/hidg2");
kb.exists() || mouse_rel.exists() || mouse_abs.exists()
}
/// Implement Drop for OtgBackend to close device files
impl Drop for OtgBackend {
fn drop(&mut self) {
self.runtime_worker_stop.store(true, Ordering::Relaxed);
if let Some(handle) = self.runtime_worker.get_mut().take() {
let _ = handle.join();
}
// Close device files
// Note: Gadget cleanup is handled by OtgService, not here
*self.keyboard_dev.lock() = None;
*self.mouse_rel_dev.lock() = None;
*self.mouse_abs_dev.lock() = None;
@@ -1236,12 +1059,6 @@ impl Drop for OtgBackend {
mod tests {
use super::*;
#[test]
fn test_otg_availability_check() {
// This just tests the function runs without panicking
let _available = is_otg_available();
}
#[test]
fn test_led_state() {
let state = LedState::from_byte(0b00000011);
@@ -1254,7 +1071,6 @@ mod tests {
#[test]
fn test_report_sizes() {
// Keyboard report is 8 bytes
let kb_report = KeyboardReport::default();
assert_eq!(kb_report.to_bytes().len(), 8);
}

View File

@@ -1,50 +1,37 @@
//! HID event types for keyboard and mouse
//! Keyboard/mouse/consumer structs (`KeyboardEvent`, `MouseEvent`, …).
use serde::{Deserialize, Serialize};
use super::keyboard::CanonicalKey;
/// Keyboard event type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum KeyEventType {
/// Key pressed down
Down,
/// Key released
Up,
}
/// Keyboard modifier flags
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeyboardModifiers {
/// Left Control
#[serde(default)]
pub left_ctrl: bool,
/// Left Shift
#[serde(default)]
pub left_shift: bool,
/// Left Alt
#[serde(default)]
pub left_alt: bool,
/// Left Meta (Windows/Super key)
#[serde(default)]
pub left_meta: bool,
/// Right Control
#[serde(default)]
pub right_ctrl: bool,
/// Right Shift
#[serde(default)]
pub right_shift: bool,
/// Right Alt (AltGr)
#[serde(default)]
pub right_alt: bool,
/// Right Meta
#[serde(default)]
pub right_meta: bool,
}
impl KeyboardModifiers {
/// Convert to USB HID modifier byte
pub fn to_hid_byte(&self) -> u8 {
let mut byte = 0u8;
if self.left_ctrl {
@@ -74,7 +61,6 @@ impl KeyboardModifiers {
byte
}
/// Create from USB HID modifier byte
pub fn from_hid_byte(byte: u8) -> Self {
Self {
left_ctrl: byte & 0x01 != 0,
@@ -88,7 +74,6 @@ impl KeyboardModifiers {
}
}
/// Check if any modifier is active
pub fn any(&self) -> bool {
self.left_ctrl
|| self.left_shift
@@ -101,21 +86,16 @@ impl KeyboardModifiers {
}
}
/// Keyboard event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyboardEvent {
/// Event type (down/up)
#[serde(rename = "type")]
pub event_type: KeyEventType,
/// Canonical keyboard key identifier shared across frontend and backend
pub key: CanonicalKey,
/// Modifier keys state
#[serde(default)]
pub modifiers: KeyboardModifiers,
}
impl KeyboardEvent {
/// Create a key down event
pub fn key_down(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Down,
@@ -124,7 +104,6 @@ impl KeyboardEvent {
}
}
/// Create a key up event
pub fn key_up(key: CanonicalKey, modifiers: KeyboardModifiers) -> Self {
Self {
event_type: KeyEventType::Up,
@@ -134,7 +113,6 @@ impl KeyboardEvent {
}
}
/// Mouse button
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseButton {
@@ -146,7 +124,6 @@ pub enum MouseButton {
}
impl MouseButton {
/// Convert to USB HID button bit
pub fn to_hid_bit(&self) -> u8 {
match self {
MouseButton::Left => 0x01,
@@ -158,44 +135,31 @@ impl MouseButton {
}
}
/// Mouse event type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MouseEventType {
/// Mouse moved (relative movement)
Move,
/// Mouse moved (absolute position)
MoveAbs,
/// Button pressed
Down,
/// Button released
Up,
/// Mouse wheel scroll
Scroll,
}
/// Mouse event
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MouseEvent {
/// Event type
#[serde(rename = "type")]
pub event_type: MouseEventType,
/// X coordinate or delta
#[serde(default)]
pub x: i32,
/// Y coordinate or delta
#[serde(default)]
pub y: i32,
/// Button (for down/up events)
#[serde(default)]
pub button: Option<MouseButton>,
/// Scroll delta (for scroll events)
#[serde(default)]
pub scroll: i8,
}
impl MouseEvent {
/// Create a relative move event
pub fn move_rel(dx: i32, dy: i32) -> Self {
Self {
event_type: MouseEventType::Move,
@@ -206,7 +170,6 @@ impl MouseEvent {
}
}
/// Create an absolute move event
pub fn move_abs(x: i32, y: i32) -> Self {
Self {
event_type: MouseEventType::MoveAbs,
@@ -217,7 +180,6 @@ impl MouseEvent {
}
}
/// Create a button down event
pub fn button_down(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Down,
@@ -228,7 +190,6 @@ impl MouseEvent {
}
}
/// Create a button up event
pub fn button_up(button: MouseButton) -> Self {
Self {
event_type: MouseEventType::Up,
@@ -239,7 +200,6 @@ impl MouseEvent {
}
}
/// Create a scroll event
pub fn scroll(delta: i8) -> Self {
Self {
event_type: MouseEventType::Scroll,
@@ -251,35 +211,19 @@ impl MouseEvent {
}
}
/// Combined HID event (keyboard or mouse)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "device", rename_all = "lowercase")]
pub enum HidEvent {
Keyboard(KeyboardEvent),
Mouse(MouseEvent),
Consumer(ConsumerEvent),
}
/// Consumer control event (multimedia keys)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsumerEvent {
/// Consumer control usage code (e.g., 0x00CD for Play/Pause)
pub usage: u16,
}
/// USB HID keyboard report (8 bytes)
#[derive(Debug, Clone, Default)]
pub struct KeyboardReport {
/// Modifier byte
pub modifiers: u8,
/// Reserved byte
pub reserved: u8,
/// Key codes (up to 6 simultaneous keys)
pub keys: [u8; 6],
}
impl KeyboardReport {
/// Convert to bytes for USB HID
pub fn to_bytes(&self) -> [u8; 8] {
[
self.modifiers,
@@ -293,7 +237,6 @@ impl KeyboardReport {
]
}
/// Add a key to the report
pub fn add_key(&mut self, key: u8) -> bool {
for slot in &mut self.keys {
if *slot == 0 {
@@ -304,56 +247,21 @@ impl KeyboardReport {
false // All slots full
}
/// Remove a key from the report
pub fn remove_key(&mut self, key: u8) {
for slot in &mut self.keys {
if *slot == key {
*slot = 0;
}
}
// Compact the array
self.keys.sort_by(|a, b| b.cmp(a));
}
/// Clear all keys
pub fn clear(&mut self) {
self.modifiers = 0;
self.keys = [0; 6];
}
}
/// USB HID mouse report
#[derive(Debug, Clone, Default)]
pub struct MouseReport {
/// Button state
pub buttons: u8,
/// X movement (-127 to 127)
pub x: i8,
/// Y movement (-127 to 127)
pub y: i8,
/// Wheel movement (-127 to 127)
pub wheel: i8,
}
impl MouseReport {
/// Convert to bytes for USB HID (relative mouse)
pub fn to_bytes_relative(&self) -> [u8; 4] {
[self.buttons, self.x as u8, self.y as u8, self.wheel as u8]
}
/// Convert to bytes for USB HID (absolute mouse)
pub fn to_bytes_absolute(&self, x: u16, y: u16) -> [u8; 6] {
[
self.buttons,
(x & 0xFF) as u8,
(x >> 8) as u8,
(y & 0xFF) as u8,
(y >> 8) as u8,
self.wheel as u8,
]
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,13 +1,4 @@
//! WebSocket HID channel for HTTP/MJPEG mode
//!
//! This provides an alternative to WebRTC DataChannel for HID input
//! when using MJPEG streaming mode.
//!
//! Uses binary protocol only (same format as DataChannel):
//! - Keyboard: [0x01, event_type, key, modifiers] (4 bytes)
//! - Mouse: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, button/scroll] (7 bytes)
//!
//! See datachannel.rs for detailed protocol specification.
//! MJPEG mode: HID over WebSocket — same binary framing as [`super::datachannel`] (`0x01`/`0x02`/`0x03`; layout detailed there).
use axum::{
extract::{
@@ -24,25 +15,20 @@ use super::datachannel::{parse_hid_message, HidChannelEvent};
use crate::state::AppState;
use crate::utils::LogThrottler;
/// Binary response codes
const RESP_OK: u8 = 0x00;
const RESP_ERR_HID_UNAVAILABLE: u8 = 0x01;
const RESP_ERR_INVALID_MESSAGE: u8 = 0x02;
/// WebSocket HID upgrade handler
pub async fn ws_hid_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
ws.on_upgrade(move |socket| handle_hid_socket(socket, state))
}
/// Handle HID WebSocket connection
async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
// Log throttler for error messages (5 second interval)
let log_throttler = LogThrottler::with_secs(5);
info!("WebSocket HID connection established (binary protocol)");
// Check if HID controller is available and send initial status
let hid_available = state.hid.is_available().await;
let initial_response = if hid_available {
vec![RESP_OK]
@@ -59,17 +45,14 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
return;
}
// Process incoming messages (binary only)
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Binary(data)) => {
// Check HID availability before processing each message
let hid_available = state.hid.is_available().await;
if !hid_available {
if log_throttler.should_log("hid_unavailable") {
warn!("HID controller not available, ignoring message");
}
// Send error response (optional, for client awareness)
let _ = sender
.send(Message::Binary(vec![RESP_ERR_HID_UNAVAILABLE].into()))
.await;
@@ -77,15 +60,12 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
}
if let Err(e) = handle_binary_message(&data, &state).await {
// Log with throttling to avoid spam
if log_throttler.should_log("binary_hid_error") {
warn!("Binary HID message error: {}", e);
}
// Don't send error response for every failed message to reduce overhead
}
}
Ok(Message::Text(text)) => {
// Text messages are no longer supported
if log_throttler.should_log("text_message_rejected") {
debug!(
"Received text message (not supported): {} bytes",
@@ -111,7 +91,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
}
}
// Reset HID state to release any held keys/buttons
if let Err(e) = state.hid.reset().await {
warn!("Failed to reset HID on WebSocket disconnect: {}", e);
}
@@ -119,7 +98,6 @@ async fn handle_hid_socket(socket: WebSocket, state: Arc<AppState>) {
info!("WebSocket HID connection ended");
}
/// Handle binary HID message (same format as DataChannel)
async fn handle_binary_message(data: &[u8], state: &AppState) -> Result<(), String> {
let event = parse_hid_message(data).ok_or("Invalid binary HID message")?;
@@ -160,12 +138,10 @@ mod tests {
assert_eq!(RESP_OK, 0x00);
assert_eq!(RESP_ERR_HID_UNAVAILABLE, 0x01);
assert_eq!(RESP_ERR_INVALID_MESSAGE, 0x02);
// assert_eq!(RESP_ERR_SEND_FAILED, 0x03); // TODO: fix test
}
#[test]
fn test_keyboard_message_format() {
// Keyboard message: [0x01, event_type, key, modifiers]
let data = [MSG_KEYBOARD, KB_EVENT_DOWN, 0x04, 0x01]; // 'A' key with left ctrl
let event = parse_hid_message(&data);
assert!(event.is_some());
@@ -173,7 +149,6 @@ mod tests {
#[test]
fn test_mouse_message_format() {
// Mouse message: [0x02, event_type, x_lo, x_hi, y_lo, y_hi, extra]
let data = [MSG_MOUSE, MS_EVENT_MOVE, 0x0A, 0x00, 0xF6, 0xFF, 0x00]; // x=10, y=-10
let event = parse_hid_message(&data);
assert!(event.is_some());

View File

@@ -1,30 +1,27 @@
//! One-KVM - Lightweight IP-KVM solution
//!
//! This crate provides the core functionality for One-KVM,
//! a remote KVM (Keyboard, Video, Mouse) solution written in Rust.
//! Core library for One-KVM (IPKVM: capture, HID, OTG, streaming, Web UI glue).
pub mod atx;
pub mod audio;
pub mod auth;
pub mod config;
pub mod db;
pub mod error;
pub mod events;
pub mod extensions;
pub mod hid;
pub mod modules;
pub mod msd;
pub mod otg;
pub mod rtsp;
pub mod rustdesk;
pub mod state;
pub mod stream;
pub mod stream_encoder;
pub mod update;
pub mod utils;
pub mod video;
pub mod web;
pub mod webrtc;
/// Auto-generated secrets module (from secrets.toml at compile time)
pub mod secrets {
include!(concat!(env!("OUT_DIR"), "/secrets_generated.rs"));
}

View File

@@ -1,7 +1,8 @@
use std::collections::HashSet;
use std::future::Future;
use std::io::Write;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
@@ -15,6 +16,7 @@ use one_kvm::atx::AtxController;
use one_kvm::audio::{AudioController, AudioControllerConfig, AudioQuality};
use one_kvm::auth::{SessionStore, UserStore};
use one_kvm::config::{self, AppConfig, ConfigStore};
use one_kvm::db::DatabasePool;
use one_kvm::events::EventBus;
use one_kvm::extensions::ExtensionManager;
use one_kvm::hid::{HidBackendType, HidController};
@@ -33,7 +35,6 @@ use one_kvm::video::{Streamer, VideoStreamManager};
use one_kvm::web;
use one_kvm::webrtc::{WebRtcStreamer, WebRtcStreamerConfig};
/// Log level for the application
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
enum LogLevel {
Error,
@@ -45,7 +46,6 @@ enum LogLevel {
Trace,
}
/// One-KVM command line arguments
#[derive(Parser, Debug)]
#[command(name = "one-kvm")]
#[command(version, about = "A open and lightweight IP-KVM solution", long_about = None)]
@@ -111,37 +111,30 @@ enum UserAction {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Parse command line arguments
let args = CliArgs::parse();
// Initialize logging with CLI arguments
init_logging(args.log_level, args.verbose);
// Install default crypto provider (required by rustls 0.23+)
CryptoProvider::install_default(ring::default_provider())
.expect("Failed to install rustls crypto provider");
tracing::info!("Starting One-KVM v{}", env!("CARGO_PKG_VERSION"));
// Determine data directory (CLI arg takes precedence)
let data_dir = args.data_dir.clone().unwrap_or_else(get_data_dir);
tracing::info!("Data directory: {}", data_dir.display());
// Run one-off CLI command and exit.
if let Some(command) = args.command {
run_cli_command(command, data_dir).await?;
return Ok(());
}
// Ensure data directory exists
tokio::fs::create_dir_all(&data_dir).await?;
// Initialize configuration store
let db_path = data_dir.join("one-kvm.db");
let config_store = ConfigStore::new(&db_path).await?;
let db = open_database_pool(&data_dir).await?;
let config_store = ConfigStore::new(db.clone_pool())?;
config_store.load().await?;
let mut config = (*config_store.get()).clone();
// Normalize MSD directory (absolute path under data dir if empty/relative)
let mut msd_dir_updated = false;
if config.msd.msd_dir.trim().is_empty() {
let msd_dir = data_dir.join("msd");
@@ -159,8 +152,6 @@ async fn main() -> anyhow::Result<()> {
if msd_dir_updated {
config_store.set(config.clone()).await?;
}
// Ensure MSD directories exist (msd/images, msd/ventoy)
let msd_dir = PathBuf::from(&config.msd.msd_dir);
if let Err(e) = tokio::fs::create_dir_all(msd_dir.join("images")).await {
tracing::warn!("Failed to create MSD images directory: {}", e);
@@ -169,7 +160,6 @@ async fn main() -> anyhow::Result<()> {
tracing::warn!("Failed to create MSD ventoy directory: {}", e);
}
// Apply CLI argument overrides to config (only if explicitly specified)
if let Some(addr) = args.address {
config.web.bind_address = addr.clone();
config.web.bind_addresses = vec![addr];
@@ -203,29 +193,20 @@ async fn main() -> anyhow::Result<()> {
config.web.http_port
};
// Log final configuration
for ip in &bind_ips {
let addr = SocketAddr::new(*ip, bind_port);
tracing::info!("Server will listen on: {}://{}", scheme, addr);
}
// Initialize session store
let session_store = SessionStore::new(
config_store.pool().clone(),
config.auth.session_timeout_secs as i64,
);
let session_store = SessionStore::new(config.auth.session_timeout_secs as i64);
// Initialize user store
let user_store = UserStore::new(config_store.pool().clone());
let user_store = UserStore::new(db.clone_pool());
// Create shutdown channel
let (shutdown_tx, _) = broadcast::channel::<()>(1);
// Create event bus for real-time notifications
let events = Arc::new(EventBus::new());
tracing::info!("Event bus initialized");
// Parse video configuration once (avoid duplication)
let (video_format, video_resolution) = parse_video_config(&config);
tracing::debug!(
"Parsed video config: {} @ {}x{}",
@@ -234,7 +215,6 @@ async fn main() -> anyhow::Result<()> {
video_resolution.height
);
// Create video streamer and initialize with config if device is set
let streamer = Streamer::new();
streamer.set_event_bus(events.clone()).await;
if let Some(ref device_path) = config.video.device {
@@ -262,19 +242,19 @@ async fn main() -> anyhow::Result<()> {
}
}
// Create WebRTC streamer
let webrtc_streamer = {
let webrtc_config = WebRtcStreamerConfig {
resolution: video_resolution,
input_format: video_format,
fps: config.video.fps,
bitrate_preset: config.stream.bitrate_preset,
encoder_backend: config.stream.encoder.to_backend(),
encoder_backend: one_kvm::stream_encoder::encoder_type_to_backend(
config.stream.encoder.clone(),
),
webrtc: {
let mut stun_servers = vec![];
let mut turn_servers = vec![];
// Check if user configured custom servers
let has_custom_stun = config
.stream
.stun_server
@@ -288,25 +268,12 @@ async fn main() -> anyhow::Result<()> {
.map(|s| !s.is_empty())
.unwrap_or(false);
// If no custom servers, use public ICE servers (like RustDesk)
if !has_custom_stun && !has_custom_turn {
use one_kvm::webrtc::config::public_ice;
if public_ice::is_configured() {
if let Some(stun) = public_ice::stun_server() {
stun_servers.push(stun.clone());
tracing::info!("Using public STUN server: {}", stun);
}
for turn in public_ice::turn_servers() {
tracing::info!("Using public TURN server: {:?}", turn.urls);
turn_servers.push(turn);
}
} else {
tracing::info!(
"No public ICE servers configured, using host candidates only"
);
}
let stun = public_ice::stun_server().to_string();
tracing::info!("Using public STUN server: {}", stun);
stun_servers.push(stun);
} else {
// Use custom servers
if let Some(ref stun) = config.stream.stun_server {
if !stun.is_empty() {
stun_servers.push(stun.clone());
@@ -342,18 +309,15 @@ async fn main() -> anyhow::Result<()> {
};
WebRtcStreamer::with_config(webrtc_config)
};
tracing::info!("WebRTC streamer created (supports H264, extensible to VP8/VP9/H265)");
tracing::info!("WebRTC streamer created");
// Create OTG Service (single instance for centralized USB gadget management)
let otg_service = Arc::new(OtgService::new());
tracing::info!("OTG Service created");
// Reconcile OTG once from the persisted config so controllers only consume its result.
if let Err(e) = otg_service.apply_config(&config.hid, &config.msd).await {
tracing::warn!("Failed to apply OTG config: {}", e);
}
// Create HID controller based on config
let hid_backend = match config.hid.backend {
config::HidBackend::Otg => HidBackendType::Otg,
config::HidBackend::Ch9329 => HidBackendType::Ch9329 {
@@ -362,27 +326,17 @@ async fn main() -> anyhow::Result<()> {
},
config::HidBackend::None => HidBackendType::None,
};
let hid = Arc::new(HidController::new(
hid_backend,
Some(otg_service.clone()), // Always pass OtgService to support hot-reload to OTG
));
let hid = Arc::new(HidController::new(hid_backend, Some(otg_service.clone())));
hid.set_event_bus(events.clone()).await;
if let Err(e) = hid.init().await {
tracing::warn!("Failed to initialize HID backend: {}", e);
}
// Create MSD controller (optional, based on config)
let msd = if config.msd.enabled {
// Initialize Ventoy resources from data directory
let ventoy_resource_dir = ventoy_img::get_resource_dir(&data_dir);
let ventoy_resource_dir = data_dir.join("ventoy");
if ventoy_resource_dir.exists() {
if let Err(e) = ventoy_img::init_resources(&ventoy_resource_dir) {
tracing::warn!("Failed to initialize Ventoy resources: {}", e);
tracing::info!(
"Ventoy resource files should be placed in: {}",
ventoy_resource_dir.display()
);
tracing::info!("Required files: {:?}", ventoy_img::required_files());
} else {
tracing::info!(
"Ventoy resources initialized from {}",
@@ -394,10 +348,6 @@ async fn main() -> anyhow::Result<()> {
"Ventoy resource directory not found: {}",
ventoy_resource_dir.display()
);
tracing::info!(
"Create the directory and place the following files: {:?}",
ventoy_img::required_files()
);
}
let controller = MsdController::new(otg_service.clone(), config.msd.msd_dir_path());
@@ -413,7 +363,6 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create ATX controller (optional, based on config)
let atx = if config.atx.enabled {
let controller_config = config.atx.to_controller_config();
let controller = AtxController::new(controller_config);
@@ -429,12 +378,21 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create Audio controller
let audio = {
let audio_config = AudioControllerConfig {
enabled: config.audio.enabled,
device: config.audio.device.clone(),
quality: AudioQuality::from_str(&config.audio.quality),
quality: match config.audio.quality.parse::<AudioQuality>() {
Ok(q) => q,
Err(e) => {
tracing::warn!(
"Invalid audio quality in config (value={:?}): {}, using balanced",
config.audio.quality,
e
);
AudioQuality::Balanced
}
},
};
let controller = AudioController::new(audio_config);
@@ -446,7 +404,6 @@ async fn main() -> anyhow::Result<()> {
config.audio.device,
config.audio.quality
);
// Start audio streaming so WebRTC can subscribe to Opus frames
if let Err(e) = controller.start_streaming().await {
tracing::warn!("Failed to start audio streaming: {}", e);
}
@@ -457,29 +414,23 @@ async fn main() -> anyhow::Result<()> {
Arc::new(controller)
};
// Create Extension manager (ttyd, gostc, easytier)
let extensions = Arc::new(ExtensionManager::new());
tracing::info!("Extension manager initialized");
// Wire up WebRTC streamer with HID controller
// This enables WebRTC DataChannel to process HID events
webrtc_streamer.set_hid_controller(hid.clone()).await;
// Wire up WebRTC streamer with Audio controller
// This enables WebRTC audio track to receive Opus frames
webrtc_streamer.set_audio_controller(audio.clone()).await;
if config.audio.enabled {
if let Err(e) = webrtc_streamer.set_audio_enabled(true).await {
tracing::warn!("Failed to enable WebRTC audio: {}", e);
} else {
tracing::info!("WebRTC audio enabled");
tracing::debug!("WebRTC audio enabled");
}
}
// Configure direct capture for WebRTC encoder pipeline
let (device_path, actual_resolution, actual_format, actual_fps, jpeg_quality) =
streamer.current_capture_config().await;
tracing::info!(
tracing::debug!(
"Initial video config: {}x{} {:?} @ {}fps",
actual_resolution.width,
actual_resolution.height,
@@ -490,22 +441,50 @@ async fn main() -> anyhow::Result<()> {
.update_video_config(actual_resolution, actual_format, actual_fps)
.await;
if let Some(device_path) = device_path {
let (subdev_path, bridge_kind, v4l2_driver) = streamer
.current_device()
.await
.map(|d| {
(
d.subdev_path.clone(),
d.bridge_kind.clone(),
Some(d.driver.clone()),
)
})
.unwrap_or((None, None, None));
webrtc_streamer
.set_capture_device(device_path, jpeg_quality)
.set_capture_device(
device_path,
jpeg_quality,
subdev_path,
bridge_kind,
v4l2_driver,
)
.await;
tracing::info!("WebRTC streamer configured for direct capture");
tracing::debug!("WebRTC streamer configured for direct capture");
} else {
tracing::warn!("No capture device configured for WebRTC");
}
// Create video stream manager (unified MJPEG/WebRTC management)
// Use with_webrtc_streamer to ensure we use the same WebRtcStreamer instance
let stream_manager =
VideoStreamManager::with_webrtc_streamer(streamer.clone(), webrtc_streamer.clone());
let stream_manager = VideoStreamManager::with_webrtc_streamer(
streamer.clone(),
webrtc_streamer.clone() as std::sync::Arc<dyn one_kvm::video::traits::VideoOutput>,
);
stream_manager.set_event_bus(events.clone()).await;
stream_manager.set_config_store(config_store.clone()).await;
{
let stream_manager_weak = Arc::downgrade(&stream_manager);
audio
.set_recovered_callback(Arc::new(move || {
if let Some(stream_manager) = stream_manager_weak.upgrade() {
tokio::spawn(async move {
stream_manager.reconnect_webrtc_audio_sources().await;
});
}
}))
.await;
}
// Initialize stream manager with configured mode
let initial_mode = config.stream.mode.clone();
if let Err(e) = stream_manager.init_with_mode(initial_mode.clone()).await {
tracing::warn!(
@@ -520,7 +499,6 @@ async fn main() -> anyhow::Result<()> {
);
}
// Create RustDesk service (optional, based on config)
let rustdesk = if config.rustdesk.is_valid() {
tracing::info!(
"Initializing RustDesk service: ID={} -> {}",
@@ -545,7 +523,6 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create RTSP service (optional, based on config)
let rtsp = if config.rtsp.enabled {
tracing::info!(
"Initializing RTSP service: rtsp://{}:{}/{}",
@@ -560,15 +537,16 @@ async fn main() -> anyhow::Result<()> {
None
};
// Create application state
let update_service = Arc::new(UpdateService::new(data_dir.join("updates")));
let state = AppState::new(
db.clone(),
config_store.clone(),
session_store,
user_store,
otg_service,
stream_manager,
webrtc_streamer.clone(),
hid,
msd,
atx,
@@ -584,12 +562,10 @@ async fn main() -> anyhow::Result<()> {
extensions.set_event_bus(events.clone()).await;
// Start RustDesk service if enabled
if let Some(ref service) = rustdesk {
if let Err(e) = service.start().await {
tracing::error!("Failed to start RustDesk service: {}", e);
} else {
// Save generated keypair and UUID to config
if let Some(updated_config) = service.save_credentials() {
if let Err(e) = config_store
.update(|cfg| {
@@ -609,7 +585,6 @@ async fn main() -> anyhow::Result<()> {
}
}
// Start RTSP service if enabled
if let Some(ref service) = rtsp {
if let Err(e) = service.start().await {
tracing::error!("Failed to start RTSP service: {}", e);
@@ -618,7 +593,6 @@ async fn main() -> anyhow::Result<()> {
}
}
// Enforce startup codec constraints (e.g. RTSP/RustDesk locks)
{
let runtime_config = state.config.get();
let constraints = StreamCodecConstraints::from_config(&runtime_config);
@@ -633,13 +607,11 @@ async fn main() -> anyhow::Result<()> {
}
}
// Start enabled extensions
{
let ext_config = config_store.get();
extensions.start_enabled(&ext_config.extensions).await;
}
// Start extension health check task (every 30 seconds)
{
let extensions_clone = extensions.clone();
let config_store_clone = config_store.clone();
@@ -656,17 +628,12 @@ async fn main() -> anyhow::Result<()> {
state.publish_device_info().await;
// Start device info broadcast task
// This monitors state change events and broadcasts DeviceInfo to all clients
spawn_device_info_broadcaster(state.clone(), events);
// Create router
let app = web::create_router(state.clone());
// Bind sockets for configured addresses
let listeners = bind_tcp_listeners(&bind_ips, bind_port)?;
// Setup graceful shutdown
let shutdown_signal = async move {
tokio::signal::ctrl_c()
.await
@@ -675,9 +642,7 @@ async fn main() -> anyhow::Result<()> {
let _ = shutdown_tx.send(());
};
// Start server
if config.web.https_enabled {
// Generate self-signed certificate if no custom cert provided
let tls_config = if let (Some(cert_path), Some(key_path)) =
(&config.web.ssl_cert_path, &config.web.ssl_key_path)
{
@@ -687,7 +652,6 @@ async fn main() -> anyhow::Result<()> {
let cert_path = cert_dir.join("server.crt");
let key_path = cert_dir.join("server.key");
// Check if certificate already exists, only generate if missing
if !cert_path.exists() || !key_path.exists() {
tracing::info!("Generating new self-signed TLS certificate");
let cert = generate_self_signed_cert()?;
@@ -701,7 +665,7 @@ async fn main() -> anyhow::Result<()> {
RustlsConfig::from_pem_file(&cert_path, &key_path).await?
};
let mut servers = FuturesUnordered::new();
let servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTPS server on {}", local_addr);
@@ -711,19 +675,9 @@ async fn main() -> anyhow::Result<()> {
servers.push(server);
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTPS server error: {}", e);
}
cleanup(&state).await;
}
}
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTPS").await;
} else {
let mut servers = FuturesUnordered::new();
let servers = FuturesUnordered::new();
for listener in listeners {
let local_addr = listener.local_addr()?;
tracing::info!("Starting HTTP server on {}", local_addr);
@@ -733,26 +687,14 @@ async fn main() -> anyhow::Result<()> {
servers.push(async move { server.await });
}
tokio::select! {
_ = shutdown_signal => {
cleanup(&state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("HTTP server error: {}", e);
}
cleanup(&state).await;
}
}
run_servers_until_shutdown(servers, shutdown_signal, &state, "HTTP").await;
}
tracing::info!("Server shutdown complete");
Ok(())
}
/// Initialize logging with tracing
fn init_logging(level: LogLevel, verbose_count: u8) {
// Verbose count overrides log level
let effective_level = match verbose_count {
0 => level,
1 => LogLevel::Verbose,
@@ -760,7 +702,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
_ => LogLevel::Trace,
};
// Build filter string based on effective level
let filter = match effective_level {
LogLevel::Error => "one_kvm=error,tower_http=error",
LogLevel::Warn => "one_kvm=warn,tower_http=warn",
@@ -770,7 +711,6 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
LogLevel::Trace => "one_kvm=trace,tower_http=debug",
};
// Environment variable takes highest priority
let env_filter =
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| filter.into());
@@ -783,23 +723,48 @@ fn init_logging(level: LogLevel, verbose_count: u8) {
}
}
/// Get the application data directory
fn get_data_dir() -> PathBuf {
// Check environment variable first
if let Ok(path) = std::env::var("ONE_KVM_DATA_DIR") {
return PathBuf::from(path);
}
// Default to system configuration directory
PathBuf::from("/etc/one-kvm")
}
async fn open_database_pool(data_dir: &Path) -> anyhow::Result<DatabasePool> {
let db_path = data_dir.join("one-kvm.db");
let db = DatabasePool::new(&db_path).await?;
db.init_schema().await?;
Ok(db)
}
async fn run_servers_until_shutdown<F, E>(
mut servers: FuturesUnordered<F>,
shutdown_signal: impl Future<Output = ()>,
state: &Arc<AppState>,
protocol: &'static str,
) where
F: Future<Output = Result<(), E>> + Send,
E: std::fmt::Display,
{
tokio::select! {
_ = shutdown_signal => {
cleanup(state).await;
}
result = servers.next() => {
if let Some(Err(e)) = result {
tracing::error!("{} server error: {}", protocol, e);
}
cleanup(state).await;
}
}
}
async fn run_cli_command(command: CliCommand, data_dir: PathBuf) -> anyhow::Result<()> {
tokio::fs::create_dir_all(&data_dir).await?;
let db_path = data_dir.join("one-kvm.db");
let config_store = ConfigStore::new(&db_path).await?;
let users = UserStore::new(config_store.pool().clone());
let sessions = SessionStore::new(config_store.pool().clone(), 0);
let db = open_database_pool(&data_dir).await?;
let users = UserStore::new(db.clone_pool());
let sessions = SessionStore::new(0);
match command {
CliCommand::User(user) => run_user_action(user.action, &users, &sessions).await,
@@ -817,15 +782,9 @@ async fn run_user_action(
}
async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow::Result<()> {
let all = users.list().await?;
let user = match all.len() {
0 => anyhow::bail!("No local user exists yet; complete setup in the web UI first."),
1 => &all[0],
_ => anyhow::bail!(
"Expected exactly one local user (single-user design), found {}. Remove extra users from the database or contact support.",
all.len()
),
};
let user = users.single_user().await?.ok_or_else(|| {
anyhow::anyhow!("No local user exists yet; complete setup in the web UI first.")
})?;
let new_password = read_new_password_interactive()?;
if new_password.len() < 4 {
@@ -833,7 +792,7 @@ async fn set_user_password(users: &UserStore, sessions: &SessionStore) -> anyhow
}
users.update_password(&user.id, &new_password).await?;
let revoked = sessions.delete_by_user_id(&user.id).await?;
let revoked = sessions.delete_all().await?;
tracing::info!(
"Password updated for user '{}' and {} sessions revoked",
@@ -869,7 +828,6 @@ fn read_new_password_interactive() -> anyhow::Result<String> {
Ok(a)
}
/// Resolve bind IPs from config, preferring bind_addresses when set.
fn resolve_bind_addresses(web: &config::WebConfig) -> anyhow::Result<Vec<IpAddr>> {
let raw_addrs = if !web.bind_addresses.is_empty() {
web.bind_addresses.as_slice()
@@ -910,7 +868,6 @@ fn bind_tcp_listeners(addrs: &[IpAddr], port: u16) -> anyhow::Result<Vec<std::ne
Ok(listeners)
}
/// Parse video format and resolution from config (avoids code duplication)
fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
let format = config
.video
@@ -922,7 +879,6 @@ fn parse_video_config(config: &AppConfig) -> (PixelFormat, Resolution) {
(format, resolution)
}
/// Generate a self-signed TLS certificate
fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyPair>> {
use rcgen::generate_simple_self_signed;
@@ -936,8 +892,6 @@ fn generate_self_signed_cert() -> anyhow::Result<rcgen::CertifiedKey<rcgen::KeyP
Ok(certified_key)
}
/// Spawn a background task that monitors state change events
/// and broadcasts DeviceInfo to all WebSocket clients with debouncing
fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
use std::time::{Duration, Instant};
@@ -1024,7 +978,6 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
let mut pending_broadcast = false;
loop {
// Use timeout to handle pending broadcasts
let recv_result = if pending_broadcast {
let remaining =
DEBOUNCE_MS.saturating_sub(last_broadcast.elapsed().as_millis() as u64);
@@ -1049,12 +1002,9 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
tracing::info!("Event bus closed, stopping DeviceInfo broadcaster");
break;
}
Err(_timeout) => {
// Debounce timeout reached, broadcast now
}
Err(_timeout) => {}
}
// Broadcast if pending and debounce time has passed
if pending_broadcast && last_broadcast.elapsed() >= Duration::from_millis(DEBOUNCE_MS) {
state.publish_device_info().await;
tracing::trace!("Broadcasted DeviceInfo (debounced)");
@@ -1070,13 +1020,10 @@ fn spawn_device_info_broadcaster(state: Arc<AppState>, events: Arc<EventBus>) {
);
}
/// Clean up subsystems on shutdown
async fn cleanup(state: &Arc<AppState>) {
// Stop all extensions
state.extensions.stop_all().await;
tracing::info!("Extensions stopped");
// Stop RustDesk service
if let Some(ref service) = *state.rustdesk.read().await {
if let Err(e) = service.stop().await {
tracing::warn!("Failed to stop RustDesk service: {}", e);
@@ -1085,7 +1032,6 @@ async fn cleanup(state: &Arc<AppState>) {
}
}
// Stop RTSP service
if let Some(ref service) = *state.rtsp.read().await {
if let Err(e) = service.stop().await {
tracing::warn!("Failed to stop RTSP service: {}", e);
@@ -1094,31 +1040,26 @@ async fn cleanup(state: &Arc<AppState>) {
}
}
// Stop video
if let Err(e) = state.stream_manager.stop().await {
tracing::warn!("Failed to stop streamer: {}", e);
}
// Shutdown HID
if let Err(e) = state.hid.shutdown().await {
tracing::warn!("Failed to shutdown HID: {}", e);
}
// Shutdown MSD
if let Some(msd) = state.msd.write().await.as_mut() {
if let Err(e) = msd.shutdown().await {
tracing::warn!("Failed to shutdown MSD: {}", e);
}
}
// Shutdown ATX
if let Some(atx) = state.atx.write().await.as_mut() {
if let Err(e) = atx.shutdown().await {
tracing::warn!("Failed to shutdown ATX: {}", e);
}
}
// Shutdown Audio
if let Err(e) = state.audio.shutdown().await {
tracing::warn!("Failed to shutdown audio: {}", e);
}

View File

@@ -1,49 +0,0 @@
//! Module management for One-KVM
//!
//! This module provides infrastructure for managing feature modules
//! (video streaming, HID control, MSD, ATX) as independent async tasks.
use std::future::Future;
use std::pin::Pin;
use tokio::sync::broadcast;
/// Module status
#[derive(Debug, Clone, PartialEq)]
pub enum ModuleStatus {
Stopped,
Starting,
Running,
Stopping,
Error(String),
}
/// Trait for feature modules
pub trait Module: Send + Sync {
/// Module name
fn name(&self) -> &'static str;
/// Current status
fn status(&self) -> ModuleStatus;
/// Start the module
fn start(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
/// Stop the module
fn stop(&mut self) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send + '_>>;
}
/// Module manager for coordinating feature modules
pub struct ModuleManager {
shutdown_rx: broadcast::Receiver<()>,
}
impl ModuleManager {
pub fn new(shutdown_rx: broadcast::Receiver<()>) -> Self {
Self { shutdown_rx }
}
/// Wait for shutdown signal
pub async fn wait_for_shutdown(&mut self) {
let _ = self.shutdown_rx.recv().await;
}
}

View File

@@ -1,11 +1,3 @@
//! MSD Controller
//!
//! Manages the mass storage device lifecycle including:
//! - Image mounting and unmounting
//! - Virtual drive management
//! - State tracking
//! - Image downloads from URL
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
@@ -14,41 +6,25 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use super::image::ImageManager;
use super::monitor::{MsdHealthMonitor, MsdHealthStatus};
use super::monitor::MsdHealthMonitor;
use super::types::{DownloadProgress, DownloadStatus, DriveInfo, ImageInfo, MsdMode, MsdState};
use crate::error::{AppError, Result};
use crate::otg::{MsdFunction, MsdLunConfig, OtgService};
/// MSD Controller
pub struct MsdController {
/// OTG Service reference
otg_service: Arc<OtgService>,
/// MSD function manager (provided by OtgService)
msd_function: RwLock<Option<MsdFunction>>,
/// Current state
state: RwLock<MsdState>,
/// Images storage path
images_path: PathBuf,
/// Ventoy directory path
ventoy_dir: PathBuf,
/// Virtual drive path
drive_path: PathBuf,
/// Event bus for broadcasting state changes (optional)
events: tokio::sync::RwLock<Option<Arc<crate::events::EventBus>>>,
/// Active downloads (download_id -> CancellationToken)
downloads: Arc<RwLock<HashMap<String, CancellationToken>>>,
/// Operation mutex lock (prevents concurrent operations)
operation_lock: Arc<RwLock<()>>,
/// Health monitor for error tracking and recovery
monitor: Arc<MsdHealthMonitor>,
}
impl MsdController {
/// Create new MSD controller
///
/// # Parameters
/// * `otg_service` - OTG service for gadget management
/// * `msd_dir` - Base directory for MSD storage
pub fn new(otg_service: Arc<OtgService>, msd_dir: impl Into<PathBuf>) -> Self {
let msd_dir = msd_dir.into();
let images_path = msd_dir.join("images");
@@ -68,11 +44,9 @@ impl MsdController {
}
}
/// Initialize the MSD controller
pub async fn init(&self) -> Result<()> {
info!("Initializing MSD controller");
// 1. Ensure images directory exists
if let Err(e) = std::fs::create_dir_all(&self.images_path) {
warn!("Failed to create images directory: {}", e);
}
@@ -80,20 +54,16 @@ impl MsdController {
warn!("Failed to create ventoy directory: {}", e);
}
// 2. Get active MSD function from OtgService
info!("Fetching MSD function from OtgService");
let msd_func = self.otg_service.msd_function().await.ok_or_else(|| {
AppError::Internal("MSD function is not active in OtgService".to_string())
})?;
// 3. Store function handle
*self.msd_function.write().await = Some(msd_func);
// 4. Update state
let mut state = self.state.write().await;
state.available = true;
// 5. Check for existing virtual drive
if self.drive_path.exists() {
if let Ok(metadata) = std::fs::metadata(&self.drive_path) {
state.drive_info = Some(DriveInfo {
@@ -114,17 +84,14 @@ impl MsdController {
Ok(())
}
/// Get current MSD state
pub async fn state(&self) -> MsdState {
self.state.read().await.clone()
}
/// Set event bus for broadcasting state changes
pub async fn set_event_bus(&self, events: std::sync::Arc<crate::events::EventBus>) {
*self.events.write().await = Some(events);
}
/// Publish an event to the event bus
async fn publish_event(&self, event: crate::events::SystemEvent) {
if let Some(ref bus) = *self.events.read().await {
bus.publish(event);
@@ -137,43 +104,21 @@ impl MsdController {
}
}
/// Check if MSD is available
pub async fn is_available(&self) -> bool {
self.state.read().await.available
}
/// Connect an image file
///
/// # Parameters
/// * `image` - Image info to mount
/// * `cdrom` - Mount as CD-ROM (read-only, removable)
/// * `read_only` - Mount as read-only
pub async fn connect_image(
&self,
image: &ImageInfo,
cdrom: bool,
read_only: bool,
) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(err);
}
self.assert_can_connect(&state).await?;
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Verify image exists
if !image.path.exists() {
let error_msg = format!("Image file not found: {}", image.path.display());
self.monitor
@@ -182,29 +127,12 @@ impl MsdController {
return Err(AppError::Internal(error_msg));
}
// Configure LUN
let config = if cdrom {
MsdLunConfig::cdrom(image.path.clone())
} else {
MsdLunConfig::disk(image.path.clone(), read_only)
};
let gadget_path = self.active_gadget_path().await?;
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(err);
}
self.configure_lun_now(&config).await?;
state.connected = true;
state.mode = MsdMode::Image;
@@ -215,42 +143,19 @@ impl MsdController {
image.name, cdrom, read_only
);
// Release the lock before publishing events
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
self.mark_device_info_dirty().await;
self.finish_connect_success().await;
Ok(())
}
/// Connect the virtual drive
pub async fn connect_drive(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
if !state.available {
let err = AppError::Internal("MSD not available".to_string());
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(err);
}
self.assert_can_connect(&state).await?;
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
// Check drive exists
if !self.drive_path.exists() {
let err =
AppError::Internal("Virtual drive not initialized. Call init first.".to_string());
@@ -260,25 +165,8 @@ impl MsdController {
return Err(err);
}
// Configure LUN as read-write disk
let config = MsdLunConfig::disk(self.drive_path.clone(), false);
let gadget_path = self.active_gadget_path().await?;
if let Some(ref msd) = *self.msd_function.read().await {
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, &config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
} else {
let err = AppError::Internal("MSD function not initialized".to_string());
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(err);
}
self.configure_lun_now(&config).await?;
state.connected = true;
state.mode = MsdMode::Drive;
@@ -286,23 +174,57 @@ impl MsdController {
info!("Connected virtual drive: {}", self.drive_path.display());
// Release the lock before publishing event
drop(state);
drop(_op_guard);
// Report recovery if we were in an error state
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
self.mark_device_info_dirty().await;
self.finish_connect_success().await;
Ok(())
}
/// Disconnect current storage
async fn assert_can_connect(&self, state: &MsdState) -> Result<()> {
if !state.available {
self.monitor
.report_error("MSD not available", "not_available")
.await;
return Err(AppError::Internal("MSD not available".to_string()));
}
if state.connected {
return Err(AppError::Internal(
"Already connected. Disconnect first.".to_string(),
));
}
Ok(())
}
async fn configure_lun_now(&self, config: &MsdLunConfig) -> Result<()> {
let gadget_path = self.active_gadget_path().await?;
let msd_hold = self.msd_function.read().await;
let Some(ref msd) = *msd_hold else {
self.monitor
.report_error("MSD function not initialized", "not_initialized")
.await;
return Err(AppError::Internal(
"MSD function not initialized".to_string(),
));
};
if let Err(e) = msd.configure_lun_async(&gadget_path, 0, config).await {
let error_msg = format!("Failed to configure LUN: {}", e);
self.monitor
.report_error(&error_msg, "configfs_error")
.await;
return Err(e);
}
Ok(())
}
async fn finish_connect_success(&self) {
if self.monitor.is_error().await {
self.monitor.report_recovered().await;
}
self.mark_device_info_dirty().await;
}
pub async fn disconnect(&self) -> Result<()> {
// Acquire operation lock to prevent concurrent operations
let _op_guard = self.operation_lock.write().await;
let mut state = self.state.write().await;
@@ -323,7 +245,6 @@ impl MsdController {
info!("Disconnected storage");
// Release the lock before publishing events
drop(state);
drop(_op_guard);
@@ -332,41 +253,31 @@ impl MsdController {
Ok(())
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
/// Get ventoy directory path
pub fn ventoy_dir(&self) -> &PathBuf {
&self.ventoy_dir
}
/// Get virtual drive path
pub fn drive_path(&self) -> &PathBuf {
&self.drive_path
}
/// Check if currently connected
pub async fn is_connected(&self) -> bool {
self.state.read().await.connected
}
/// Get current mode
pub async fn mode(&self) -> MsdMode {
self.state.read().await.mode.clone()
}
/// Update drive info
pub async fn update_drive_info(&self, info: DriveInfo) {
let mut state = self.state.write().await;
state.drive_info = Some(info);
}
/// Start downloading an image from URL
///
/// Returns the download_id that can be used to track or cancel the download.
/// Progress is reported via MsdDownloadProgress events.
pub async fn download_image(
&self,
url: String,
@@ -375,18 +286,15 @@ impl MsdController {
let download_id = uuid::Uuid::new_v4().to_string();
let cancel_token = CancellationToken::new();
// Register download
{
let mut downloads = self.downloads.write().await;
downloads.insert(download_id.clone(), cancel_token.clone());
}
// Extract filename for initial response
let display_filename = filename
.clone()
.unwrap_or_else(|| url.rsplit('/').next().unwrap_or("download").to_string());
// Create initial progress
let initial_progress = DownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
@@ -398,7 +306,6 @@ impl MsdController {
error: None,
};
// Publish started event
self.publish_event(crate::events::SystemEvent::MsdDownloadProgress {
download_id: download_id.clone(),
url: url.clone(),
@@ -410,18 +317,15 @@ impl MsdController {
})
.await;
// Clone what we need for the spawned task
let images_path = self.images_path.clone();
let events = self.events.read().await.clone();
let downloads = self.downloads.clone();
let download_id_clone = download_id.clone();
let url_clone = url.clone();
// Spawn download task
tokio::spawn(async move {
let manager = ImageManager::new(images_path);
// Create progress callback
let events_for_callback = events.clone();
let download_id_for_callback = download_id_clone.clone();
let url_for_callback = url_clone.clone();
@@ -443,18 +347,15 @@ impl MsdController {
}
};
// Run download
let result = manager
.download_from_url(&url_clone, filename, progress_callback)
.await;
// Remove from active downloads
{
let mut downloads_guard = downloads.write().await;
downloads_guard.remove(&download_id_clone);
}
// Publish completion event
match result {
Ok(image_info) => {
if let Some(ref bus) = events {
@@ -489,7 +390,6 @@ impl MsdController {
Ok(initial_progress)
}
/// Cancel an active download
pub async fn cancel_download(&self, download_id: &str) -> Result<()> {
let mut downloads = self.downloads.write().await;
@@ -505,12 +405,6 @@ impl MsdController {
}
}
/// Get list of active download IDs
pub async fn active_downloads(&self) -> Vec<String> {
let downloads = self.downloads.read().await;
downloads.keys().cloned().collect()
}
async fn active_gadget_path(&self) -> Result<PathBuf> {
self.otg_service
.gadget_path()
@@ -518,16 +412,13 @@ impl MsdController {
.ok_or_else(|| AppError::Internal("OTG gadget path is not available".to_string()))
}
/// Shutdown the controller
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down MSD controller");
// 1. Disconnect if connected
if let Err(e) = self.disconnect().await {
warn!("Error disconnecting during shutdown: {}", e);
}
// 2. Clear local state
*self.msd_function.write().await = None;
let mut state = self.state.write().await;
@@ -537,27 +428,9 @@ impl MsdController {
Ok(())
}
/// Get the health monitor reference
pub fn monitor(&self) -> &Arc<MsdHealthMonitor> {
&self.monitor
}
/// Get current health status
pub async fn health_status(&self) -> MsdHealthStatus {
self.monitor.status().await
}
/// Check if the MSD is healthy
pub async fn is_healthy(&self) -> bool {
self.monitor.is_healthy().await
}
}
impl Drop for MsdController {
fn drop(&mut self) {
// Cleanup is handled by OtgGadgetManager when the gadget is torn down
// Individual controllers don't need to cleanup the ConfigFS
}
}
#[cfg(test)]
@@ -573,7 +446,6 @@ mod tests {
let controller = MsdController::new(otg_service, &msd_dir);
// Check that MSD is not initialized (msd_function is None)
let state = controller.state().await;
assert!(!state.available);
assert!(controller.images_path.ends_with("images"));

View File

@@ -1,53 +1,37 @@
//! Image file manager
//!
//! Handles ISO/IMG image file operations:
//! - List available images
//! - Upload new images
//! - Delete images
//! - Metadata management
//! - Download from URL
use chrono::Utc;
use futures::StreamExt;
use std::fs::{self, File};
use std::io::{self, Read, Write};
use std::fs;
#[cfg(test)]
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use time::OffsetDateTime;
use tokio::io::AsyncWriteExt;
use tracing::info;
use super::types::ImageInfo;
use crate::error::{AppError, Result};
/// Maximum image size (32 GB)
const MAX_IMAGE_SIZE: u64 = 32 * 1024 * 1024 * 1024;
/// Progress report throttle interval (milliseconds)
const PROGRESS_THROTTLE_MS: u64 = 200;
/// Progress report throttle bytes threshold (512 KB)
const PROGRESS_THROTTLE_BYTES: u64 = 512 * 1024;
/// Image Manager
pub struct ImageManager {
/// Images storage directory
images_path: PathBuf,
}
impl ImageManager {
/// Create a new image manager
pub fn new(images_path: PathBuf) -> Self {
Self { images_path }
}
/// Ensure images directory exists
pub fn ensure_dir(&self) -> Result<()> {
fs::create_dir_all(&self.images_path)
.map_err(|e| AppError::Internal(format!("Failed to create images directory: {}", e)))?;
Ok(())
}
/// List all available images
pub fn list(&self) -> Result<Vec<ImageInfo>> {
self.ensure_dir()?;
@@ -68,28 +52,26 @@ impl ImageManager {
}
}
// Sort by creation time (newest first)
images.sort_by(|a, b| b.created_at.cmp(&a.created_at));
Ok(images)
}
/// Get image info from path
fn get_image_info(&self, path: &Path) -> Option<ImageInfo> {
let metadata = fs::metadata(path).ok()?;
let name = path.file_name()?.to_string_lossy().to_string();
// Use filename hash as ID (stable across restarts)
let id = format!("{:x}", md5_hash(&name));
let id = stable_image_id_from_filename(&name);
let created_at = metadata
.created()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
.map(|d| {
chrono::DateTime::from_timestamp(d.as_secs() as i64, 0).unwrap_or_else(Utc::now)
OffsetDateTime::from_unix_timestamp(d.as_secs() as i64)
.unwrap_or_else(|_| OffsetDateTime::now_utc())
})
.unwrap_or_else(Utc::now);
.unwrap_or_else(OffsetDateTime::now_utc);
Some(ImageInfo {
id,
@@ -100,7 +82,6 @@ impl ImageManager {
})
}
/// Get image by ID
pub fn get(&self, id: &str) -> Result<ImageInfo> {
for image in self.list()? {
if image.id == id {
@@ -110,24 +91,21 @@ impl ImageManager {
Err(AppError::NotFound(format!("Image not found: {}", id)))
}
/// Get image by name
pub fn get_by_name(&self, name: &str) -> Result<ImageInfo> {
let path = self.images_path.join(name);
self.get_image_info(&path)
.ok_or_else(|| AppError::NotFound(format!("Image not found: {}", name)))
}
/// Create a new image from bytes
pub fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
#[cfg(test)]
fn create(&self, name: &str, data: &[u8]) -> Result<ImageInfo> {
self.ensure_dir()?;
// Validate name
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Check size
if data.len() as u64 > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
@@ -135,7 +113,6 @@ impl ImageManager {
)));
}
// Write file
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
@@ -144,11 +121,10 @@ impl ImageManager {
)));
}
let mut file = File::create(&path)
let mut file = fs::File::create(&path)
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
file.write_all(data).map_err(|e| {
// Try to clean up on error
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
@@ -158,55 +134,6 @@ impl ImageManager {
self.get_by_name(&name)
}
/// Create a new image from a file stream (for chunked uploads)
pub fn create_from_stream<R: Read>(
&self,
name: &str,
reader: &mut R,
expected_size: Option<u64>,
) -> Result<ImageInfo> {
self.ensure_dir()?;
let name = sanitize_filename(name);
if name.is_empty() {
return Err(AppError::Internal("Invalid filename".to_string()));
}
if let Some(size) = expected_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::Internal(format!(
"Image too large. Maximum size: {} GB",
MAX_IMAGE_SIZE / 1024 / 1024 / 1024
)));
}
}
let path = self.images_path.join(&name);
if path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
name
)));
}
// Create file and copy data
let mut file = File::create(&path)
.map_err(|e| AppError::Internal(format!("Failed to create image file: {}", e)))?;
let bytes_written = io::copy(reader, &mut file).map_err(|e| {
let _ = fs::remove_file(&path);
AppError::Internal(format!("Failed to write image data: {}", e))
})?;
info!("Created image: {} ({} bytes)", name, bytes_written);
self.get_by_name(&name)
}
/// Create a new image from an async multipart field (streaming, memory-efficient)
///
/// This method streams data directly to disk without buffering the entire file in memory,
/// making it suitable for large files (multi-GB ISOs).
pub async fn create_from_multipart_field(
&self,
name: &str,
@@ -219,12 +146,10 @@ impl ImageManager {
return Err(AppError::Internal("Invalid filename".to_string()));
}
// Use a temporary file during upload
let temp_name = format!(".upload_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_name);
let final_path = self.images_path.join(&name);
// Check if final file already exists
if final_path.exists() {
return Err(AppError::Internal(format!(
"Image already exists: {}",
@@ -232,23 +157,19 @@ impl ImageManager {
)));
}
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
let mut bytes_written: u64 = 0;
// Stream chunks directly to disk
while let Some(chunk) = field
.chunk()
.await
.map_err(|e| AppError::Internal(format!("Failed to read upload chunk: {}", e)))?
{
// Check size limit
bytes_written += chunk.len() as u64;
if bytes_written > MAX_IMAGE_SIZE {
// Cleanup and return error
drop(file);
let _ = tokio::fs::remove_file(&temp_path).await;
return Err(AppError::Internal(format!(
@@ -257,19 +178,16 @@ impl ImageManager {
)));
}
// Write chunk to file
file.write_all(&chunk)
.await
.map_err(|e| AppError::Internal(format!("Failed to write chunk: {}", e)))?;
}
// Flush and close file
file.flush()
.await
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
drop(file);
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
@@ -285,7 +203,6 @@ impl ImageManager {
self.get_by_name(&name)
}
/// Delete an image by ID
pub fn delete(&self, id: &str) -> Result<()> {
let image = self.get(id)?;
@@ -296,45 +213,6 @@ impl ImageManager {
Ok(())
}
/// Delete an image by name
pub fn delete_by_name(&self, name: &str) -> Result<()> {
let path = self.images_path.join(name);
if !path.exists() {
return Err(AppError::NotFound(format!("Image not found: {}", name)));
}
fs::remove_file(&path)
.map_err(|e| AppError::Internal(format!("Failed to delete image: {}", e)))?;
info!("Deleted image: {}", name);
Ok(())
}
/// Get total storage used
pub fn used_space(&self) -> u64 {
self.list()
.map(|images| images.iter().map(|i| i.size).sum())
.unwrap_or(0)
}
/// Check if storage has space for new image
pub fn has_space(&self, size: u64) -> bool {
// For now, just check against max size
// In the future, could check disk space
size <= MAX_IMAGE_SIZE
}
/// Download image from URL with progress callback
///
/// # Arguments
/// * `url` - The URL to download from
/// * `filename` - Optional custom filename (extracted from URL or Content-Disposition if not provided)
/// * `progress_callback` - Callback function called with (bytes_downloaded, total_bytes)
///
/// # Returns
/// * `Ok(ImageInfo)` - The downloaded image info
/// * `Err(AppError)` - If download fails
pub async fn download_from_url<F>(
&self,
url: &str,
@@ -346,20 +224,17 @@ impl ImageManager {
{
self.ensure_dir()?;
// Validate URL
let parsed_url = reqwest::Url::parse(url)
.map_err(|e| AppError::BadRequest(format!("Invalid URL: {}", e)))?;
info!("Starting download from: {}", url);
// Create HTTP client with timeout
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(3600)) // 1 hour timeout for large files
.timeout(std::time::Duration::from_secs(3600))
.connect_timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| AppError::Internal(format!("Failed to create HTTP client: {}", e)))?;
// Send HEAD request first to get content info
let head_response = client
.head(url)
.send()
@@ -379,7 +254,6 @@ impl ImageManager {
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
// Check file size
if let Some(size) = total_size {
if size > MAX_IMAGE_SIZE {
return Err(AppError::BadRequest(format!(
@@ -390,11 +264,9 @@ impl ImageManager {
}
}
// Determine filename
let final_filename = if let Some(name) = filename {
sanitize_filename(&name)
} else {
// Try Content-Disposition header first
let from_header = head_response
.headers()
.get(reqwest::header::CONTENT_DISPOSITION)
@@ -404,7 +276,6 @@ impl ImageManager {
if let Some(name) = from_header {
sanitize_filename(&name)
} else {
// Fall back to URL path
let path = parsed_url.path();
let name = path.rsplit('/').next().unwrap_or("download");
let name = urlencoding::decode(name).unwrap_or_else(|_| name.into());
@@ -418,7 +289,6 @@ impl ImageManager {
));
}
// Check if file already exists
let final_path = self.images_path.join(&final_filename);
if final_path.exists() {
return Err(AppError::BadRequest(format!(
@@ -427,11 +297,9 @@ impl ImageManager {
)));
}
// Create temporary file for download
let temp_filename = format!(".download_{}", uuid::Uuid::new_v4());
let temp_path = self.images_path.join(&temp_filename);
// Start actual download
let response = client
.get(url)
.send()
@@ -445,7 +313,6 @@ impl ImageManager {
)));
}
// Get actual content length from response (may differ from HEAD)
let content_length = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
@@ -453,19 +320,16 @@ impl ImageManager {
.and_then(|s| s.parse::<u64>().ok())
.or(total_size);
// Create temp file
let mut file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
// Stream download with progress (throttled)
let mut stream = response.bytes_stream();
let mut downloaded: u64 = 0;
let mut last_report_time = Instant::now();
let mut last_reported_bytes: u64 = 0;
let throttle_interval = Duration::from_millis(PROGRESS_THROTTLE_MS);
// Report initial progress
progress_callback(0, content_length);
while let Some(chunk_result) = stream.next().await {
@@ -473,14 +337,12 @@ impl ImageManager {
chunk_result.map_err(|e| AppError::Internal(format!("Download error: {}", e)))?;
file.write_all(&chunk).await.map_err(|e| {
// Cleanup on error
let _ = std::fs::remove_file(&temp_path);
AppError::Internal(format!("Failed to write data: {}", e))
})?;
downloaded += chunk.len() as u64;
// Throttled progress reporting: report if enough time or bytes have passed
let now = Instant::now();
let time_elapsed = now.duration_since(last_report_time) >= throttle_interval;
let bytes_elapsed = downloaded - last_reported_bytes >= PROGRESS_THROTTLE_BYTES;
@@ -492,18 +354,15 @@ impl ImageManager {
}
}
// Always report final progress
if downloaded != last_reported_bytes {
progress_callback(downloaded, content_length);
}
// Ensure all data is flushed
file.flush()
.await
.map_err(|e| AppError::Internal(format!("Failed to flush file: {}", e)))?;
drop(file);
// Verify downloaded size
let metadata = tokio::fs::metadata(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to read file metadata: {}", e)))?;
@@ -519,7 +378,6 @@ impl ImageManager {
}
}
// Move temp file to final location
tokio::fs::rename(&temp_path, &final_path)
.await
.map_err(|e| {
@@ -533,35 +391,29 @@ impl ImageManager {
metadata.len()
);
// Return image info
self.get_by_name(&final_filename)
}
/// Get images storage path
pub fn images_path(&self) -> &PathBuf {
&self.images_path
}
}
/// Simple hash function for generating stable IDs
fn md5_hash(s: &str) -> u64 {
fn stable_image_id_from_filename(name: &str) -> String {
let mut hash: u64 = 0;
for (i, byte) in s.bytes().enumerate() {
for (i, byte) in name.bytes().enumerate() {
hash = hash.wrapping_add((byte as u64).wrapping_mul((i as u64).wrapping_add(1)));
hash = hash.wrapping_mul(31);
}
hash
format!("{:x}", hash)
}
/// Sanitize filename to prevent path traversal
fn sanitize_filename(name: &str) -> String {
let name = name.trim();
let name = name.replace(['/', '\\', '\0', ':', '*', '?', '"', '<', '>', '|'], "_");
// Remove leading dots (hidden files)
let name = name.trim_start_matches('.');
// Limit length
if name.len() > 255 {
name[..255].to_string()
} else {
@@ -569,17 +421,10 @@ fn sanitize_filename(name: &str) -> String {
}
}
/// Extract filename from Content-Disposition header
fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
// Handle both:
// Content-Disposition: attachment; filename="example.iso"
// Content-Disposition: attachment; filename*=UTF-8''example.iso
// Try filename* first (RFC 5987)
if let Some(pos) = header.find("filename*=") {
let start = pos + 10;
let value = &header[start..];
// Format: charset'language'value
if let Some(quote_start) = value.find("''") {
let encoded = value[quote_start + 2..].split(';').next()?;
let decoded = urlencoding::decode(encoded.trim()).ok()?;
@@ -590,7 +435,6 @@ fn extract_filename_from_content_disposition(header: &str) -> Option<String> {
}
}
// Try filename next
if let Some(pos) = header.find("filename=") {
let start = pos + 9;
let value = &header[start..];
@@ -612,7 +456,7 @@ mod tests {
#[test]
fn test_sanitize_filename() {
assert_eq!(sanitize_filename("test.iso"), "test.iso");
assert_eq!(sanitize_filename("../test.iso"), "_test.iso"); // .. becomes empty after trim_start_matches('.')
assert_eq!(sanitize_filename("../test.iso"), "_test.iso");
assert_eq!(sanitize_filename("test/file.iso"), "test_file.iso");
assert_eq!(sanitize_filename(".hidden.iso"), "hidden.iso");
}

View File

@@ -1,19 +1,3 @@
//! MSD (Mass Storage Device) module
//!
//! Provides virtual USB storage functionality with two modes:
//! - Image mounting: Mount ISO/IMG files for system installation
//! - Ventoy drive: Bootable exFAT drive for multiple ISO files
//!
//! Architecture:
//! ```text
//! Web API --> MSD Controller --> ConfigFS Mass Storage --> Target PC
//! |
//! ┌──────┴──────┐
//! │ │
//! Image Manager Ventoy Drive
//! (ISO/IMG) (Bootable exFAT)
//! ```
pub mod controller;
pub mod image;
pub mod monitor;
@@ -22,12 +6,11 @@ pub mod ventoy_drive;
pub use controller::MsdController;
pub use image::ImageManager;
pub use monitor::{MsdHealthMonitor, MsdHealthStatus, MsdMonitorConfig};
pub use monitor::MsdHealthMonitor;
pub use types::{
DownloadProgress, DownloadStatus, DriveFile, DriveInfo, DriveInitRequest, ImageDownloadRequest,
ImageInfo, MsdConnectRequest, MsdMode, MsdState,
};
pub use ventoy_drive::VentoyDrive;
// Re-export from otg module for backward compatibility
pub use crate::otg::{MsdFunction, MsdLunConfig};

View File

@@ -1,99 +1,46 @@
//! MSD (Mass Storage Device) health monitoring
//!
//! This module provides health monitoring for MSD operations, including:
//! - ConfigFS operation error tracking
//! - Image mount/unmount error tracking
//! - Error state tracking
//! - Log throttling to prevent log flooding
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::utils::LogThrottler;
/// MSD health status
const LOG_THROTTLE_SECS: u64 = 5;
#[derive(Debug, Clone, PartialEq, Default)]
pub enum MsdHealthStatus {
/// Device is healthy and operational
pub(crate) enum MsdHealthStatus {
#[default]
Healthy,
/// Device has an error
Error {
/// Human-readable error reason
reason: String,
/// Error code for programmatic handling
error_code: String,
},
}
/// MSD health monitor configuration
#[derive(Debug, Clone)]
pub struct MsdMonitorConfig {
/// Log throttle interval in seconds
pub log_throttle_secs: u64,
}
impl Default for MsdMonitorConfig {
fn default() -> Self {
Self {
log_throttle_secs: 5,
}
}
}
/// MSD health monitor
///
/// Monitors MSD operation health and manages error state.
pub struct MsdHealthMonitor {
/// Current health status
status: RwLock<MsdHealthStatus>,
/// Log throttler to prevent log flooding
throttler: LogThrottler,
/// Error count (for tracking)
error_count: AtomicU32,
/// Last error code (for change detection)
last_error_code: RwLock<Option<String>>,
}
impl MsdHealthMonitor {
/// Create a new MSD health monitor with the specified configuration
pub fn new(config: MsdMonitorConfig) -> Self {
let throttle_secs = config.log_throttle_secs;
pub fn with_defaults() -> Self {
Self {
status: RwLock::new(MsdHealthStatus::Healthy),
throttler: LogThrottler::with_secs(throttle_secs),
throttler: LogThrottler::with_secs(LOG_THROTTLE_SECS),
error_count: AtomicU32::new(0),
last_error_code: RwLock::new(None),
}
}
/// Create a new MSD health monitor with default configuration
pub fn with_defaults() -> Self {
Self::new(MsdMonitorConfig::default())
}
/// Report an error from MSD operations
///
/// This method is called when an MSD operation fails. It:
/// 1. Updates the health status
/// 2. Logs the error (with throttling)
/// 3. Updates in-memory error state
///
/// # Arguments
///
/// * `reason` - Human-readable error description
/// * `error_code` - Error code for programmatic handling
pub async fn report_error(&self, reason: &str, error_code: &str) {
let count = self.error_count.fetch_add(1, Ordering::Relaxed) + 1;
// Check if error code changed
let error_changed = {
let last = self.last_error_code.read().await;
last.as_ref().map(|s| s.as_str()) != Some(error_code)
};
// Log with throttling (always log if error type changed)
let throttle_key = format!("msd_{}", error_code);
if error_changed || self.throttler.should_log(&throttle_key) {
warn!(
@@ -102,29 +49,21 @@ impl MsdHealthMonitor {
);
}
// Update last error code
*self.last_error_code.write().await = Some(error_code.to_string());
// Update status
*self.status.write().await = MsdHealthStatus::Error {
reason: reason.to_string(),
error_code: error_code.to_string(),
};
}
/// Report that the MSD has recovered from error
///
/// This method is called when an MSD operation succeeds after errors.
/// It resets the error state.
pub async fn report_recovered(&self) {
let prev_status = self.status.read().await.clone();
// Only report recovery if we were in an error state
if prev_status != MsdHealthStatus::Healthy {
let error_count = self.error_count.load(Ordering::Relaxed);
info!("MSD recovered after {} errors", error_count);
// Reset state
self.error_count.store(0, Ordering::Relaxed);
self.throttler.clear_all();
*self.last_error_code.write().await = None;
@@ -132,29 +71,25 @@ impl MsdHealthMonitor {
}
}
/// Get the current health status
pub async fn status(&self) -> MsdHealthStatus {
#[cfg(test)]
pub(crate) async fn status(&self) -> MsdHealthStatus {
self.status.read().await.clone()
}
/// Get the current error count
pub fn error_count(&self) -> u32 {
#[cfg(test)]
pub(crate) fn error_count(&self) -> u32 {
self.error_count.load(Ordering::Relaxed)
}
/// Check if the monitor is in an error state
pub async fn is_error(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Error { .. })
}
/// Check if the monitor is healthy
pub async fn is_healthy(&self) -> bool {
#[cfg(test)]
pub(crate) async fn is_healthy(&self) -> bool {
matches!(*self.status.read().await, MsdHealthStatus::Healthy)
}
/// Reset the monitor to healthy state without publishing events
///
/// This is useful during initialization.
pub async fn reset(&self) {
self.error_count.store(0, Ordering::Relaxed);
*self.last_error_code.write().await = None;
@@ -162,7 +97,6 @@ impl MsdHealthMonitor {
self.throttler.clear_all();
}
/// Get the current error message if in error state
pub async fn error_message(&self) -> Option<String> {
match &*self.status.read().await {
MsdHealthStatus::Error { reason, .. } => Some(reason.clone()),
@@ -212,13 +146,11 @@ mod tests {
async fn test_report_recovered() {
let monitor = MsdHealthMonitor::with_defaults();
// First report an error
monitor
.report_error("Image not found", "image_not_found")
.await;
assert!(monitor.is_error().await);
// Then report recovery
monitor.report_recovered().await;
assert!(monitor.is_healthy().await);
assert_eq!(monitor.error_count(), 0);

View File

@@ -1,52 +1,38 @@
//! MSD data types and structures
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use time::OffsetDateTime;
/// MSD operating mode
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum MsdMode {
/// No storage connected
#[default]
None,
/// Image file mounted (ISO/IMG)
Image,
/// Virtual drive (FAT32) connected
Drive,
}
/// Image file metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageInfo {
/// Unique image ID
pub id: String,
/// Display name
pub name: String,
/// File path on disk
#[serde(skip_serializing)]
pub path: PathBuf,
/// File size in bytes
pub size: u64,
/// Creation timestamp
pub created_at: DateTime<Utc>,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
}
impl ImageInfo {
/// Create new image info
pub fn new(id: String, name: String, path: PathBuf, size: u64) -> Self {
Self {
id,
name,
path,
size,
created_at: Utc::now(),
created_at: OffsetDateTime::now_utc(),
}
}
/// Format size for display
pub fn size_display(&self) -> String {
const KB: u64 = 1024;
const MB: u64 = KB * 1024;
@@ -64,18 +50,12 @@ impl ImageInfo {
}
}
/// MSD state information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MsdState {
/// Whether MSD feature is available
pub available: bool,
/// Current mode
pub mode: MsdMode,
/// Whether storage is connected to target
pub connected: bool,
/// Currently mounted image (if mode is Image)
pub current_image: Option<ImageInfo>,
/// Virtual drive info (if mode is Drive)
pub drive_info: Option<DriveInfo>,
}
@@ -91,24 +71,17 @@ impl Default for MsdState {
}
}
/// Virtual drive information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveInfo {
/// Drive size in bytes
pub size: u64,
/// Used space in bytes
pub used: u64,
/// Free space in bytes
pub free: u64,
/// Whether drive is initialized
pub initialized: bool,
/// Drive file path
#[serde(skip_serializing)]
pub path: PathBuf,
}
impl DriveInfo {
/// Create new drive info
pub fn new(path: PathBuf, size: u64) -> Self {
Self {
size,
@@ -120,91 +93,60 @@ impl DriveInfo {
}
}
/// File entry in virtual drive
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriveFile {
/// File name
pub name: String,
/// Relative path from drive root
pub path: String,
/// File size in bytes (0 for directories)
pub size: u64,
/// Whether this is a directory
pub is_dir: bool,
/// Last modified timestamp
pub modified: Option<DateTime<Utc>>,
#[serde(with = "time::serde::rfc3339::option")]
pub modified: Option<OffsetDateTime>,
}
/// MSD connect request
#[derive(Debug, Clone, Deserialize)]
pub struct MsdConnectRequest {
/// Connection mode: "image" or "drive"
pub mode: MsdMode,
/// Image ID to mount (required for image mode)
pub image_id: Option<String>,
/// Mount as CD-ROM (optional, defaults based on image type)
#[serde(default)]
pub cdrom: Option<bool>,
/// Mount as read-only
#[serde(default)]
pub read_only: Option<bool>,
}
/// Virtual drive init request
#[derive(Debug, Clone, Deserialize)]
pub struct DriveInitRequest {
/// Drive size in megabytes (defaults to 16GB)
#[serde(default = "default_drive_size")]
pub size_mb: u32,
/// Optional custom path for Ventoy installation
pub ventoy_path: Option<String>,
}
fn default_drive_size() -> u32 {
16 * 1024 // 16GB
16 * 1024
}
/// Image download request
#[derive(Debug, Clone, Deserialize)]
pub struct ImageDownloadRequest {
/// URL to download from
pub url: String,
/// Optional custom filename
pub filename: Option<String>,
}
/// Download status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DownloadStatus {
/// Download has started
Started,
/// Download is in progress
InProgress,
/// Download completed successfully
Completed,
/// Download failed
Failed,
}
/// Download progress information
#[derive(Debug, Clone, Serialize)]
pub struct DownloadProgress {
/// Unique download ID
pub download_id: String,
/// Source URL
pub url: String,
/// Target filename
pub filename: String,
/// Bytes downloaded so far
pub bytes_downloaded: u64,
/// Total file size (None if unknown)
pub total_bytes: Option<u64>,
/// Progress percentage (0.0 - 100.0, None if total unknown)
pub progress_pct: Option<f32>,
/// Download status
pub status: DownloadStatus,
/// Error message if failed
pub error: Option<String>,
}
@@ -218,7 +160,7 @@ mod tests {
"test".into(),
"test.iso".into(),
PathBuf::from("/tmp/test.iso"),
1024 * 1024 * 1024 * 2, // 2 GB
1024 * 1024 * 1024 * 2,
);
assert!(info.size_display().contains("GB"));
}

View File

@@ -1,8 +1,3 @@
//! Ventoy Virtual Drive
//!
//! Replaces FAT32 VirtualDrive with a Ventoy bootable image.
//! Provides a bootable USB with exFAT data partition for ISO files.
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
@@ -13,33 +8,20 @@ use ventoy_img::{FileInfo as VentoyFileInfo, VentoyError, VentoyImage};
use super::types::{DriveFile, DriveInfo};
use crate::error::{AppError, Result};
/// Chunk size for streaming reads (64 KB)
const STREAM_CHUNK_SIZE: usize = 64 * 1024;
/// Minimum drive size (1 GB) - Ventoy requires space for boot partition
const MIN_DRIVE_SIZE_MB: u32 = 1024;
/// Maximum drive size (128 GB)
const MAX_DRIVE_SIZE_MB: u32 = 128 * 1024;
/// Default drive label
const DEFAULT_LABEL: &str = "ONE-KVM";
/// Ventoy Drive Manager
///
/// Thread-safe wrapper around VentoyImage providing async file operations.
/// Uses spawn_blocking for all ventoy-img-rs operations since they are synchronous.
/// Uses RwLock to allow concurrent read operations while serializing writes.
pub struct VentoyDrive {
/// Drive image path
path: PathBuf,
/// RwLock for concurrent reads, exclusive writes
/// (ventoy-img-rs operations are synchronous and not thread-safe)
lock: Arc<RwLock<()>>,
}
impl VentoyDrive {
/// Create new Ventoy drive manager
pub fn new(path: PathBuf) -> Self {
Self {
path,
@@ -47,40 +29,32 @@ impl VentoyDrive {
}
}
/// Check if drive image exists
pub fn exists(&self) -> bool {
self.path.exists()
}
/// Get drive path
pub fn path(&self) -> &PathBuf {
&self.path
}
/// Initialize a new Ventoy drive image
///
/// Creates a bootable Ventoy image with the specified size.
/// The image includes boot partitions and an exFAT data partition.
pub async fn init(&self, size_mb: u32) -> Result<DriveInfo> {
let size_mb = size_mb.clamp(MIN_DRIVE_SIZE_MB, MAX_DRIVE_SIZE_MB);
let size_str = format!("{}M", size_mb);
let path = self.path.clone();
let _lock = self.lock.write().await; // Write lock for initialization
let _lock = self.lock.write().await;
info!("Creating {} MB Ventoy drive at {}", size_mb, path.display());
// Run Ventoy creation in blocking task
let info = tokio::task::spawn_blocking(move || {
VentoyImage::create(&path, &size_str, DEFAULT_LABEL).map_err(ventoy_to_app_error)?;
// Get file metadata for DriveInfo
let metadata = std::fs::metadata(&path)
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
Ok::<DriveInfo, AppError>(DriveInfo {
size: metadata.len(),
used: 0,
free: metadata.len(), // Approximate - exFAT overhead not calculated
free: metadata.len(),
initialized: true,
path,
})
@@ -92,20 +66,18 @@ impl VentoyDrive {
Ok(info)
}
/// Get drive information
pub async fn info(&self) -> Result<DriveInfo> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
let path = self.path.clone();
let _lock = self.lock.read().await; // Read lock for info query
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let metadata = std::fs::metadata(&path)
.map_err(|e| AppError::Internal(format!("Failed to read drive metadata: {}", e)))?;
// Open image to get file list and calculate used space
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
let files = image.list_files_recursive().map_err(ventoy_to_app_error)?;
@@ -116,7 +88,6 @@ impl VentoyDrive {
.map(|f| f.size)
.sum();
// Note: This is approximate since we don't have exact exFAT overhead
let size = metadata.len();
let free = size.saturating_sub(used);
@@ -132,7 +103,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// List files at a given path (or root if empty/"/")
pub async fn list_files(&self, dir_path: &str) -> Result<Vec<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -140,7 +110,7 @@ impl VentoyDrive {
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.read().await; // Read lock for listing
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -161,9 +131,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Write a file to the drive from multipart upload (streaming)
///
/// Streams the file directly into the Ventoy image's exFAT partition.
pub async fn write_file_from_multipart_field(
&self,
file_path: &str,
@@ -173,12 +140,10 @@ impl VentoyDrive {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, stream to a temporary file (to get the size)
let temp_dir = self.path.parent().unwrap_or(Path::new("/tmp"));
let temp_name = format!(".upload_ventoy_{}", uuid::Uuid::new_v4());
let temp_path = temp_dir.join(&temp_name);
// Stream upload to temp file
let mut temp_file = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| AppError::Internal(format!("Failed to create temp file: {}", e)))?;
@@ -201,23 +166,16 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Failed to flush temp file: {}", e)))?;
drop(temp_file);
// Now copy from temp file to Ventoy image
let path = self.path.clone();
let file_path = file_path.to_string();
let temp_path_clone = temp_path.clone();
let _lock = self.lock.write().await; // Write lock for file write
let _lock = self.lock.write().await;
let result = tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use add_file_to_path which handles streaming internally
image
.add_file_to_path(
&temp_path_clone,
&file_path,
true, // create_parents
true, // overwrite
)
.add_file_to_path(&temp_path_clone, &file_path, true, true)
.map_err(ventoy_to_app_error)?;
Ok::<(), AppError>(())
@@ -225,14 +183,13 @@ impl VentoyDrive {
.await
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?;
// Cleanup temp file
let _ = tokio::fs::remove_file(&temp_path).await;
result?;
Ok(bytes_written)
}
/// Read a file from the drive (for download)
#[cfg(test)]
pub async fn read_file(&self, file_path: &str) -> Result<Vec<u8>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -240,7 +197,7 @@ impl VentoyDrive {
let path = self.path.clone();
let file_path = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file read
let _lock = self.lock.read().await;
tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -251,10 +208,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Get file information without reading content
///
/// Returns file size, name, and other metadata.
/// Returns None if the file doesn't exist.
pub async fn get_file_info(&self, file_path: &str) -> Result<Option<DriveFile>> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -262,7 +215,7 @@ impl VentoyDrive {
let path = self.path.clone();
let file_path_owned = file_path.to_string();
let _lock = self.lock.read().await; // Read lock for file info
let _lock = self.lock.read().await;
let info = tokio::task::spawn_blocking(move || {
let image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -282,10 +235,6 @@ impl VentoyDrive {
}))
}
/// Read a file from the drive as a stream (for large file downloads)
///
/// Returns an async channel receiver that yields chunks of file data.
/// This avoids loading the entire file into memory.
pub async fn read_file_stream(
&self,
file_path: &str,
@@ -297,7 +246,6 @@ impl VentoyDrive {
return Err(AppError::Internal("Drive not initialized".to_string()));
}
// First, get the file size
let file_info = self
.get_file_info(file_path)
.await?
@@ -315,15 +263,12 @@ impl VentoyDrive {
let file_path_owned = file_path.to_string();
let lock = self.lock.clone();
// Create a channel for streaming data
let (tx, rx) =
tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, std::io::Error>>(8);
// Spawn blocking task to read and send chunks
tokio::task::spawn_blocking(move || {
// Hold read lock for the entire read operation
let rt = tokio::runtime::Handle::current();
let _lock = rt.block_on(lock.read()); // Read lock for streaming
let _lock = rt.block_on(lock.read());
let image = match VentoyImage::open(&path) {
Ok(img) => img,
@@ -333,10 +278,8 @@ impl VentoyDrive {
}
};
// Create a channel writer that sends chunks
let mut chunk_writer = ChannelWriter::new(tx.clone(), rt.clone());
// Stream the file through the writer
if let Err(e) = image.read_file_to_writer(&file_path_owned, &mut chunk_writer) {
let _ = rt.block_on(tx.send(Err(std::io::Error::other(e.to_string()))));
}
@@ -345,7 +288,6 @@ impl VentoyDrive {
Ok((file_size, rx))
}
/// Create a directory
pub async fn mkdir(&self, dir_path: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -353,7 +295,7 @@ impl VentoyDrive {
let path = self.path.clone();
let dir_path = dir_path.to_string();
let _lock = self.lock.write().await; // Write lock for mkdir
let _lock = self.lock.write().await;
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
@@ -366,7 +308,6 @@ impl VentoyDrive {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Delete a file or directory
pub async fn delete(&self, path_to_delete: &str) -> Result<()> {
if !self.exists() {
return Err(AppError::Internal("Drive not initialized".to_string()));
@@ -374,12 +315,11 @@ impl VentoyDrive {
let path = self.path.clone();
let path_to_delete = path_to_delete.to_string();
let _lock = self.lock.write().await; // Write lock for delete
let _lock = self.lock.write().await;
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).map_err(ventoy_to_app_error)?;
// Use recursive delete to handle directories
image
.remove_recursive(&path_to_delete)
.map_err(ventoy_to_app_error)
@@ -389,7 +329,6 @@ impl VentoyDrive {
}
}
/// Convert VentoyError to AppError
fn ventoy_to_app_error(err: VentoyError) -> AppError {
match err {
VentoyError::Io(e) => AppError::Io(e),
@@ -405,7 +344,6 @@ fn ventoy_to_app_error(err: VentoyError) -> AppError {
}
}
/// Convert VentoyFileInfo to DriveFile
fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFile {
let full_path = if parent_path.is_empty() || parent_path == "/" {
format!("/{}", info.name)
@@ -418,13 +356,10 @@ fn ventoy_file_to_drive_file(info: VentoyFileInfo, parent_path: &str) -> DriveFi
path: full_path,
size: info.size,
is_dir: info.is_directory,
modified: None, // Ventoy FileInfo doesn't include timestamps
modified: None,
}
}
/// A writer that sends chunks to an async channel
///
/// This bridges the sync Write trait with async channels for streaming.
struct ChannelWriter {
tx: tokio::sync::mpsc::Sender<std::result::Result<bytes::Bytes, std::io::Error>>,
rt: tokio::runtime::Handle,
@@ -484,7 +419,6 @@ impl std::io::Write for ChannelWriter {
impl Drop for ChannelWriter {
fn drop(&mut self) {
// Flush any remaining data when the writer is dropped
let _ = self.flush_buffer();
}
}
@@ -496,16 +430,13 @@ mod tests {
use std::sync::OnceLock;
use tempfile::TempDir;
/// Path to ventoy resources directory
static RESOURCE_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../ventoy-img-rs/resources");
/// Initialize ventoy resources once
fn init_ventoy_resources() -> bool {
static INIT: OnceLock<bool> = OnceLock::new();
*INIT.get_or_init(|| {
let resource_path = std::path::Path::new(RESOURCE_DIR);
// Decompress xz files if needed
let core_xz = resource_path.join("core.img.xz");
let core_img = resource_path.join("core.img");
if core_xz.exists() && !core_img.exists() {
@@ -524,7 +455,6 @@ mod tests {
}
}
// Initialize resources
if let Err(e) = ventoy_img::resources::init_resources(resource_path) {
eprintln!("Failed to init ventoy resources: {}", e);
return false;
@@ -534,7 +464,6 @@ mod tests {
})
}
/// Decompress xz file using system command
fn decompress_xz(src: &std::path::Path, dst: &std::path::Path) -> std::io::Result<()> {
let output = Command::new("xz")
.args(["-d", "-k", "-c", src.to_str().unwrap()])
@@ -551,7 +480,6 @@ mod tests {
Ok(())
}
/// Ensure resources are initialized, skip test if failed
fn ensure_resources() -> bool {
if !init_ventoy_resources() {
eprintln!("Skipping test: ventoy resources not available");
@@ -602,15 +530,12 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Write a test file
let test_content = b"Hello, Ventoy!";
let test_file_path = temp_dir.path().join("test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive using ventoy-img directly
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -619,7 +544,6 @@ mod tests {
.await
.unwrap();
// Read file from drive
let read_data = drive.read_file("/test.txt").await.unwrap();
assert_eq!(read_data, test_content);
}
@@ -633,18 +557,14 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a directory
drive.mkdir("/mydir").await.unwrap();
// Write a test file
let test_content = b"Test file content for info check";
let test_file_path = temp_dir.path().join("info_test.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -653,7 +573,6 @@ mod tests {
.await
.unwrap();
// Test get_file_info for file
let file_info = drive.get_file_info("/info_test.txt").await.unwrap();
assert!(file_info.is_some());
let file_info = file_info.unwrap();
@@ -661,14 +580,12 @@ mod tests {
assert_eq!(file_info.size, test_content.len() as u64);
assert!(!file_info.is_dir);
// Test get_file_info for directory
let dir_info = drive.get_file_info("/mydir").await.unwrap();
assert!(dir_info.is_some());
let dir_info = dir_info.unwrap();
assert_eq!(dir_info.name, "mydir");
assert!(dir_info.is_dir);
// Test get_file_info for non-existent file
let not_found = drive.get_file_info("/nonexistent.txt").await.unwrap();
assert!(not_found.is_none());
}
@@ -682,16 +599,13 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create test data that spans multiple chunks (>64KB)
let test_size = 200 * 1024; // 200 KB
let test_size = 200 * 1024;
let test_content: Vec<u8> = (0..test_size).map(|i| (i % 256) as u8).collect();
let test_file_path = temp_dir.path().join("large_file.bin");
std::fs::write(&test_file_path, &test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
let file_path_clone = test_file_path.clone();
tokio::task::spawn_blocking(move || {
@@ -701,18 +615,15 @@ mod tests {
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/large_file.bin").await.unwrap();
assert_eq!(file_size, test_size as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.len(), test_content.len());
assert_eq!(received_data, test_content);
}
@@ -726,15 +637,12 @@ mod tests {
let drive_path = temp_dir.path().join("test_ventoy.img");
let drive = VentoyDrive::new(drive_path.clone());
// Initialize drive
drive.init(MIN_DRIVE_SIZE_MB).await.unwrap();
// Create a small test file
let test_content = b"Small file for streaming test";
let test_file_path = temp_dir.path().join("small.txt");
std::fs::write(&test_file_path, test_content).unwrap();
// Add file to drive
let path = drive.path().clone();
tokio::task::spawn_blocking(move || {
let mut image = VentoyImage::open(&path).unwrap();
@@ -743,18 +651,15 @@ mod tests {
.await
.unwrap();
// Stream read the file
let (file_size, mut rx) = drive.read_file_stream("/small.txt").await.unwrap();
assert_eq!(file_size, test_content.len() as u64);
// Collect all chunks
let mut received_data = Vec::new();
while let Some(chunk_result) = rx.recv().await {
let chunk = chunk_result.expect("Chunk should not be an error");
received_data.extend_from_slice(&chunk);
}
// Verify data matches
assert_eq!(received_data.as_slice(), test_content);
}
}

View File

@@ -1,5 +1,3 @@
//! ConfigFS file operations for USB Gadget
use std::fs::{self, File, OpenOptions};
use std::io::Write;
use std::path::Path;
@@ -7,34 +5,18 @@ use std::process::Command;
use crate::error::{AppError, Result};
/// ConfigFS base path for USB gadgets
pub const CONFIGFS_PATH: &str = "/sys/kernel/config/usb_gadget";
/// Default gadget name
pub const DEFAULT_GADGET_NAME: &str = "one-kvm";
/// USB Vendor ID (Linux Foundation) - default value
pub const DEFAULT_USB_VENDOR_ID: u16 = 0x1d6b;
/// USB Product ID (Multifunction Composite Gadget) - default value
pub const DEFAULT_USB_PRODUCT_ID: u16 = 0x0104;
/// USB device version - default value
pub const DEFAULT_USB_BCD_DEVICE: u16 = 0x0100;
/// USB spec version (USB 2.0)
pub const USB_BCD_USB: u16 = 0x0200;
/// Check if ConfigFS is available
pub fn is_configfs_available() -> bool {
Path::new(CONFIGFS_PATH).exists()
}
/// Ensure libcomposite support is available for USB gadget operations.
///
/// This is a best-effort runtime fallback for systems where `libcomposite`
/// is built as a module and not loaded yet. It does not try to mount configfs;
/// mounting remains an explicit system responsibility.
/// Loads `libcomposite` if needed; does not mount configfs.
pub fn ensure_libcomposite_loaded() -> Result<()> {
if is_configfs_available() {
return Ok(());
@@ -66,7 +48,6 @@ pub fn ensure_libcomposite_loaded() -> Result<()> {
}
}
/// Find available UDC (USB Device Controller)
pub fn find_udc() -> Option<String> {
let udc_path = Path::new("/sys/class/udc");
if !udc_path.exists() {
@@ -80,40 +61,17 @@ pub fn find_udc() -> Option<String> {
.next()
}
/// Check if UDC is known to have low endpoint resources
pub fn is_low_endpoint_udc(name: &str) -> bool {
let name = name.to_ascii_lowercase();
name.contains("musb") || name.contains("musb-hdrc")
}
/// Resolve preferred UDC name if available, otherwise auto-detect
pub fn resolve_udc_name(preferred: Option<&str>) -> Option<String> {
if let Some(name) = preferred {
let path = Path::new("/sys/class/udc").join(name);
if path.exists() {
return Some(name.to_string());
}
}
find_udc()
}
/// Write string content to a file
///
/// For sysfs files, this function appends a newline and flushes
/// to ensure the kernel processes the write immediately.
///
/// IMPORTANT: sysfs attributes require a single atomic write() syscall.
/// The kernel processes the value on the first write(), so we must
/// build the complete buffer (including newline) before writing.
/// Sysfs/configfs: one write syscall with final buffer (incl. newline when needed).
pub fn write_file(path: &Path, content: &str) -> Result<()> {
// For sysfs files (especially write-only ones like forced_eject),
// we need to use simple O_WRONLY without O_TRUNC
// O_TRUNC may fail on special files or require read permission
let mut file = OpenOptions::new()
.write(true)
.open(path)
.or_else(|e| {
// If open fails, try create (for regular files)
if path.exists() {
Err(e)
} else {
@@ -122,9 +80,6 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
})
.map_err(|e| AppError::Internal(format!("Failed to open {}: {}", path.display(), e)))?;
// Build complete buffer with newline, then write in single syscall.
// This is critical for sysfs - multiple write() calls may cause
// the kernel to only process partial data or return EINVAL.
let data: std::borrow::Cow<[u8]> = if content.ends_with('\n') {
content.as_bytes().into()
} else {
@@ -136,14 +91,12 @@ pub fn write_file(path: &Path, content: &str) -> Result<()> {
file.write_all(&data)
.map_err(|e| AppError::Internal(format!("Failed to write to {}: {}", path.display(), e)))?;
// Explicitly flush to ensure sysfs processes the write
file.flush()
.map_err(|e| AppError::Internal(format!("Failed to flush {}: {}", path.display(), e)))?;
Ok(())
}
/// Write binary content to a file
pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
let mut file = File::create(path)
.map_err(|e| AppError::Internal(format!("Failed to create {}: {}", path.display(), e)))?;
@@ -154,14 +107,6 @@ pub fn write_bytes(path: &Path, data: &[u8]) -> Result<()> {
Ok(())
}
/// Read string content from a file
pub fn read_file(path: &Path) -> Result<String> {
fs::read_to_string(path)
.map(|s| s.trim().to_string())
.map_err(|e| AppError::Internal(format!("Failed to read {}: {}", path.display(), e)))
}
/// Create directory if not exists
pub fn create_dir(path: &Path) -> Result<()> {
fs::create_dir_all(path).map_err(|e| {
AppError::Internal(format!(
@@ -172,7 +117,6 @@ pub fn create_dir(path: &Path) -> Result<()> {
})
}
/// Remove directory
pub fn remove_dir(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_dir(path).map_err(|e| {
@@ -186,7 +130,6 @@ pub fn remove_dir(path: &Path) -> Result<()> {
Ok(())
}
/// Remove file
pub fn remove_file(path: &Path) -> Result<()> {
if path.exists() {
fs::remove_file(path).map_err(|e| {
@@ -196,7 +139,6 @@ pub fn remove_file(path: &Path) -> Result<()> {
Ok(())
}
/// Create symlink
pub fn create_symlink(src: &Path, dest: &Path) -> Result<()> {
std::os::unix::fs::symlink(src, dest).map_err(|e| {
AppError::Internal(format!(

View File

@@ -1,11 +1,7 @@
//! USB Endpoint allocation management
use crate::error::{AppError, Result};
/// Default maximum endpoints for typical UDC
pub const DEFAULT_MAX_ENDPOINTS: u8 = 16;
/// Endpoint allocator - manages UDC endpoint resources
#[derive(Debug, Clone)]
pub struct EndpointAllocator {
max_endpoints: u8,
@@ -13,7 +9,6 @@ pub struct EndpointAllocator {
}
impl EndpointAllocator {
/// Create a new endpoint allocator
pub fn new(max_endpoints: u8) -> Self {
Self {
max_endpoints,
@@ -21,7 +16,6 @@ impl EndpointAllocator {
}
}
/// Allocate endpoints for a function
pub fn allocate(&mut self, count: u8) -> Result<()> {
if self.used_endpoints + count > self.max_endpoints {
return Err(AppError::Internal(format!(
@@ -34,27 +28,22 @@ impl EndpointAllocator {
Ok(())
}
/// Release endpoints
pub fn release(&mut self, count: u8) {
self.used_endpoints = self.used_endpoints.saturating_sub(count);
}
/// Get available endpoint count
pub fn available(&self) -> u8 {
self.max_endpoints.saturating_sub(self.used_endpoints)
}
/// Get used endpoint count
pub fn used(&self) -> u8 {
self.used_endpoints
}
/// Get maximum endpoint count
pub fn max(&self) -> u8 {
self.max_endpoints
}
/// Check if can allocate
pub fn can_allocate(&self, count: u8) -> bool {
self.available() >= count
}
@@ -82,7 +71,6 @@ mod tests {
alloc.allocate(4).unwrap();
assert_eq!(alloc.available(), 2);
// Should fail - not enough endpoints
assert!(alloc.allocate(3).is_err());
alloc.release(2);

View File

@@ -1,42 +1,17 @@
//! USB Gadget Function trait definition
use std::path::Path;
use crate::error::Result;
/// Function metadata
#[derive(Debug, Clone)]
pub struct FunctionMeta {
/// Function name (e.g., "hid.usb0")
pub name: String,
/// Human-readable description
pub description: String,
/// Number of endpoints used
pub endpoints: u8,
/// Whether the function is enabled
pub enabled: bool,
}
/// USB Gadget Function trait
pub trait GadgetFunction: Send + Sync {
/// Get function name (e.g., "hid.usb0", "mass_storage.usb0")
fn name(&self) -> &str;
/// Get number of endpoints required
fn endpoints_required(&self) -> u8;
/// Get function metadata
fn meta(&self) -> FunctionMeta;
/// Create function directory and configuration in ConfigFS
fn create(&self, gadget_path: &Path) -> Result<()>;
/// Link function to configuration
fn link(&self, config_path: &Path, gadget_path: &Path) -> Result<()>;
/// Unlink function from configuration
fn unlink(&self, config_path: &Path) -> Result<()>;
/// Cleanup function directory
fn cleanup(&self, gadget_path: &Path) -> Result<()>;
}

View File

@@ -1,35 +1,24 @@
//! HID Function implementation for USB Gadget
use std::path::{Path, PathBuf};
use tracing::debug;
use super::configfs::{
create_dir, create_symlink, remove_dir, remove_file, write_bytes, write_file,
};
use super::function::{FunctionMeta, GadgetFunction};
use super::function::GadgetFunction;
use super::report_desc::{
CONSUMER_CONTROL, KEYBOARD, KEYBOARD_WITH_LED, MOUSE_ABSOLUTE, MOUSE_RELATIVE,
};
use crate::error::Result;
/// HID function type
#[derive(Debug, Clone)]
pub enum HidFunctionType {
/// Keyboard
Keyboard,
/// Relative mouse (traditional mouse movement)
/// Uses 1 endpoint: IN
MouseRelative,
/// Absolute mouse (touchscreen-like positioning)
/// Uses 1 endpoint: IN
MouseAbsolute,
/// Consumer control (multimedia keys)
/// Uses 1 endpoint: IN
ConsumerControl,
}
impl HidFunctionType {
/// Get the base endpoint cost for this function type.
pub fn endpoints(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1,
@@ -39,27 +28,24 @@ impl HidFunctionType {
}
}
/// Get HID protocol
pub fn protocol(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1, // Keyboard
HidFunctionType::MouseRelative => 2, // Mouse
HidFunctionType::MouseAbsolute => 2, // Mouse
HidFunctionType::ConsumerControl => 0, // None
HidFunctionType::Keyboard => 1,
HidFunctionType::MouseRelative => 2,
HidFunctionType::MouseAbsolute => 2,
HidFunctionType::ConsumerControl => 0,
}
}
/// Get HID subclass
pub fn subclass(&self) -> u8 {
match self {
HidFunctionType::Keyboard => 1, // Boot interface
HidFunctionType::MouseRelative => 1, // Boot interface
HidFunctionType::MouseAbsolute => 0, // No boot interface
HidFunctionType::ConsumerControl => 0, // No boot interface
HidFunctionType::Keyboard => 1,
HidFunctionType::MouseRelative => 1,
HidFunctionType::MouseAbsolute => 0,
HidFunctionType::ConsumerControl => 0,
}
}
/// Get report length in bytes
pub fn report_length(&self, _keyboard_leds: bool) -> u8 {
match self {
HidFunctionType::Keyboard => 8,
@@ -69,7 +55,6 @@ impl HidFunctionType {
}
}
/// Get report descriptor
pub fn report_desc(&self, keyboard_leds: bool) -> &'static [u8] {
match self {
HidFunctionType::Keyboard => {
@@ -84,33 +69,17 @@ impl HidFunctionType {
HidFunctionType::ConsumerControl => CONSUMER_CONTROL,
}
}
/// Get description
pub fn description(&self) -> &'static str {
match self {
HidFunctionType::Keyboard => "Keyboard",
HidFunctionType::MouseRelative => "Relative Mouse",
HidFunctionType::MouseAbsolute => "Absolute Mouse",
HidFunctionType::ConsumerControl => "Consumer Control",
}
}
}
/// HID Function for USB Gadget
#[derive(Debug, Clone)]
pub struct HidFunction {
/// Instance number (usb0, usb1, ...)
instance: u8,
/// Function type
func_type: HidFunctionType,
/// Cached function name (avoids repeated allocation)
name: String,
/// Whether keyboard LED/status feedback is enabled.
keyboard_leds: bool,
}
impl HidFunction {
/// Create a keyboard function
pub fn keyboard(instance: u8, keyboard_leds: bool) -> Self {
Self {
instance,
@@ -120,7 +89,6 @@ impl HidFunction {
}
}
/// Create a relative mouse function
pub fn mouse_relative(instance: u8) -> Self {
Self {
instance,
@@ -130,7 +98,6 @@ impl HidFunction {
}
}
/// Create an absolute mouse function
pub fn mouse_absolute(instance: u8) -> Self {
Self {
instance,
@@ -140,7 +107,6 @@ impl HidFunction {
}
}
/// Create a consumer control function
pub fn consumer_control(instance: u8) -> Self {
Self {
instance,
@@ -150,12 +116,10 @@ impl HidFunction {
}
}
/// Get function path in gadget
fn function_path(&self, gadget_path: &Path) -> PathBuf {
gadget_path.join("functions").join(self.name())
}
/// Get expected device path (e.g., /dev/hidg0)
pub fn device_path(&self) -> PathBuf {
PathBuf::from(format!("/dev/hidg{}", self.instance))
}
@@ -170,20 +134,10 @@ impl GadgetFunction for HidFunction {
self.func_type.endpoints()
}
fn meta(&self) -> FunctionMeta {
FunctionMeta {
name: self.name().to_string(),
description: self.func_type.description().to_string(),
endpoints: self.endpoints_required(),
enabled: true,
}
}
fn create(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
create_dir(&func_path)?;
// Set HID parameters
write_file(
&func_path.join("protocol"),
&self.func_type.protocol().to_string(),
@@ -197,7 +151,6 @@ impl GadgetFunction for HidFunction {
&self.func_type.report_length(self.keyboard_leds).to_string(),
)?;
// Write report descriptor
write_bytes(
&func_path.join("report_desc"),
self.func_type.report_desc(self.keyboard_leds),

View File

@@ -1,6 +1,3 @@
//! OTG Gadget Manager - unified management for USB Gadget functions
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tracing::{debug, error, info, warn};
@@ -11,14 +8,13 @@ use super::configfs::{
DEFAULT_USB_VENDOR_ID, USB_BCD_USB,
};
use super::endpoint::{EndpointAllocator, DEFAULT_MAX_ENDPOINTS};
use super::function::{FunctionMeta, GadgetFunction};
use super::function::GadgetFunction;
use super::hid::HidFunction;
use super::msd::MsdFunction;
use crate::error::{AppError, Result};
const REBIND_DELAY_MS: u64 = 300;
/// USB Gadget device descriptor configuration
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GadgetDescriptor {
pub vendor_id: u16,
@@ -42,44 +38,28 @@ impl Default for GadgetDescriptor {
}
}
/// OTG Gadget Manager - unified management for HID and MSD
pub struct OtgGadgetManager {
/// Gadget name
gadget_name: String,
/// Gadget path in ConfigFS
gadget_path: PathBuf,
/// Configuration path
config_path: PathBuf,
/// Device descriptor
descriptor: GadgetDescriptor,
/// Endpoint allocator
endpoint_allocator: EndpointAllocator,
/// HID instance counter
hid_instance: u8,
/// MSD instance counter
msd_instance: u8,
/// Registered functions
functions: Vec<Box<dyn GadgetFunction>>,
/// Function metadata
meta: HashMap<String, FunctionMeta>,
/// Bound UDC name
bound_udc: Option<String>,
/// Whether gadget was created by us
created_by_us: bool,
}
impl OtgGadgetManager {
/// Create a new gadget manager with default settings
pub fn new() -> Self {
Self::with_config(DEFAULT_GADGET_NAME, DEFAULT_MAX_ENDPOINTS)
}
/// Create a new gadget manager with custom configuration
pub fn with_config(gadget_name: &str, max_endpoints: u8) -> Self {
Self::with_descriptor(gadget_name, max_endpoints, GadgetDescriptor::default())
}
/// Create a new gadget manager with custom descriptor
pub fn with_descriptor(
gadget_name: &str,
max_endpoints: u8,
@@ -96,30 +76,24 @@ impl OtgGadgetManager {
endpoint_allocator: EndpointAllocator::new(max_endpoints),
hid_instance: 0,
msd_instance: 0,
// Pre-allocate for typical use: 3 HID (keyboard, rel mouse, abs mouse) + 1 MSD
functions: Vec::with_capacity(4),
meta: HashMap::with_capacity(4),
bound_udc: None,
created_by_us: false,
}
}
/// Check if ConfigFS is available
pub fn is_available() -> bool {
is_configfs_available()
}
/// Find available UDC
pub fn find_udc() -> Option<String> {
find_udc()
}
/// Check if gadget exists
pub fn gadget_exists(&self) -> bool {
self.gadget_path.exists()
}
/// Check if gadget is bound to UDC
pub fn is_bound(&self) -> bool {
let udc_file = self.gadget_path.join("UDC");
if let Ok(content) = fs::read_to_string(&udc_file) {
@@ -129,8 +103,6 @@ impl OtgGadgetManager {
}
}
/// Add keyboard function
/// Returns the expected device path (e.g., /dev/hidg0)
pub fn add_keyboard(&mut self, keyboard_leds: bool) -> Result<PathBuf> {
let func = HidFunction::keyboard(self.hid_instance, keyboard_leds);
let device_path = func.device_path();
@@ -139,7 +111,6 @@ impl OtgGadgetManager {
Ok(device_path)
}
/// Add relative mouse function
pub fn add_mouse_relative(&mut self) -> Result<PathBuf> {
let func = HidFunction::mouse_relative(self.hid_instance);
let device_path = func.device_path();
@@ -148,7 +119,6 @@ impl OtgGadgetManager {
Ok(device_path)
}
/// Add absolute mouse function
pub fn add_mouse_absolute(&mut self) -> Result<PathBuf> {
let func = HidFunction::mouse_absolute(self.hid_instance);
let device_path = func.device_path();
@@ -157,7 +127,6 @@ impl OtgGadgetManager {
Ok(device_path)
}
/// Add consumer control function (multimedia keys)
pub fn add_consumer_control(&mut self) -> Result<PathBuf> {
let func = HidFunction::consumer_control(self.hid_instance);
let device_path = func.device_path();
@@ -166,7 +135,6 @@ impl OtgGadgetManager {
Ok(device_path)
}
/// Add MSD function (returns MsdFunction handle for LUN configuration)
pub fn add_msd(&mut self) -> Result<MsdFunction> {
let func = MsdFunction::new(self.msd_instance);
let func_clone = func.clone();
@@ -175,11 +143,9 @@ impl OtgGadgetManager {
Ok(func_clone)
}
/// Add a generic function
fn add_function(&mut self, func: Box<dyn GadgetFunction>) -> Result<()> {
let endpoints = func.endpoints_required();
// Check endpoint availability
if !self.endpoint_allocator.can_allocate(endpoints) {
return Err(AppError::Internal(format!(
"Not enough endpoints for function {}: need {}, available {}",
@@ -189,70 +155,55 @@ impl OtgGadgetManager {
)));
}
// Allocate endpoints
self.endpoint_allocator.allocate(endpoints)?;
// Store metadata
self.meta.insert(func.name().to_string(), func.meta());
// Store function
self.functions.push(func);
Ok(())
}
/// Setup the gadget (create directories and configure)
pub fn setup(&mut self) -> Result<()> {
info!("Setting up OTG USB Gadget: {}", self.gadget_name);
debug!("Setting up OTG USB Gadget: {}", self.gadget_name);
// Check ConfigFS availability
if !Self::is_available() {
return Err(AppError::Internal(
"ConfigFS not available. Is it mounted at /sys/kernel/config?".to_string(),
));
}
// Check if gadget already exists and is bound
if self.gadget_exists() {
if self.is_bound() {
info!("Gadget already exists and is bound, skipping setup");
debug!("Gadget already exists and is bound, skipping setup");
return Ok(());
}
warn!("Gadget exists but not bound, will reconfigure");
self.cleanup()?;
}
// Create gadget directory
create_dir(&self.gadget_path)?;
self.created_by_us = true;
// Set device descriptors
self.set_device_descriptors()?;
// Create strings
self.create_strings()?;
// Create configuration
self.create_configuration()?;
// Create and link all functions
for func in &self.functions {
func.create(&self.gadget_path)?;
func.link(&self.config_path, &self.gadget_path)?;
}
info!("OTG USB Gadget setup complete");
debug!("OTG USB Gadget setup complete");
Ok(())
}
/// Bind gadget to a specific UDC
pub fn bind(&mut self, udc: &str) -> Result<()> {
// Recreate config symlinks before binding to avoid kernel gadget issues after rebind
if let Err(e) = self.recreate_config_links() {
warn!("Failed to recreate gadget config links before bind: {}", e);
}
info!("Binding gadget to UDC: {}", udc);
debug!("Binding gadget to UDC: {}", udc);
write_file(&self.gadget_path.join("UDC"), &udc)?;
self.bound_udc = Some(udc.to_string());
std::thread::sleep(std::time::Duration::from_millis(REBIND_DELAY_MS));
@@ -260,7 +211,6 @@ impl OtgGadgetManager {
Ok(())
}
/// Unbind gadget from UDC
pub fn unbind(&mut self) -> Result<()> {
if self.is_bound() {
write_file(&self.gadget_path.join("UDC"), "")?;
@@ -271,7 +221,6 @@ impl OtgGadgetManager {
Ok(())
}
/// Cleanup all resources
pub fn cleanup(&mut self) -> Result<()> {
if !self.gadget_exists() {
return Ok(());
@@ -279,29 +228,23 @@ impl OtgGadgetManager {
info!("Cleaning up OTG USB Gadget: {}", self.gadget_name);
// Unbind from UDC first
let _ = self.unbind();
// Unlink and cleanup functions
for func in self.functions.iter().rev() {
let _ = func.unlink(&self.config_path);
}
// Remove config strings
let config_strings = self.config_path.join("strings/0x409");
let _ = remove_dir(&config_strings);
let _ = remove_dir(&self.config_path);
// Cleanup functions
for func in self.functions.iter().rev() {
let _ = func.cleanup(&self.gadget_path);
}
// Remove gadget strings
let gadget_strings = self.gadget_path.join("strings/0x409");
let _ = remove_dir(&gadget_strings);
// Remove gadget directory
if let Err(e) = remove_dir(&self.gadget_path) {
warn!("Could not remove gadget directory: {}", e);
}
@@ -311,7 +254,6 @@ impl OtgGadgetManager {
Ok(())
}
/// Set USB device descriptors
fn set_device_descriptors(&self) -> Result<()> {
write_file(
&self.gadget_path.join("idVendor"),
@@ -329,14 +271,13 @@ impl OtgGadgetManager {
&self.gadget_path.join("bcdUSB"),
&format!("0x{:04x}", USB_BCD_USB),
)?;
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?; // Composite device
write_file(&self.gadget_path.join("bDeviceClass"), "0x00")?;
write_file(&self.gadget_path.join("bDeviceSubClass"), "0x00")?;
write_file(&self.gadget_path.join("bDeviceProtocol"), "0x00")?;
debug!("Set device descriptors");
Ok(())
}
/// Create USB strings
fn create_strings(&self) -> Result<()> {
let strings_path = self.gadget_path.join("strings/0x409");
create_dir(&strings_path)?;
@@ -354,41 +295,23 @@ impl OtgGadgetManager {
Ok(())
}
/// Create configuration
fn create_configuration(&self) -> Result<()> {
create_dir(&self.config_path)?;
// Create config strings
let strings_path = self.config_path.join("strings/0x409");
create_dir(&strings_path)?;
write_file(&strings_path.join("configuration"), "Config 1: HID + MSD")?;
// Set max power (500mA)
write_file(&self.config_path.join("MaxPower"), "500")?;
debug!("Created configuration c.1");
Ok(())
}
/// Get function metadata
pub fn get_meta(&self) -> &HashMap<String, FunctionMeta> {
&self.meta
}
/// Get endpoint usage info
pub fn endpoint_info(&self) -> (u8, u8) {
(
self.endpoint_allocator.used(),
self.endpoint_allocator.max(),
)
}
/// Get gadget path
pub fn gadget_path(&self) -> &PathBuf {
&self.gadget_path
}
/// Recreate config symlinks from functions directory
fn recreate_config_links(&self) -> Result<()> {
let functions_path = self.gadget_path.join("functions");
if !functions_path.exists() || !self.config_path.exists() {
@@ -450,15 +373,10 @@ impl Drop for OtgGadgetManager {
}
}
/// Wait for HID devices to become available
///
/// Uses exponential backoff starting from 10ms, capped at 100ms,
/// to reduce CPU usage while still providing fast response.
pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) -> bool {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_millis(timeout_ms);
// Exponential backoff: start at 10ms, double each time, cap at 100ms
let mut delay_ms = 10u64;
const MAX_DELAY_MS: u64 = 100;
@@ -467,7 +385,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) ->
return true;
}
// Calculate remaining time to avoid overshooting timeout
let remaining = timeout.saturating_sub(start.elapsed());
let sleep_duration = std::time::Duration::from_millis(delay_ms).min(remaining);
@@ -477,7 +394,6 @@ pub async fn wait_for_hid_devices(device_paths: &[PathBuf], timeout_ms: u64) ->
tokio::time::sleep(sleep_duration).await;
// Exponential backoff with cap
delay_ms = (delay_ms * 2).min(MAX_DELAY_MS);
}
@@ -492,18 +408,16 @@ mod tests {
fn test_manager_creation() {
let manager = OtgGadgetManager::new();
assert_eq!(manager.gadget_name, DEFAULT_GADGET_NAME);
assert!(!manager.gadget_exists()); // Won't exist in test environment
assert!(!manager.gadget_exists());
}
#[test]
fn test_endpoint_tracking() {
let mut manager = OtgGadgetManager::with_config("test", 8);
// Keyboard uses 1 endpoint
let _ = manager.add_keyboard(false);
assert_eq!(manager.endpoint_allocator.used(), 1);
// Mouse uses 1 endpoint each
let _ = manager.add_mouse_relative();
let _ = manager.add_mouse_absolute();
assert_eq!(manager.endpoint_allocator.used(), 3);

View File

@@ -1,21 +1,4 @@
//! OTG USB Gadget unified management module
//!
//! This module provides unified management for USB Gadget functions:
//! - HID (Keyboard, Mouse)
//! - MSD (Mass Storage Device)
//!
//! Architecture:
//! ```text
//! OtgService (high-level coordination)
//! └── OtgGadgetManager (gadget lifecycle)
//! ├── EndpointAllocator (manages UDC endpoints)
//! ├── HidFunction (keyboard, mouse_rel, mouse_abs)
//! └── MsdFunction (mass storage)
//! ```
//!
//! The recommended way to use this module is through `OtgService`, which provides
//! a high-level interface for enabling/disabling HID and MSD functions independently.
//! Both `HidController` and `MsdController` should share the same `OtgService` instance.
//! USB OTG composite gadget (HID + MSD).
pub mod configfs;
pub mod endpoint;
@@ -26,10 +9,6 @@ pub mod msd;
pub mod report_desc;
pub mod service;
pub use endpoint::EndpointAllocator;
pub use function::{FunctionMeta, GadgetFunction};
pub use hid::{HidFunction, HidFunctionType};
pub use manager::{wait_for_hid_devices, OtgGadgetManager};
pub use msd::{MsdFunction, MsdLunConfig};
pub use report_desc::{KEYBOARD, MOUSE_ABSOLUTE, MOUSE_RELATIVE};
pub use service::{HidDevicePaths, OtgDesiredState, OtgService, OtgServiceState};
pub use service::{HidDevicePaths, OtgService};

View File

@@ -1,25 +1,17 @@
//! MSD (Mass Storage Device) Function implementation for USB Gadget
use std::fs;
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};
use super::configfs::{create_dir, create_symlink, remove_dir, remove_file, write_file};
use super::function::{FunctionMeta, GadgetFunction};
use super::function::GadgetFunction;
use crate::error::{AppError, Result};
/// MSD LUN configuration
#[derive(Debug, Clone)]
pub struct MsdLunConfig {
/// File/image path to expose
pub file: PathBuf,
/// Mount as CD-ROM
pub cdrom: bool,
/// Read-only mode
pub ro: bool,
/// Removable media
pub removable: bool,
/// Disable Force Unit Access
pub nofua: bool,
}
@@ -36,7 +28,6 @@ impl Default for MsdLunConfig {
}
impl MsdLunConfig {
/// Create CD-ROM configuration
pub fn cdrom(file: PathBuf) -> Self {
Self {
file,
@@ -47,7 +38,6 @@ impl MsdLunConfig {
}
}
/// Create disk configuration
pub fn disk(file: PathBuf, read_only: bool) -> Self {
Self {
file,
@@ -59,38 +49,26 @@ impl MsdLunConfig {
}
}
/// MSD Function for USB Gadget
#[derive(Debug, Clone)]
pub struct MsdFunction {
/// Instance number (usb0, usb1, ...)
instance: u8,
/// Cached function name (avoids repeated allocation)
name: String,
}
impl MsdFunction {
/// Create a new MSD function
pub fn new(instance: u8) -> Self {
Self {
instance,
name: format!("mass_storage.usb{}", instance),
}
}
/// Get function path in gadget
fn function_path(&self, gadget_path: &Path) -> PathBuf {
gadget_path.join("functions").join(self.name())
}
/// Get LUN path
fn lun_path(&self, gadget_path: &Path, lun: u8) -> PathBuf {
self.function_path(gadget_path).join(format!("lun.{}", lun))
}
/// Configure a LUN with specified settings (async version)
///
/// This is the preferred method for async contexts. It runs the blocking
/// file I/O and USB timing operations in a separate thread pool.
pub async fn configure_lun_async(
&self,
gadget_path: &Path,
@@ -106,17 +84,6 @@ impl MsdFunction {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Configure a LUN with specified settings
/// Note: This should be called after the gadget is set up
///
/// This implementation is based on PiKVM's MSD drive configuration.
/// Key improvements:
/// - Uses forced_eject when available (safer than clearing file directly)
/// - Reduced sleep times to minimize HID interference
/// - Better retry logic for EBUSY errors
///
/// **Note**: This is a blocking function. In async contexts, prefer
/// `configure_lun_async` to avoid blocking the runtime.
pub fn configure_lun(&self, gadget_path: &Path, lun: u8, config: &MsdLunConfig) -> Result<()> {
let lun_path = self.lun_path(gadget_path, lun);
@@ -124,7 +91,6 @@ impl MsdFunction {
create_dir(&lun_path)?;
}
// Batch read all current values to minimize syscalls
let read_attr = |attr: &str| -> String {
fs::read_to_string(lun_path.join(attr))
.unwrap_or_default()
@@ -137,28 +103,21 @@ impl MsdFunction {
let current_removable = read_attr("removable");
let current_nofua = read_attr("nofua");
// Prepare new values
let new_cdrom = if config.cdrom { "1" } else { "0" };
let new_ro = if config.ro { "1" } else { "0" };
let new_removable = if config.removable { "1" } else { "0" };
let new_nofua = if config.nofua { "1" } else { "0" };
// Disconnect current file first using forced_eject if available (PiKVM approach)
let forced_eject_path = lun_path.join("forced_eject");
if forced_eject_path.exists() {
// forced_eject is safer - it forcibly detaches regardless of host state
debug!("Using forced_eject to clear LUN {}", lun);
let _ = write_file(&forced_eject_path, "1");
} else {
// Fallback to clearing file directly
let _ = write_file(&lun_path.join("file"), "");
}
// Brief yield to allow USB stack to process the disconnect
// Reduced from 200ms to 50ms - let USB protocol handle timing
std::thread::sleep(std::time::Duration::from_millis(50));
// Write only changed attributes
let cdrom_changed = current_cdrom != new_cdrom;
if cdrom_changed {
debug!(
@@ -186,13 +145,11 @@ impl MsdFunction {
write_file(&lun_path.join("nofua"), new_nofua)?;
}
// If cdrom mode changed, brief yield for USB host
if cdrom_changed {
debug!("CDROM mode changed, brief yield for USB host");
std::thread::sleep(std::time::Duration::from_millis(50));
}
// Set file path (this triggers the actual mount) - with retry on EBUSY
if config.file.exists() {
let file_path = config.file.to_string_lossy();
let mut last_error = None;
@@ -210,7 +167,6 @@ impl MsdFunction {
return Ok(());
}
Err(e) => {
// Check if it's EBUSY (error code 16)
let is_busy = e.to_string().contains("Device or resource busy")
|| e.to_string().contains("os error 16");
@@ -220,7 +176,6 @@ impl MsdFunction {
lun,
attempt + 1
);
// Exponential backoff: 50, 100, 200, 400ms
std::thread::sleep(std::time::Duration::from_millis(50 << attempt));
last_error = Some(e);
continue;
@@ -231,7 +186,6 @@ impl MsdFunction {
}
}
// If we get here, all retries failed
if let Some(e) = last_error {
return Err(e);
}
@@ -242,9 +196,6 @@ impl MsdFunction {
Ok(())
}
/// Disconnect LUN (async version)
///
/// Preferred for async contexts.
pub async fn disconnect_lun_async(&self, gadget_path: &Path, lun: u8) -> Result<()> {
let gadget_path = gadget_path.to_path_buf();
let this = self.clone();
@@ -254,17 +205,10 @@ impl MsdFunction {
.map_err(|e| AppError::Internal(format!("Task join error: {}", e)))?
}
/// Disconnect LUN (clear file)
///
/// This method uses forced_eject when available, which is safer than
/// directly clearing the file path. Based on PiKVM's implementation.
/// See: https://docs.kernel.org/usb/mass-storage.html
pub fn disconnect_lun(&self, gadget_path: &Path, lun: u8) -> Result<()> {
let lun_path = self.lun_path(gadget_path, lun);
if lun_path.exists() {
// Prefer forced_eject if available (PiKVM approach)
// forced_eject forcibly detaches the backing file regardless of host state
let forced_eject_path = lun_path.join("forced_eject");
if forced_eject_path.exists() {
debug!(
@@ -282,7 +226,6 @@ impl MsdFunction {
}
}
} else {
// Fallback to clearing file directly
write_file(&lun_path.join("file"), "")?;
}
info!("LUN {} disconnected", lun);
@@ -291,7 +234,6 @@ impl MsdFunction {
Ok(())
}
/// Get current LUN file path
pub fn get_lun_file(&self, gadget_path: &Path, lun: u8) -> Option<PathBuf> {
let lun_path = self.lun_path(gadget_path, lun);
let file_path = lun_path.join("file");
@@ -306,7 +248,6 @@ impl MsdFunction {
None
}
/// Check if LUN is connected
pub fn is_lun_connected(&self, gadget_path: &Path, lun: u8) -> bool {
self.get_lun_file(gadget_path, lun).is_some()
}
@@ -318,39 +259,23 @@ impl GadgetFunction for MsdFunction {
}
fn endpoints_required(&self) -> u8 {
2 // IN + OUT for bulk transfers
}
fn meta(&self) -> FunctionMeta {
FunctionMeta {
name: self.name().to_string(),
description: if self.instance == 0 {
"Mass Storage Drive".to_string()
} else {
format!("Extra Drive #{}", self.instance)
},
endpoints: self.endpoints_required(),
enabled: true,
}
2
}
fn create(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
create_dir(&func_path)?;
// Set stall to 0 (workaround for some hosts)
let stall_path = func_path.join("stall");
if stall_path.exists() {
let _ = write_file(&stall_path, "0");
}
// LUN 0 is created automatically, but ensure it exists
let lun0_path = func_path.join("lun.0");
if !lun0_path.exists() {
create_dir(&lun0_path)?;
}
// Set default LUN 0 parameters
let _ = write_file(&lun0_path.join("cdrom"), "0");
let _ = write_file(&lun0_path.join("ro"), "0");
let _ = write_file(&lun0_path.join("removable"), "1");
@@ -382,12 +307,10 @@ impl GadgetFunction for MsdFunction {
fn cleanup(&self, gadget_path: &Path) -> Result<()> {
let func_path = self.function_path(gadget_path);
// Disconnect all LUNs first
for lun in 0..8 {
let _ = self.disconnect_lun(gadget_path, lun);
}
// Remove function directory
if let Err(e) = remove_dir(&func_path) {
warn!("Could not remove MSD function directory: {}", e);
}

View File

@@ -1,10 +1,3 @@
//! HID Report Descriptors
/// Keyboard HID Report Descriptor (no LED output)
/// Report format (8 bytes input):
/// [0] Modifier keys (8 bits)
/// [1] Reserved
/// [2-7] Key codes (6 keys)
pub const KEYBOARD: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x06, // Usage (Keyboard)
@@ -34,13 +27,6 @@ pub const KEYBOARD: &[u8] = &[
0xC0, // End Collection
];
/// Keyboard HID Report Descriptor with LED output support.
/// Input report format (8 bytes):
/// [0] Modifier keys (8 bits)
/// [1] Reserved
/// [2-7] Key codes (6 keys)
/// Output report format (1 byte):
/// [0] Num Lock / Caps Lock / Scroll Lock / Compose / Kana
pub const KEYBOARD_WITH_LED: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x06, // Usage (Keyboard)
@@ -81,12 +67,6 @@ pub const KEYBOARD_WITH_LED: &[u8] = &[
0xC0, // End Collection
];
/// Relative Mouse HID Report Descriptor (4 bytes report)
/// Report format:
/// [0] Buttons (5 bits) + padding (3 bits)
/// [1] X movement (signed 8-bit)
/// [2] Y movement (signed 8-bit)
/// [3] Wheel (signed 8-bit)
pub const MOUSE_RELATIVE: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x02, // Usage (Mouse)
@@ -126,12 +106,6 @@ pub const MOUSE_RELATIVE: &[u8] = &[
0xC0, // End Collection
];
/// Absolute Mouse HID Report Descriptor (6 bytes report)
/// Report format:
/// [0] Buttons (5 bits) + padding (3 bits)
/// [1-2] X position (16-bit, 0-32767)
/// [3-4] Y position (16-bit, 0-32767)
/// [5] Wheel (signed 8-bit)
pub const MOUSE_ABSOLUTE: &[u8] = &[
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x02, // Usage (Mouse)
@@ -177,10 +151,6 @@ pub const MOUSE_ABSOLUTE: &[u8] = &[
0xC0, // End Collection
];
/// Consumer Control HID Report Descriptor (2 bytes report)
/// Report format:
/// [0-1] Consumer Control Usage (16-bit little-endian)
/// Supports: Play/Pause, Stop, Next/Prev Track, Mute, Volume Up/Down, etc.
pub const CONSUMER_CONTROL: &[u8] = &[
0x05, 0x0C, // Usage Page (Consumer)
0x09, 0x01, // Usage (Consumer Control)

View File

@@ -1,9 +1,3 @@
//! OTG Service - unified gadget lifecycle management
//!
//! This module provides centralized management for USB OTG gadget functions.
//! It is the single owner of the USB gadget desired state and reconciles
//! ConfigFS to match that state.
use std::path::PathBuf;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info, warn};
@@ -13,7 +7,6 @@ use super::msd::MsdFunction;
use crate::config::{HidBackend, HidConfig, MsdConfig, OtgDescriptorConfig, OtgHidFunctions};
use crate::error::{AppError, Result};
/// HID device paths
#[derive(Debug, Clone, Default)]
pub struct HidDevicePaths {
pub keyboard: Option<PathBuf>,
@@ -26,26 +19,20 @@ pub struct HidDevicePaths {
impl HidDevicePaths {
pub fn existing_paths(&self) -> Vec<PathBuf> {
let mut paths = Vec::new();
if let Some(ref p) = self.keyboard {
paths.push(p.clone());
}
if let Some(ref p) = self.mouse_relative {
paths.push(p.clone());
}
if let Some(ref p) = self.mouse_absolute {
paths.push(p.clone());
}
if let Some(ref p) = self.consumer {
paths.push(p.clone());
}
paths
[
&self.keyboard,
&self.mouse_relative,
&self.mouse_absolute,
&self.consumer,
]
.into_iter()
.filter_map(|p| p.as_ref().cloned())
.collect()
}
}
/// Desired OTG gadget state derived from configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OtgDesiredState {
pub(crate) struct OtgDesiredState {
pub udc: Option<String>,
pub descriptor: GadgetDescriptor,
pub hid_functions: Option<OtgHidFunctions>,
@@ -68,7 +55,7 @@ impl Default for OtgDesiredState {
}
impl OtgDesiredState {
pub fn from_config(hid: &HidConfig, msd: &MsdConfig) -> Result<Self> {
pub(crate) fn from_config(hid: &HidConfig, msd: &MsdConfig) -> Result<Self> {
let hid_functions = if hid.backend == HidBackend::Otg {
let functions = hid.constrained_otg_functions();
Some(functions)
@@ -96,45 +83,28 @@ impl OtgDesiredState {
}
}
/// OTG Service state
#[derive(Debug, Clone, Default)]
pub struct OtgServiceState {
/// Whether the gadget is created and bound
struct OtgServiceState {
pub gadget_active: bool,
/// Whether HID functions are enabled
pub hid_enabled: bool,
/// Whether MSD function is enabled
pub msd_enabled: bool,
/// Bound UDC name
pub configured_udc: Option<String>,
/// HID device paths (set after gadget setup)
pub hid_paths: Option<HidDevicePaths>,
/// HID function selection (set after gadget setup)
pub hid_functions: Option<OtgHidFunctions>,
/// Whether keyboard LED/status feedback is enabled.
pub keyboard_leds_enabled: bool,
/// Applied endpoint budget.
pub max_endpoints: u8,
/// Applied descriptor configuration
pub descriptor: Option<GadgetDescriptor>,
/// Error message if setup failed
pub error: Option<String>,
}
/// OTG Service - unified gadget lifecycle management
pub struct OtgService {
/// The underlying gadget manager
manager: Mutex<Option<OtgGadgetManager>>,
/// Current state
state: RwLock<OtgServiceState>,
/// MSD function handle (for runtime LUN configuration)
msd_function: RwLock<Option<MsdFunction>>,
/// Desired OTG state
desired: RwLock<OtgDesiredState>,
}
impl OtgService {
/// Create a new OTG service
pub fn new() -> Self {
Self {
manager: Mutex::new(None),
@@ -144,55 +114,29 @@ impl OtgService {
}
}
/// Check if OTG is available on this system
pub fn is_available() -> bool {
OtgGadgetManager::is_available() && OtgGadgetManager::find_udc().is_some()
}
/// Get current service state
pub async fn state(&self) -> OtgServiceState {
self.state.read().await.clone()
}
/// Check if gadget is active
pub async fn is_gadget_active(&self) -> bool {
self.state.read().await.gadget_active
}
/// Check if HID is enabled
pub async fn is_hid_enabled(&self) -> bool {
self.state.read().await.hid_enabled
}
/// Check if MSD is enabled
pub async fn is_msd_enabled(&self) -> bool {
self.state.read().await.msd_enabled
}
/// Get gadget path (for MSD LUN configuration)
pub async fn gadget_path(&self) -> Option<PathBuf> {
let manager = self.manager.lock().await;
manager.as_ref().map(|m| m.gadget_path().clone())
}
/// Get HID device paths
pub async fn hid_device_paths(&self) -> Option<HidDevicePaths> {
self.state.read().await.hid_paths.clone()
}
/// Get MSD function handle (for LUN configuration)
pub async fn msd_function(&self) -> Option<MsdFunction> {
self.msd_function.read().await.clone()
}
/// Apply desired OTG state derived from the current application config.
pub async fn apply_config(&self, hid: &HidConfig, msd: &MsdConfig) -> Result<()> {
let desired = OtgDesiredState::from_config(hid, msd)?;
self.apply_desired_state(desired).await
}
/// Apply a fully materialized desired OTG state.
pub async fn apply_desired_state(&self, desired: OtgDesiredState) -> Result<()> {
pub(crate) async fn apply_desired_state(&self, desired: OtgDesiredState) -> Result<()> {
{
let mut current = self.desired.write().await;
*current = desired;
@@ -204,7 +148,7 @@ impl OtgService {
async fn reconcile_gadget(&self) -> Result<()> {
let desired = self.desired.read().await.clone();
info!(
debug!(
"Reconciling OTG gadget: HID={}, MSD={}, UDC={:?}",
desired.hid_enabled(),
desired.msd_enabled,
@@ -222,7 +166,7 @@ impl OtgService {
&& state.max_endpoints == desired.max_endpoints
&& state.descriptor.as_ref() == Some(&desired.descriptor)
{
info!("OTG gadget already matches desired state");
debug!("OTG gadget already matches desired state");
return Ok(());
}
}
@@ -230,7 +174,7 @@ impl OtgService {
{
let mut manager = self.manager.lock().await;
if let Some(mut m) = manager.take() {
info!("Cleaning up existing gadget before OTG reconcile");
debug!("Cleaning up existing gadget before OTG reconcile");
if let Err(e) = m.cleanup() {
warn!("Error cleaning up existing gadget: {}", e);
}
@@ -392,7 +336,6 @@ impl OtgService {
Ok(())
}
/// Shutdown the OTG service and cleanup all resources
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down OTG service");
@@ -425,12 +368,6 @@ impl Default for OtgService {
}
}
impl Drop for OtgService {
fn drop(&mut self) {
debug!("OtgService dropping");
}
}
impl From<&OtgDescriptorConfig> for GadgetDescriptor {
fn from(config: &OtgDescriptorConfig) -> Self {
Self {
@@ -452,17 +389,8 @@ mod tests {
use super::*;
#[test]
fn test_service_creation() {
fn service_new_and_availability_probe() {
let _service = OtgService::new();
let _ = OtgService::is_available();
}
#[tokio::test]
async fn test_initial_state() {
let service = OtgService::new();
let state = service.state().await;
assert!(!state.gadget_active);
assert!(!state.hid_enabled);
assert!(!state.msd_enabled);
}
}

73
src/rtsp/auth.rs Normal file
View File

@@ -0,0 +1,73 @@
use base64::Engine;
use crate::config::RtspConfig;
use super::types::RtspRequest;
pub(crate) fn extract_basic_auth(req: &RtspRequest) -> Option<(String, String)> {
let value = req.headers.get("authorization")?;
let mut parts = value.split_whitespace();
let scheme = parts.next()?;
if !scheme.eq_ignore_ascii_case("basic") {
return None;
}
let b64 = parts.next()?;
let decoded = base64::engine::general_purpose::STANDARD.decode(b64).ok()?;
let raw = String::from_utf8(decoded).ok()?;
let (user, pass) = raw.split_once(':')?;
Some((user.to_string(), pass.to_string()))
}
pub(crate) fn rtsp_auth_credentials(config: &RtspConfig) -> Option<(String, String)> {
let username = config.username.as_ref()?.trim();
if username.is_empty() {
return None;
}
Some((
username.to_string(),
config.password.clone().unwrap_or_default(),
))
}
#[cfg(test)]
mod tests {
use super::*;
use rtsp_types as rtsp;
use std::collections::HashMap;
#[test]
fn rtsp_auth_requires_non_empty_username() {
let mut config = RtspConfig::default();
config.password = Some("secret".to_string());
assert!(rtsp_auth_credentials(&config).is_none());
config.username = Some("".to_string());
assert!(rtsp_auth_credentials(&config).is_none());
config.username = Some("user".to_string());
let credentials = rtsp_auth_credentials(&config).expect("expected credentials");
assert_eq!(credentials, ("user".to_string(), "secret".to_string()));
config.password = None;
let credentials = rtsp_auth_credentials(&config).expect("expected credentials");
assert_eq!(credentials, ("user".to_string(), "".to_string()));
}
#[test]
fn extract_basic_auth_roundtrip() {
let encoded = base64::engine::general_purpose::STANDARD.encode(b"alice:pwd");
let mut headers = HashMap::new();
headers.insert("authorization".to_string(), format!("Basic {}", encoded));
let req = RtspRequest {
method: rtsp::Method::Options,
uri: "*".to_string(),
version: rtsp::Version::V1_0,
headers,
};
assert_eq!(
extract_basic_auth(&req),
Some(("alice".to_string(), "pwd".to_string()))
);
}
}

96
src/rtsp/bitstream.rs Normal file
View File

@@ -0,0 +1,96 @@
use bytes::Bytes;
use crate::video::encoder::registry::VideoEncoderType;
use crate::video::shared_video_pipeline::EncodedVideoFrame;
use super::state::ParameterSets;
pub(crate) fn update_parameter_sets(params: &mut ParameterSets, frame: &EncodedVideoFrame) {
let nal_units = split_annexb_nal_units(frame.data.as_ref());
match frame.codec {
VideoEncoderType::H264 => {
for nal in nal_units {
match h264_nal_type(nal) {
Some(7) => params.h264_sps = Some(Bytes::copy_from_slice(nal)),
Some(8) => params.h264_pps = Some(Bytes::copy_from_slice(nal)),
_ => {}
}
}
}
VideoEncoderType::H265 => {
for nal in nal_units {
match h265_nal_type(nal) {
Some(32) => params.h265_vps = Some(Bytes::copy_from_slice(nal)),
Some(33) => params.h265_sps = Some(Bytes::copy_from_slice(nal)),
Some(34) => params.h265_pps = Some(Bytes::copy_from_slice(nal)),
_ => {}
}
}
}
_ => {}
}
}
fn split_annexb_nal_units(data: &[u8]) -> Vec<&[u8]> {
let mut nal_units = Vec::new();
let mut cursor = 0usize;
while let Some((start, start_code_len)) = find_annexb_start_code(data, cursor) {
let nal_start = start + start_code_len;
if nal_start >= data.len() {
break;
}
let next_start = find_annexb_start_code(data, nal_start)
.map(|(idx, _)| idx)
.unwrap_or(data.len());
let mut nal_end = next_start;
while nal_end > nal_start && data[nal_end - 1] == 0 {
nal_end -= 1;
}
if nal_end > nal_start {
nal_units.push(&data[nal_start..nal_end]);
}
cursor = next_start;
}
nal_units
}
fn find_annexb_start_code(data: &[u8], from: usize) -> Option<(usize, usize)> {
if from >= data.len() {
return None;
}
let mut i = from;
while i + 3 <= data.len() {
if i + 4 <= data.len()
&& data[i] == 0
&& data[i + 1] == 0
&& data[i + 2] == 0
&& data[i + 3] == 1
{
return Some((i, 4));
}
if data[i] == 0 && data[i + 1] == 0 && data[i + 2] == 1 {
return Some((i, 3));
}
i += 1;
}
None
}
fn h264_nal_type(nal: &[u8]) -> Option<u8> {
nal.first().map(|value| value & 0x1f)
}
fn h265_nal_type(nal: &[u8]) -> Option<u8> {
nal.first().map(|value| (value >> 1) & 0x3f)
}

9
src/rtsp/codec.rs Normal file
View File

@@ -0,0 +1,9 @@
use crate::config::RtspCodec;
use crate::video::encoder::VideoCodecType;
pub(crate) fn rtsp_codec_to_video(codec: RtspCodec) -> VideoCodecType {
match codec {
RtspCodec::H264 => VideoCodecType::H264,
RtspCodec::H265 => VideoCodecType::H265,
}
}

View File

@@ -1,3 +1,14 @@
pub mod service;
//! RTSP TCP server exposing H.264/H.265 video from [`VideoStreamManager`](crate::video::VideoStreamManager).
mod auth;
mod bitstream;
mod codec;
mod protocol;
mod response;
mod sdp;
mod service;
mod state;
mod streaming;
mod types;
pub use service::{RtspService, RtspServiceStatus};

193
src/rtsp/protocol.rs Normal file
View File

@@ -0,0 +1,193 @@
use rtsp_types as rtsp;
use std::collections::HashMap;
use super::types::RtspRequest;
pub(crate) const OPTIONS_PUBLIC_CAPABILITIES: &str =
"OPTIONS, DESCRIBE, SETUP, PLAY, GET_PARAMETER, SET_PARAMETER, TEARDOWN";
pub(crate) fn strip_interleaved_frames_prefix(buffer: &mut Vec<u8>) -> bool {
if buffer.len() < 4 || buffer[0] != b'$' {
return false;
}
let payload_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize;
let frame_len = 4 + payload_len;
if buffer.len() < frame_len {
return false;
}
buffer.drain(0..frame_len);
true
}
pub(crate) fn take_rtsp_request_from_buffer(buffer: &mut Vec<u8>) -> Option<String> {
let delimiter = b"\r\n\r\n";
let pos = find_bytes(buffer, delimiter)?;
let req_end = pos + delimiter.len();
let req_bytes: Vec<u8> = buffer.drain(0..req_end).collect();
Some(String::from_utf8_lossy(&req_bytes).to_string())
}
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
haystack
.windows(needle.len())
.position(|window| window == needle)
}
pub(crate) fn parse_rtsp_request(raw: &str) -> Option<RtspRequest> {
let (message, consumed): (rtsp::Message<Vec<u8>>, usize) =
rtsp::Message::parse(raw.as_bytes()).ok()?;
if consumed != raw.len() {
return None;
}
let request = match message {
rtsp::Message::Request(req) => req,
_ => return None,
};
let uri = request
.request_uri()
.map(|value| value.as_str().to_string())
.unwrap_or_default();
let mut headers = HashMap::new();
for (name, value) in request.headers() {
headers.insert(name.to_string().to_ascii_lowercase(), value.to_string());
}
Some(RtspRequest {
method: request.method().clone(),
uri,
version: request.version(),
headers,
})
}
pub(crate) fn parse_interleaved_channel(transport: &str) -> Option<u8> {
let lower = transport.to_ascii_lowercase();
if let Some((_, v)) = lower.split_once("interleaved=") {
let head = v.split(';').next().unwrap_or(v);
let first = head.split('-').next().unwrap_or(head).trim();
return first.parse::<u8>().ok();
}
None
}
pub(crate) fn is_tcp_transport_request(transport: &str) -> bool {
transport
.split(',')
.map(str::trim)
.map(str::to_ascii_lowercase)
.any(|item| item.contains("rtp/avp/tcp") || item.contains("interleaved="))
}
pub(crate) fn is_valid_rtsp_path(method: &rtsp::Method, uri: &str, configured_path: &str) -> bool {
if matches!(method, rtsp::Method::Options) && uri.trim() == "*" {
return true;
}
let normalized_cfg = configured_path.trim_matches('/');
if normalized_cfg.is_empty() {
return false;
}
let request_path = extract_rtsp_path(uri);
if request_path == normalized_cfg {
return true;
}
if !matches!(method, rtsp::Method::Setup | rtsp::Method::Teardown) {
return false;
}
let control_track_path = format!("{}/trackID=0", normalized_cfg);
request_path == "trackID=0" || request_path == control_track_path
}
fn extract_rtsp_path(uri: &str) -> String {
let raw_path = if let Some((_, remainder)) = uri.split_once("://") {
match remainder.find('/') {
Some(idx) => &remainder[idx..],
None => "/",
}
} else {
uri
};
raw_path
.split('?')
.next()
.unwrap_or(raw_path)
.split('#')
.next()
.unwrap_or(raw_path)
.trim_matches('/')
.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rtsp_path_matching_follows_sdp_control_rules() {
assert!(is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live",
"live"
));
assert!(is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live/?token=1",
"/live/"
));
assert!(!is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live2",
"live"
));
assert!(!is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/",
"/"
));
assert!(is_valid_rtsp_path(
&rtsp::Method::Setup,
"rtsp://127.0.0.1/live/trackID=0",
"live"
));
assert!(is_valid_rtsp_path(
&rtsp::Method::Setup,
"rtsp://127.0.0.1/trackID=0",
"live"
));
assert!(!is_valid_rtsp_path(
&rtsp::Method::Describe,
"rtsp://127.0.0.1/live/trackID=0",
"live"
));
assert!(is_valid_rtsp_path(&rtsp::Method::Options, "*", "live"));
}
#[test]
fn transport_parsing_detects_tcp_interleaved_requests() {
assert!(is_tcp_transport_request(
"RTP/AVP/TCP;unicast;interleaved=0-1"
));
assert!(is_tcp_transport_request("RTP/AVP;unicast;interleaved=2-3"));
assert!(!is_tcp_transport_request(
"RTP/AVP;unicast;client_port=8000-8001"
));
}
#[test]
fn options_public_includes_standard_methods() {
assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("GET_PARAMETER"));
assert!(OPTIONS_PUBLIC_CAPABILITIES.contains("TEARDOWN"));
}
}

81
src/rtsp/response.rs Normal file
View File

@@ -0,0 +1,81 @@
use rtsp_types as rtsp;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::error::{AppError, Result};
use super::types::RtspRequest;
async fn serialize_and_write<W: AsyncWrite + Unpin>(
stream: &mut W,
response: rtsp::Response<Vec<u8>>,
) -> Result<()> {
let mut data = Vec::new();
response
.write(&mut data)
.map_err(|e| AppError::BadRequest(format!("failed to serialize RTSP response: {}", e)))?;
stream.write_all(&data).await?;
Ok(())
}
pub(crate) async fn send_simple_response<W: AsyncWrite + Unpin>(
stream: &mut W,
code: u16,
_reason: &str,
cseq: Option<&str>,
body: &str,
) -> Result<()> {
let mut builder = rtsp::Response::builder(rtsp::Version::V1_0, status_code_from_u16(code));
if let Some(cseq) = cseq {
builder = builder.header(rtsp::headers::CSEQ, cseq);
}
let response = builder.build(body.as_bytes().to_vec());
serialize_and_write(stream, response).await
}
pub(crate) async fn send_response<W: AsyncWrite + Unpin>(
stream: &mut W,
req: &RtspRequest,
code: u16,
_reason: &str,
extra_headers: Vec<(String, String)>,
body: &str,
session_id: &str,
) -> Result<()> {
let cseq = req
.headers
.get("cseq")
.cloned()
.unwrap_or_else(|| "1".to_string());
let mut builder = rtsp::Response::builder(req.version, status_code_from_u16(code))
.header(rtsp::headers::CSEQ, cseq.as_str());
if !session_id.is_empty() {
builder = builder.header(rtsp::headers::SESSION, session_id);
}
for (name, value) in extra_headers {
let header_name = rtsp::HeaderName::try_from(name.as_str()).map_err(|e| {
AppError::BadRequest(format!("invalid RTSP header name {}: {}", name, e))
})?;
builder = builder.header(header_name, value);
}
let response = builder.build(body.as_bytes().to_vec());
serialize_and_write(stream, response).await
}
pub(crate) fn status_code_from_u16(code: u16) -> rtsp::StatusCode {
match code {
200 => rtsp::StatusCode::Ok,
400 => rtsp::StatusCode::BadRequest,
401 => rtsp::StatusCode::Unauthorized,
404 => rtsp::StatusCode::NotFound,
405 => rtsp::StatusCode::MethodNotAllowed,
453 => rtsp::StatusCode::NotEnoughBandwidth,
455 => rtsp::StatusCode::MethodNotValidInThisState,
461 => rtsp::StatusCode::UnsupportedTransport,
_ => rtsp::StatusCode::InternalServerError,
}
}

224
src/rtsp/sdp.rs Normal file
View File

@@ -0,0 +1,224 @@
use base64::Engine;
use sdp_types as sdp;
use crate::config::RtspConfig;
use crate::video::encoder::VideoCodecType;
use crate::webrtc::rtp::parse_profile_level_id_from_sps;
use super::state::ParameterSets;
pub(crate) fn build_h264_fmtp(payload_type: u8, params: &ParameterSets) -> String {
let mut attrs = vec!["packetization-mode=1".to_string()];
if let Some(sps) = params.h264_sps.as_ref() {
if let Some(profile_level_id) = parse_profile_level_id_from_sps(sps) {
attrs.push(format!("profile-level-id={}", profile_level_id));
}
} else {
attrs.push("profile-level-id=42e01f".to_string());
}
if let (Some(sps), Some(pps)) = (params.h264_sps.as_ref(), params.h264_pps.as_ref()) {
let sps_b64 = base64::engine::general_purpose::STANDARD.encode(sps.as_ref());
let pps_b64 = base64::engine::general_purpose::STANDARD.encode(pps.as_ref());
attrs.push(format!("sprop-parameter-sets={},{}", sps_b64, pps_b64));
}
format!("{} {}", payload_type, attrs.join(";"))
}
pub(crate) fn build_h265_fmtp(payload_type: u8, params: &ParameterSets) -> String {
let mut attrs = Vec::new();
if let Some(vps) = params.h265_vps.as_ref() {
attrs.push(format!(
"sprop-vps={}",
base64::engine::general_purpose::STANDARD.encode(vps.as_ref())
));
}
if let Some(sps) = params.h265_sps.as_ref() {
attrs.push(format!(
"sprop-sps={}",
base64::engine::general_purpose::STANDARD.encode(sps.as_ref())
));
}
if let Some(pps) = params.h265_pps.as_ref() {
attrs.push(format!(
"sprop-pps={}",
base64::engine::general_purpose::STANDARD.encode(pps.as_ref())
));
}
if attrs.is_empty() {
format!("{} profile-id=1", payload_type)
} else {
format!("{} {}", payload_type, attrs.join(";"))
}
}
pub(crate) fn build_sdp(
config: &RtspConfig,
codec: VideoCodecType,
params: &ParameterSets,
) -> String {
let (payload_type, codec_name, fmtp_value) = match codec {
VideoCodecType::H264 => (96u8, "H264", build_h264_fmtp(96, params)),
VideoCodecType::H265 => (99u8, "H265", build_h265_fmtp(99, params)),
_ => {
tracing::warn!("RTSP SDP: unexpected VideoCodecType, falling back to H264");
(96u8, "H264", build_h264_fmtp(96, params))
}
};
let session = sdp::Session {
origin: sdp::Origin {
username: Some("-".to_string()),
sess_id: "0".to_string(),
sess_version: 0,
nettype: "IN".to_string(),
addrtype: "IP4".to_string(),
unicast_address: config.bind.clone(),
},
session_name: "One-KVM RTSP Stream".to_string(),
session_description: None,
uri: None,
emails: Vec::new(),
phones: Vec::new(),
connection: Some(sdp::Connection {
nettype: "IN".to_string(),
addrtype: "IP4".to_string(),
connection_address: "0.0.0.0".to_string(),
}),
bandwidths: Vec::new(),
times: vec![sdp::Time {
start_time: 0,
stop_time: 0,
repeats: Vec::new(),
}],
time_zones: Vec::new(),
key: None,
attributes: vec![sdp::Attribute {
attribute: "control".to_string(),
value: Some("*".to_string()),
}],
medias: vec![sdp::Media {
media: "video".to_string(),
port: 0,
num_ports: None,
proto: "RTP/AVP".to_string(),
fmt: payload_type.to_string(),
media_title: None,
connections: Vec::new(),
bandwidths: Vec::new(),
key: None,
attributes: vec![
sdp::Attribute {
attribute: "rtpmap".to_string(),
value: Some(format!("{} {}/90000", payload_type, codec_name)),
},
sdp::Attribute {
attribute: "fmtp".to_string(),
value: Some(fmtp_value),
},
sdp::Attribute {
attribute: "control".to_string(),
value: Some("trackID=0".to_string()),
},
],
}],
};
let mut output = Vec::new();
if let Err(err) = session.write(&mut output) {
tracing::warn!("Failed to serialize SDP with sdp-types: {}", err);
return String::new();
}
match String::from_utf8(output) {
Ok(sdp_text) => sdp_text,
Err(err) => {
tracing::warn!("Failed to convert SDP bytes to UTF-8: {}", err);
String::new()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::RtspConfig;
use bytes::Bytes;
#[test]
fn build_sdp_h264_is_parseable_with_expected_video_attributes() {
let config = RtspConfig::default();
let mut params = ParameterSets::default();
params.h264_sps = Some(Bytes::from_static(&[0x67, 0x42, 0xe0, 0x1f, 0x96, 0x54]));
params.h264_pps = Some(Bytes::from_static(&[0x68, 0xce, 0x06, 0xe2]));
let sdp_text = build_sdp(&config, VideoCodecType::H264, &params);
assert!(!sdp_text.is_empty());
let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed");
assert_eq!(session.session_name, "One-KVM RTSP Stream");
assert_eq!(session.medias.len(), 1);
let media = &session.medias[0];
assert_eq!(media.media, "video");
assert_eq!(media.proto, "RTP/AVP");
assert_eq!(media.fmt, "96");
let has_rtpmap = media.attributes.iter().any(|attr| {
attr.attribute == "rtpmap" && attr.value.as_deref() == Some("96 H264/90000")
});
assert!(has_rtpmap);
let fmtp_value = media
.attributes
.iter()
.find(|attr| attr.attribute == "fmtp")
.and_then(|attr| attr.value.as_deref())
.expect("missing fmtp value");
assert!(fmtp_value.starts_with("96 "));
assert!(fmtp_value.contains("packetization-mode=1"));
assert!(fmtp_value.contains("sprop-parameter-sets="));
}
#[test]
fn build_sdp_h265_is_parseable_with_expected_video_attributes() {
let config = RtspConfig::default();
let mut params = ParameterSets::default();
params.h265_vps = Some(Bytes::from_static(&[0x40, 0x01, 0x0c, 0x01]));
params.h265_sps = Some(Bytes::from_static(&[0x42, 0x01, 0x01, 0x60]));
params.h265_pps = Some(Bytes::from_static(&[0x44, 0x01, 0xc0, 0x73]));
let sdp_text = build_sdp(&config, VideoCodecType::H265, &params);
assert!(!sdp_text.is_empty());
let session = sdp::Session::parse(sdp_text.as_bytes()).expect("sdp parse failed");
assert_eq!(session.medias.len(), 1);
let media = &session.medias[0];
assert_eq!(media.media, "video");
assert_eq!(media.proto, "RTP/AVP");
assert_eq!(media.fmt, "99");
let has_rtpmap = media.attributes.iter().any(|attr| {
attr.attribute == "rtpmap" && attr.value.as_deref() == Some("99 H265/90000")
});
assert!(has_rtpmap);
let fmtp_value = media
.attributes
.iter()
.find(|attr| attr.attribute == "fmtp")
.and_then(|attr| attr.value.as_deref())
.expect("missing fmtp value");
assert!(fmtp_value.starts_with("99 "));
assert!(fmtp_value.contains("sprop-vps="));
assert!(fmtp_value.contains("sprop-sps="));
assert!(fmtp_value.contains("sprop-pps="));
}
}

File diff suppressed because it is too large Load Diff

28
src/rtsp/state.rs Normal file
View File

@@ -0,0 +1,28 @@
use bytes::Bytes;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
#[derive(Default, Clone)]
pub(crate) struct ParameterSets {
pub h264_sps: Option<Bytes>,
pub h264_pps: Option<Bytes>,
pub h265_vps: Option<Bytes>,
pub h265_sps: Option<Bytes>,
pub h265_pps: Option<Bytes>,
}
#[derive(Clone)]
pub(crate) struct SharedRtspState {
pub active_client: Arc<Mutex<Option<SocketAddr>>>,
pub parameter_sets: Arc<RwLock<ParameterSets>>,
}
impl SharedRtspState {
pub fn new() -> Self {
Self {
active_client: Arc::new(Mutex::new(None)),
parameter_sets: Arc::new(RwLock::new(ParameterSets::default())),
}
}
}

367
src/rtsp/streaming.rs Normal file
View File

@@ -0,0 +1,367 @@
use bytes::Bytes;
use rand::Rng;
use rtp::packet::Packet;
use rtp::packetizer::Payloader;
use rtsp_types as rtsp;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::{sleep, Duration};
use webrtc::util::{Marshal, MarshalSize};
use crate::config::RtspCodec;
use crate::error::{AppError, Result};
use crate::video::encoder::registry::VideoEncoderType;
use crate::video::shared_video_pipeline::EncodedVideoFrame;
use crate::video::VideoStreamManager;
use crate::webrtc::h265_payloader::H265Payloader;
use super::bitstream::update_parameter_sets;
use super::protocol::{
parse_rtsp_request, strip_interleaved_frames_prefix, take_rtsp_request_from_buffer,
};
use super::response::send_response;
use super::state::SharedRtspState;
use super::types::RtspRequest;
pub(crate) const RTP_CLOCK_RATE: u32 = 90_000;
pub(crate) const RTP_MTU: usize = 1200;
pub(crate) const RTSP_BUF_SIZE: usize = 8192;
const RTSP_RESUBSCRIBE_DELAY_MS: u64 = 300;
pub(crate) async fn stream_video_interleaved(
stream: TcpStream,
video_manager: &Arc<VideoStreamManager>,
rtsp_codec: RtspCodec,
channel: u8,
shared: SharedRtspState,
session_id: String,
) -> Result<()> {
let (mut reader, mut writer) = stream.into_split();
let mut rx = video_manager
.subscribe_encoded_frames()
.await
.ok_or_else(|| {
AppError::VideoError("RTSP failed to subscribe encoded frames".to_string())
})?;
video_manager.request_keyframe().await.ok();
let payload_type = match rtsp_codec {
RtspCodec::H264 => 96,
RtspCodec::H265 => 99,
};
let mut sequence_number: u16 = rand::rng().random();
let ssrc: u32 = rand::rng().random();
let mut h264_payloader = rtp::codecs::h264::H264Payloader::default();
let mut h265_payloader = H265Payloader::new();
let mut ctrl_read_buf = [0u8; RTSP_BUF_SIZE];
let mut ctrl_buffer = Vec::with_capacity(RTSP_BUF_SIZE);
// 4-byte interleaved prefix + RTP header + payload shard (≤ RTP_MTU)
let mut interleaved_rtp_buf = Vec::with_capacity(4 + RTP_MTU + 96);
let mut last_rtp_timestamp: u32 = 0;
loop {
tokio::select! {
maybe_frame = rx.recv() => {
let Some(frame) = maybe_frame else {
tracing::warn!("RTSP encoded frame subscription ended, attempting to restart pipeline");
if let Some(new_rx) = video_manager.subscribe_encoded_frames().await {
rx = new_rx;
let _ = video_manager.request_keyframe().await;
tracing::info!("RTSP frame subscription recovered");
} else {
tracing::warn!(
"RTSP failed to resubscribe encoded frames, retrying in {}ms",
RTSP_RESUBSCRIBE_DELAY_MS
);
sleep(Duration::from_millis(RTSP_RESUBSCRIBE_DELAY_MS)).await;
}
continue;
};
if !is_frame_codec_match(&frame, &rtsp_codec) {
continue;
}
{
let mut params = shared.parameter_sets.write().await;
update_parameter_sets(&mut params, &frame);
}
let rtp_timestamp = monotonic_rtp_timestamp(
frame.pts_ms,
&mut last_rtp_timestamp,
frame.duration,
);
let payloads: Vec<Bytes> = match rtsp_codec {
RtspCodec::H264 => h264_payloader
.payload(RTP_MTU, &frame.data)
.map_err(|e| AppError::VideoError(format!("H264 payload failed: {}", e)))?,
RtspCodec::H265 => h265_payloader.payload(RTP_MTU, &frame.data),
};
if payloads.is_empty() {
continue;
}
let total_payloads = payloads.len();
for (idx, payload) in payloads.into_iter().enumerate() {
let marker = idx == total_payloads.saturating_sub(1);
let packet = Packet {
header: rtp::header::Header {
version: 2,
padding: false,
extension: false,
marker,
payload_type,
sequence_number,
timestamp: rtp_timestamp,
ssrc,
..Default::default()
},
payload,
};
sequence_number = sequence_number.wrapping_add(1);
send_interleaved_rtp(&mut writer, channel, &packet, &mut interleaved_rtp_buf)
.await?;
}
if frame.is_keyframe {
tracing::debug!("RTSP keyframe sent");
}
}
read_res = reader.read(&mut ctrl_read_buf) => {
let n = read_res?;
if n == 0 {
break;
}
ctrl_buffer.extend_from_slice(&ctrl_read_buf[..n]);
while strip_interleaved_frames_prefix(&mut ctrl_buffer) {}
while let Some(raw_req) = take_rtsp_request_from_buffer(&mut ctrl_buffer) {
let Some(req) = parse_rtsp_request(&raw_req) else {
continue;
};
if handle_play_control_request(&mut writer, &req, &session_id).await? {
return Ok(());
}
while strip_interleaved_frames_prefix(&mut ctrl_buffer) {}
}
}
}
}
Ok(())
}
pub(crate) async fn send_interleaved_rtp<W: AsyncWrite + Unpin>(
stream: &mut W,
channel: u8,
packet: &Packet,
marshal_buf: &mut Vec<u8>,
) -> Result<()> {
let rtp_len = packet.marshal_size();
let rtp_len_u16 = u16::try_from(rtp_len).map_err(|_| {
AppError::VideoError(format!(
"RTP packet too large for interleaved framing: {} bytes",
rtp_len
))
})?;
marshal_buf.clear();
marshal_buf.reserve(4 + rtp_len);
marshal_buf.extend_from_slice(&[b'$', channel, (rtp_len_u16 >> 8) as u8, rtp_len_u16 as u8]);
let body_off = marshal_buf.len();
marshal_buf.resize(body_off + rtp_len, 0);
let written = packet
.marshal_to(&mut marshal_buf[body_off..])
.map_err(|e| AppError::VideoError(format!("RTP marshal failed: {}", e)))?;
if written != rtp_len {
return Err(AppError::VideoError(format!(
"RTP marshal size mismatch: wrote {written}, expected {rtp_len}"
)));
}
stream.write_all(marshal_buf).await?;
Ok(())
}
pub(crate) async fn handle_play_control_request<W: AsyncWrite + Unpin>(
stream: &mut W,
req: &RtspRequest,
session_id: &str,
) -> Result<bool> {
use super::protocol::OPTIONS_PUBLIC_CAPABILITIES;
match &req.method {
rtsp::Method::Teardown => {
send_response(stream, req, 200, "OK", vec![], "", session_id).await?;
Ok(true)
}
rtsp::Method::Options => {
send_response(
stream,
req,
200,
"OK",
vec![(
"Public".to_string(),
OPTIONS_PUBLIC_CAPABILITIES.to_string(),
)],
"",
session_id,
)
.await?;
Ok(false)
}
rtsp::Method::GetParameter | rtsp::Method::SetParameter => {
send_response(stream, req, 200, "OK", vec![], "", session_id).await?;
Ok(false)
}
_ => {
send_response(
stream,
req,
405,
"Method Not Allowed",
vec![],
"",
session_id,
)
.await?;
Ok(false)
}
}
}
fn pts_to_rtp_timestamp(pts_ms: i64) -> u32 {
if pts_ms <= 0 {
return 0;
}
((pts_ms as u64 * RTP_CLOCK_RATE as u64) / 1000) as u32
}
fn rtp_timestamp_increment(frame_duration: Duration) -> u32 {
let inc = (frame_duration.as_secs_f64() * f64::from(RTP_CLOCK_RATE)).round() as u32;
inc.max(1)
}
fn monotonic_rtp_timestamp(pts_ms: i64, last: &mut u32, frame_duration: Duration) -> u32 {
let from_pts = pts_to_rtp_timestamp(pts_ms);
let inc = rtp_timestamp_increment(frame_duration);
let ts = if from_pts > *last {
from_pts
} else {
last.wrapping_add(inc)
};
*last = ts;
ts
}
fn is_frame_codec_match(frame: &EncodedVideoFrame, codec: &RtspCodec) -> bool {
matches!(
(frame.codec, codec),
(VideoEncoderType::H264, RtspCodec::H264) | (VideoEncoderType::H265, RtspCodec::H265)
)
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use tokio::io::{duplex, AsyncReadExt};
fn make_test_request(method: rtsp::Method) -> RtspRequest {
let mut headers = HashMap::new();
headers.insert("cseq".to_string(), "7".to_string());
RtspRequest {
method,
uri: "rtsp://127.0.0.1/live".to_string(),
version: rtsp::Version::V1_0,
headers,
}
}
async fn read_response_from_duplex(
mut client: tokio::io::DuplexStream,
) -> rtsp::Response<Vec<u8>> {
let mut buf = vec![0u8; 4096];
let n = client
.read(&mut buf)
.await
.expect("failed to read rtsp response");
assert!(n > 0);
let (message, consumed): (rtsp::Message<Vec<u8>>, usize) =
rtsp::Message::parse(&buf[..n]).expect("failed to parse rtsp response");
assert_eq!(consumed, n);
match message {
rtsp::Message::Response(response) => response,
_ => panic!("expected RTSP response"),
}
}
#[tokio::test]
async fn play_control_teardown_returns_ok_and_stop() {
let req = make_test_request(rtsp::Method::Teardown);
let (client, mut server) = duplex(4096);
let should_stop = handle_play_control_request(&mut server, &req, "session-1")
.await
.expect("control handling failed");
assert!(should_stop);
drop(server);
let response = read_response_from_duplex(client).await;
assert_eq!(response.status(), rtsp::StatusCode::Ok);
}
#[tokio::test]
async fn play_control_pause_returns_method_not_allowed() {
let req = make_test_request(rtsp::Method::Pause);
let (client, mut server) = duplex(4096);
let should_stop = handle_play_control_request(&mut server, &req, "session-1")
.await
.expect("control handling failed");
assert!(!should_stop);
drop(server);
let response = read_response_from_duplex(client).await;
assert_eq!(response.status(), rtsp::StatusCode::MethodNotAllowed);
}
#[test]
fn monotonic_rtp_timestamp_steps_when_pts_stays_zero() {
let d = Duration::from_millis(33);
let mut last = 0u32;
let a = monotonic_rtp_timestamp(0, &mut last, d);
let b = monotonic_rtp_timestamp(0, &mut last, d);
let c = monotonic_rtp_timestamp(0, &mut last, d);
assert!(a > 0);
assert!(b > a);
assert!(c > b);
}
#[test]
fn monotonic_rtp_timestamp_uses_pts_when_it_advances() {
let d = Duration::from_millis(33);
let mut last = 0u32;
let a = monotonic_rtp_timestamp(1000, &mut last, d);
assert_eq!(a, 90_000);
let b = monotonic_rtp_timestamp(2000, &mut last, d);
assert_eq!(b, 180_000);
}
}

53
src/rtsp/types.rs Normal file
View File

@@ -0,0 +1,53 @@
use rand::Rng;
use rtsp_types as rtsp;
use std::collections::HashMap;
use std::fmt;
#[derive(Debug, Clone, PartialEq)]
pub enum RtspServiceStatus {
Stopped,
Starting,
Running,
Error(String),
}
impl fmt::Display for RtspServiceStatus {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Stopped => write!(f, "stopped"),
Self::Starting => write!(f, "starting"),
Self::Running => write!(f, "running"),
Self::Error(err) => write!(f, "error: {}", err),
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct RtspRequest {
pub method: rtsp::Method,
pub uri: String,
pub version: rtsp::Version,
pub headers: HashMap<String, String>,
}
pub(crate) struct RtspConnectionState {
pub session_id: String,
pub setup_done: bool,
pub interleaved_channel: u8,
}
impl RtspConnectionState {
pub fn new() -> Self {
Self {
session_id: generate_session_id(),
setup_done: false,
interleaved_channel: 0,
}
}
}
pub(crate) fn generate_session_id() -> String {
let mut rng = rand::rng();
let value: u64 = rng.random();
format!("{:016x}", value)
}

View File

@@ -1,21 +1,11 @@
//! RustDesk BytesCodec - Variable-length framing for TCP messages
//!
//! RustDesk uses a custom variable-length encoding for message framing:
//! - Length <= 0x3F (63): 1-byte header, format `(len << 2)`
//! - Length <= 0x3FFF (16383): 2-byte LE header, format `(len << 2) | 0x1`
//! - Length <= 0x3FFFFF (4194303): 3-byte LE header, format `(len << 2) | 0x2`
//! - Length <= 0x3FFFFFFF (1073741823): 4-byte LE header, format `(len << 2) | 0x3`
//!
//! The low 2 bits of the first byte indicate the header length (+1).
//! Variable-length TCP framing (RustDesk wire format).
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Maximum packet length (1GB)
const MAX_PACKET_LENGTH: usize = 0x3FFFFFFF;
/// Encode a message with RustDesk's variable-length framing
pub fn encode_frame(data: &[u8]) -> io::Result<Vec<u8>> {
let len = data.len();
let mut buf = Vec::with_capacity(len + 4);
@@ -44,8 +34,6 @@ pub fn encode_frame(data: &[u8]) -> io::Result<Vec<u8>> {
Ok(buf)
}
/// Decode the header to get message length
/// Returns (header_length, message_length)
fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) {
let head_len = ((first_byte & 0x3) + 1) as usize;
@@ -64,21 +52,17 @@ fn decode_header(first_byte: u8, header_bytes: &[u8]) -> (usize, usize) {
(head_len, msg_len)
}
/// Read a single framed message from an async reader
pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<BytesMut> {
// Read first byte to determine header length
let mut first_byte = [0u8; 1];
reader.read_exact(&mut first_byte).await?;
let head_len = ((first_byte[0] & 0x3) + 1) as usize;
// Read remaining header bytes if needed
let mut header_rest = [0u8; 3];
if head_len > 1 {
reader.read_exact(&mut header_rest[..head_len - 1]).await?;
}
// Calculate message length
let (_, msg_len) = decode_header(first_byte[0], &header_rest);
if msg_len > MAX_PACKET_LENGTH {
@@ -88,7 +72,6 @@ pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Byte
));
}
// Read message body
let mut buf = BytesMut::with_capacity(msg_len);
buf.resize(msg_len, 0);
reader.read_exact(&mut buf).await?;
@@ -96,7 +79,6 @@ pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Byte
Ok(buf)
}
/// Write a framed message to an async writer
pub async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, data: &[u8]) -> io::Result<()> {
let frame = encode_frame(data)?;
writer.write_all(&frame).await?;
@@ -104,10 +86,6 @@ pub async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, data: &[u8]) ->
Ok(())
}
/// Write a framed message using a reusable buffer (reduces allocations)
///
/// This version reuses the provided BytesMut buffer to avoid allocation on each call.
/// The buffer is cleared before use and will grow as needed.
pub async fn write_frame_buffered<W: AsyncWrite + Unpin>(
writer: &mut W,
data: &[u8],
@@ -120,11 +98,9 @@ pub async fn write_frame_buffered<W: AsyncWrite + Unpin>(
Ok(())
}
/// Encode a message with RustDesk's variable-length framing into an existing buffer
pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
let len = data.len();
// Reserve space for header (max 4 bytes) + data
buf.reserve(4 + len);
if len <= 0x3F {
@@ -149,7 +125,7 @@ pub fn encode_frame_into(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
Ok(())
}
/// BytesCodec for stateful decoding (compatible with tokio-util codec)
/// Stateful decoder for `Framed`.
#[derive(Debug, Clone, Copy)]
pub struct BytesCodec {
state: DecodeState,
@@ -180,7 +156,6 @@ impl BytesCodec {
self.max_packet_length = n;
}
/// Decode from a BytesMut buffer (for use with Framed)
pub fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
let n = match self.state {
DecodeState::Head => match self.decode_head(src)? {
@@ -242,7 +217,6 @@ impl BytesCodec {
Ok(Some(src.split_to(n)))
}
/// Encode a message into a BytesMut buffer
pub fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> io::Result<()> {
let len = data.len();
@@ -276,7 +250,7 @@ mod tests {
fn test_encode_decode_small() {
let data = vec![1u8; 63];
let encoded = encode_frame(&data).unwrap();
assert_eq!(encoded.len(), 63 + 1); // 1 byte header
assert_eq!(encoded.len(), 63 + 1);
let mut codec = BytesCodec::new();
let mut buf = BytesMut::from(&encoded[..]);
@@ -288,7 +262,7 @@ mod tests {
fn test_encode_decode_medium() {
let data = vec![2u8; 1000];
let encoded = encode_frame(&data).unwrap();
assert_eq!(encoded.len(), 1000 + 2); // 2 byte header
assert_eq!(encoded.len(), 1000 + 2);
let mut codec = BytesCodec::new();
let mut buf = BytesMut::from(&encoded[..]);
@@ -300,7 +274,7 @@ mod tests {
fn test_encode_decode_large() {
let data = vec![3u8; 100000];
let encoded = encode_frame(&data).unwrap();
assert_eq!(encoded.len(), 100000 + 3); // 3 byte header
assert_eq!(encoded.len(), 100000 + 3);
let mut codec = BytesCodec::new();
let mut buf = BytesMut::from(&encoded[..]);

View File

@@ -1,57 +1,26 @@
//! RustDesk Configuration
//!
//! Configuration types for the RustDesk protocol integration.
use serde::{Deserialize, Serialize};
use typeshare::typeshare;
/// RustDesk configuration
#[typeshare]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RustDeskConfig {
/// Enable RustDesk protocol
pub enabled: bool,
/// Rendezvous server address (hbbs), e.g., "rs.example.com" or "192.168.1.100:21116"
/// Required for RustDesk to function
pub rendezvous_server: String,
/// Relay server address (hbbr), if different from rendezvous server
/// Usually the same host as rendezvous server but different port (21117)
pub relay_server: Option<String>,
/// Relay server authentication key (licence_key)
/// Required if the relay server is configured with -k option
#[typeshare(skip)]
pub relay_key: Option<String>,
/// Device ID (9-digit number), auto-generated if empty
pub device_id: String,
/// Device password for client authentication
#[typeshare(skip)]
pub device_password: String,
/// Public key for encryption (Curve25519, base64 encoded), auto-generated
#[typeshare(skip)]
pub public_key: Option<String>,
/// Private key for encryption (Curve25519, base64 encoded), auto-generated
#[typeshare(skip)]
pub private_key: Option<String>,
/// Signing public key (Ed25519, base64 encoded), auto-generated
/// Used for SignedId verification by clients
#[typeshare(skip)]
pub signing_public_key: Option<String>,
/// Signing private key (Ed25519, base64 encoded), auto-generated
/// Used for signing SignedId messages
#[typeshare(skip)]
pub signing_private_key: Option<String>,
/// UUID for rendezvous server registration (persisted to avoid UUID_MISMATCH)
#[typeshare(skip)]
pub uuid: Option<String>,
}
@@ -75,8 +44,6 @@ impl Default for RustDeskConfig {
}
impl RustDeskConfig {
/// Check if the configuration is valid for starting the service
/// Returns true if enabled and has a valid server
pub fn is_valid(&self) -> bool {
self.enabled
&& !self.rendezvous_server.is_empty()
@@ -84,44 +51,35 @@ impl RustDeskConfig {
&& !self.device_password.is_empty()
}
/// Get the rendezvous server (user-configured)
pub fn effective_rendezvous_server(&self) -> &str {
&self.rendezvous_server
}
/// Generate a new random device ID
pub fn generate_device_id() -> String {
generate_device_id()
}
/// Generate a new random password
pub fn generate_password() -> String {
generate_random_password()
}
/// Get or generate the UUID for rendezvous registration
/// Returns (uuid_bytes, is_new) where is_new indicates if a new UUID was generated
pub fn ensure_uuid(&mut self) -> ([u8; 16], bool) {
if let Some(ref uuid_str) = self.uuid {
// Try to parse existing UUID
if let Ok(uuid) = uuid::Uuid::parse_str(uuid_str) {
return (*uuid.as_bytes(), false);
}
}
// Generate new UUID
let new_uuid = uuid::Uuid::new_v4();
self.uuid = Some(new_uuid.to_string());
(*new_uuid.as_bytes(), true)
}
/// Get the UUID bytes (returns None if not set)
pub fn get_uuid_bytes(&self) -> Option<[u8; 16]> {
self.uuid
.as_ref()
.and_then(|s| uuid::Uuid::parse_str(s).ok().map(|u| *u.as_bytes()))
}
/// Get the rendezvous server address with default port
pub fn rendezvous_addr(&self) -> String {
let server = &self.rendezvous_server;
if server.is_empty() {
@@ -133,7 +91,6 @@ impl RustDeskConfig {
}
}
/// Get the relay server address with default port
pub fn relay_addr(&self) -> Option<String> {
self.relay_server
.as_ref()
@@ -145,7 +102,6 @@ impl RustDeskConfig {
}
})
.or_else(|| {
// Default: same host as rendezvous server
let server = &self.rendezvous_server;
if !server.is_empty() {
let host = server.split(':').next().unwrap_or("");
@@ -161,7 +117,6 @@ impl RustDeskConfig {
}
}
/// Generate a random 9-digit device ID
pub fn generate_device_id() -> String {
use rand::Rng;
let mut rng = rand::rng();
@@ -169,7 +124,6 @@ pub fn generate_device_id() -> String {
id.to_string()
}
/// Generate a random 8-character password
pub fn generate_random_password() -> String {
use rand::Rng;
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
@@ -212,7 +166,6 @@ mod tests {
config.rendezvous_server = "example.com:21116".to_string();
assert_eq!(config.rendezvous_addr(), "example.com:21116");
// Empty server returns empty string
config.rendezvous_server = String::new();
assert_eq!(config.rendezvous_addr(), "");
}
@@ -224,17 +177,14 @@ mod tests {
..Default::default()
};
// Rendezvous server configured, relay defaults to same host
assert_eq!(config.relay_addr(), Some("example.com:21117".to_string()));
// Explicit relay server
config.relay_server = Some("relay.example.com".to_string());
assert_eq!(
config.relay_addr(),
Some("relay.example.com:21117".to_string())
);
// No rendezvous server, relay is None
config.rendezvous_server = String::new();
config.relay_server = None;
assert_eq!(config.relay_addr(), None);
@@ -247,10 +197,8 @@ mod tests {
..Default::default()
};
// When user sets a server, use it
assert_eq!(config.effective_rendezvous_server(), "custom.example.com");
// When empty, returns empty
config.rendezvous_server = String::new();
assert_eq!(config.effective_rendezvous_server(), "");
}

View File

@@ -1,12 +1,4 @@
//! RustDesk Connection Handler
//!
//! This module handles incoming connections from RustDesk clients.
//! It manages the connection lifecycle including:
//! - Connection establishment (P2P or via relay)
//! - Encrypted handshake
//! - Authentication
//! - Message routing (video, audio, input)
//! - Video frame streaming (shared with WebRTC)
//! Incoming RustDesk TCP sessions (handshake, AV, input).
use std::net::SocketAddr;
use std::sync::Arc;
@@ -23,6 +15,7 @@ use tracing::{debug, error, info, warn};
use crate::audio::AudioController;
use crate::hid::{CanonicalKey, HidController, KeyEventType, KeyboardEvent, KeyboardModifiers};
use crate::utils::hostname_from_etc;
use crate::video::codec_constraints::{
encoder_codec_to_id, encoder_codec_to_video_codec, video_codec_to_encoder_codec,
};
@@ -94,13 +87,6 @@ impl InputThrottler {
}
}
/// Get system hostname
fn get_hostname() -> String {
std::fs::read_to_string("/etc/hostname")
.map(|s| s.trim().to_string())
.unwrap_or_else(|_| "One-KVM".to_string())
}
/// Connection state
#[derive(Debug, Clone, PartialEq)]
pub enum ConnectionState {
@@ -1165,7 +1151,7 @@ impl Connection {
let mut peer_info = PeerInfo::new();
peer_info.username = "one-kvm".to_string();
peer_info.hostname = get_hostname();
peer_info.hostname = hostname_from_etc();
peer_info.platform = RUSTDESK_COMPAT_PLATFORM.to_string();
peer_info.displays.push(display_info);
peer_info.current_display = 0;
@@ -1786,7 +1772,7 @@ async fn run_audio_streaming(
}
// Subscribe to the audio Opus stream
let mut opus_rx = match audio_controller.subscribe_opus_async().await {
let mut opus_rx = match audio_controller.subscribe_opus().await {
Some(rx) => rx,
None => {
// Audio not available, wait and retry
@@ -1831,18 +1817,18 @@ async fn run_audio_streaming(
break 'subscribe_loop;
}
result = opus_rx.changed() => {
if result.is_err() {
// Pipeline was restarted
info!("Audio pipeline closed for connection {}, re-subscribing...", conn_id);
audio_adapter.reset();
tokio::time::sleep(Duration::from_millis(100)).await;
continue 'subscribe_loop;
}
let opus_frame = match opus_rx.borrow().clone() {
result = opus_rx.recv() => {
let opus_frame = match result {
Some(frame) => frame,
None => continue,
None => {
info!(
"Audio pipeline closed for connection {}, re-subscribing...",
conn_id
);
audio_adapter.reset();
tokio::time::sleep(Duration::from_millis(100)).await;
continue 'subscribe_loop;
}
};
// Convert OpusFrame to RustDesk AudioFrame message

View File

@@ -1,10 +1,4 @@
//! RustDesk Cryptography
//!
//! This module implements the NaCl-based cryptography used by RustDesk:
//! - Curve25519 for key exchange
//! - XSalsa20-Poly1305 for authenticated encryption
//! - Ed25519 for signatures
//! - Ed25519 to Curve25519 key conversion for unified keypair usage
//! NaCl crypto (RustDesk-compatible).
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use sodiumoxide::crypto::box_::{self, Nonce, PublicKey, SecretKey};
@@ -12,7 +6,6 @@ use sodiumoxide::crypto::secretbox;
use sodiumoxide::crypto::sign::{self, ed25519};
use thiserror::Error;
/// Cryptography errors
#[derive(Debug, Error)]
pub enum CryptoError {
#[error("Failed to initialize sodiumoxide")]
@@ -31,13 +24,10 @@ pub enum CryptoError {
KeyConversionFailed,
}
/// Initialize the cryptography library
/// Must be called before using any crypto functions
pub fn init() -> Result<(), CryptoError> {
sodiumoxide::init().map_err(|_| CryptoError::InitError)
}
/// A keypair for asymmetric encryption
#[derive(Clone)]
pub struct KeyPair {
pub public_key: PublicKey,
@@ -45,7 +35,6 @@ pub struct KeyPair {
}
impl KeyPair {
/// Generate a new random keypair
pub fn generate() -> Self {
let (public_key, secret_key) = box_::gen_keypair();
Self {
@@ -54,7 +43,6 @@ impl KeyPair {
}
}
/// Create from existing keys
pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result<Self, CryptoError> {
let pk = PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?;
let sk = SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?;
@@ -64,27 +52,22 @@ impl KeyPair {
})
}
/// Get public key as bytes
pub fn public_key_bytes(&self) -> &[u8] {
self.public_key.as_ref()
}
/// Get secret key as bytes
pub fn secret_key_bytes(&self) -> &[u8] {
self.secret_key.as_ref()
}
/// Encode public key as base64
pub fn public_key_base64(&self) -> String {
BASE64.encode(self.public_key_bytes())
}
/// Encode secret key as base64
pub fn secret_key_base64(&self) -> String {
BASE64.encode(self.secret_key_bytes())
}
/// Create from base64-encoded keys
pub fn from_base64(public_key: &str, secret_key: &str) -> Result<Self, CryptoError> {
let pk_bytes = BASE64
.decode(public_key)
@@ -96,15 +79,10 @@ impl KeyPair {
}
}
/// Generate a random nonce for box encryption
pub fn generate_nonce() -> Nonce {
box_::gen_nonce()
}
/// Encrypt data using public-key cryptography (NaCl box)
///
/// Uses the sender's secret key and receiver's public key for encryption.
/// Returns (nonce, ciphertext).
pub fn encrypt_box(
data: &[u8],
their_public_key: &PublicKey,
@@ -115,7 +93,6 @@ pub fn encrypt_box(
(nonce, ciphertext)
}
/// Decrypt data using public-key cryptography (NaCl box)
pub fn decrypt_box(
ciphertext: &[u8],
nonce: &Nonce,
@@ -126,14 +103,12 @@ pub fn decrypt_box(
.map_err(|_| CryptoError::DecryptionFailed)
}
/// Encrypt data with a precomputed shared key
pub fn encrypt_with_key(data: &[u8], key: &secretbox::Key) -> (secretbox::Nonce, Vec<u8>) {
let nonce = secretbox::gen_nonce();
let ciphertext = secretbox::seal(data, &nonce, key);
(nonce, ciphertext)
}
/// Decrypt data with a precomputed shared key
pub fn decrypt_with_key(
ciphertext: &[u8],
nonce: &secretbox::Nonce,
@@ -142,8 +117,6 @@ pub fn decrypt_with_key(
secretbox::open(ciphertext, nonce, key).map_err(|_| CryptoError::DecryptionFailed)
}
/// Compute a shared symmetric key from public/private keypair
/// This is the precomputed key for the NaCl box
pub fn precompute_key(
their_public_key: &PublicKey,
our_secret_key: &SecretKey,
@@ -151,23 +124,18 @@ pub fn precompute_key(
box_::precompute(their_public_key, our_secret_key)
}
/// Create a symmetric key from raw bytes
pub fn symmetric_key_from_slice(key: &[u8]) -> Result<secretbox::Key, CryptoError> {
secretbox::Key::from_slice(key).ok_or(CryptoError::InvalidKeyLength)
}
/// Parse a nonce from bytes
pub fn nonce_from_slice(bytes: &[u8]) -> Result<Nonce, CryptoError> {
Nonce::from_slice(bytes).ok_or(CryptoError::InvalidNonce)
}
/// Parse a public key from bytes
pub fn public_key_from_slice(bytes: &[u8]) -> Result<PublicKey, CryptoError> {
PublicKey::from_slice(bytes).ok_or(CryptoError::InvalidKeyLength)
}
/// Hash a password for storage/comparison
/// RustDesk uses simple SHA256 for password hashing
pub fn hash_password(password: &str, salt: &str) -> Vec<u8> {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
@@ -176,35 +144,24 @@ pub fn hash_password(password: &str, salt: &str) -> Vec<u8> {
hasher.finalize().to_vec()
}
/// RustDesk double hash for password verification
/// Client calculates: SHA256(SHA256(password + salt) + challenge)
/// This matches what the client sends in LoginRequest
pub fn hash_password_double(password: &str, salt: &str, challenge: &str) -> Vec<u8> {
use sha2::{Digest, Sha256};
// First hash: SHA256(password + salt)
let mut hasher1 = Sha256::new();
hasher1.update(password.as_bytes());
hasher1.update(salt.as_bytes());
let first_hash = hasher1.finalize();
// Second hash: SHA256(first_hash + challenge)
let mut hasher2 = Sha256::new();
hasher2.update(first_hash);
hasher2.update(challenge.as_bytes());
hasher2.finalize().to_vec()
}
/// Verify a password hash
pub fn verify_password(password: &str, salt: &str, expected_hash: &[u8]) -> bool {
let computed = hash_password(password, salt);
// Constant-time comparison would be better, but for our use case this is acceptable
computed == expected_hash
}
/// Decrypt symmetric key using Curve25519 secret key directly
///
/// This is used when we have a fresh Curve25519 keypair for the connection
/// (as per RustDesk protocol - each connection generates a new keypair)
pub fn decrypt_symmetric_key(
their_temp_public_key: &[u8],
sealed_symmetric_key: &[u8],
@@ -217,7 +174,6 @@ pub fn decrypt_symmetric_key(
let their_pk =
PublicKey::from_slice(their_temp_public_key).ok_or(CryptoError::InvalidKeyLength)?;
// Use zero nonce as per RustDesk protocol
let nonce = box_::Nonce([0u8; box_::NONCEBYTES]);
let key_bytes = box_::open(sealed_symmetric_key, &nonce, &their_pk, our_secret_key)
@@ -226,11 +182,7 @@ pub fn decrypt_symmetric_key(
secretbox::Key::from_slice(&key_bytes).ok_or(CryptoError::InvalidKeyLength)
}
/// Encrypt a message using the negotiated symmetric key
///
/// RustDesk uses a specific nonce format for session encryption
pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) -> Vec<u8> {
// Create nonce from counter (little-endian, padded to 24 bytes)
let mut nonce_bytes = [0u8; secretbox::NONCEBYTES];
nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes());
let nonce = secretbox::Nonce(nonce_bytes);
@@ -238,13 +190,11 @@ pub fn encrypt_message(data: &[u8], key: &secretbox::Key, nonce_counter: u64) ->
secretbox::seal(data, &nonce, key)
}
/// Decrypt a message using the negotiated symmetric key
pub fn decrypt_message(
ciphertext: &[u8],
key: &secretbox::Key,
nonce_counter: u64,
) -> Result<Vec<u8>, CryptoError> {
// Create nonce from counter (little-endian, padded to 24 bytes)
let mut nonce_bytes = [0u8; secretbox::NONCEBYTES];
nonce_bytes[..8].copy_from_slice(&nonce_counter.to_le_bytes());
let nonce = secretbox::Nonce(nonce_bytes);
@@ -252,7 +202,6 @@ pub fn decrypt_message(
secretbox::open(ciphertext, &nonce, key).map_err(|_| CryptoError::DecryptionFailed)
}
/// Ed25519 signing keypair for RustDesk SignedId messages
#[derive(Clone)]
pub struct SigningKeyPair {
pub public_key: sign::PublicKey,
@@ -260,7 +209,6 @@ pub struct SigningKeyPair {
}
impl SigningKeyPair {
/// Generate a new random signing keypair
pub fn generate() -> Self {
let (public_key, secret_key) = sign::gen_keypair();
Self {
@@ -269,7 +217,6 @@ impl SigningKeyPair {
}
}
/// Create from existing keys
pub fn from_keys(public_key: &[u8], secret_key: &[u8]) -> Result<Self, CryptoError> {
let pk = sign::PublicKey::from_slice(public_key).ok_or(CryptoError::InvalidKeyLength)?;
let sk = sign::SecretKey::from_slice(secret_key).ok_or(CryptoError::InvalidKeyLength)?;
@@ -279,27 +226,22 @@ impl SigningKeyPair {
})
}
/// Get public key as bytes
pub fn public_key_bytes(&self) -> &[u8] {
self.public_key.as_ref()
}
/// Get secret key as bytes
pub fn secret_key_bytes(&self) -> &[u8] {
self.secret_key.as_ref()
}
/// Encode public key as base64
pub fn public_key_base64(&self) -> String {
BASE64.encode(self.public_key_bytes())
}
/// Encode secret key as base64
pub fn secret_key_base64(&self) -> String {
BASE64.encode(self.secret_key_bytes())
}
/// Create from base64-encoded keys
pub fn from_base64(public_key: &str, secret_key: &str) -> Result<Self, CryptoError> {
let pk_bytes = BASE64
.decode(public_key)
@@ -310,42 +252,27 @@ impl SigningKeyPair {
Self::from_keys(&pk_bytes, &sk_bytes)
}
/// Sign a message
/// Returns the signature prepended to the message (as per RustDesk protocol)
pub fn sign(&self, message: &[u8]) -> Vec<u8> {
sign::sign(message, &self.secret_key)
}
/// Sign a message and return just the signature (64 bytes)
pub fn sign_detached(&self, message: &[u8]) -> [u8; 64] {
let sig = sign::sign_detached(message, &self.secret_key);
// Use as_ref() to access the signature bytes since the inner field is private
let sig_bytes: &[u8] = sig.as_ref();
let mut result = [0u8; 64];
result.copy_from_slice(sig_bytes);
result
}
/// Convert Ed25519 public key to Curve25519 public key for encryption
///
/// This allows using the same keypair for both signing and encryption,
/// which is required by RustDesk's protocol where clients encrypt the
/// symmetric key using the public key from IdPk.
pub fn to_curve25519_pk(&self) -> Result<PublicKey, CryptoError> {
ed25519::to_curve25519_pk(&self.public_key).map_err(|_| CryptoError::KeyConversionFailed)
}
/// Convert Ed25519 secret key to Curve25519 secret key for decryption
///
/// This allows decrypting messages that were encrypted using the
/// converted public key.
pub fn to_curve25519_sk(&self) -> Result<SecretKey, CryptoError> {
ed25519::to_curve25519_sk(&self.secret_key).map_err(|_| CryptoError::KeyConversionFailed)
}
}
/// Verify a signed message
/// Returns the original message if signature is valid
pub fn verify_signed(
signed_message: &[u8],
public_key: &sign::PublicKey,

View File

@@ -1,8 +1,3 @@
//! RustDesk Frame Adapters
//!
//! Converts One-KVM video/audio frames to RustDesk protocol format.
//! Optimized for zero-copy where possible and buffer reuse.
use bytes::Bytes;
use protobuf::Message as ProtobufMessage;
@@ -11,7 +6,6 @@ use super::protocol::hbb::message::{
CursorData, CursorPosition, EncodedVideoFrame, EncodedVideoFrames, Message, Misc, VideoFrame,
};
/// Video codec type for RustDesk
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VideoCodec {
H264,
@@ -22,7 +16,6 @@ pub enum VideoCodec {
}
impl VideoCodec {
/// Get the codec ID for the RustDesk protocol
pub fn to_codec_id(self) -> i32 {
match self {
VideoCodec::H264 => 0,
@@ -34,21 +27,15 @@ impl VideoCodec {
}
}
/// Video frame adapter for converting to RustDesk format
pub struct VideoFrameAdapter {
/// Current codec
codec: VideoCodec,
/// Frame sequence number
seq: u32,
/// Timestamp offset
timestamp_base: u64,
/// Cached H264 SPS/PPS (Annex B NAL without start code)
h264_sps: Option<Bytes>,
h264_pps: Option<Bytes>,
}
impl VideoFrameAdapter {
/// Create a new video frame adapter
pub fn new(codec: VideoCodec) -> Self {
Self {
codec,
@@ -59,14 +46,10 @@ impl VideoFrameAdapter {
}
}
/// Set codec type
pub fn set_codec(&mut self, codec: VideoCodec) {
self.codec = codec;
}
/// Convert encoded video data to RustDesk Message (zero-copy version)
///
/// This version takes Bytes directly to avoid copying the frame data.
pub fn encode_frame_from_bytes(
&mut self,
data: Bytes,
@@ -74,7 +57,6 @@ impl VideoFrameAdapter {
timestamp_ms: u64,
) -> Message {
let data = self.prepare_h264_frame(data, is_keyframe);
// Calculate relative timestamp
if self.seq == 0 {
self.timestamp_base = timestamp_ms;
}
@@ -87,11 +69,9 @@ impl VideoFrameAdapter {
self.seq = self.seq.wrapping_add(1);
// Wrap in EncodedVideoFrames container
let mut frames = EncodedVideoFrames::new();
frames.frames.push(frame);
// Create the appropriate VideoFrame variant based on codec
let mut video_frame = VideoFrame::new();
match self.codec {
VideoCodec::H264 => video_frame.union = Some(vf_union::Union::H264s(frames)),
@@ -111,7 +91,6 @@ impl VideoFrameAdapter {
return data;
}
// Parse SPS/PPS from Annex B data (without start codes)
let (sps, pps) = crate::webrtc::rtp::extract_sps_pps(&data);
let mut has_sps = false;
let mut has_pps = false;
@@ -125,7 +104,6 @@ impl VideoFrameAdapter {
has_pps = true;
}
// Inject cached SPS/PPS before IDR when missing
if is_keyframe && (!has_sps || !has_pps) {
if let (Some(sps), Some(pps)) = (self.h264_sps.as_ref(), self.h264_pps.as_ref()) {
let mut out = Vec::with_capacity(8 + sps.len() + pps.len() + data.len());
@@ -141,14 +119,10 @@ impl VideoFrameAdapter {
data
}
/// Convert encoded video data to RustDesk Message
pub fn encode_frame(&mut self, data: &[u8], is_keyframe: bool, timestamp_ms: u64) -> Message {
self.encode_frame_from_bytes(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms)
}
/// Encode frame to bytes for sending (zero-copy version)
///
/// Takes Bytes directly to avoid copying the frame data.
pub fn encode_frame_bytes_zero_copy(
&mut self,
data: Bytes,
@@ -159,7 +133,6 @@ impl VideoFrameAdapter {
Bytes::from(msg.write_to_bytes().unwrap_or_default())
}
/// Encode frame to bytes for sending
pub fn encode_frame_bytes(
&mut self,
data: &[u8],
@@ -169,24 +142,18 @@ impl VideoFrameAdapter {
self.encode_frame_bytes_zero_copy(Bytes::copy_from_slice(data), is_keyframe, timestamp_ms)
}
/// Get current sequence number
pub fn seq(&self) -> u32 {
self.seq
}
}
/// Audio frame adapter for converting to RustDesk format
pub struct AudioFrameAdapter {
/// Sample rate
sample_rate: u32,
/// Channels
channels: u8,
/// Format sent flag
format_sent: bool,
}
impl AudioFrameAdapter {
/// Create a new audio frame adapter
pub fn new(sample_rate: u32, channels: u8) -> Self {
Self {
sample_rate,
@@ -195,7 +162,6 @@ impl AudioFrameAdapter {
}
}
/// Create audio format message (should be sent once before audio frames)
pub fn create_format_message(&mut self) -> Message {
self.format_sent = true;
@@ -211,12 +177,10 @@ impl AudioFrameAdapter {
msg
}
/// Check if format message has been sent
pub fn format_sent(&self) -> bool {
self.format_sent
}
/// Convert Opus audio data to RustDesk Message
pub fn encode_opus_frame(&self, data: &[u8]) -> Message {
let mut frame = AudioFrame::new();
frame.data = Bytes::copy_from_slice(data);
@@ -226,23 +190,19 @@ impl AudioFrameAdapter {
msg
}
/// Encode Opus frame to bytes for sending
pub fn encode_opus_bytes(&self, data: &[u8]) -> Bytes {
let msg = self.encode_opus_frame(data);
Bytes::from(msg.write_to_bytes().unwrap_or_default())
}
/// Reset state (call when restarting audio stream)
pub fn reset(&mut self) {
self.format_sent = false;
}
}
/// Cursor data adapter
pub struct CursorAdapter;
impl CursorAdapter {
/// Create cursor data message
pub fn encode_cursor(
id: u64,
hotx: i32,
@@ -264,7 +224,6 @@ impl CursorAdapter {
msg
}
/// Create cursor position message
pub fn encode_position(x: i32, y: i32) -> Message {
let mut pos = CursorPosition::new();
pos.x = x;
@@ -284,7 +243,6 @@ mod tests {
fn test_video_frame_encoding() {
let mut adapter = VideoFrameAdapter::new(VideoCodec::H264);
// Encode a keyframe
let data = vec![0x00, 0x00, 0x00, 0x01, 0x67]; // H264 SPS NAL
let msg = adapter.encode_frame(&data, true, 0);
@@ -324,7 +282,6 @@ mod tests {
fn test_audio_frame_encoding() {
let adapter = AudioFrameAdapter::new(48000, 2);
// Encode an Opus frame
let opus_data = vec![0xFC, 0x01, 0x02]; // Fake Opus data
let msg = adapter.encode_opus_frame(&opus_data);

View File

@@ -1,7 +1,3 @@
//! RustDesk HID Adapter
//!
//! Converts RustDesk HID events (KeyEvent, MouseEvent) to One-KVM HID events.
use super::protocol::hbb::message::key_event as ke_union;
use super::protocol::{ControlKey, KeyEvent, MouseEvent};
use crate::hid::{
@@ -10,8 +6,6 @@ use crate::hid::{
};
use protobuf::Enum;
/// Mouse event types from RustDesk protocol
/// mask = (button << 3) | event_type
pub mod mouse_type {
pub const MOVE: i32 = 0;
pub const DOWN: i32 = 1;
@@ -21,7 +15,6 @@ pub mod mouse_type {
pub const MOVE_RELATIVE: i32 = 5;
}
/// Mouse button IDs from RustDesk protocol (before left shift by 3)
pub mod mouse_button {
pub const LEFT: i32 = 0x01;
pub const RIGHT: i32 = 0x02;
@@ -30,9 +23,6 @@ pub mod mouse_button {
pub const FORWARD: i32 = 0x10;
}
/// Convert RustDesk MouseEvent to One-KVM MouseEvent(s)
/// Returns a Vec because a single RustDesk event may need multiple One-KVM events
/// (e.g., move + button + scroll)
pub fn convert_mouse_event(
event: &MouseEvent,
screen_width: u32,
@@ -41,23 +31,18 @@ pub fn convert_mouse_event(
) -> Vec<OneKvmMouseEvent> {
let mut events = Vec::new();
// Parse RustDesk mask format: (button << 3) | event_type
let event_type = event.mask & 0x07;
let button_id = event.mask >> 3;
let include_abs_move = !relative_mode;
match event_type {
mouse_type::MOVE => {
// RustDesk uses absolute coordinates
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
// Normalize to 0-32767 range for absolute mouse (USB HID standard)
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
let abs_y = ((y as u64 * 32767) / screen_height.max(1) as u64) as i32;
// Move event - may have button held down (button_id > 0 means dragging)
// Just send move, button state is tracked separately by HID backend
events.push(OneKvmMouseEvent {
event_type: MouseEventType::MoveAbs,
x: abs_x,
@@ -67,7 +52,6 @@ pub fn convert_mouse_event(
});
}
mouse_type::MOVE_RELATIVE => {
// Relative movement uses delta values directly (dx, dy).
events.push(OneKvmMouseEvent {
event_type: MouseEventType::Move,
x: event.x,
@@ -78,7 +62,6 @@ pub fn convert_mouse_event(
}
mouse_type::DOWN => {
if include_abs_move {
// Button down - first move, then press
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -104,7 +87,6 @@ pub fn convert_mouse_event(
}
mouse_type::UP => {
if include_abs_move {
// Button up - first move, then release
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -130,7 +112,6 @@ pub fn convert_mouse_event(
}
mouse_type::WHEEL => {
if include_abs_move {
// Scroll event - move first, then scroll
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -144,9 +125,6 @@ pub fn convert_mouse_event(
});
}
// RustDesk encodes scroll direction in the y coordinate
// Positive y = scroll up, Negative y = scroll down
// The button_id field is not used for direction
let scroll = if event.y > 0 { 1i8 } else { -1i8 };
events.push(OneKvmMouseEvent {
event_type: MouseEventType::Scroll,
@@ -158,7 +136,6 @@ pub fn convert_mouse_event(
}
_ => {
if include_abs_move {
// Unknown event type, just move
let x = event.x.max(0) as u32;
let y = event.y.max(0) as u32;
let abs_x = ((x as u64 * 32767) / screen_width.max(1) as u64) as i32;
@@ -177,7 +154,6 @@ pub fn convert_mouse_event(
events
}
/// Convert RustDesk button ID to One-KVM MouseButton
fn button_id_to_button(button_id: i32) -> Option<MouseButton> {
match button_id {
mouse_button::LEFT => Some(MouseButton::Left),
@@ -187,34 +163,19 @@ fn button_id_to_button(button_id: i32) -> Option<MouseButton> {
}
}
/// Convert RustDesk KeyEvent to One-KVM KeyboardEvent
///
/// RustDesk KeyEvent has two modes:
/// - down=true/false: Key state (pressed/released)
/// - press=true: Complete key press (down + up), used for typing
///
/// For press=true events, we only send Down and let the caller handle
/// the timing for Up event if needed. Most systems handle this correctly.
pub fn convert_key_event(event: &KeyEvent) -> Option<KeyboardEvent> {
// Determine if this is a key down or key up event
// press=true means "key was pressed" (down event)
// down=true means key is currently held down
// down=false with press=false means key was released
let event_type = if event.down || event.press {
KeyEventType::Down
} else {
KeyEventType::Up
};
// For modifier keys sent as ControlKey, don't include them in modifiers
// to avoid double-pressing. The modifier will be tracked by HID state.
let modifiers = if is_modifier_control_key(event) {
KeyboardModifiers::default()
} else {
parse_modifiers(event)
};
// Handle control keys
if let Some(ke_union::Union::ControlKey(ck)) = &event.union {
if let Some(key) = control_key_to_hid(ck.value()) {
let key = CanonicalKey::from_hid_usage(key)?;
@@ -226,9 +187,7 @@ pub fn convert_key_event(event: &KeyEvent) -> Option<KeyboardEvent> {
}
}
// Handle character keys (chr field contains platform-specific keycode)
if let Some(ke_union::Union::Chr(chr)) = &event.union {
// chr contains USB HID scancode on Windows, X11 keycode on Linux
if let Some(key) = keycode_to_hid(*chr) {
let key = CanonicalKey::from_hid_usage(key)?;
return Some(KeyboardEvent {
@@ -239,13 +198,9 @@ pub fn convert_key_event(event: &KeyEvent) -> Option<KeyboardEvent> {
}
}
// Handle unicode (for text input, we'd need to convert to scancodes)
// Unicode input requires more complex handling, skip for now
None
}
/// Check if the event is a modifier key sent as ControlKey
fn is_modifier_control_key(event: &KeyEvent) -> bool {
if let Some(ke_union::Union::ControlKey(ck)) = &event.union {
let val = ck.value();
@@ -260,7 +215,6 @@ fn is_modifier_control_key(event: &KeyEvent) -> bool {
false
}
/// Parse modifier keys from RustDesk KeyEvent into KeyboardModifiers
fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers {
let mut modifiers = KeyboardModifiers::default();
@@ -281,7 +235,6 @@ fn parse_modifiers(event: &KeyEvent) -> KeyboardModifiers {
modifiers
}
/// Convert RustDesk ControlKey to USB HID usage code
fn control_key_to_hid(key: i32) -> Option<u8> {
match key {
x if x == ControlKey::Alt as i32 => Some(0xE2), // Left Alt
@@ -342,67 +295,47 @@ fn control_key_to_hid(key: i32) -> Option<u8> {
}
}
/// Convert platform keycode to USB HID usage code
/// Handles Windows Virtual Key Codes, X11 keycodes, and ASCII codes
fn keycode_to_hid(keycode: u32) -> Option<u8> {
// First try ASCII code mapping (RustDesk often sends ASCII codes)
if let Some(hid) = ascii_to_hid(keycode) {
return Some(hid);
}
// Then try Windows Virtual Key Code mapping
if let Some(hid) = windows_vk_to_hid(keycode) {
return Some(hid);
}
// Fall back to X11 keycode mapping for Linux clients
x11_keycode_to_hid(keycode)
}
/// Convert ASCII code to USB HID usage code
fn ascii_to_hid(ascii: u32) -> Option<u8> {
match ascii {
// Lowercase letters a-z (ASCII 97-122)
97..=122 => {
// USB HID: a=0x04, b=0x05, ..., z=0x1D
Some((ascii - 97 + 0x04) as u8)
}
// Uppercase letters A-Z (ASCII 65-90)
65..=90 => {
// USB HID: A=0x04, B=0x05, ..., Z=0x1D (same as lowercase)
Some((ascii - 65 + 0x04) as u8)
}
// Numbers 0-9 (ASCII 48-57)
97..=122 => Some((ascii - 97 + 0x04) as u8),
65..=90 => Some((ascii - 65 + 0x04) as u8),
48 => Some(0x27), // 0
49..=57 => Some((ascii - 49 + 0x1E) as u8), // 1-9
// Common punctuation
32 => Some(0x2C), // Space
13 => Some(0x28), // Enter (CR)
10 => Some(0x28), // Enter (LF)
9 => Some(0x2B), // Tab
27 => Some(0x29), // Escape
8 => Some(0x2A), // Backspace
127 => Some(0x4C), // Delete
// Symbols (US keyboard layout)
45 => Some(0x2D), // -
61 => Some(0x2E), // =
91 => Some(0x2F), // [
93 => Some(0x30), // ]
92 => Some(0x31), // \
59 => Some(0x33), // ;
39 => Some(0x34), // '
96 => Some(0x35), // `
44 => Some(0x36), // ,
46 => Some(0x37), // .
47 => Some(0x38), // /
32 => Some(0x2C), // Space
13 => Some(0x28), // Enter (CR)
10 => Some(0x28), // Enter (LF)
9 => Some(0x2B), // Tab
27 => Some(0x29), // Escape
8 => Some(0x2A), // Backspace
127 => Some(0x4C), // Delete
45 => Some(0x2D), // -
61 => Some(0x2E), // =
91 => Some(0x2F), // [
93 => Some(0x30), // ]
92 => Some(0x31), // \
59 => Some(0x33), // ;
39 => Some(0x34), // '
96 => Some(0x35), // `
44 => Some(0x36), // ,
46 => Some(0x37), // .
47 => Some(0x38), // /
_ => None,
}
}
/// Convert Windows Virtual Key Code to USB HID usage code
fn windows_vk_to_hid(vk: u32) -> Option<u8> {
match vk {
// Letters A-Z (VK_A=0x41 to VK_Z=0x5A)
0x41..=0x5A => {
// USB HID: A=0x04, B=0x05, ..., Z=0x1D
let letter = (vk - 0x41) as u8;
Some(match letter {
0 => 0x04, // A
@@ -434,21 +367,16 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
_ => return None,
})
}
// Numbers 0-9 (VK_0=0x30 to VK_9=0x39)
0x30 => Some(0x27), // 0
0x31..=0x39 => Some((vk - 0x31 + 0x1E) as u8), // 1-9
// Numpad 0-9 (VK_NUMPAD0=0x60 to VK_NUMPAD9=0x69)
0x60 => Some(0x62), // Numpad 0
0x61..=0x69 => Some((vk - 0x61 + 0x59) as u8), // Numpad 1-9
// Numpad operators
0x6A => Some(0x55), // Numpad *
0x6B => Some(0x57), // Numpad +
0x6D => Some(0x56), // Numpad -
0x6E => Some(0x63), // Numpad .
0x6F => Some(0x54), // Numpad /
// Function keys F1-F12 (VK_F1=0x70 to VK_F12=0x7B)
0x6A => Some(0x55), // Numpad *
0x6B => Some(0x57), // Numpad +
0x6D => Some(0x56), // Numpad -
0x6E => Some(0x63), // Numpad .
0x6F => Some(0x54), // Numpad /
0x70..=0x7B => Some((vk - 0x70 + 0x3A) as u8),
// Special keys
0x08 => Some(0x2A), // Backspace
0x09 => Some(0x2B), // Tab
0x0D => Some(0x28), // Enter
@@ -464,7 +392,6 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
0x28 => Some(0x51), // Down Arrow
0x2D => Some(0x49), // Insert
0x2E => Some(0x4C), // Delete
// OEM keys (US keyboard layout)
0xBA => Some(0x33), // ; :
0xBB => Some(0x2E), // = +
0xBC => Some(0x36), // , <
@@ -476,66 +403,56 @@ fn windows_vk_to_hid(vk: u32) -> Option<u8> {
0xDC => Some(0x31), // \ |
0xDD => Some(0x30), // ] }
0xDE => Some(0x34), // ' "
// Lock keys
0x14 => Some(0x39), // Caps Lock
0x90 => Some(0x53), // Num Lock
0x91 => Some(0x47), // Scroll Lock
// Print Screen, Pause
0x2C => Some(0x46), // Print Screen
0x13 => Some(0x48), // Pause
_ => None,
}
}
/// Convert X11 keycode to USB HID usage code (for Linux clients)
fn x11_keycode_to_hid(keycode: u32) -> Option<u8> {
match keycode {
// Numbers: X11 keycode 10="1", 11="2", ..., 18="9", 19="0"
10..=18 => Some((keycode - 10 + 0x1E) as u8), // 1-9
19 => Some(0x27), // 0
// Punctuation
20 => Some(0x2D), // -
21 => Some(0x2E), // =
34 => Some(0x2F), // [
35 => Some(0x30), // ]
// Letters (X11 keycodes are row-based)
// Row 1: q(24) w(25) e(26) r(27) t(28) y(29) u(30) i(31) o(32) p(33)
24 => Some(0x14), // q
25 => Some(0x1A), // w
26 => Some(0x08), // e
27 => Some(0x15), // r
28 => Some(0x17), // t
29 => Some(0x1C), // y
30 => Some(0x18), // u
31 => Some(0x0C), // i
32 => Some(0x12), // o
33 => Some(0x13), // p
// Row 2: a(38) s(39) d(40) f(41) g(42) h(43) j(44) k(45) l(46)
38 => Some(0x04), // a
39 => Some(0x16), // s
40 => Some(0x07), // d
41 => Some(0x09), // f
42 => Some(0x0A), // g
43 => Some(0x0B), // h
44 => Some(0x0D), // j
45 => Some(0x0E), // k
46 => Some(0x0F), // l
47 => Some(0x33), // ;
48 => Some(0x34), // '
49 => Some(0x35), // `
51 => Some(0x31), // \
// Row 3: z(52) x(53) c(54) v(55) b(56) n(57) m(58)
52 => Some(0x1D), // z
53 => Some(0x1B), // x
54 => Some(0x06), // c
55 => Some(0x19), // v
56 => Some(0x05), // b
57 => Some(0x11), // n
58 => Some(0x10), // m
59 => Some(0x36), // ,
60 => Some(0x37), // .
61 => Some(0x38), // /
// Space
20 => Some(0x2D), // -
21 => Some(0x2E), // =
34 => Some(0x2F), // [
35 => Some(0x30), // ]
24 => Some(0x14), // q
25 => Some(0x1A), // w
26 => Some(0x08), // e
27 => Some(0x15), // r
28 => Some(0x17), // t
29 => Some(0x1C), // y
30 => Some(0x18), // u
31 => Some(0x0C), // i
32 => Some(0x12), // o
33 => Some(0x13), // p
38 => Some(0x04), // a
39 => Some(0x16), // s
40 => Some(0x07), // d
41 => Some(0x09), // f
42 => Some(0x0A), // g
43 => Some(0x0B), // h
44 => Some(0x0D), // j
45 => Some(0x0E), // k
46 => Some(0x0F), // l
47 => Some(0x33), // ;
48 => Some(0x34), // '
49 => Some(0x35), // `
51 => Some(0x31), // \
52 => Some(0x1D), // z
53 => Some(0x1B), // x
54 => Some(0x06), // c
55 => Some(0x19), // v
56 => Some(0x05), // b
57 => Some(0x11), // n
58 => Some(0x10), // m
59 => Some(0x36), // ,
60 => Some(0x37), // .
61 => Some(0x38), // /
65 => Some(0x2C),
_ => None,
}
@@ -573,7 +490,6 @@ mod tests {
let events = convert_mouse_event(&event, 1920, 1080, false);
assert!(events.len() >= 2);
// Should have a button down event
assert!(events
.iter()
.any(|e| e.event_type == MouseEventType::Down && e.button == Some(MouseButton::Left)));

View File

@@ -1,17 +1,4 @@
//! RustDesk Protocol Integration Module
//!
//! This module implements the RustDesk client protocol, enabling One-KVM devices
//! to be accessed via standard RustDesk clients through existing hbbs/hbbr servers.
//!
//! ## Architecture
//!
//! - `config`: Configuration types for RustDesk settings
//! - `protocol`: Protobuf message wrappers and serialization
//! - `crypto`: NaCl cryptography (key generation, encryption, signatures)
//! - `rendezvous`: Communication with hbbs rendezvous server
//! - `connection`: Client session handling
//! - `frame_adapters`: Video/audio frame conversion to RustDesk format
//! - `hid_adapter`: RustDesk HID events to One-KVM conversion
//! RustDesk peer protocol (hbbs / hbbr).
pub mod bytes_codec;
pub mod config;
@@ -44,19 +31,13 @@ use self::connection::ConnectionManager;
use self::protocol::{make_local_addr, make_relay_response, make_request_relay};
use self::rendezvous::{AddrMangle, RendezvousMediator, RendezvousStatus};
/// Relay connection timeout
const RELAY_CONNECT_TIMEOUT_MS: u64 = 10_000;
/// RustDesk service status
#[derive(Debug, Clone, PartialEq)]
pub enum ServiceStatus {
/// Service is stopped
Stopped,
/// Service is starting
Starting,
/// Service is running and registered with rendezvous server
Running,
/// Service encountered an error
Error(String),
}
@@ -71,15 +52,8 @@ impl std::fmt::Display for ServiceStatus {
}
}
/// Default port for direct TCP connections (same as RustDesk)
const DIRECT_LISTEN_PORT: u16 = 21118;
/// RustDesk Service
///
/// Manages the RustDesk protocol integration, including:
/// - Registration with hbbs rendezvous server
/// - Accepting connections from RustDesk clients
/// - Streaming video/audio and receiving HID input
pub struct RustDeskService {
config: Arc<RwLock<RustDeskConfig>>,
status: Arc<RwLock<ServiceStatus>>,
@@ -95,7 +69,6 @@ pub struct RustDeskService {
}
impl RustDeskService {
/// Create a new RustDesk service instance
pub fn new(
config: RustDeskConfig,
video_manager: Arc<VideoStreamManager>,
@@ -120,42 +93,34 @@ impl RustDeskService {
}
}
/// Get the port for direct TCP connections
pub fn listen_port(&self) -> u16 {
*self.listen_port.read()
}
/// Get current service status
pub fn status(&self) -> ServiceStatus {
self.status.read().clone()
}
/// Get current configuration
pub fn config(&self) -> RustDeskConfig {
self.config.read().clone()
}
/// Update configuration
pub fn update_config(&self, config: RustDeskConfig) {
*self.config.write() = config;
}
/// Get rendezvous status
pub fn rendezvous_status(&self) -> Option<RendezvousStatus> {
self.rendezvous.read().as_ref().map(|r| r.status())
}
/// Get device ID
pub fn device_id(&self) -> String {
self.config.read().device_id.clone()
}
/// Get connection count
pub fn connection_count(&self) -> usize {
self.connection_manager.connection_count()
}
/// Start the RustDesk service
pub async fn start(&self) -> anyhow::Result<()> {
let config = self.config.read().clone();
@@ -181,74 +146,44 @@ impl RustDeskService {
config.rendezvous_addr()
);
// Initialize crypto
if let Err(e) = crypto::init() {
error!("Failed to initialize crypto: {}", e);
*self.status.write() = ServiceStatus::Error(e.to_string());
return Err(e.into());
}
// Create and start rendezvous mediator with relay callback
let mediator = Arc::new(RendezvousMediator::new(config.clone()));
// Set the keypair on connection manager (Curve25519 for encryption)
let keypair = mediator.ensure_keypair();
self.connection_manager.set_keypair(keypair);
// Set the signing keypair on connection manager (Ed25519 for SignedId)
let signing_keypair = mediator.ensure_signing_keypair();
self.connection_manager.set_signing_keypair(signing_keypair);
// Set the HID controller on connection manager
self.connection_manager.set_hid(self.hid.clone());
// Set the audio controller on connection manager for audio streaming
self.connection_manager.set_audio(self.audio.clone());
// Set the video manager on connection manager for video streaming
self.connection_manager
.set_video_manager(self.video_manager.clone());
*self.rendezvous.write() = Some(mediator.clone());
// Start TCP listener BEFORE the rendezvous mediator to ensure port is set correctly
// This prevents race condition where mediator starts registration with wrong port
let (tcp_handles, listen_port) = self.start_tcp_listener_with_port().await?;
*self.tcp_listener_handle.write() = Some(tcp_handles);
// Set the listen port on mediator before starting the registration loop
mediator.set_listen_port(listen_port);
// Create relay request handler
let connection_manager = self.connection_manager.clone();
let video_manager = self.video_manager.clone();
let hid = self.hid.clone();
let audio = self.audio.clone();
let service_config = self.config.clone();
// Set the punch callback on the mediator (tries P2P first, then relay)
let connection_manager_punch = self.connection_manager.clone();
let video_manager_punch = self.video_manager.clone();
let hid_punch = self.hid.clone();
let audio_punch = self.audio.clone();
let service_config_punch = self.config.clone();
mediator.set_punch_callback(Arc::new(
mediator.set_punch_callback(Arc::new({
let connection_manager = connection_manager.clone();
let service_config = service_config.clone();
move |peer_addr, rendezvous_addr, relay_server, uuid, socket_addr, device_id| {
let conn_mgr = connection_manager_punch.clone();
let video = video_manager_punch.clone();
let hid = hid_punch.clone();
let audio = audio_punch.clone();
let config = service_config_punch.clone();
let conn_mgr = connection_manager.clone();
let config = service_config.clone();
tokio::spawn(async move {
// Get relay_key from config (no public server fallback)
let relay_key = {
let cfg = config.read();
cfg.relay_key.clone().unwrap_or_default()
};
// Try P2P direct connection first
if let Some(addr) = peer_addr {
info!("Attempting P2P direct connection to {}", addr);
match punch::try_direct_connection(addr).await {
@@ -265,7 +200,7 @@ impl RustDeskService {
}
}
// Fall back to relay
let relay_key = rustdesk_relay_key(&config);
if let Err(e) = handle_relay_request(
&rendezvous_addr,
&relay_server,
@@ -274,34 +209,23 @@ impl RustDeskService {
&device_id,
&relay_key,
conn_mgr,
video,
hid,
audio,
)
.await
{
error!("Failed to handle relay request: {}", e);
}
});
},
));
}
}));
// Set the relay callback on the mediator
mediator.set_relay_callback(Arc::new(
mediator.set_relay_callback(Arc::new({
let connection_manager = connection_manager.clone();
let service_config = service_config.clone();
move |rendezvous_addr, relay_server, uuid, socket_addr, device_id| {
let conn_mgr = connection_manager.clone();
let video = video_manager.clone();
let hid = hid.clone();
let audio = audio.clone();
let config = service_config.clone();
tokio::spawn(async move {
// Get relay_key from config (no public server fallback)
let relay_key = {
let cfg = config.read();
cfg.relay_key.clone().unwrap_or_default()
};
let relay_key = rustdesk_relay_key(&config);
if let Err(e) = handle_relay_request(
&rendezvous_addr,
&relay_server,
@@ -310,19 +234,15 @@ impl RustDeskService {
&device_id,
&relay_key,
conn_mgr,
video,
hid,
audio,
)
.await
{
error!("Failed to handle relay request: {}", e);
}
});
},
));
}
}));
// Set the intranet callback on the mediator for same-LAN connections
let connection_manager2 = self.connection_manager.clone();
mediator.set_intranet_callback(Arc::new(
move |rendezvous_addr, peer_socket_addr, local_addr, relay_server, device_id| {
@@ -345,7 +265,6 @@ impl RustDeskService {
},
));
// Spawn rendezvous task
let status = self.status.clone();
let handle = tokio::spawn(async move {
loop {
@@ -357,7 +276,6 @@ impl RustDeskService {
Err(e) => {
error!("Rendezvous mediator error: {}", e);
*status.write() = ServiceStatus::Error(e.to_string());
// Wait before retry
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
*status.write() = ServiceStatus::Starting;
}
@@ -372,10 +290,7 @@ impl RustDeskService {
Ok(())
}
/// Start TCP listener for direct peer connections
/// Returns the join handle and the port that was bound
async fn start_tcp_listener_with_port(&self) -> anyhow::Result<(Vec<JoinHandle<()>>, u16)> {
// Try to bind to the default port, or find an available port
let (listeners, listen_port) = match self.bind_direct_listeners(DIRECT_LISTEN_PORT) {
Ok(result) => result,
Err(err) => {
@@ -453,7 +368,6 @@ impl RustDeskService {
Ok((listeners, listen_port))
}
/// Stop the RustDesk service
pub async fn stop(&self) -> anyhow::Result<()> {
if self.status() == ServiceStatus::Stopped {
return Ok(());
@@ -461,23 +375,18 @@ impl RustDeskService {
info!("Stopping RustDesk service");
// Send shutdown signal (this will stop the TCP listener)
let _ = self.shutdown_tx.send(());
// Close all connections
self.connection_manager.close_all();
// Stop rendezvous mediator
if let Some(mediator) = self.rendezvous.read().as_ref() {
mediator.stop();
}
// Wait for rendezvous task to finish
if let Some(handle) = self.rendezvous_handle.write().take() {
handle.abort();
}
// Wait for TCP listener task to finish
if let Some(handles) = self.tcp_listener_handle.write().take() {
for handle in handles {
handle.abort();
@@ -490,15 +399,12 @@ impl RustDeskService {
Ok(())
}
/// Restart the service with new configuration
pub async fn restart(&self, config: RustDeskConfig) -> anyhow::Result<()> {
self.stop().await?;
self.update_config(config);
self.start().await
}
/// Save keypair and UUID to config
/// Returns the updated config if changes were made
pub fn save_credentials(&self) -> Option<RustDeskConfig> {
if let Some(mediator) = self.rendezvous.read().as_ref() {
let kp = mediator.ensure_keypair();
@@ -506,7 +412,6 @@ impl RustDeskService {
let mut config = self.config.write();
let mut changed = false;
// Save encryption keypair (Curve25519)
let pk = kp.public_key_base64();
let sk = kp.secret_key_base64();
if config.public_key.as_ref() != Some(&pk) || config.private_key.as_ref() != Some(&sk) {
@@ -515,7 +420,6 @@ impl RustDeskService {
changed = true;
}
// Save signing keypair (Ed25519)
let signing_pk = skp.public_key_base64();
let signing_sk = skp.secret_key_base64();
if config.signing_public_key.as_ref() != Some(&signing_pk)
@@ -526,7 +430,6 @@ impl RustDeskService {
changed = true;
}
// Save UUID if it was newly generated
if mediator.uuid_needs_save() {
let mediator_config = mediator.config();
if let Some(uuid) = mediator_config.uuid {
@@ -545,21 +448,16 @@ impl RustDeskService {
None
}
/// Save keypair to config (deprecated, use save_credentials instead)
#[deprecated(note = "Use save_credentials instead")]
pub fn save_keypair(&self) {
let _ = self.save_credentials();
}
}
/// Handle relay request from rendezvous server
///
/// Correct flow based on RustDesk protocol:
/// 1. Connect to RENDEZVOUS server (not relay!)
/// 2. Send RelayResponse with client's socket_addr
/// 3. Connect to RELAY server
/// 4. Accept connection without waiting for response
#[allow(clippy::too_many_arguments)]
fn rustdesk_relay_key(config: &Arc<RwLock<RustDeskConfig>>) -> String {
config.read().relay_key.clone().unwrap_or_default()
}
async fn handle_relay_request(
rendezvous_addr: &str,
relay_server: &str,
@@ -568,16 +466,12 @@ async fn handle_relay_request(
device_id: &str,
relay_key: &str,
connection_manager: Arc<ConnectionManager>,
_video_manager: Arc<VideoStreamManager>,
_hid: Arc<HidController>,
_audio: Arc<AudioController>,
) -> anyhow::Result<()> {
info!(
"Handling relay request: rendezvous={}, relay={}, uuid={}",
rendezvous_addr, relay_server, uuid
);
// Step 1: Connect to RENDEZVOUS server and send RelayResponse
let rendezvous_socket_addr: SocketAddr = tokio::net::lookup_host(rendezvous_addr)
.await?
.next()
@@ -597,8 +491,7 @@ async fn handle_relay_request(
rendezvous_socket_addr
);
// Send RelayResponse to rendezvous server with client's socket_addr
// IMPORTANT: Include our device ID so rendezvous server can look up and sign our public key
// Rendezvous looks up our pk by device id (must set `id`, not raw pk on wire).
let relay_response = make_relay_response(uuid, socket_addr, relay_server, device_id);
let bytes = relay_response
.write_to_bytes()
@@ -606,10 +499,8 @@ async fn handle_relay_request(
bytes_codec::write_frame(&mut rendezvous_stream, &bytes).await?;
debug!("Sent RelayResponse to rendezvous server for uuid={}", uuid);
// Close rendezvous connection - we don't need to wait for response
drop(rendezvous_stream);
// Step 2: Connect to RELAY server and send RequestRelay to identify ourselves
let relay_addr: SocketAddr = tokio::net::lookup_host(relay_server)
.await?
.next()
@@ -624,9 +515,7 @@ async fn handle_relay_request(
info!("Connected to relay server at {}", relay_addr);
// Send RequestRelay to relay server with our uuid, licence_key, and peer's socket_addr
// The licence_key is required if the relay server is configured with -k option
// The socket_addr is CRITICAL - the relay server uses it to match us with the peer
// Relay pairs peers by uuid + mangled peer socket_addr (required when hbbr uses -k).
let request_relay = make_request_relay(uuid, relay_key, socket_addr);
let bytes = request_relay
.write_to_bytes()
@@ -634,10 +523,8 @@ async fn handle_relay_request(
bytes_codec::write_frame(&mut stream, &bytes).await?;
debug!("Sent RequestRelay to relay server for uuid={}", uuid);
// Decode peer address for logging
let peer_addr = rendezvous::AddrMangle::decode(socket_addr).unwrap_or(relay_addr);
// Step 3: Accept connection - relay server bridges the connection
connection_manager
.accept_connection(stream, peer_addr)
.await?;
@@ -649,14 +536,6 @@ async fn handle_relay_request(
Ok(())
}
/// Handle intranet/same-LAN connection request
///
/// When the server determines that the client and peer are on the same intranet
/// (same public IP or both on LAN), it sends FetchLocalAddr to the peer.
/// The peer must:
/// 1. Open a TCP connection to the rendezvous server
/// 2. Send LocalAddr with our local address
/// 3. Accept the peer connection over that same TCP stream
async fn handle_intranet_request(
rendezvous_addr: &str,
peer_socket_addr: &[u8],
@@ -670,11 +549,9 @@ async fn handle_intranet_request(
rendezvous_addr, local_addr, device_id
);
// Decode peer address for logging
let peer_addr = AddrMangle::decode(peer_socket_addr);
debug!("Peer address from FetchLocalAddr: {:?}", peer_addr);
// Connect to rendezvous server via TCP with timeout
let mut stream =
tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(rendezvous_addr))
.await
@@ -685,7 +562,6 @@ async fn handle_intranet_request(
rendezvous_addr
);
// Build LocalAddr message with our local address (mangled)
let local_addr_bytes = AddrMangle::encode(local_addr);
let msg = make_local_addr(
peer_socket_addr,
@@ -698,24 +574,16 @@ async fn handle_intranet_request(
.write_to_bytes()
.map_err(|e| anyhow::anyhow!("Failed to encode: {}", e))?;
// Send LocalAddr using RustDesk's variable-length framing
bytes_codec::write_frame(&mut stream, &bytes).await?;
info!("Sent LocalAddr to rendezvous server, waiting for peer connection");
// Now the rendezvous server will forward this to the client,
// and the client will connect to us through this same TCP stream.
// The server proxies the connection between client and peer.
// Get peer address for logging/connection tracking
let effective_peer_addr = peer_addr.unwrap_or_else(|| {
// If we can't decode the peer address, use the rendezvous server address
rendezvous_addr
.parse()
.unwrap_or_else(|_| "0.0.0.0:0".parse().unwrap())
});
// Accept the connection - the stream is now a proxied connection to the client
connection_manager
.accept_connection(stream, effective_peer_addr)
.await?;

View File

@@ -1,18 +1,12 @@
//! RustDesk Protocol Messages
//!
//! This module provides the compiled protobuf messages for the RustDesk protocol.
//! Messages are generated from rendezvous.proto and message.proto at build time.
//! Uses protobuf-rust (same as RustDesk server) for full compatibility.
//! Protobuf wrappers (`protos/` → `OUT_DIR`).
use protobuf::Message;
// Include the generated protobuf code
#[path = ""]
pub mod hbb {
include!(concat!(env!("OUT_DIR"), "/protos/mod.rs"));
}
// Re-export commonly used types
pub use hbb::rendezvous::{
punch_hole_response, relay_response, rendezvous_message, ConfigUpdate, ConnType,
FetchLocalAddr, HealthCheck, KeyExchange, LocalAddr, NatType, OnlineRequest, OnlineResponse,
@@ -21,7 +15,6 @@ pub use hbb::rendezvous::{
RequestRelay, SoftwareUpdate, TestNatRequest, TestNatResponse,
};
// Re-export message.proto types
pub use hbb::message::{
key_event, login_response, message, misc, AudioFormat, AudioFrame, Auth2FA, Clipboard,
ControlKey, CursorData, CursorPosition, DisplayInfo, EncodedVideoFrame, EncodedVideoFrames,
@@ -30,7 +23,6 @@ pub use hbb::message::{
SupportedResolutions, TestDelay, VideoFrame, WindowsSessions,
};
/// Helper to create a RendezvousMessage with RegisterPeer
pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage {
let mut rp = RegisterPeer::new();
rp.id = id.to_string();
@@ -41,7 +33,6 @@ pub fn make_register_peer(id: &str, serial: i32) -> RendezvousMessage {
msg
}
/// Helper to create a RendezvousMessage with RegisterPk
pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> RendezvousMessage {
let mut rpk = RegisterPk::new();
rpk.id = id.to_string();
@@ -54,7 +45,6 @@ pub fn make_register_pk(id: &str, uuid: &[u8], pk: &[u8], old_id: &str) -> Rende
msg
}
/// Helper to create a PunchHoleSent message
pub fn make_punch_hole_sent(
socket_addr: &[u8],
id: &str,
@@ -74,10 +64,7 @@ pub fn make_punch_hole_sent(
msg
}
/// Helper to create a RelayResponse message (sent to rendezvous server)
/// IMPORTANT: The union field should be `Id` (our device ID), NOT `Pk`.
/// The rendezvous server will look up our registered public key using this ID,
/// sign it with the server's private key, and set the `pk` field before forwarding to client.
/// Use `id` (device id), not raw `pk`; hbbs fills `pk` when forwarding.
pub fn make_relay_response(
uuid: &str,
socket_addr: &[u8],
@@ -96,13 +83,7 @@ pub fn make_relay_response(
msg
}
/// Helper to create a RequestRelay message (sent to relay server to identify ourselves)
///
/// The `licence_key` is required if the relay server is configured with a key.
/// If the key doesn't match, the relay server will silently reject the connection.
///
/// IMPORTANT: `socket_addr` is the peer's encoded socket address (from FetchLocalAddr/RelayResponse).
/// The relay server uses this to match the two peers connecting to the same relay session.
/// `socket_addr` must be the peer's mangled addr; `licence_key` required if hbbr uses `-k`.
pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) -> RendezvousMessage {
let mut rr = RequestRelay::new();
rr.uuid = uuid.to_string();
@@ -114,8 +95,6 @@ pub fn make_request_relay(uuid: &str, licence_key: &str, socket_addr: &[u8]) ->
msg
}
/// Helper to create a LocalAddr response message
/// This is sent in response to FetchLocalAddr when a peer on the same LAN wants to connect
pub fn make_local_addr(
socket_addr: &[u8],
local_addr: &[u8],
@@ -135,12 +114,10 @@ pub fn make_local_addr(
msg
}
/// Decode a RendezvousMessage from bytes
pub fn decode_rendezvous_message(buf: &[u8]) -> Result<RendezvousMessage, protobuf::Error> {
RendezvousMessage::parse_from_bytes(buf)
}
/// Decode a Message (session message) from bytes
pub fn decode_message(buf: &[u8]) -> Result<hbb::message::Message, protobuf::Error> {
hbb::message::Message::parse_from_bytes(buf)
}

View File

@@ -1,36 +1,19 @@
//! P2P Punch Hole Implementation
//!
//! This module implements TCP direct connection attempts with relay fallback.
//! When a PunchHole request is received, we try to connect directly to the peer.
//! If the direct connection fails (timeout), we fall back to relay.
//! Direct TCP attempt before relay fallback.
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tracing::{debug, info, warn};
use tracing::{debug, info};
use super::connection::ConnectionManager;
/// Timeout for direct TCP connection attempt
const DIRECT_CONNECT_TIMEOUT_MS: u64 = 3000;
/// Result of a punch hole attempt
#[derive(Debug)]
pub enum PunchResult {
/// Direct connection succeeded
DirectConnection(TcpStream),
/// Direct connection failed, should use relay
NeedRelay,
}
/// Attempt direct TCP connection to peer
///
/// This is a simplified P2P approach:
/// 1. Try to connect directly to the peer's address
/// 2. If successful within timeout, use direct connection
/// 3. If failed or timeout, fall back to relay
pub async fn try_direct_connection(peer_addr: SocketAddr) -> PunchResult {
info!("Attempting direct TCP connection to {}", peer_addr);
@@ -54,76 +37,3 @@ pub async fn try_direct_connection(peer_addr: SocketAddr) -> PunchResult {
}
}
}
/// Punch hole handler that tries direct connection first, then falls back to relay
pub struct PunchHoleHandler {
connection_manager: Arc<ConnectionManager>,
}
impl PunchHoleHandler {
pub fn new(connection_manager: Arc<ConnectionManager>) -> Self {
Self { connection_manager }
}
/// Handle punch hole request
///
/// Tries direct connection first, falls back to relay if needed.
/// Returns true if direct connection succeeded, false if relay is needed.
pub async fn handle_punch_hole(&self, peer_addr: Option<SocketAddr>) -> bool {
let peer_addr = match peer_addr {
Some(addr) => addr,
None => {
warn!("No peer address available for punch hole");
return false;
}
};
match try_direct_connection(peer_addr).await {
PunchResult::DirectConnection(stream) => {
// Direct connection succeeded, accept it
match self
.connection_manager
.accept_connection(stream, peer_addr)
.await
{
Ok(_) => {
info!("P2P direct connection established with {}", peer_addr);
true
}
Err(e) => {
warn!("Failed to accept direct connection: {}", e);
false
}
}
}
PunchResult::NeedRelay => {
debug!("Direct connection failed, need relay for {}", peer_addr);
false
}
}
}
}
/// Spawn a punch hole attempt with relay fallback
///
/// This function spawns an async task that:
/// 1. Tries direct TCP connection to peer
/// 2. If successful, accepts the connection
/// 3. If failed, calls the relay callback
pub fn spawn_punch_with_fallback<F>(
connection_manager: Arc<ConnectionManager>,
peer_addr: Option<SocketAddr>,
relay_callback: F,
) where
F: FnOnce() + Send + 'static,
{
tokio::spawn(async move {
let handler = PunchHoleHandler::new(connection_manager);
if !handler.handle_punch_hole(peer_addr).await {
// Direct connection failed, use relay
info!("Falling back to relay connection");
relay_callback();
}
});
}

View File

@@ -1,8 +1,4 @@
//! RustDesk Rendezvous Mediator
//!
//! This module handles communication with the hbbs rendezvous server.
//! It registers the device ID and public key, handles punch hole requests,
//! and relay requests.
//! HBBS UDP registration; punch / relay / intranet callbacks.
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
@@ -24,19 +20,14 @@ use super::protocol::{
rendezvous_message, NatType, RendezvousMessage,
};
/// Registration interval in milliseconds
const REG_INTERVAL_MS: u64 = 12_000;
/// Minimum registration timeout
const MIN_REG_TIMEOUT_MS: u64 = 3_000;
/// Maximum registration timeout
const MAX_REG_TIMEOUT_MS: u64 = 30_000;
/// Timer interval for checking registration status
const TIMER_INTERVAL_MS: u64 = 300;
/// Rendezvous mediator status
#[derive(Debug, Clone, PartialEq)]
pub enum RendezvousStatus {
Disconnected,
@@ -58,44 +49,13 @@ impl std::fmt::Display for RendezvousStatus {
}
}
/// Callback for handling incoming connection requests
pub type ConnectionCallback = Arc<dyn Fn(ConnectionRequest) + Send + Sync>;
/// Incoming connection request from a RustDesk client
#[derive(Debug, Clone)]
pub struct ConnectionRequest {
/// Peer socket address (encoded)
pub socket_addr: Vec<u8>,
/// Relay server to use
pub relay_server: String,
/// NAT type
pub nat_type: NatType,
/// Connection UUID
pub uuid: String,
/// Whether to use secure connection
pub secure: bool,
}
/// Callback type for relay requests
/// Parameters: rendezvous_addr, relay_server, uuid, socket_addr (client's mangled address), device_id
pub type RelayCallback = Arc<dyn Fn(String, String, String, Vec<u8>, String) + Send + Sync>;
/// Callback type for P2P punch hole requests
/// Parameters: peer_addr (decoded), relay_callback_params (rendezvous_addr, relay_server, uuid, socket_addr, device_id)
/// Returns: should call relay callback if P2P fails
pub type PunchCallback =
Arc<dyn Fn(Option<SocketAddr>, String, String, String, Vec<u8>, String) + Send + Sync>;
/// Callback type for intranet/local address connections
/// Parameters: rendezvous_addr, peer_socket_addr (mangled), local_addr, relay_server, device_id
pub type IntranetCallback = Arc<dyn Fn(String, Vec<u8>, SocketAddr, String, String) + Send + Sync>;
/// Rendezvous Mediator
///
/// Handles communication with hbbs rendezvous server:
/// - Registers device ID and public key
/// - Maintains keep-alive with server
/// - Handles punch hole and relay requests
pub struct RendezvousMediator {
config: Arc<RwLock<RustDeskConfig>>,
keypair: Arc<RwLock<Option<KeyPair>>>,
@@ -114,11 +74,9 @@ pub struct RendezvousMediator {
}
impl RendezvousMediator {
/// Create a new rendezvous mediator
pub fn new(mut config: RustDeskConfig) -> Self {
let (shutdown_tx, _) = broadcast::channel(1);
// Get or generate UUID from config (persisted)
let (uuid, uuid_needs_save) = config.ensure_uuid();
Self {
@@ -139,88 +97,71 @@ impl RendezvousMediator {
}
}
/// Set the TCP listen port for direct connections
pub fn set_listen_port(&self, port: u16) {
let old_port = *self.listen_port.read();
if old_port != port {
*self.listen_port.write() = port;
// Port changed, increment serial to notify server
self.increment_serial();
}
}
/// Get the TCP listen port
pub fn listen_port(&self) -> u16 {
*self.listen_port.read()
}
/// Increment the serial number to indicate local state change
pub fn increment_serial(&self) {
let mut serial = self.serial.write();
*serial = serial.wrapping_add(1);
debug!("Serial incremented to {}", *serial);
}
/// Get current serial number
pub fn serial(&self) -> i32 {
*self.serial.read()
}
/// Check if UUID needs to be saved to persistent storage
pub fn uuid_needs_save(&self) -> bool {
*self.uuid_needs_save.read()
}
/// Get the current config (with UUID set)
pub fn config(&self) -> RustDeskConfig {
self.config.read().clone()
}
/// Mark UUID as saved
pub fn mark_uuid_saved(&self) {
*self.uuid_needs_save.write() = false;
}
/// Set the callback for relay requests
pub fn set_relay_callback(&self, callback: RelayCallback) {
*self.relay_callback.write() = Some(callback);
}
/// Set the callback for P2P punch hole requests
pub fn set_punch_callback(&self, callback: PunchCallback) {
*self.punch_callback.write() = Some(callback);
}
/// Set the callback for intranet/local address connections
pub fn set_intranet_callback(&self, callback: IntranetCallback) {
*self.intranet_callback.write() = Some(callback);
}
/// Get current status
pub fn status(&self) -> RendezvousStatus {
self.status.read().clone()
}
/// Update configuration
pub fn update_config(&self, config: RustDeskConfig) {
*self.config.write() = config;
// Config changed, increment serial to notify server
self.increment_serial();
}
/// Initialize or get keypair (Curve25519 for encryption)
pub fn ensure_keypair(&self) -> KeyPair {
let mut keypair_guard = self.keypair.write();
if keypair_guard.is_none() {
let config = self.config.read();
// Try to load from config first
if let (Some(pk), Some(sk)) = (&config.public_key, &config.private_key) {
if let Ok(kp) = KeyPair::from_base64(pk, sk) {
*keypair_guard = Some(kp.clone());
return kp;
}
}
// Generate new keypair
let kp = KeyPair::generate();
*keypair_guard = Some(kp.clone());
kp
@@ -229,12 +170,10 @@ impl RendezvousMediator {
}
}
/// Initialize or get signing keypair (Ed25519 for SignedId)
pub fn ensure_signing_keypair(&self) -> SigningKeyPair {
let mut signing_guard = self.signing_keypair.write();
if signing_guard.is_none() {
let config = self.config.read();
// Try to load from config first
if let (Some(pk), Some(sk)) = (&config.signing_public_key, &config.signing_private_key)
{
if let Ok(skp) = SigningKeyPair::from_base64(pk, sk) {
@@ -245,7 +184,6 @@ impl RendezvousMediator {
warn!("Failed to decode signing keypair from config, generating new one");
}
}
// Generate new signing keypair
let skp = SigningKeyPair::generate();
debug!("Generated new signing keypair");
*signing_guard = Some(skp.clone());
@@ -255,12 +193,10 @@ impl RendezvousMediator {
}
}
/// Get the device ID
pub fn device_id(&self) -> String {
self.config.read().device_id.clone()
}
/// Start the rendezvous mediator
pub async fn start(&self) -> anyhow::Result<()> {
let config = self.config.read().clone();
let effective_server = config.effective_rendezvous_server();
@@ -284,13 +220,11 @@ impl RendezvousMediator {
config.device_id, addr
);
// Resolve server address
let server_addr: SocketAddr = tokio::net::lookup_host(&addr)
.await?
.next()
.ok_or_else(|| anyhow::anyhow!("Failed to resolve {}", addr))?;
// Create UDP socket (match address family, enforce IPV6_V6ONLY)
let bind_addr = match server_addr {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
@@ -302,11 +236,9 @@ impl RendezvousMediator {
info!("Connected to rendezvous server at {}", server_addr);
*self.status.write() = RendezvousStatus::Connected;
// Start registration loop
self.registration_loop(socket).await
}
/// Main registration loop
async fn registration_loop(&self, socket: UdpSocket) -> anyhow::Result<()> {
let mut timer = interval(Duration::from_millis(TIMER_INTERVAL_MS));
let mut recv_buf = vec![0u8; 65535];
@@ -318,7 +250,6 @@ impl RendezvousMediator {
loop {
tokio::select! {
// Handle incoming messages
result = socket.recv(&mut recv_buf) => {
match result {
Ok(len) => {
@@ -336,7 +267,6 @@ impl RendezvousMediator {
}
}
// Periodic registration
_ = timer.tick() => {
let now = Instant::now();
let expired = last_register_resp
@@ -360,7 +290,6 @@ impl RendezvousMediator {
}
}
// Shutdown signal
_ = shutdown_rx.recv() => {
info!("Rendezvous mediator shutting down");
break;
@@ -372,20 +301,16 @@ impl RendezvousMediator {
Ok(())
}
/// Send registration message
async fn send_register(&self, socket: &UdpSocket) -> anyhow::Result<()> {
let key_confirmed = *self.key_confirmed.read();
if !key_confirmed {
// Send RegisterPk with public key
self.send_register_pk(socket).await
} else {
// Send RegisterPeer heartbeat
self.send_register_peer(socket).await
}
}
/// Send RegisterPeer message
async fn send_register_peer(&self, socket: &UdpSocket) -> anyhow::Result<()> {
let id = self.device_id();
let serial = *self.serial.read();
@@ -398,12 +323,8 @@ impl RendezvousMediator {
Ok(())
}
/// Send RegisterPk message
/// Uses the Ed25519 signing public key for registration
async fn send_register_pk(&self, socket: &UdpSocket) -> anyhow::Result<()> {
let id = self.device_id();
// Use signing public key (Ed25519) for RegisterPk
// This is what clients will use to verify our SignedId signature
let signing_keypair = self.ensure_signing_keypair();
let pk = signing_keypair.public_key_bytes();
let uuid = *self.uuid.read();
@@ -417,12 +338,6 @@ impl RendezvousMediator {
Ok(())
}
/// Handle FetchLocalAddr - send to callback for proper TCP handling
///
/// The intranet callback will:
/// 1. Open a TCP connection to the rendezvous server
/// 2. Send LocalAddr message
/// 3. Accept the peer connection over that same TCP stream
async fn send_local_addr(
&self,
_udp_socket: &UdpSocket,
@@ -431,21 +346,17 @@ impl RendezvousMediator {
) -> anyhow::Result<()> {
let id = self.device_id();
// Get our actual local IP addresses for same-LAN connection
let local_addrs = get_local_addresses();
if local_addrs.is_empty() {
debug!("No local addresses available for LocalAddr response");
return Ok(());
}
// Get the rendezvous server address for TCP connection
let config = self.config.read().clone();
let rendezvous_addr = config.rendezvous_addr();
// Use TCP listen port for direct connections
let listen_port = self.listen_port();
// Use the first local IP
let local_ip = local_addrs[0];
let local_sock_addr = SocketAddr::new(local_ip, listen_port);
@@ -454,7 +365,6 @@ impl RendezvousMediator {
local_sock_addr, rendezvous_addr
);
// Call the intranet callback if set
if let Some(callback) = self.intranet_callback.read().as_ref() {
callback(
rendezvous_addr,
@@ -470,7 +380,6 @@ impl RendezvousMediator {
Ok(())
}
/// Handle response from rendezvous server
async fn handle_response(
&self,
socket: &UdpSocket,
@@ -486,7 +395,6 @@ impl RendezvousMediator {
match msg.union {
Some(rendezvous_message::Union::RegisterPeerResponse(rpr)) => {
if rpr.request_pk {
// Server wants us to register our public key
info!("Server requested public key registration");
*self.key_confirmed.write() = false;
self.send_register_pk(socket).await?;
@@ -497,30 +405,24 @@ impl RendezvousMediator {
info!("Received RegisterPkResponse: result={:?}", rpr.result);
match rpr.result.value() {
0 => {
// OK
info!("✓ Public key registered successfully with server");
*self.key_confirmed.write() = true;
// Increment serial after successful registration
self.increment_serial();
*self.status.write() = RendezvousStatus::Registered;
}
2 => {
// UUID_MISMATCH
warn!("UUID mismatch, need to re-register");
*self.key_confirmed.write() = false;
}
3 => {
// ID_EXISTS
error!("Device ID already exists on server");
*self.status.write() =
RendezvousStatus::Error("Device ID already exists".to_string());
}
4 => {
// TOO_FREQUENT
warn!("Registration too frequent");
}
5 => {
// INVALID_ID_FORMAT
error!("Invalid device ID format");
*self.status.write() =
RendezvousStatus::Error("Invalid ID format".to_string());
@@ -540,7 +442,6 @@ impl RendezvousMediator {
let effective_relay_server =
select_relay_server(config.relay_server.as_deref(), &ph.relay_server);
// Decode the peer's socket address
let peer_addr = if !ph.socket_addr.is_empty() {
AddrMangle::decode(&ph.socket_addr)
} else {
@@ -556,9 +457,7 @@ impl RendezvousMediator {
ph.nat_type
);
// Send PunchHoleSent to acknowledge
// IMPORTANT: socket_addr in PunchHoleSent should be the PEER's address (from PunchHole),
// not our own address. This is how RustDesk protocol works.
let id = self.device_id();
info!(
@@ -586,16 +485,11 @@ impl RendezvousMediator {
info!("Sent PunchHoleSent response successfully");
}
// Try P2P direct connection first, fall back to relay if needed
if let Some(relay_server) = effective_relay_server {
// Generate a standard UUID v4 for relay pairing
// This must match the format used by RustDesk client
let uuid = uuid::Uuid::new_v4().to_string();
let rendezvous_addr = config.rendezvous_addr();
let device_id = config.device_id.clone();
// Use punch callback if set (tries P2P first, then relay)
// Otherwise fall back to relay callback directly
if let Some(callback) = self.punch_callback.read().as_ref() {
callback(
peer_addr,
@@ -630,7 +524,6 @@ impl RendezvousMediator {
rr.uuid,
rr.secure
);
// Call the relay callback to handle the connection
if let Some(callback) = self.relay_callback.read().as_ref() {
if let Some(relay_server) = effective_relay_server {
let rendezvous_addr = config.rendezvous_addr();
@@ -653,7 +546,6 @@ impl RendezvousMediator {
select_relay_server(config.relay_server.as_deref(), &fla.relay_server)
.unwrap_or_default();
// Decode the peer address for logging
let peer_addr = AddrMangle::decode(&fla.socket_addr);
info!(
"Received FetchLocalAddr request: peer_addr={:?}, socket_addr_len={}, relay_server={}, effective_relay_server={}",
@@ -662,7 +554,6 @@ impl RendezvousMediator {
fla.relay_server,
effective_relay_server
);
// Respond with our local address for same-LAN direct connection
self.send_local_addr(socket, &fla.socket_addr, &effective_relay_server)
.await?;
}
@@ -671,7 +562,6 @@ impl RendezvousMediator {
*self.serial.write() = cu.serial;
}
Some(other) => {
// Log the actual message type for debugging
let type_name = match other {
rendezvous_message::Union::PunchHoleRequest(_) => "PunchHoleRequest",
rendezvous_message::Union::PunchHoleResponse(_) => "PunchHoleResponse",
@@ -696,23 +586,18 @@ impl RendezvousMediator {
Ok(())
}
/// Stop the rendezvous mediator
pub fn stop(&self) {
info!("Stopping rendezvous mediator");
let _ = self.shutdown_tx.send(());
*self.status.write() = RendezvousStatus::Disconnected;
}
/// Get a shutdown receiver
pub fn shutdown_rx(&self) -> broadcast::Receiver<()> {
self.shutdown_tx.subscribe()
}
}
/// AddrMangle - RustDesk's address encoding scheme
///
/// Certain routers and firewalls scan packets and modify IP addresses.
/// This encoding mangles the address to avoid detection.
/// RustDesk mangled socket encoding.
pub struct AddrMangle;
fn normalize_relay_server(server: &str) -> Option<String> {
@@ -735,9 +620,7 @@ fn select_relay_server(local_relay: Option<&str>, server_relay: &str) -> Option<
}
impl AddrMangle {
/// Encode a SocketAddr to bytes using RustDesk's mangle algorithm
pub fn encode(addr: SocketAddr) -> Vec<u8> {
// Try to convert IPv6-mapped IPv4 to plain IPv4
let addr = try_into_v4(addr);
match addr {
@@ -753,7 +636,6 @@ impl AddrMangle {
let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF));
let bytes = v.to_le_bytes();
// Remove trailing zeros
let mut n_padding = 0;
for i in bytes.iter().rev() {
if *i == 0u8 {
@@ -774,13 +656,11 @@ impl AddrMangle {
}
}
/// Decode bytes to SocketAddr using RustDesk's mangle algorithm
pub fn decode(bytes: &[u8]) -> Option<SocketAddr> {
use std::convert::TryInto;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4};
if bytes.len() > 16 {
// IPv6 format: 16 bytes IP + 2 bytes port
if bytes.len() != 18 {
return None;
}
@@ -791,7 +671,6 @@ impl AddrMangle {
return Some(SocketAddr::new(std::net::IpAddr::V6(ip), port));
}
// IPv4 mangled format
let mut padded = [0u8; 16];
padded[..bytes.len()].copy_from_slice(bytes);
let number = u128::from_le_bytes(padded);
@@ -805,7 +684,6 @@ impl AddrMangle {
}
}
/// Try to convert IPv6-mapped IPv4 address to plain IPv4
fn try_into_v4(addr: SocketAddr) -> SocketAddr {
match addr {
SocketAddr::V6(v6) if !addr.ip().is_loopback() => {
@@ -818,41 +696,30 @@ fn try_into_v4(addr: SocketAddr) -> SocketAddr {
addr
}
/// Check if an interface name belongs to Docker or other virtual networks
fn is_virtual_interface(name: &str) -> bool {
// Docker interfaces
name.starts_with("docker")
|| name.starts_with("br-")
|| name.starts_with("veth")
// Kubernetes/container interfaces
|| name.starts_with("cni")
|| name.starts_with("flannel")
|| name.starts_with("calico")
|| name.starts_with("weave")
// Virtual bridge interfaces
|| name.starts_with("virbr")
|| name.starts_with("lxcbr")
|| name.starts_with("lxdbr")
// VPN interfaces (usually not useful for LAN discovery)
|| name.starts_with("tun")
|| name.starts_with("tap")
}
/// Check if an IP address is in a Docker/container private range
fn is_docker_ip(ip: &std::net::IpAddr) -> bool {
if let std::net::IpAddr::V4(ipv4) = ip {
let octets = ipv4.octets();
// Docker default bridge: 172.17.0.0/16
if octets[0] == 172 && octets[1] == 17 {
return true;
}
// Docker user-defined networks: 172.18-31.0.0/16
if octets[0] == 172 && (18..=31).contains(&octets[1]) {
return true;
}
// Docker overlay networks: 10.0.0.0/8 (common range)
// Note: 10.x.x.x is also used for corporate LANs, so we only filter
// specific Docker-like patterns (10.0.x.x with small third octet)
if octets[0] == 10 && octets[1] == 0 && octets[2] < 10 {
return true;
}
@@ -860,22 +727,18 @@ fn is_docker_ip(ip: &std::net::IpAddr) -> bool {
false
}
/// Get local IP addresses (non-loopback, non-Docker)
fn get_local_addresses() -> Vec<std::net::IpAddr> {
let mut addrs = Vec::new();
// Use pnet or network-interface crate if available, otherwise use simple method
#[cfg(target_os = "linux")]
{
if let Ok(interfaces) = std::fs::read_dir("/sys/class/net") {
for entry in interfaces.flatten() {
let iface_name = entry.file_name().to_string_lossy().to_string();
// Skip loopback and virtual interfaces
if iface_name == "lo" || is_virtual_interface(&iface_name) {
continue;
}
// Try to get IP via command (simple approach)
if let Ok(output) = std::process::Command::new("ip")
.args(["-4", "addr", "show", &iface_name])
.output()
@@ -886,7 +749,6 @@ fn get_local_addresses() -> Vec<std::net::IpAddr> {
let ip_part = &line[inet_pos + 5..];
if let Some(slash_pos) = ip_part.find('/') {
if let Ok(ip) = ip_part[..slash_pos].parse::<std::net::IpAddr>() {
// Skip loopback and Docker IPs
if !ip.is_loopback() && !is_docker_ip(&ip) {
addrs.push(ip);
}
@@ -899,15 +761,11 @@ fn get_local_addresses() -> Vec<std::net::IpAddr> {
}
}
// Fallback: try to get default route interface IP
if addrs.is_empty() {
// Try using DNS lookup to get local IP (connects to external server)
if let Ok(socket) = std::net::UdpSocket::bind("0.0.0.0:0") {
// Connect to a public DNS server (doesn't actually send data)
if socket.connect("8.8.8.8:53").is_ok() {
if let Ok(local_addr) = socket.local_addr() {
let ip = local_addr.ip();
// Skip loopback and Docker IPs
if !ip.is_loopback() && !is_docker_ip(&ip) {
addrs.push(ip);
}

View File

@@ -1,12 +1,13 @@
use std::{collections::VecDeque, sync::Arc};
use tokio::sync::{broadcast, watch, RwLock};
use tokio::sync::{broadcast, watch, Mutex, RwLock};
use crate::atx::AtxController;
use crate::audio::AudioController;
use crate::auth::{SessionStore, UserStore};
use crate::config::ConfigStore;
use crate::db::DatabasePool;
use crate::events::{
AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, MsdDeviceInfo, SystemEvent,
AtxDeviceInfo, AudioDeviceInfo, EventBus, HidDeviceInfo, LedState, MsdDeviceInfo, SystemEvent,
TtydDeviceInfo, VideoDeviceInfo,
};
use crate::extensions::{ExtensionId, ExtensionManager};
@@ -17,68 +18,68 @@ use crate::rtsp::RtspService;
use crate::rustdesk::RustDeskService;
use crate::update::UpdateService;
use crate::video::VideoStreamManager;
use crate::webrtc::WebRtcStreamer;
/// Application-wide state shared across handlers
///
/// # Video Streaming
///
/// All video operations should go through `stream_manager`:
/// - `stream_manager.webrtc_streamer()` - WebRTC streaming (H264, extensible to VP8/VP9/H265)
/// - `stream_manager.mjpeg_handler()` - MJPEG stream handler
/// - `stream_manager.streamer()` - Low-level video capture
/// - `stream_manager.start()` / `stream_manager.stop()` - Stream control
/// - `stream_manager.stats()` - Stream statistics
/// - `stream_manager.list_devices()` - List video devices
#[derive(Clone)]
pub struct ConfigApplyLocks {
pub video: Arc<Mutex<()>>,
pub stream: Arc<Mutex<()>>,
pub otg: Arc<Mutex<()>>,
pub audio: Arc<Mutex<()>>,
pub atx: Arc<Mutex<()>>,
pub rustdesk: Arc<Mutex<()>>,
pub rtsp: Arc<Mutex<()>>,
}
impl ConfigApplyLocks {
fn new() -> Self {
Self {
video: Arc::new(Mutex::new(())),
stream: Arc::new(Mutex::new(())),
otg: Arc::new(Mutex::new(())),
audio: Arc::new(Mutex::new(())),
atx: Arc::new(Mutex::new(())),
rustdesk: Arc::new(Mutex::new(())),
rtsp: Arc::new(Mutex::new(())),
}
}
}
/// Shared Axum/App state: video flows through [`VideoStreamManager`]; WebRTC SDP/ICE/sessions on [`WebRtcStreamer`].
pub struct AppState {
/// Configuration store
pub db: DatabasePool,
pub config: ConfigStore,
/// Session store
pub sessions: SessionStore,
/// User store
pub users: UserStore,
/// OTG Service - centralized USB gadget lifecycle management
/// This is the single owner of OtgGadgetManager, coordinating HID and MSD functions
pub otg_service: Arc<OtgService>,
/// Video stream manager (unified MJPEG/WebRTC management)
/// This is the single entry point for all video operations.
pub stream_manager: Arc<VideoStreamManager>,
/// HID controller
pub webrtc: Arc<WebRtcStreamer>,
pub hid: Arc<HidController>,
/// MSD controller (optional, may not be initialized)
pub msd: Arc<RwLock<Option<MsdController>>>,
/// ATX controller (optional, may not be initialized)
pub atx: Arc<RwLock<Option<AtxController>>>,
/// Audio controller
pub audio: Arc<AudioController>,
/// RustDesk remote access service (optional)
pub rustdesk: Arc<RwLock<Option<Arc<RustDeskService>>>>,
/// RTSP streaming service (optional)
pub rtsp: Arc<RwLock<Option<Arc<RtspService>>>>,
/// Extension manager (ttyd, gostc, easytier)
pub extensions: Arc<ExtensionManager>,
/// Event bus for real-time notifications
pub events: Arc<EventBus>,
/// Latest device info snapshot for WebSocket clients
device_info_tx: watch::Sender<Option<SystemEvent>>,
/// Online update service
pub update: Arc<UpdateService>,
/// Shutdown signal sender
pub shutdown_tx: broadcast::Sender<()>,
/// Recently revoked session IDs (for client kick detection)
pub revoked_sessions: Arc<RwLock<VecDeque<String>>>,
/// Data directory path
pub config_apply_locks: ConfigApplyLocks,
data_dir: std::path::PathBuf,
}
impl AppState {
/// Create new application state
#[allow(clippy::too_many_arguments)]
pub fn new(
db: DatabasePool,
config: ConfigStore,
sessions: SessionStore,
users: UserStore,
otg_service: Arc<OtgService>,
stream_manager: Arc<VideoStreamManager>,
webrtc: Arc<WebRtcStreamer>,
hid: Arc<HidController>,
msd: Option<MsdController>,
atx: Option<AtxController>,
@@ -94,11 +95,13 @@ impl AppState {
let (device_info_tx, _device_info_rx) = watch::channel(None);
Arc::new(Self {
db,
config,
sessions,
users,
otg_service,
stream_manager,
webrtc,
hid,
msd: Arc::new(RwLock::new(msd)),
atx: Arc::new(RwLock::new(atx)),
@@ -111,26 +114,23 @@ impl AppState {
update,
shutdown_tx,
revoked_sessions: Arc::new(RwLock::new(VecDeque::new())),
config_apply_locks: ConfigApplyLocks::new(),
data_dir,
})
}
/// Get data directory path
pub fn data_dir(&self) -> &std::path::PathBuf {
&self.data_dir
}
/// Subscribe to shutdown signal
pub fn shutdown_signal(&self) -> broadcast::Receiver<()> {
self.shutdown_tx.subscribe()
}
/// Subscribe to the latest device info snapshot.
pub fn subscribe_device_info(&self) -> watch::Receiver<Option<SystemEvent>> {
self.device_info_tx.subscribe()
}
/// Record revoked session IDs (bounded queue)
pub async fn remember_revoked_sessions(&self, session_ids: Vec<String>) {
if session_ids.is_empty() {
return;
@@ -144,19 +144,12 @@ impl AppState {
}
}
/// Check if a session ID was revoked (kicked)
pub async fn is_session_revoked(&self, session_id: &str) -> bool {
let guard = self.revoked_sessions.read().await;
guard.iter().any(|id| id == session_id)
}
/// Get complete device information for WebSocket clients
///
/// This method collects the current state of all devices (video, HID, MSD, ATX, Audio)
/// and returns a DeviceInfo event that can be sent to clients.
/// Uses tokio::join! to collect all device info in parallel for better performance.
pub async fn get_device_info(&self) -> SystemEvent {
// Collect all device info in parallel
let (video, hid, msd, atx, audio, ttyd) = tokio::join!(
self.collect_video_info(),
self.collect_hid_info(),
@@ -176,19 +169,15 @@ impl AppState {
}
}
/// Publish DeviceInfo event to all connected WebSocket clients
pub async fn publish_device_info(&self) {
let device_info = self.get_device_info().await;
let _ = self.device_info_tx.send(Some(device_info));
}
/// Collect video device information
async fn collect_video_info(&self) -> VideoDeviceInfo {
// Use stream_manager to get video info (includes stream_mode)
self.stream_manager.get_video_info().await
}
/// Collect HID device information
async fn collect_hid_info(&self) -> HidDeviceInfo {
let state = self.hid.snapshot().await;
@@ -199,14 +188,19 @@ impl AppState {
online: state.online,
supports_absolute_mouse: state.supports_absolute_mouse,
keyboard_leds_enabled: state.keyboard_leds_enabled,
led_state: state.led_state,
led_state: LedState {
num_lock: state.led_state.num_lock,
caps_lock: state.led_state.caps_lock,
scroll_lock: state.led_state.scroll_lock,
compose: state.led_state.compose,
kana: state.led_state.kana,
},
device: state.device,
error: state.error,
error_code: state.error_code,
}
}
/// Collect MSD device information (optional)
async fn collect_msd_info(&self) -> Option<MsdDeviceInfo> {
let msd_guard = self.msd.read().await;
let msd = msd_guard.as_ref()?;
@@ -227,9 +221,7 @@ impl AppState {
})
}
/// Collect ATX device information (optional)
async fn collect_atx_info(&self) -> Option<AtxDeviceInfo> {
// Predefined backend strings to avoid repeated allocations
const BACKEND_POWER_ONLY: &str = "power: configured, reset: none";
const BACKEND_RESET_ONLY: &str = "power: none, reset: configured";
const BACKEND_BOTH: &str = "power: configured, reset: configured";
@@ -254,7 +246,6 @@ impl AppState {
})
}
/// Collect Audio device information (optional)
async fn collect_audio_info(&self) -> Option<AudioDeviceInfo> {
let status = self.audio.status().await;
@@ -267,7 +258,6 @@ impl AppState {
})
}
/// Collect ttyd status information
async fn collect_ttyd_info(&self) -> TtydDeviceInfo {
let status = self.extensions.status(ExtensionId::Ttyd).await;

View File

@@ -1,90 +1,39 @@
//! MJPEG stream handler
//!
//! Manages video frame distribution and per-client statistics.
use arc_swap::ArcSwap;
use bytes::Bytes;
use parking_lot::Mutex as ParkingMutex;
use parking_lot::RwLock as ParkingRwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::broadcast;
use tracing::{debug, info, warn};
/// Generation token paired with `client_id` so [`unregister_client`] ignores stale drops.
pub type ClientGeneration = u64;
use crate::video::encoder::traits::{Encoder, EncoderConfig};
use crate::video::encoder::JpegEncoder;
use crate::video::format::{PixelFormat, Resolution};
use crate::video::format::PixelFormat;
use crate::video::VideoFrame;
/// Cached "no signal" placeholder JPEG (640×360 dark-gray image).
/// Generated once on first use and reused for all NoSignal frames.
static NO_SIGNAL_JPEG: OnceLock<Bytes> = OnceLock::new();
/// Generate a minimal "no signal" JPEG (640×360, dark gray background).
/// Uses turbojpeg directly to produce a valid JPEG without additional deps.
fn generate_no_signal_jpeg() -> Bytes {
const W: usize = 640;
const H: usize = 360;
let y_size = W * H;
let uv_size = y_size / 4;
let mut i420 = vec![0u8; y_size + uv_size * 2];
// Y = 32 (dark gray, above the 16 black floor so it is clearly visible)
i420[..y_size].fill(32);
// U and V = 128 (neutral chroma → no colour tint)
i420[y_size..].fill(128);
match turbojpeg::Compressor::new() {
Ok(mut compressor) => {
let _ = compressor.set_quality(70);
let yuv = turbojpeg::YuvImage {
pixels: i420.as_slice(),
width: W,
height: H,
align: 1,
subsamp: turbojpeg::Subsamp::Sub2x2,
};
match compressor.compress_yuv_to_vec(yuv) {
Ok(jpeg) => Bytes::from(jpeg),
Err(_) => Bytes::new(),
}
}
Err(_) => Bytes::new(),
}
}
/// Return a reference to the cached no-signal JPEG bytes.
fn no_signal_jpeg() -> &'static Bytes {
NO_SIGNAL_JPEG.get_or_init(generate_no_signal_jpeg)
}
/// Client ID type (UUID string)
pub type ClientId = String;
/// Per-client session information
#[derive(Debug, Clone)]
pub struct ClientSession {
/// Unique client ID
pub id: ClientId,
/// Connection timestamp
pub generation: ClientGeneration,
pub connected_at: Instant,
/// Last activity timestamp (frame sent)
pub last_activity: Instant,
/// Frames sent to this client
pub frames_sent: u64,
/// FPS calculator (1-second rolling window)
pub fps_calculator: FpsCalculator,
}
impl ClientSession {
/// Create a new client session
pub fn new(id: ClientId) -> Self {
pub fn new(id: ClientId, generation: ClientGeneration) -> Self {
let now = Instant::now();
Self {
id,
generation,
connected_at: now,
last_activity: now,
frames_sent: 0,
@@ -92,78 +41,49 @@ impl ClientSession {
}
}
/// Get connection duration
pub fn connected_duration(&self) -> Duration {
self.last_activity.duration_since(self.connected_at)
}
/// Get idle duration
pub fn idle_duration(&self) -> Duration {
Instant::now().duration_since(self.last_activity)
pub fn connected_elapsed(&self) -> Duration {
self.connected_at.elapsed()
}
}
/// Rolling window FPS calculator
#[derive(Debug, Clone)]
pub struct FpsCalculator {
/// Frame timestamps in last window
frame_times: VecDeque<Instant>,
/// Window duration (default 1 second)
window: Duration,
/// Cached count of frames in current window (optimization to avoid O(n) filtering)
count_in_window: usize,
}
impl FpsCalculator {
/// Create a new FPS calculator with 1-second window
pub fn new() -> Self {
Self {
frame_times: VecDeque::with_capacity(120), // Max 120fps tracking
frame_times: VecDeque::with_capacity(120),
window: Duration::from_secs(1),
count_in_window: 0,
}
}
/// Record a frame sent
pub fn record_frame(&mut self) {
let now = Instant::now();
self.frame_times.push_back(now);
self.prune(now);
}
// Remove frames outside window and maintain count
/// Rolling-window FPS sample count (~1s).
pub fn current_fps(&mut self) -> u32 {
self.prune(Instant::now());
self.frame_times.len() as u32
}
fn prune(&mut self, now: Instant) {
let cutoff = now - self.window;
while let Some(&oldest) = self.frame_times.front() {
if oldest < cutoff {
self.frame_times.pop_front();
} else {
break;
}
while matches!(self.frame_times.front(), Some(&t) if t < cutoff) {
self.frame_times.pop_front();
}
// Update cached count
self.count_in_window = self.frame_times.len();
}
/// Calculate current FPS (frames in last 1 second window)
pub fn current_fps(&self) -> u32 {
// Return cached count instead of filtering entire deque (O(1) instead of O(n))
self.count_in_window as u32
}
}
impl Default for FpsCalculator {
fn default() -> Self {
Self::new()
}
}
/// Auto-pause configuration
#[derive(Debug, Clone)]
pub struct AutoPauseConfig {
/// Enable auto-pause when no clients
pub enabled: bool,
/// Delay before pausing (default 10s)
pub shutdown_delay_secs: u64,
/// Client timeout for cleanup (default 30s)
pub client_timeout_secs: u64,
}
@@ -177,49 +97,35 @@ impl Default for AutoPauseConfig {
}
}
/// MJPEG stream handler
/// Manages video frame distribution to HTTP clients
pub struct MjpegStreamHandler {
/// Current frame (latest) - using ArcSwap for lock-free reads
current_frame: ArcSwap<Option<VideoFrame>>,
/// Frame update notification
frame_notify: broadcast::Sender<()>,
/// Whether stream is online
online: AtomicBool,
/// Frame sequence counter
sequence: AtomicU64,
/// Per-client sessions (ClientId -> ClientSession)
/// Use parking_lot::RwLock for better performance
clients: ParkingRwLock<HashMap<ClientId, ClientSession>>,
/// Auto-pause configuration
next_generation: AtomicU64,
auto_pause_config: ParkingRwLock<AutoPauseConfig>,
/// Last frame timestamp
last_frame_ts: ParkingRwLock<Option<Instant>>,
/// Dropped same frames count
dropped_same_frames: AtomicU64,
/// Maximum consecutive same frames to drop (0 = disabled)
max_drop_same_frames: AtomicU64,
/// JPEG encoder for non-JPEG input formats
jpeg_encoder: ParkingMutex<Option<JpegEncoder>>,
/// JPEG quality for software encoding (1-100)
jpeg_quality: AtomicU64,
}
impl MjpegStreamHandler {
/// Create a new MJPEG stream handler
pub fn new() -> Self {
Self::with_drop_limit(100) // Default: drop up to 100 same frames
Self::with_drop_limit(100)
}
/// Create handler with custom drop limit
pub fn with_drop_limit(max_drop: u64) -> Self {
let (frame_notify, _) = broadcast::channel(16); // Buffer size 16 for low latency
let (frame_notify, _) = broadcast::channel(16);
Self {
current_frame: ArcSwap::from_pointee(None),
frame_notify,
online: AtomicBool::new(false),
sequence: AtomicU64::new(0),
clients: ParkingRwLock::new(HashMap::new()),
next_generation: AtomicU64::new(1),
jpeg_encoder: ParkingMutex::new(None),
auto_pause_config: ParkingRwLock::new(AutoPauseConfig::default()),
last_frame_ts: ParkingRwLock::new(None),
@@ -229,16 +135,12 @@ impl MjpegStreamHandler {
}
}
/// Set JPEG quality for software encoding (1-100)
pub fn set_jpeg_quality(&self, quality: u8) {
let clamped = quality.clamp(1, 100) as u64;
self.jpeg_quality.store(clamped, Ordering::Relaxed);
}
/// Update current frame
pub fn update_frame(&self, frame: VideoFrame) {
// Fast path: if no MJPEG clients are connected, do minimal bookkeeping and avoid
// expensive work (JPEG encoding and per-frame dedup hashing).
let has_clients = !self.clients.read().is_empty();
if !has_clients {
self.dropped_same_frames.store(0, Ordering::Relaxed);
@@ -246,8 +148,6 @@ impl MjpegStreamHandler {
self.online.store(frame.online, Ordering::SeqCst);
*self.last_frame_ts.write() = Some(Instant::now());
// Keep the latest compressed frame for "instant first frame" when a client connects.
// Avoid retaining large raw buffers when there are no MJPEG clients.
if frame.format.is_compressed() {
self.current_frame.store(Arc::new(Some(frame)));
} else {
@@ -256,7 +156,6 @@ impl MjpegStreamHandler {
return;
}
// If frame is not JPEG, encode it
let frame = if !frame.format.is_compressed() {
match self.encode_to_jpeg(&frame) {
Ok(jpeg_frame) => jpeg_frame,
@@ -269,17 +168,13 @@ impl MjpegStreamHandler {
frame
};
// Frame deduplication (ustreamer-style)
// Check if this frame is identical to the previous one
let max_drop = self.max_drop_same_frames.load(Ordering::Relaxed);
if max_drop > 0 && frame.online {
let current = self.current_frame.load();
if let Some(ref prev_frame) = **current {
let dropped_count = self.dropped_same_frames.load(Ordering::Relaxed);
// Check if we should drop this frame
if dropped_count < max_drop && frames_are_identical(prev_frame, &frame) {
// Check last frame timestamp to ensure minimum 1fps
let last_ts = *self.last_frame_ts.read();
let should_force_send = if let Some(ts) = last_ts {
ts.elapsed() >= Duration::from_secs(1)
@@ -288,16 +183,13 @@ impl MjpegStreamHandler {
};
if !should_force_send {
// Drop this duplicate frame
self.dropped_same_frames.fetch_add(1, Ordering::Relaxed);
return;
}
// If more than 1 second since last frame, force send even if identical
}
}
}
// Frame is different or limit reached or forced by 1fps guarantee, update
self.dropped_same_frames.store(0, Ordering::Relaxed);
self.sequence.fetch_add(1, Ordering::Relaxed);
@@ -305,17 +197,14 @@ impl MjpegStreamHandler {
*self.last_frame_ts.write() = Some(Instant::now());
self.current_frame.store(Arc::new(Some(frame)));
// Notify waiting clients
let _ = self.frame_notify.send(());
}
/// Encode a non-JPEG frame to JPEG
fn encode_to_jpeg(&self, frame: &VideoFrame) -> Result<VideoFrame, String> {
let resolution = frame.resolution;
let sequence = self.sequence.load(Ordering::Relaxed);
let desired_quality = self.jpeg_quality.load(Ordering::Relaxed) as u32;
// Get or create encoder
let mut encoder_guard = self.jpeg_encoder.lock();
let encoder = encoder_guard.get_or_insert_with(|| {
let config = EncoderConfig::jpeg(resolution, desired_quality);
@@ -328,15 +217,12 @@ impl MjpegStreamHandler {
enc
}
Err(e) => {
warn!("Failed to create JPEG encoder: {}, using default", e);
// Try with default config
JpegEncoder::new(EncoderConfig::jpeg(resolution, desired_quality))
.expect("Failed to create default JPEG encoder")
warn!("Failed to create JPEG encoder: {}", e);
panic!("Failed to create JPEG encoder");
}
}
});
// Check if resolution changed
if encoder.config().resolution != resolution {
debug!(
"Resolution changed, recreating JPEG encoder: {}x{}",
@@ -354,11 +240,13 @@ impl MjpegStreamHandler {
}
}
// Encode based on input format
let encoded = match frame.format {
PixelFormat::Yuyv => encoder
.encode_yuyv(frame.data(), sequence)
.map_err(|e| format!("YUYV encode failed: {}", e))?,
PixelFormat::Yvyu => encoder
.encode_yvyu(frame.data(), sequence)
.map_err(|e| format!("YVYU encode failed: {}", e))?,
PixelFormat::Nv12 => encoder
.encode_nv12(frame.data(), sequence)
.map_err(|e| format!("NV12 encode failed: {}", e))?,
@@ -382,80 +270,53 @@ impl MjpegStreamHandler {
}
};
// Create new VideoFrame with JPEG data (zero-copy: Bytes -> Arc<Bytes>)
Ok(VideoFrame::new(
encoded.data,
resolution,
PixelFormat::Mjpeg,
0, // stride not relevant for JPEG
0,
sequence,
))
}
/// Set stream offline
pub fn set_offline(&self) {
self.online.store(false, Ordering::SeqCst);
let _ = self.frame_notify.send(());
}
/// Push a "no signal" placeholder JPEG to all connected MJPEG clients.
///
/// Unlike `set_offline()`, this keeps the stream marked as **online** so
/// that HTTP clients remain connected and see the placeholder image instead
/// of a black/empty screen. Call this whenever the capture thread enters
/// the `NoSignal` state.
pub fn push_no_signal_placeholder(&self) {
let jpeg = no_signal_jpeg();
if jpeg.is_empty() {
return;
}
let frame = VideoFrame::new(
jpeg.clone(),
Resolution::new(640, 360),
PixelFormat::Mjpeg,
0,
self.sequence.fetch_add(1, Ordering::Relaxed),
);
// Store as current frame so late-joining clients get it immediately.
self.current_frame.store(Arc::new(Some(frame)));
// Ensure stream is marked online so the HTTP handler keeps iterating.
self.online.store(true, Ordering::SeqCst);
// Wake up waiting HTTP clients.
let _ = self.frame_notify.send(());
}
/// Set stream online (called when streaming starts)
pub fn set_online(&self) {
self.online.store(true, Ordering::SeqCst);
}
/// Check if stream is online
pub fn is_online(&self) -> bool {
self.online.load(Ordering::SeqCst)
}
/// Get current client count
pub fn client_count(&self) -> u64 {
self.clients.read().len() as u64
}
/// Register a new client
pub fn register_client(&self, client_id: ClientId) {
let session = ClientSession::new(client_id.clone());
/// Connects `client_id`; return value must be passed to [`unregister_client`].
pub fn register_client(&self, client_id: ClientId) -> ClientGeneration {
let generation = self.next_generation.fetch_add(1, Ordering::Relaxed);
let session = ClientSession::new(client_id.clone(), generation);
self.clients.write().insert(client_id.clone(), session);
info!(
"Client {} connected (total: {})",
client_id,
self.client_count()
);
generation
}
/// Unregister a client
pub fn unregister_client(&self, client_id: &str) {
if let Some(session) = self.clients.write().remove(client_id) {
let duration = session.connected_duration();
pub fn unregister_client(&self, client_id: &str, expected_generation: ClientGeneration) {
let mut clients = self.clients.write();
match clients.get(client_id) {
Some(session) if session.generation == expected_generation => {}
_ => return,
}
if let Some(session) = clients.remove(client_id) {
let duration = session.connected_elapsed();
let duration_secs = duration.as_secs_f32();
let avg_fps = if duration_secs > 0.1 {
session.frames_sent as f32 / duration_secs
@@ -469,7 +330,6 @@ impl MjpegStreamHandler {
}
}
/// Record frame sent to a specific client
pub fn record_frame_sent(&self, client_id: &str) {
if let Some(session) = self.clients.write().get_mut(client_id) {
session.last_activity = Instant::now();
@@ -478,54 +338,46 @@ impl MjpegStreamHandler {
}
}
/// Get per-client statistics
pub fn get_clients_stat(&self) -> HashMap<String, crate::events::types::ClientStats> {
// write() because `current_fps()` mutates the underlying VecDeque
// to prune stale samples. Held for ~microseconds, called once per
// second by the stats broadcaster.
self.clients
.read()
.iter()
.write()
.iter_mut()
.map(|(id, session)| {
(
id.clone(),
crate::events::types::ClientStats {
id: id.clone(),
fps: session.fps_calculator.current_fps(),
connected_secs: session.connected_duration().as_secs(),
connected_secs: session.connected_elapsed().as_secs(),
},
)
})
.collect()
}
/// Get auto-pause configuration
pub fn auto_pause_config(&self) -> AutoPauseConfig {
self.auto_pause_config.read().clone()
}
/// Update auto-pause configuration
pub fn set_auto_pause_config(&self, config: AutoPauseConfig) {
let config_clone = config.clone();
*self.auto_pause_config.write() = config;
info!(
"Auto-pause config updated: enabled={}, delay={}s, timeout={}s",
config_clone.enabled,
config_clone.shutdown_delay_secs,
config_clone.client_timeout_secs
config.enabled, config.shutdown_delay_secs, config.client_timeout_secs
);
*self.auto_pause_config.write() = config;
}
/// Get current frame (if any)
pub fn current_frame(&self) -> Option<VideoFrame> {
(**self.current_frame.load()).clone()
}
/// Subscribe to frame updates
pub fn subscribe(&self) -> broadcast::Receiver<()> {
self.frame_notify.subscribe()
}
/// Disconnect all clients (used during config changes)
/// This clears the client list and sets the stream offline,
/// which will cause all active MJPEG streams to terminate.
pub fn disconnect_all_clients(&self) {
let count = {
let mut clients = self.clients.write();
@@ -536,32 +388,26 @@ impl MjpegStreamHandler {
if count > 0 {
info!("Disconnected all {} MJPEG clients for config change", count);
}
// Set offline to signal all streaming tasks to stop
self.set_offline();
}
}
impl Default for MjpegStreamHandler {
fn default() -> Self {
Self::new()
}
}
/// RAII guard for client lifecycle management
/// Ensures cleanup even on panic or abrupt disconnection
pub struct ClientGuard {
client_id: ClientId,
generation: ClientGeneration,
handler: Arc<MjpegStreamHandler>,
}
impl ClientGuard {
/// Create a new client guard
pub fn new(client_id: ClientId, handler: Arc<MjpegStreamHandler>) -> Self {
handler.register_client(client_id.clone());
Self { client_id, handler }
let generation = handler.register_client(client_id.clone());
Self {
client_id,
generation,
handler,
}
}
/// Get client ID
pub fn id(&self) -> &ClientId {
&self.client_id
}
@@ -569,13 +415,12 @@ impl ClientGuard {
impl Drop for ClientGuard {
fn drop(&mut self) {
self.handler.unregister_client(&self.client_id);
self.handler
.unregister_client(&self.client_id, self.generation);
}
}
impl MjpegStreamHandler {
/// Start stale client cleanup task
/// Should be called once when handler is created
pub fn start_cleanup_task(self: Arc<Self>) {
let handler = self.clone();
tokio::spawn(async move {
@@ -589,7 +434,6 @@ impl MjpegStreamHandler {
let now = Instant::now();
let mut stale = Vec::new();
// Find stale clients
{
let clients = handler.clients.read();
for (id, session) in clients.iter() {
@@ -599,7 +443,6 @@ impl MjpegStreamHandler {
}
}
// Remove stale clients
if !stale.is_empty() {
let mut clients = handler.clients.write();
for id in stale {
@@ -617,10 +460,7 @@ impl MjpegStreamHandler {
}
}
/// Compare two frames for equality (hash-based, ustreamer-style)
/// Returns true if frames are identical in geometry and content
fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
// Quick checks first (geometry)
if a.len() != b.len() {
return false;
}
@@ -641,13 +481,10 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
return false;
}
// Avoid hashing the whole frame for obviously different frames by sampling a few
// fixed-size windows first. If all samples match, fall back to the cached hash.
let a_data = a.data();
let b_data = b.data();
let len = a_data.len();
// Small frames: direct compare is cheap.
if len <= 256 {
return a_data == b_data;
}
@@ -655,7 +492,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
const SAMPLE: usize = 16;
debug_assert!(len == b_data.len());
// Head + tail.
if a_data[..SAMPLE] != b_data[..SAMPLE] {
return false;
}
@@ -663,7 +499,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
return false;
}
// Two interior samples (quarter + middle) to catch common "same header/footer" cases.
let quarter = len / 4;
let quarter_start = quarter.saturating_sub(SAMPLE / 2);
if a_data[quarter_start..quarter_start + SAMPLE]
@@ -677,8 +512,6 @@ fn frames_are_identical(a: &VideoFrame, b: &VideoFrame) -> bool {
return false;
}
// Compare hashes instead of full binary data.
// Hash is computed once and cached in OnceLock for efficiency.
a.get_hash() == b.get_hash()
}
@@ -694,7 +527,6 @@ mod tests {
assert!(!handler.is_online());
assert_eq!(handler.client_count(), 0);
// Create a frame
let _frame = VideoFrame::new(
Bytes::from(vec![0xFF, 0xD8, 0x00, 0x00, 0xFF, 0xD9]),
Resolution::VGA,
@@ -708,15 +540,47 @@ mod tests {
fn test_fps_calculator() {
let mut calc = FpsCalculator::new();
// Initially empty
assert_eq!(calc.current_fps(), 0);
// Record some frames
calc.record_frame();
calc.record_frame();
calc.record_frame();
// Should have 3 frames in window
assert!(calc.frame_times.len() == 3);
assert_eq!(calc.current_fps(), 3);
assert_eq!(calc.frame_times.len(), 3);
}
#[test]
fn test_fps_calculator_decays_without_new_frames() {
let mut calc = FpsCalculator::new();
calc.window = Duration::from_millis(50);
calc.record_frame();
calc.record_frame();
assert_eq!(calc.current_fps(), 2);
std::thread::sleep(Duration::from_millis(80));
assert_eq!(calc.current_fps(), 0);
assert!(calc.frame_times.is_empty());
}
#[test]
fn test_client_guard_generation_isolation() {
let handler = Arc::new(MjpegStreamHandler::new());
let id = "shared-id".to_string();
let stale = ClientGuard::new(id.clone(), handler.clone());
let stale_gen = stale.generation;
let fresh = ClientGuard::new(id.clone(), handler.clone());
assert_ne!(stale_gen, fresh.generation);
assert_eq!(handler.client_count(), 1);
drop(stale);
assert_eq!(handler.client_count(), 1);
drop(fresh);
assert_eq!(handler.client_count(), 0);
}
}

View File

@@ -1,11 +1,4 @@
//! Video streaming module
//!
//! Provides MJPEG streaming and WebSocket handlers for MJPEG mode.
//!
//! # Components
//!
//! - `MjpegStreamHandler` - HTTP multipart MJPEG video streaming
//! - `WsHidHandler` - WebSocket HID input handler
//! MJPEG multipart streaming and WebSocket HID (for MJPEG mode).
pub mod mjpeg;
pub mod ws_hid;

View File

@@ -1,25 +1,4 @@
//! WebSocket HID Handler for MJPEG mode
//!
//! This module provides a standalone WebSocket HID handler that can be used
//! independently of the application state. It manages multiple WebSocket
//! connections and forwards HID events to the HID controller.
//!
//! # Protocol
//!
//! Only binary protocol is supported for optimal performance.
//! See `crate::hid::datachannel` for message format details.
//!
//! # Architecture
//!
//! ```text
//! WsHidHandler
//! |
//! +-- clients: HashMap<ClientId, WsHidClient>
//! +-- hid_controller: Arc<HidController>
//! |
//! +-- add_client() -> spawns client handler task
//! +-- remove_client()
//! ```
//! WebSocket HID for MJPEG mode; binary messages per `crate::hid::datachannel`.
use axum::extract::ws::{Message, WebSocket};
use futures::{SinkExt, StreamExt};
@@ -34,51 +13,34 @@ use tracing::{debug, error, info, warn};
use crate::hid::datachannel::{parse_hid_message, HidChannelEvent};
use crate::hid::HidController;
/// Client ID type
pub type ClientId = String;
/// WebSocket HID client information
#[derive(Debug)]
pub struct WsHidClient {
/// Client ID
pub id: ClientId,
/// Connection timestamp
pub connected_at: Instant,
/// Events processed
pub events_processed: AtomicU64,
/// Shutdown signal sender
shutdown_tx: mpsc::Sender<()>,
}
impl WsHidClient {
/// Get events processed count
pub fn events_count(&self) -> u64 {
self.events_processed.load(Ordering::Relaxed)
}
/// Get connection duration in seconds
pub fn connected_secs(&self) -> u64 {
self.connected_at.elapsed().as_secs()
}
}
/// WebSocket HID Handler
///
/// Manages WebSocket connections for HID input in MJPEG mode.
/// Only binary protocol is supported for optimal performance.
pub struct WsHidHandler {
/// HID controller reference
hid_controller: RwLock<Option<Arc<HidController>>>,
/// Active clients
clients: RwLock<HashMap<ClientId, Arc<WsHidClient>>>,
/// Running state
running: AtomicBool,
/// Total events processed
total_events: AtomicU64,
}
impl WsHidHandler {
/// Create a new WebSocket HID handler
pub fn new() -> Arc<Self> {
Arc::new(Self {
hid_controller: RwLock::new(None),
@@ -88,50 +50,39 @@ impl WsHidHandler {
})
}
/// Set HID controller
pub fn set_hid_controller(&self, hid: Arc<HidController>) {
*self.hid_controller.write() = Some(hid);
info!("WsHidHandler: HID controller set");
}
/// Get HID controller
pub fn hid_controller(&self) -> Option<Arc<HidController>> {
self.hid_controller.read().clone()
}
/// Check if HID controller is available
pub fn is_hid_available(&self) -> bool {
self.hid_controller.read().is_some()
}
/// Get client count
pub fn client_count(&self) -> usize {
self.clients.read().len()
}
/// Check if running
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
/// Stop the handler
pub fn stop(&self) {
self.running.store(false, Ordering::SeqCst);
// Signal all clients to disconnect
let clients = self.clients.read();
for client in clients.values() {
let _ = client.shutdown_tx.try_send(());
}
}
/// Get total events processed
pub fn total_events(&self) -> u64 {
self.total_events.load(Ordering::Relaxed)
}
/// Add a new WebSocket client
///
/// This spawns a background task to handle the WebSocket connection.
pub async fn add_client(self: &Arc<Self>, client_id: ClientId, socket: WebSocket) {
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
@@ -151,7 +102,6 @@ impl WsHidHandler {
self.client_count()
);
// Spawn handler task
let handler = self.clone();
tokio::spawn(async move {
handler
@@ -161,7 +111,6 @@ impl WsHidHandler {
});
}
/// Remove a client
pub fn remove_client(&self, client_id: &str) {
if let Some(client) = self.clients.write().remove(client_id) {
info!(
@@ -173,7 +122,6 @@ impl WsHidHandler {
}
}
/// Handle a WebSocket client connection
async fn handle_client(
&self,
client_id: ClientId,
@@ -183,7 +131,6 @@ impl WsHidHandler {
) {
let (mut sender, mut receiver) = socket.split();
// Send initial status as binary: 0x00 = ok, 0x01 = error
let status_byte = if self.is_hid_available() {
0x00u8
} else {
@@ -222,7 +169,6 @@ impl WsHidHandler {
debug!("WsHidHandler: Client {} stream ended", client_id);
break;
}
// Ignore text messages - binary protocol only
Some(Ok(Message::Text(_))) => {
warn!("WsHidHandler: Ignoring text message from client {} (binary protocol only)", client_id);
}
@@ -232,7 +178,6 @@ impl WsHidHandler {
}
}
// Reset HID state when client disconnects to release any held keys/buttons
let hid = self.hid_controller.read().clone();
if let Some(hid) = hid {
if let Err(e) = hid.reset().await {
@@ -246,7 +191,6 @@ impl WsHidHandler {
}
}
/// Handle binary HID message
async fn handle_binary_message(&self, data: &[u8], client: &WsHidClient) -> Result<(), String> {
let hid = self
.hid_controller
@@ -279,17 +223,6 @@ impl WsHidHandler {
}
}
impl Default for WsHidHandler {
fn default() -> Self {
Self {
hid_controller: RwLock::new(None),
clients: RwLock::new(HashMap::new()),
running: AtomicBool::new(true),
total_events: AtomicU64::new(0),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

18
src/stream_encoder.rs Normal file
View File

@@ -0,0 +1,18 @@
//! `EncoderType` → `EncoderBackend` (breaks config ↔ video import cycles).
use crate::config::EncoderType;
use crate::video::encoder::EncoderBackend;
/// `None` means “auto” in WebRTC / pipeline (same as `EncoderType::Auto`).
pub fn encoder_type_to_backend(encoder: EncoderType) -> Option<EncoderBackend> {
match encoder {
EncoderType::Auto => None,
EncoderType::Software => Some(EncoderBackend::Software),
EncoderType::Vaapi => Some(EncoderBackend::Vaapi),
EncoderType::Nvenc => Some(EncoderBackend::Nvenc),
EncoderType::Qsv => Some(EncoderBackend::Qsv),
EncoderType::Amf => Some(EncoderBackend::Amf),
EncoderType::Rkmpp => Some(EncoderBackend::Rkmpp),
EncoderType::V4l2m2m => Some(EncoderBackend::V4l2m2m),
}
}

View File

@@ -142,15 +142,10 @@ impl UpdateService {
}
pub async fn overview(&self, channel: UpdateChannel) -> Result<UpdateOverviewResponse> {
let channels: ChannelsManifest = self.fetch_json("/v1/channels.json").await?;
let releases: ReleasesManifest = self.fetch_json("/v1/releases.json").await?;
let (channels, releases) = self.fetch_manifests().await?;
let current_version = parse_version(env!("CARGO_PKG_VERSION"))?;
let latest_version_str = match channel {
UpdateChannel::Stable => channels.stable,
UpdateChannel::Beta => channels.beta,
};
let latest_version = parse_version(&latest_version_str)?;
let latest_version = parse_version(&channel_head_version(&channels, channel))?;
let current_parts = parse_version_parts(&current_version)?;
let latest_parts = parse_version_parts(&latest_version)?;
@@ -159,11 +154,7 @@ impl UpdateService {
if release.channel != channel {
continue;
}
let version = match parse_version(&release.version) {
Ok(v) => v,
Err(_) => continue,
};
let version_parts = match parse_version_parts(&version) {
let version_parts = match parse_version_parts(&release.version) {
Ok(parts) => parts,
Err(_) => continue,
};
@@ -253,16 +244,11 @@ impl UpdateService {
)
.await;
let channels: ChannelsManifest = self.fetch_json("/v1/channels.json").await?;
let releases: ReleasesManifest = self.fetch_json("/v1/releases.json").await?;
let (channels, releases) = self.fetch_manifests().await?;
let current_version = parse_version(env!("CARGO_PKG_VERSION"))?;
let target_version = if let Some(channel) = req.channel {
let version_str = match channel {
UpdateChannel::Stable => channels.stable,
UpdateChannel::Beta => channels.beta,
};
parse_version(&version_str)?
parse_version(&channel_head_version(&channels, channel))?
} else {
parse_version(req.target_version.as_deref().unwrap_or_default())?
};
@@ -443,6 +429,12 @@ impl UpdateService {
Ok(())
}
async fn fetch_manifests(&self) -> Result<(ChannelsManifest, ReleasesManifest)> {
let channels = self.fetch_json("/v1/channels.json").await?;
let releases = self.fetch_json("/v1/releases.json").await?;
Ok((channels, releases))
}
async fn fetch_json<T: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<T> {
let url = format!("{}{}", self.base_url.trim_end_matches('/'), path);
let response = self
@@ -494,22 +486,7 @@ impl UpdateService {
}
fn parse_version(input: &str) -> Result<String> {
let parts: Vec<&str> = input.split('.').collect();
if parts.len() != 3 {
return Err(AppError::Internal(format!(
"Invalid version {}, expected x.x.x",
input
)));
}
if parts
.iter()
.any(|p| p.is_empty() || !p.chars().all(|c| c.is_ascii_digit()))
{
return Err(AppError::Internal(format!(
"Invalid version {}, expected numeric x.x.x",
input
)));
}
parse_version_parts(input)?;
Ok(input.to_string())
}
@@ -527,16 +504,26 @@ fn parse_version_parts(input: &str) -> Result<[u64; 3]> {
input
)));
}
let major = parts[0]
.parse::<u64>()
.map_err(|e| AppError::Internal(format!("Invalid major version {}: {}", parts[0], e)))?;
let minor = parts[1]
.parse::<u64>()
.map_err(|e| AppError::Internal(format!("Invalid minor version {}: {}", parts[1], e)))?;
let patch = parts[2]
.parse::<u64>()
.map_err(|e| AppError::Internal(format!("Invalid patch version {}: {}", parts[2], e)))?;
Ok([major, minor, patch])
let mut out = [0u64; 3];
for (i, p) in parts.iter().enumerate() {
if p.is_empty() || !p.chars().all(|c| c.is_ascii_digit()) {
return Err(AppError::Internal(format!(
"Invalid version {}, expected numeric x.x.x",
input
)));
}
out[i] = p
.parse::<u64>()
.map_err(|e| AppError::Internal(format!("Invalid version component {}: {}", p, e)))?;
}
Ok(out)
}
fn channel_head_version(channels: &ChannelsManifest, channel: UpdateChannel) -> String {
match channel {
UpdateChannel::Stable => channels.stable.clone(),
UpdateChannel::Beta => channels.beta.clone(),
}
}
fn compare_version_parts(a: &[u64; 3], b: &[u64; 3]) -> std::cmp::Ordering {

23
src/utils/fs.rs Normal file
View File

@@ -0,0 +1,23 @@
//! Small filesystem helpers.
use std::path::Path;
/// Read a UTF-8 file and trim surrounding whitespace.
pub fn read_trimmed(path: &Path) -> Option<String> {
std::fs::read_to_string(path)
.ok()
.map(|value| value.trim().to_string())
}
/// Sorted list of directory entry names (lossy exclusion on non-UTF8).
pub fn list_dir_names(path: &Path) -> Vec<String> {
let mut names = std::fs::read_dir(path)
.ok()
.into_iter()
.flatten()
.flatten()
.filter_map(|entry| entry.file_name().into_string().ok())
.collect::<Vec<_>>();
names.sort();
names
}

15
src/utils/host.rs Normal file
View File

@@ -0,0 +1,15 @@
//! Host identity helpers.
/// Truncated content of `/etc/hostname`. Used where RustDesk peers expect the configured static name.
pub fn hostname_from_etc() -> String {
std::fs::read_to_string("/etc/hostname")
.map(|s| s.trim().to_string())
.unwrap_or_else(|_| "One-KVM".to_string())
}
/// Current kernel hostname (`gethostname`). Used for live device info in the UI.
pub fn hostname_uname() -> String {
nix::unistd::gethostname()
.map(|s| s.to_string_lossy().into_owned())
.unwrap_or_else(|_| "unknown".to_string())
}

View File

@@ -1,9 +1,11 @@
//! Utility modules for One-KVM
//!
//! This module contains common utilities used across the codebase.
//! Shared utilities.
pub mod fs;
pub mod host;
pub mod net;
pub mod throttle;
pub use fs::{list_dir_names, read_trimmed};
pub use host::{hostname_from_etc, hostname_uname};
pub use net::{bind_tcp_listener, bind_udp_socket};
pub use throttle::LogThrottler;

View File

@@ -1,44 +1,15 @@
//! Log throttling utility
//!
//! Provides a mechanism to limit how often the same log message is recorded,
//! preventing log flooding when errors occur repeatedly.
//! Limits repeated identical log lines (e.g. reconnect failures).
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Log throttler that limits how often the same message is logged
///
/// This is useful for preventing log flooding when errors occur repeatedly,
/// such as when a device is disconnected and reconnection attempts fail.
///
/// # Example
///
/// ```rust
/// use one_kvm::utils::LogThrottler;
/// use std::time::Duration;
///
/// let throttler = LogThrottler::new(Duration::from_secs(5));
///
/// // First call returns true
/// assert!(throttler.should_log("device_error"));
///
/// // Subsequent calls within 5 seconds return false
/// assert!(!throttler.should_log("device_error"));
/// ```
pub struct LogThrottler {
/// Map of message key to last log time
last_logged: RwLock<HashMap<String, Instant>>,
/// Throttle interval
interval: Duration,
}
impl LogThrottler {
/// Create a new log throttler with the specified interval
///
/// # Arguments
///
/// * `interval` - The minimum time between log messages for the same key
pub fn new(interval: Duration) -> Self {
Self {
last_logged: RwLock::new(HashMap::new()),
@@ -46,23 +17,14 @@ impl LogThrottler {
}
}
/// Create a new log throttler with interval specified in seconds
pub fn with_secs(secs: u64) -> Self {
Self::new(Duration::from_secs(secs))
}
/// Check if a message should be logged (not throttled)
///
/// Returns `true` if the message should be logged, `false` if it should be throttled.
/// If `true` is returned, the internal timestamp is updated.
///
/// # Arguments
///
/// * `key` - A unique identifier for the message type
/// Returns whether to emit the log line; updates the timestamp when `true`.
pub fn should_log(&self, key: &str) -> bool {
let now = Instant::now();
// First check with read lock (fast path)
{
let map = self.last_logged.read().unwrap();
if let Some(last) = map.get(key) {
@@ -72,9 +34,7 @@ impl LogThrottler {
}
}
// Update with write lock
let mut map = self.last_logged.write().unwrap();
// Double-check after acquiring write lock
if let Some(last) = map.get(key) {
if now.duration_since(*last) < self.interval {
return false;
@@ -84,32 +44,14 @@ impl LogThrottler {
true
}
/// Clear throttle state for a specific key
///
/// This should be called when an error condition recovers,
/// so the next error will be logged immediately.
///
/// # Arguments
///
/// * `key` - The key to clear
/// Call when a condition recovers so the next failure logs immediately.
pub fn clear(&self, key: &str) {
self.last_logged.write().unwrap().remove(key);
}
/// Clear all throttle state
pub fn clear_all(&self) {
self.last_logged.write().unwrap().clear();
}
/// Get the number of tracked keys
pub fn len(&self) -> usize {
self.last_logged.read().unwrap().len()
}
/// Check if the throttler is empty
pub fn is_empty(&self) -> bool {
self.last_logged.read().unwrap().is_empty()
}
}
impl Clone for LogThrottler {
@@ -122,23 +64,11 @@ impl Clone for LogThrottler {
}
impl Default for LogThrottler {
/// Create a default log throttler with 5 second interval
fn default() -> Self {
Self::with_secs(5)
}
}
/// Macro for throttled warning logging
///
/// # Example
///
/// ```rust
/// use one_kvm::utils::LogThrottler;
/// use one_kvm::warn_throttled;
///
/// let throttler = LogThrottler::default();
/// warn_throttled!(throttler, "my_error", "Error occurred: {}", "details");
/// ```
#[macro_export]
macro_rules! warn_throttled {
($throttler:expr, $key:expr, $($arg:tt)*) => {
@@ -148,7 +78,6 @@ macro_rules! warn_throttled {
};
}
/// Macro for throttled error logging
#[macro_export]
macro_rules! error_throttled {
($throttler:expr, $key:expr, $($arg:tt)*) => {
@@ -158,16 +87,6 @@ macro_rules! error_throttled {
};
}
/// Macro for throttled info logging
#[macro_export]
macro_rules! info_throttled {
($throttler:expr, $key:expr, $($arg:tt)*) => {
if $throttler.should_log($key) {
tracing::info!($($arg)*);
}
};
}
#[cfg(test)]
mod tests {
use super::*;
@@ -183,16 +102,11 @@ mod tests {
fn test_throttling() {
let throttler = LogThrottler::new(Duration::from_millis(100));
// First call should succeed
assert!(throttler.should_log("test_key"));
// Immediate second call should be throttled
assert!(!throttler.should_log("test_key"));
// Wait for throttle to expire
thread::sleep(Duration::from_millis(150));
// Should succeed again
assert!(throttler.should_log("test_key"));
}
@@ -200,7 +114,6 @@ mod tests {
fn test_different_keys() {
let throttler = LogThrottler::with_secs(10);
// Different keys should be independent
assert!(throttler.should_log("key1"));
assert!(throttler.should_log("key2"));
assert!(!throttler.should_log("key1"));
@@ -214,10 +127,8 @@ mod tests {
assert!(throttler.should_log("test_key"));
assert!(!throttler.should_log("test_key"));
// Clear the key
throttler.clear("test_key");
// Should be able to log again
assert!(throttler.should_log("test_key"));
}
@@ -239,19 +150,4 @@ mod tests {
let throttler = LogThrottler::default();
assert!(throttler.should_log("test"));
}
#[test]
fn test_len_and_is_empty() {
let throttler = LogThrottler::with_secs(10);
assert!(throttler.is_empty());
assert_eq!(throttler.len(), 0);
throttler.should_log("key1");
assert!(!throttler.is_empty());
assert_eq!(throttler.len(), 1);
throttler.should_log("key2");
assert_eq!(throttler.len(), 2);
}
}

View File

@@ -0,0 +1,30 @@
//! Shared tuning for V4L2 MJPEG capture paths (`Streamer` + `SharedVideoPipeline`).
/// Frames smaller than this are treated as incomplete / noise.
pub(crate) const MIN_CAPTURE_FRAME_SIZE: usize = 128;
/// After startup, validate JPEG header every N frames to limit CPU use.
pub(crate) const JPEG_VALIDATE_INTERVAL: u64 = 30;
/// Validate every MJPEG frame for the first N frames (UVC warm-up / bad headers).
pub(crate) const STARTUP_JPEG_VALIDATE_FRAMES: u64 = 3;
#[inline]
pub(crate) fn should_validate_jpeg_frame(validate_counter: u64) -> bool {
validate_counter <= STARTUP_JPEG_VALIDATE_FRAMES
|| validate_counter.is_multiple_of(JPEG_VALIDATE_INTERVAL)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn jpeg_validation_policy_startup_then_interval() {
assert!(should_validate_jpeg_frame(1));
assert!(should_validate_jpeg_frame(2));
assert!(should_validate_jpeg_frame(3));
assert!(!should_validate_jpeg_frame(4));
assert!(should_validate_jpeg_frame(30));
}
}

View File

@@ -0,0 +1,75 @@
//! Shared capture status and error classification helpers.
use std::io;
use crate::video::SignalStatus;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureIoErrorKind {
DeviceLost,
TransientSignal { status: Option<SignalStatus> },
Other,
}
pub fn signal_status_from_capture_kind(kind: &str) -> SignalStatus {
SignalStatus::from_str(kind).unwrap_or(SignalStatus::NoSignal)
}
pub fn classify_capture_io_error(err: &io::Error) -> CaptureIoErrorKind {
match err.raw_os_error() {
// ENXIO / ENODEV / ESHUTDOWN: the device node or endpoint is gone.
Some(6) | Some(19) | Some(108) => CaptureIoErrorKind::DeviceLost,
// EIO / EPIPE: source or transport glitched; EPROTO is common for UVC USB.
Some(5) | Some(32) => CaptureIoErrorKind::TransientSignal { status: None },
Some(71) => CaptureIoErrorKind::TransientSignal {
status: Some(SignalStatus::UvcUsbError),
},
_ => CaptureIoErrorKind::Other,
}
}
pub fn is_device_lost_message(message: &str) -> bool {
message.contains("No such file or directory")
|| message.contains("No such device")
|| message.contains("os error 2")
|| message.contains("ENODEV")
|| message.contains("ENXIO")
|| message.contains("ESHUTDOWN")
}
pub fn capture_error_log_key(err: &io::Error) -> String {
let message = err.to_string();
if message.contains("dqbuf failed") && message.contains("EINVAL") {
"capture_dqbuf_einval".to_string()
} else if message.contains("dqbuf failed") {
"capture_dqbuf".to_string()
} else {
format!("capture_{:?}", err.kind())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn maps_known_signal_status_strings() {
assert_eq!(
signal_status_from_capture_kind("out_of_range"),
SignalStatus::OutOfRange
);
assert_eq!(
signal_status_from_capture_kind("unknown"),
SignalStatus::NoSignal
);
}
#[test]
fn classifies_source_change_log_keys() {
let err = io::Error::other("dqbuf failed: EINVAL");
assert_eq!(capture_error_log_key(&err), "capture_dqbuf_einval");
let err = io::Error::new(io::ErrorKind::TimedOut, "capture timeout");
assert_eq!(capture_error_log_key(&err), "capture_TimedOut");
}
}

View File

@@ -135,7 +135,7 @@ pub async fn enforce_constraints_with_stream_manager(
}
if current_mode == StreamMode::WebRTC {
let current_codec = stream_manager.webrtc_streamer().current_video_codec().await;
let current_codec = stream_manager.current_video_codec().await;
if !constraints.is_webrtc_codec_allowed(current_codec) {
let target_codec = constraints.preferred_webrtc_codec();
stream_manager.set_video_codec(target_codec).await?;

355
src/video/csi_bridge.rs Normal file
View File

@@ -0,0 +1,355 @@
//! CSI/HDMI bridge helpers: subdev discovery, DV probe, RK628 "fake VGA" filter (must run before `S_FMT` / `STREAMON` on capture — see RK628 driver).
use std::fs::File;
use std::io;
use std::os::fd::{AsFd, AsRawFd, FromRawFd};
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use libc;
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
use tracing::{debug, info, warn};
use v4l2r::bindings::{
v4l2_bt_timings, v4l2_dv_timings, V4L2_DV_BT_656_1120, V4L2_DV_FL_HAS_CEA861_VIC,
};
use v4l2r::ioctl::{self, Event as V4l2Event, EventType, QueryDvTimingsError, SubscribeEventFlags};
use v4l2r::nix::errno::Errno;
use crate::video::SignalStatus;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CsiBridgeKind {
Rk628,
RkHdmirx,
Tc358743,
Unknown,
}
impl CsiBridgeKind {
fn from_subdev_name(name: &str) -> Option<Self> {
let lower = name.to_ascii_lowercase();
if lower.contains("rk628") {
Some(Self::Rk628)
} else if lower.contains("hdmirx") || lower.contains("hdmi-rx") {
Some(Self::RkHdmirx)
} else if lower.contains("tc358743") || lower.contains("tc358746") {
Some(Self::Tc358743)
} else {
None
}
}
fn has_no_signal_fingerprint(self) -> bool {
matches!(self, Self::Rk628)
}
}
#[derive(Debug, Clone)]
pub enum ProbeResult {
Locked(DvTimingsMode),
NoCable,
NoSync,
OutOfRange,
NoSignal,
}
impl ProbeResult {
pub fn as_status(&self) -> Option<SignalStatus> {
match self {
ProbeResult::Locked(_) => None,
ProbeResult::NoCable => Some(SignalStatus::NoCable),
ProbeResult::NoSync => Some(SignalStatus::NoSync),
ProbeResult::OutOfRange => Some(SignalStatus::OutOfRange),
ProbeResult::NoSignal => Some(SignalStatus::NoSignal),
}
}
pub fn is_locked(&self) -> bool {
matches!(self, ProbeResult::Locked(_))
}
}
/// Scalar copy of BT timings (avoids unaligned refs into packed union).
#[derive(Clone, Copy)]
pub struct DvTimingsMode {
pub width: u32,
pub height: u32,
pub pixelclock: u64,
pub fps: Option<f64>,
pub raw: v4l2_dv_timings,
}
impl std::fmt::Debug for DvTimingsMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DvTimingsMode")
.field("width", &self.width)
.field("height", &self.height)
.field("pixelclock", &self.pixelclock)
.field("fps", &self.fps)
.finish()
}
}
/// Heuristic: scan `/sys/class/video4linux/v4l-subdev*` names for rk628 / hdmirx / tc358743.
pub fn discover_subdev_for_video(video_path: &Path) -> Option<(PathBuf, CsiBridgeKind)> {
let sysfs_base = Path::new("/sys/class/video4linux");
let entries = std::fs::read_dir(sysfs_base).ok()?;
for entry in entries.flatten() {
let name = entry.file_name();
let name_str = name.to_string_lossy();
if !name_str.starts_with("v4l-subdev") {
continue;
}
let Some(kind) = read_sysfs_name(&entry.path())
.as_deref()
.and_then(CsiBridgeKind::from_subdev_name)
else {
continue;
};
let dev_path = PathBuf::from("/dev").join(&*name_str);
if dev_path.exists() {
info!(
"Discovered CSI bridge subdev for {:?}: {:?} ({:?})",
video_path, dev_path, kind
);
return Some((dev_path, kind));
}
}
debug!(
"No CSI bridge subdev found in /sys/class/video4linux for {:?}",
video_path
);
None
}
fn read_sysfs_name(subdev_sysfs: &Path) -> Option<String> {
std::fs::read_to_string(subdev_sysfs.join("name"))
.ok()
.map(|s| s.trim().to_string())
}
pub fn open_subdev(path: &Path) -> io::Result<File> {
File::options().read(true).write(true).open(path)
}
pub fn probe_signal(subdev_fd: &impl AsRawFd, kind: CsiBridgeKind) -> ProbeResult {
match ioctl::query_dv_timings::<v4l2_dv_timings>(subdev_fd) {
Ok(timings) => classify_timings(timings, kind),
Err(QueryDvTimingsError::NoLink) => ProbeResult::NoCable,
Err(QueryDvTimingsError::UnstableSignal) => ProbeResult::NoSync,
Err(QueryDvTimingsError::IoctlError(Errno::ERANGE)) => ProbeResult::OutOfRange,
Err(QueryDvTimingsError::IoctlError(Errno::EIO | Errno::EREMOTEIO | Errno::ETIMEDOUT)) => {
ProbeResult::NoSync
}
Err(QueryDvTimingsError::Unsupported) | Err(QueryDvTimingsError::IoctlError(_)) => {
ProbeResult::NoSignal
}
}
}
/// RK628 can block `QUERY_DV_TIMINGS` for seconds; probe uses a dup + timeout.
pub const RK628_SUBDEV_PROBE_TIMEOUT: Duration = Duration::from_millis(3000);
pub fn probe_signal_thread_timeout(
subdev_fd: &impl AsRawFd,
kind: CsiBridgeKind,
limit: Duration,
) -> Option<ProbeResult> {
let raw = subdev_fd.as_raw_fd();
let dup_fd = unsafe { libc::dup(raw) };
if dup_fd < 0 {
warn!(
"dup(subdev) for threaded DV probe failed: {}",
io::Error::last_os_error()
);
return None;
}
let dup_file = unsafe { File::from_raw_fd(dup_fd) };
let (tx, rx) = mpsc::channel::<ProbeResult>();
let handle = thread::spawn(move || {
let probe = probe_signal(&dup_file, kind);
let _ = tx.send(probe);
});
match rx.recv_timeout(limit) {
Ok(r) => {
let _ = handle.join();
Some(r)
}
Err(mpsc::RecvTimeoutError::Timeout) => {
warn!(
"QUERY_DV_TIMINGS exceeded {:?} (RK628 HDMI mode change?) — abandoning probe thread",
limit
);
drop(handle);
None
}
Err(mpsc::RecvTimeoutError::Disconnected) => {
let _ = handle.join();
None
}
}
}
fn classify_timings(timings: v4l2_dv_timings, kind: CsiBridgeKind) -> ProbeResult {
let timings_type: u32 = timings.type_;
if timings_type != V4L2_DV_BT_656_1120 {
warn!(
"QUERY_DV_TIMINGS returned unexpected type {}, treating as NoSignal",
timings_type
);
return ProbeResult::NoSignal;
}
let bt: v4l2_bt_timings = unsafe { timings.__bindgen_anon_1.bt };
let width: u32 = bt.width;
let height: u32 = bt.height;
let pixelclock: u64 = bt.pixelclock;
if width == 0 || height == 0 || width <= 64 || height <= 64 {
return ProbeResult::NoSignal;
}
if kind.has_no_signal_fingerprint() && is_rk628_no_signal_fingerprint(&bt) {
debug!(
"RK628 reports synthetic {}x{} @ {} Hz VGA fingerprint → NoSignal",
width, height, pixelclock
);
return ProbeResult::NoSignal;
}
let total_h: u64 = (width + bt.hfrontporch + bt.hsync + bt.hbackporch) as u64;
let total_v: u64 = (height + bt.vfrontporch + bt.vsync + bt.vbackporch) as u64;
let fps = if total_h > 0 && total_v > 0 && pixelclock > 0 {
Some(pixelclock as f64 / (total_h as f64 * total_v as f64))
} else {
None
};
ProbeResult::Locked(DvTimingsMode {
width,
height,
pixelclock,
fps,
raw: timings,
})
}
/// RK628 returns DMT 640x480 @ ~25.175 MHz, VIC=1 when unlocked; do not stream on that.
fn is_rk628_no_signal_fingerprint(bt: &v4l2_bt_timings) -> bool {
let width: u32 = bt.width;
let height: u32 = bt.height;
let pixelclock: u64 = bt.pixelclock;
let flags: u32 = bt.flags;
let vic: u8 = bt.cea861_vic;
if width != 640 || height != 480 {
return false;
}
let pclk_matches = (pixelclock as i64 - 25_175_000).abs() < 50_000;
let has_vic_flag = flags & V4L2_DV_FL_HAS_CEA861_VIC != 0;
pclk_matches && has_vic_flag && vic == 1
}
pub fn apply_dv_timings(subdev_fd: &impl AsRawFd, timings: v4l2_dv_timings) {
match ioctl::s_dv_timings::<_, v4l2_dv_timings>(subdev_fd, timings) {
Ok(_) => debug!("S_DV_TIMINGS ok on subdev"),
Err(e) => debug!(
"S_DV_TIMINGS failed on subdev ({}), continuing with queried mode",
e
),
}
}
pub fn subscribe_source_change(subdev_fd: &impl AsRawFd) -> io::Result<()> {
ioctl::subscribe_event(
subdev_fd,
EventType::SourceChange(0),
SubscribeEventFlags::empty(),
)
.map_err(|e| io::Error::other(format!("subscribe_event(SOURCE_CHANGE): {}", e)))
}
/// `Ok(true)` if a SOURCE_CHANGE was drained; `Ok(false)` on timeout.
pub fn wait_source_change(subdev_fd: &File, timeout: Duration) -> io::Result<bool> {
let mut fds = [PollFd::new(subdev_fd.as_fd(), PollFlags::POLLPRI)];
let timeout_ms = timeout.as_millis().min(u16::MAX as u128) as u16;
let ready = poll(&mut fds, PollTimeout::from(timeout_ms))?;
if ready == 0 {
return Ok(false);
}
if let Some(revents) = fds[0].revents() {
if !revents.contains(PollFlags::POLLPRI) {
return Ok(false);
}
}
let mut drained = 0u32;
while let Ok(_ev) = ioctl::dqevent::<V4l2Event>(subdev_fd) {
drained = drained.saturating_add(1);
if drained >= 16 {
break;
}
}
debug!("subdev source_change drained {} event(s)", drained);
Ok(true)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rk628_fingerprint_matches_vga() {
let mut bt: v4l2_bt_timings = unsafe { std::mem::zeroed() };
bt.width = 640;
bt.height = 480;
bt.pixelclock = 25_175_000;
bt.flags = V4L2_DV_FL_HAS_CEA861_VIC;
bt.cea861_vic = 1;
assert!(is_rk628_no_signal_fingerprint(&bt));
}
#[test]
fn rk628_fingerprint_rejects_real_1080p() {
let mut bt: v4l2_bt_timings = unsafe { std::mem::zeroed() };
bt.width = 1920;
bt.height = 1080;
bt.pixelclock = 148_500_000;
bt.flags = V4L2_DV_FL_HAS_CEA861_VIC;
bt.cea861_vic = 16;
assert!(!is_rk628_no_signal_fingerprint(&bt));
}
#[test]
fn rk628_fingerprint_rejects_real_vga_without_vic() {
// A hypothetical legit VGA source would *not* carry the CEA VIC
// flag from the bridge (RK628 sets it synthetically when unlocked).
let mut bt: v4l2_bt_timings = unsafe { std::mem::zeroed() };
bt.width = 640;
bt.height = 480;
bt.pixelclock = 25_175_000;
bt.flags = 0;
bt.cea861_vic = 0;
assert!(!is_rk628_no_signal_fingerprint(&bt));
}
#[test]
fn from_subdev_name_recognises_known_bridges() {
assert_eq!(
CsiBridgeKind::from_subdev_name("rk628-csi-v4l2 9-0051"),
Some(CsiBridgeKind::Rk628)
);
assert_eq!(
CsiBridgeKind::from_subdev_name("rk-hdmirx-ctrl"),
Some(CsiBridgeKind::RkHdmirx)
);
assert_eq!(
CsiBridgeKind::from_subdev_name("tc358743 2-000f"),
Some(CsiBridgeKind::Tc358743)
);
assert_eq!(CsiBridgeKind::from_subdev_name("mystery"), None);
}
}

View File

@@ -16,11 +16,13 @@ use v4l2r::ioctl::{
use v4l2r::nix::errno::Errno;
use v4l2r::{Format as V4l2rFormat, QueueType};
use super::csi_bridge;
use super::format::{PixelFormat, Resolution};
use super::is_rk_hdmirx_driver;
use super::{is_rk_hdmirx_driver, is_rkcif_driver};
use crate::error::{AppError, Result};
const DEVICE_PROBE_TIMEOUT_MS: u64 = 400;
/// Per-node probe limit; rkcif/RK628 ioctl chains can exceed 1s under contention.
const DEVICE_PROBE_TIMEOUT_MS: u64 = 10_000;
/// Information about a video device
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -43,6 +45,43 @@ pub struct VideoDeviceInfo {
pub is_capture_card: bool,
/// Priority score for device selection (higher is better)
pub priority: u32,
/// Whether an HDMI signal is currently detected (CSI/HDMI bridge devices only;
/// always `true` for USB capture cards).
pub has_signal: bool,
/// Path of the bridge subdev (`/dev/v4l-subdevN`) paired with this
/// capture node, if any. On Rockchip boards that wire an RK628 /
/// TC358746 / RK-HDMIRX through `rkcif`, `QUERY_DV_TIMINGS`,
/// `S_DV_TIMINGS`, `SUBSCRIBE_EVENT(SOURCE_CHANGE)`, `S_EDID` etc. all
/// return `ENOTTY` on the video node — they only work here. `None`
/// for USB UVC and for bridges that expose DV ioctls on the video node
/// directly (tc358743 via `uvcvideo`).
pub subdev_path: Option<PathBuf>,
/// Classification of the paired bridge (drives fingerprint logic for
/// RK628's synthetic-VGA no-signal pattern).
pub bridge_kind: Option<String>,
}
#[derive(Debug, Clone)]
pub struct VideoDeviceRecoveryHint {
pub path: PathBuf,
pub name: String,
pub driver: String,
pub bus_info: String,
pub card: String,
pub is_capture_card: bool,
}
impl From<&VideoDeviceInfo> for VideoDeviceRecoveryHint {
fn from(device: &VideoDeviceInfo) -> Self {
Self {
path: device.path.clone(),
name: device.name.clone(),
driver: device.driver.clone(),
bus_info: device.bus_info.clone(),
card: device.card.clone(),
is_capture_card: device.is_capture_card,
}
}
}
/// Information about a supported format
@@ -147,12 +186,110 @@ impl VideoDevice {
read_write: flags.contains(Capabilities::READWRITE),
};
let formats = if is_rk_hdmirx_driver(&caps.driver, &caps.card) {
self.enumerate_current_format_only()?
// For CSI/HDMI bridges, try to locate the paired subdev *before*
// the signal check: RK628 + rkcif places QUERY_DV_TIMINGS on the
// subdev (the video node returns ENOTTY). Tc358743 and rk_hdmirx
// typically expose DV ioctls on the video node itself, but having
// the subdev handle for EDID/event subscription doesn't hurt.
let (subdev_path, bridge_kind) =
if is_rkcif_driver(&caps.driver) || is_rk_hdmirx_driver(&caps.driver, &caps.card) {
match csi_bridge::discover_subdev_for_video(&self.path) {
Some((path, kind)) => (Some(path), Some(format!("{:?}", kind).to_lowercase())),
None => (None, None),
}
} else {
(None, None)
};
// Probe the HDMI source for both signal presence *and* the live
// frame-rate. rkcif's `VIDIOC_ENUM_FRAMEINTERVALS` returns a
// meaningless `1.0..30.0` StepWise range, so the only trustworthy
// fps for rkcif + RK628 / rk_hdmirx boards comes from the bridge
// subdev's DV timings (pixelclock / total_width / total_height).
//
// Preference order:
// 1. Bridge subdev — on rkcif boards this is the *only* node
// where QUERY_DV_TIMINGS works, and it lets the RK628
// fingerprint filter kick in before we return has_signal=true.
// 2. Video node fallback — for rk_hdmirx / tc358743 where DV
// timings are exposed on the capture node directly.
// 3. USB UVC — always true (no signal concept), no hdmi_fps.
// Subdev-reported HDMI source mode (width, height, fps). On rkcif +
// RK628 boards this is the *only* place DV timings work; the video
// node itself returns ENOTTY for QUERY/G_DV_TIMINGS, so without
// threading this through to `enumerate_bridge_formats` the format
// list ends up with zero resolutions and `select_resolution` falls
// back to the user's preferred value (e.g. 4K) even when the real
// source is 1080p.
let mut subdev_hdmi_mode: Option<(u32, u32, Option<f64>)> = None;
let (has_signal, hdmi_fps) = if let Some(subdev_path) = subdev_path.as_ref() {
match csi_bridge::open_subdev(subdev_path) {
Ok(subdev_fd) => {
let kind = parse_bridge_kind(bridge_kind.as_deref())
.unwrap_or(csi_bridge::CsiBridgeKind::Unknown);
let probe = csi_bridge::probe_signal(&subdev_fd, kind);
debug!(
"has_signal via subdev {:?} ({:?}): {:?}",
subdev_path, kind, probe
);
let fps = match &probe {
csi_bridge::ProbeResult::Locked(mode) => {
subdev_hdmi_mode = Some((mode.width, mode.height, mode.fps));
mode.fps
}
_ => None,
};
(probe.is_locked(), fps)
}
Err(e) => {
warn!("Failed to open subdev {:?}: {}", subdev_path, e);
(false, None)
}
}
} else if is_rk_hdmirx_driver(&caps.driver, &caps.card) || is_rkcif_driver(&caps.driver) {
let dv = self.current_dv_timings_mode();
debug!(
"has_signal via video node {:?} (driver={}): dv_timings={:?}",
self.path, caps.driver, dv
);
let has_signal = dv
.as_ref()
.map(|(w, h, _)| *w > 64 && *h > 64)
.unwrap_or(false);
let fps = if has_signal {
dv.and_then(|(_, _, f)| f)
} else {
None
};
(has_signal, fps)
} else {
self.enumerate_formats()?
(true, None)
};
let mut formats =
if is_rk_hdmirx_driver(&caps.driver, &caps.card) || is_rkcif_driver(&caps.driver) {
// CSI/HDMI bridge drivers (rk_hdmirx, rkcif) expose multiple pixel
// formats via ENUM_FMT (e.g. rk_hdmirx: BGR3/NV24/NV16/NV12) but
// `ENUM_FRAMESIZES` is fiction for these drivers (rkcif reports a
// degenerate `64x64 StepWise 8/8` that only describes its DMA
// engine, rk_hdmirx returns ENOTTY). The only authoritative
// resolution is whatever the bridge subdev's DV timings report,
// so we treat the HDMI source mode as the single allowed
// resolution for every pixel format.
self.enumerate_bridge_formats(subdev_hdmi_mode)?
} else {
self.enumerate_formats()?
};
// For CSI/HDMI bridges, the driver-enumerated fps list is fiction
// (rkcif: always `1..30`; rk_hdmirx: typically `ENOTTY`). Replace
// it with the live HDMI source fps derived from the bridge DV
// timings so the UI reflects what the sink is actually receiving.
if let Some(fps) = hdmi_fps {
override_resolution_fps(&mut formats, fps);
}
// Determine if this is likely an HDMI capture card
let is_capture_card = Self::detect_capture_card(&caps.card, &caps.driver, &formats);
@@ -160,6 +297,11 @@ impl VideoDevice {
let priority =
Self::calculate_priority(&caps.card, &caps.driver, &formats, is_capture_card);
debug!(
"Device {:?}: {} formats, priority={}, has_signal={}, hdmi_fps={:?}, is_capture_card={}, subdev={:?}",
self.path, formats.len(), priority, has_signal, hdmi_fps, is_capture_card, subdev_path
);
Ok(VideoDeviceInfo {
path: self.path.clone(),
name: caps.card.clone(),
@@ -170,6 +312,9 @@ impl VideoDevice {
capabilities,
is_capture_card,
priority,
has_signal,
subdev_path,
bridge_kind,
})
}
@@ -213,32 +358,119 @@ impl VideoDevice {
Ok(formats)
}
fn enumerate_current_format_only(&self) -> Result<Vec<FormatInfo>> {
let current = self.get_format()?;
let Some(format) = PixelFormat::from_v4l2r(current.pixelformat) else {
/// Enumerate formats for CSI/HDMI bridge devices (rk_hdmirx, rkcif).
///
/// Uses `VIDIOC_ENUM_FMT` to discover all supported pixel formats (the
/// output of `v4l2-ctl --list-formats`) and attaches the HDMI source
/// resolution read from the bridge DV timings (or G_FMT as a last
/// resort) as the single allowed resolution for every format.
///
/// `ENUM_FRAMESIZES` is deliberately ignored here: rkcif advertises a
/// degenerate `64x64 StepWise 8/8` that only describes its DMA engine
/// (not what the HDMI source can actually deliver), and rk_hdmirx
/// typically returns ENOTTY. Neither the bridge nor rkcif performs
/// any hardware scaling, so the capture resolution is always the
/// HDMI source mode.
///
/// Returned formats are sorted by `PixelFormat::priority()` so the
/// higher-level `select_format` picks a sensible default (NV12 > YUYV on
/// rkcif / rk_hdmirx) instead of whatever the driver happens to
/// have stuck as the current active format.
fn enumerate_bridge_formats(
&self,
subdev_hdmi_mode: Option<(u32, u32, Option<f64>)>,
) -> Result<Vec<FormatInfo>> {
let queue = self.capture_queue_type()?;
let current_fmt = self.get_format().ok();
if let Some(fmt) = &current_fmt {
debug!(
"Current active format {:?} is not supported by One-KVM, falling back to full enumeration",
current.pixelformat
"enumerate_bridge_formats: current G_FMT -> {:?} {}x{}",
fmt.pixelformat, fmt.width, fmt.height
);
return self.enumerate_formats();
};
}
let description = self
.format_description(current.pixelformat)
.unwrap_or_else(|| format.to_string());
// Preference order for the HDMI source resolution:
// 1. Subdev-reported DV timings (authoritative on rkcif + RK628 where
// the video node returns ENOTTY for QUERY_DV_TIMINGS).
// 2. Video-node DV timings / G_FMT (rk_hdmirx, tc358743 direct).
let hdmi_mode = subdev_hdmi_mode
.map(|(w, h, fps)| {
let mut fps_list = Vec::new();
if let Some(f) = fps {
fps_list.push(f);
}
if let Some(parm_fps) = self.current_parm_fps() {
fps_list.push(parm_fps);
}
normalize_fps_list(&mut fps_list);
ResolutionInfo::new(w, h, fps_list)
})
.or_else(|| self.current_mode_resolution_info());
if let Some(info) = &hdmi_mode {
debug!(
"enumerate_bridge_formats: HDMI source mode {}x{} (from {})",
info.width,
info.height,
if subdev_hdmi_mode.is_some() {
"subdev"
} else {
"video node"
}
);
} else {
debug!("enumerate_bridge_formats: no HDMI source mode available");
}
let mut resolutions = self.enumerate_resolutions(current.pixelformat)?;
if resolutions.is_empty() {
if let Some(current_mode) = self.current_mode_resolution_info() {
resolutions.push(current_mode);
let mut formats: Vec<FormatInfo> = Vec::new();
for desc in FormatIterator::new(&self.fd, queue) {
let Some(format) = PixelFormat::from_v4l2r(desc.pixelformat) else {
debug!(
"enumerate_bridge_formats: skipping unsupported fourcc {:?} ({})",
desc.pixelformat, desc.description
);
continue;
};
let resolutions = hdmi_mode.clone().into_iter().collect();
formats.push(FormatInfo {
format,
resolutions,
description: desc.description.clone(),
});
}
if formats.is_empty() {
// Fallback: driver refused ENUM_FMT entirely, use just the current
// active format reported by G_FMT so we still have something.
if let Some(fmt) = current_fmt {
if let Some(format) = PixelFormat::from_v4l2r(fmt.pixelformat) {
let description = self
.format_description(fmt.pixelformat)
.unwrap_or_else(|| format.to_string());
let resolutions = hdmi_mode.into_iter().collect();
formats.push(FormatInfo {
format,
resolutions,
description,
});
}
}
}
Ok(vec![FormatInfo {
format,
resolutions,
description,
}])
// Highest priority first (MJPEG > NV12 > NV16 > NV24 > BGR24 > ...).
formats.sort_by(|a, b| b.format.priority().cmp(&a.format.priority()));
debug!(
"enumerate_bridge_formats: resolved formats {:?}",
formats
.iter()
.map(|f| format!("{}({} res)", f.format, f.resolutions.len()))
.collect::<Vec<_>>()
);
Ok(formats)
}
/// Enumerate resolutions for a specific format
@@ -259,24 +491,26 @@ impl VideoDevice {
resolutions.push(ResolutionInfo::new(d.width, d.height, fps));
}
FrmSizeTypes::StepWise(s) => {
for res in [
Resolution::VGA,
Resolution::HD720,
Resolution::HD1080,
Resolution::UHD4K,
] {
if res.width >= s.min_width
&& res.width <= s.max_width
&& res.height >= s.min_height
&& res.height <= s.max_height
{
let fps = self
.enumerate_fps(fourcc, res.width, res.height)
.unwrap_or_default();
resolutions
.push(ResolutionInfo::new(res.width, res.height, fps));
}
// StepWise ranges are ignored on purpose: on
// CSI/HDMI bridge drivers (rkcif) the range
// only describes the DMA engine's capability
// and not what the HDMI source can deliver,
// so synthesising candidate resolutions from
// it is misleading. Bridge devices go
// through `enumerate_bridge_formats` and use
// the DV-timings source mode directly; for
// any other driver that emits StepWise we
// fall back to the current active mode below.
debug!(
"ENUM_FRAMESIZES {:?}: ignoring StepWise {}x{} - {}x{} step {}/{}",
fourcc, s.min_width, s.min_height,
s.max_width, s.max_height,
s.step_width, s.step_height
);
if resolutions.is_empty() {
should_fallback_to_current_mode = true;
}
break;
}
}
}
@@ -449,6 +683,8 @@ impl VideoDevice {
"macrosilicon",
"tc358743",
"uvc",
"rkcif",
"rk_hdmirx",
];
// Check card/driver names
@@ -637,22 +873,18 @@ impl VideoDevice {
/// Enumerate all video capture devices
pub fn enumerate_devices() -> Result<Vec<VideoDeviceInfo>> {
info!("Enumerating video devices...");
debug!("Enumerating video devices...");
let mut devices = Vec::new();
// Scan /dev/video* devices
// First pass: collect candidates that pass the sysfs-based pre-filter.
// This avoids opening orphan /dev/videoN nodes (ENODEV) and m2m codec
// nodes (ENOTTY) that would otherwise waste one syscall + one ioctl each.
let mut candidates: Vec<PathBuf> = Vec::new();
for entry in std::fs::read_dir("/dev")
.map_err(|e| AppError::VideoError(format!("Failed to read /dev: {}", e)))?
{
let entry = match entry {
Ok(e) => e,
Err(_) => continue,
};
let Ok(entry) = entry else { continue };
let path = entry.path();
let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
if !name.starts_with("video") {
continue;
}
@@ -663,11 +895,31 @@ pub fn enumerate_devices() -> Result<Vec<VideoDeviceInfo>> {
debug!("Skipping non-capture candidate (sysfs): {:?}", path);
continue;
}
candidates.push(path);
}
// Try to open and query the device (with timeout)
match probe_device_with_timeout(&path, Duration::from_millis(DEVICE_PROBE_TIMEOUT_MS)) {
collapse_rkcif_probe_candidates(&mut candidates);
// Second pass: probe the remaining candidates in parallel. Each probe
// already spawns its own worker thread inside `probe_device_with_timeout`,
// so the total wall-clock time is bounded by `DEVICE_PROBE_TIMEOUT_MS`
// rather than (N × per-probe-latency).
let timeout = Duration::from_millis(DEVICE_PROBE_TIMEOUT_MS);
let mut handles = Vec::with_capacity(candidates.len());
for path in candidates {
handles.push(std::thread::spawn(move || {
(path.clone(), probe_device_with_timeout(&path, timeout))
}));
}
let mut devices = Vec::new();
for handle in handles {
let (path, info) = match handle.join() {
Ok(pair) => pair,
Err(_) => continue,
};
match info {
Some(info) => {
// Only include devices with video capture capability
if info.capabilities.video_capture || info.capabilities.video_capture_mplane {
info!(
"Found capture device: {} ({}) - {} formats",
@@ -686,13 +938,128 @@ pub fn enumerate_devices() -> Result<Vec<VideoDeviceInfo>> {
}
}
// Sort by priority (highest first)
devices.sort_by(|a, b| b.priority.cmp(&a.priority));
// Sort by priority (highest first), then by path (lowest first) as tiebreaker.
// The path tiebreaker ensures deterministic ordering when multiple sub-devices
// share the same priority (e.g. rkcif nodes), so that /dev/video0 is preferred
// over /dev/video10 after deduplication.
devices.sort_by(|a, b| {
b.priority
.cmp(&a.priority)
.then_with(|| a.path.cmp(&b.path))
});
// Deduplicate rkcif sub-devices: the driver exposes many /dev/video* nodes
// for a single MIPI CSI pipeline. Keep only the highest-priority node per
// (driver, bus_info) group so users see one device instead of ~11.
dedup_platform_subdevices(&mut devices);
info!("Found {} video capture devices", devices.len());
Ok(devices)
}
pub fn select_recovery_device(
devices: &[VideoDeviceInfo],
hint: &VideoDeviceRecoveryHint,
) -> Option<VideoDeviceInfo> {
devices
.iter()
.find(|device| device.path == hint.path)
.or_else(|| {
if hint.bus_info.trim().is_empty() {
None
} else {
devices
.iter()
.find(|device| device.bus_info == hint.bus_info)
}
})
.or_else(|| {
if hint.driver.trim().is_empty() || hint.card.trim().is_empty() {
None
} else {
devices
.iter()
.find(|device| device.driver == hint.driver && device.card == hint.card)
}
})
.or_else(|| {
if hint.driver.trim().is_empty() || hint.name.trim().is_empty() {
None
} else {
devices
.iter()
.find(|device| device.driver == hint.driver && device.name == hint.name)
}
})
.or_else(|| {
if hint.is_capture_card {
devices.iter().find(|device| device.is_capture_card)
} else {
None
}
})
.or_else(|| devices.first())
.cloned()
}
/// Collapse platform sub-device nodes that share the same driver + bus_info
/// into a single entry (the one with the highest priority / most formats).
/// Currently applies to the `rkcif` driver on Rockchip SoCs where each
/// media-pipeline link creates its own `/dev/video*` node.
fn dedup_platform_subdevices(devices: &mut Vec<VideoDeviceInfo>) {
// devices is already sorted by priority (descending).
// Walk the list and keep only the first (highest-priority) representative
// of each (driver, bus_info) group that needs deduplication.
let mut seen = std::collections::HashSet::new();
devices.retain(|d| {
if !is_rkcif_driver(&d.driver) || d.bus_info.is_empty() {
return true;
}
let key = (d.driver.clone(), d.bus_info.clone());
seen.insert(key)
});
}
/// rkcif registers many `/dev/video*` queues; probing all in parallel can
/// contend and time out. Keep one node per board (lowest `videoN`).
fn collapse_rkcif_probe_candidates(candidates: &mut Vec<PathBuf>) {
let mut rkcif: Vec<PathBuf> = Vec::new();
let mut rest: Vec<PathBuf> = Vec::new();
for p in candidates.drain(..) {
if sysfs_uevent_driver(&p).is_some_and(|d| d.contains("rkcif")) {
rkcif.push(p);
} else {
rest.push(p);
}
}
if let Some(one) = rkcif
.iter()
.min_by_key(|p| video_index(p).unwrap_or(u32::MAX))
.cloned()
{
rest.push(one);
}
*candidates = rest;
}
fn sysfs_uevent_driver(path: &Path) -> Option<String> {
let name = path.file_name()?.to_str()?;
let uevent = read_sysfs_string(
&Path::new("/sys/class/video4linux")
.join(name)
.join("device/uevent"),
)?;
extract_uevent_value(&uevent, "driver")
}
fn video_index(path: &Path) -> Option<u32> {
path.file_name()?
.to_str()?
.strip_prefix("video")?
.parse()
.ok()
}
fn probe_device_with_timeout(path: &Path, timeout: Duration) -> Option<VideoDeviceInfo> {
let path = path.to_path_buf();
let path_for_thread = path.clone();
@@ -725,8 +1092,29 @@ fn sysfs_maybe_capture(path: &Path) -> bool {
Some(name) => name,
None => return true,
};
// Fast-path: nodes whose filename clearly marks them as m2m codecs
// (e.g. /dev/video-enc0, /dev/video-dec0 on Rockchip). These never
// answer VIDIOC_QUERYCAP as capture devices.
let name_lower = name.to_ascii_lowercase();
let filename_skip = ["-enc", "-dec", "-codec", "-m2m", "-vepu", "-vdpu"];
if filename_skip.iter().any(|hint| name_lower.contains(hint)) {
return false;
}
let sysfs_base = Path::new("/sys/class/video4linux").join(name);
// Orphan /dev/videoN nodes (no matching sysfs entry) can appear when the
// kernel driver that created them has been unloaded but the device nodes
// were never cleaned up. Opening them returns ENODEV; skip the probe.
if !sysfs_base.exists() {
debug!(
"Skipping {:?}: no matching /sys/class/video4linux entry",
path
);
return false;
}
let sysfs_name = read_sysfs_string(&sysfs_base.join("name"))
.unwrap_or_default()
.to_lowercase();
@@ -746,19 +1134,57 @@ fn sysfs_maybe_capture(path: &Path) -> bool {
"macrosilicon",
"tc358743",
"grabber",
"rkcif",
"rk_hdmirx",
];
if capture_hints.iter().any(|hint| sysfs_name.contains(hint)) {
maybe_capture = true;
}
if let Some(driver) = driver {
if driver.contains("uvcvideo") || driver.contains("tc358743") {
if let Some(driver) = &driver {
if driver.contains("uvcvideo")
|| driver.contains("tc358743")
|| driver.contains("rkcif")
|| driver.contains("rk_hdmirx")
{
maybe_capture = true;
}
}
// Skip known non-capture drivers (RK video codecs, Hantro VPU, ISP/VPE
// pipelines, MIPI ISP statistics / params nodes). These would otherwise
// succeed QUERYCAP but expose only VIDEO_M2M / STATS / PARAMS and get
// filtered later — skipping here saves an open() + ioctl() per node.
let driver_skip = [
"rkvenc",
"rkvdec",
"vepu",
"vdpu",
"hantro",
"mpp_",
"rockchip-vpu",
];
if let Some(driver) = &driver {
if driver_skip.iter().any(|hint| driver.contains(hint)) {
return false;
}
}
let skip_hints = [
"codec", "decoder", "encoder", "isp", "mem2mem", "m2m", "vbi", "radio", "metadata",
"codec",
"decoder",
"encoder",
"isp",
"mem2mem",
"m2m",
"vbi",
"radio",
"metadata",
"output",
// rkisp sub-nodes that are not video capture queues
"rkisp-statistics",
"rkisp-input-params",
"rkisp_rawrd",
"rkisp_rawwr",
];
if skip_hints.iter().any(|hint| sysfs_name.contains(hint)) && !maybe_capture {
return false;
@@ -783,6 +1209,18 @@ fn extract_uevent_value(content: &str, key: &str) -> Option<String> {
None
}
/// Parse the `bridge_kind` string serialised into `VideoDeviceInfo` back
/// into the strongly-typed enum used by [`csi_bridge`].
pub(crate) fn parse_bridge_kind(kind: Option<&str>) -> Option<csi_bridge::CsiBridgeKind> {
Some(match kind? {
"rk628" => csi_bridge::CsiBridgeKind::Rk628,
"rkhdmirx" => csi_bridge::CsiBridgeKind::RkHdmirx,
"tc358743" => csi_bridge::CsiBridgeKind::Tc358743,
"unknown" => csi_bridge::CsiBridgeKind::Unknown,
_ => return None,
})
}
fn dv_timings_fps(bt: &v4l2_bt_timings) -> Option<f64> {
let total_width = bt.width + bt.hfrontporch + bt.hsync + bt.hbackporch;
let total_height = if bt.interlaced != 0 {
@@ -813,6 +1251,24 @@ fn normalize_fps_list(fps_list: &mut Vec<f64>) {
fps_list.dedup_by(|a, b| (*a - *b).abs() < 0.01);
}
/// Replace every `ResolutionInfo::fps` in `formats` with the single HDMI
/// source frame-rate. Used for CSI/HDMI bridge devices (rkcif, rk_hdmirx)
/// whose `VIDIOC_ENUM_FRAMEINTERVALS` returns meaningless StepWise values
/// — the only trustworthy fps comes from the bridge DV-timings on the
/// paired subdev. Silently no-op when `fps` normalises to empty.
fn override_resolution_fps(formats: &mut [FormatInfo], fps: f64) {
let mut normalized = vec![fps];
normalize_fps_list(&mut normalized);
if normalized.is_empty() {
return;
}
for fi in formats.iter_mut() {
for res in fi.resolutions.iter_mut() {
res.fps = normalized.clone();
}
}
}
/// Find the best video device for KVM use
pub fn find_best_device() -> Result<VideoDeviceInfo> {
let devices = enumerate_devices()?;
@@ -827,6 +1283,35 @@ pub fn find_best_device() -> Result<VideoDeviceInfo> {
mod tests {
use super::*;
fn test_device(
path: &str,
name: &str,
driver: &str,
bus_info: &str,
card: &str,
is_capture_card: bool,
priority: u32,
) -> VideoDeviceInfo {
VideoDeviceInfo {
path: PathBuf::from(path),
name: name.to_string(),
driver: driver.to_string(),
bus_info: bus_info.to_string(),
card: card.to_string(),
formats: Vec::new(),
capabilities: DeviceCapabilities {
video_capture: true,
streaming: true,
..Default::default()
},
is_capture_card,
priority,
has_signal: true,
subdev_path: None,
bridge_kind: None,
}
}
#[test]
fn test_pixel_format_conversion() {
let format = PixelFormat::Mjpeg;
@@ -842,4 +1327,70 @@ mod tests {
assert_eq!(res.height, 1080);
assert!(res.is_valid());
}
#[test]
fn recovery_selection_prefers_original_path() {
let original = test_device(
"/dev/video0",
"USB Capture",
"uvcvideo",
"usb-1",
"USB Capture",
true,
100,
);
let other = test_device(
"/dev/video2",
"USB Capture",
"uvcvideo",
"usb-1",
"USB Capture",
true,
200,
);
let hint = VideoDeviceRecoveryHint::from(&original);
let selected = select_recovery_device(&[other, original.clone()], &hint).unwrap();
assert_eq!(selected.path, original.path);
}
#[test]
fn recovery_selection_matches_bus_info_after_path_change() {
let original = test_device(
"/dev/video0",
"USB Capture",
"uvcvideo",
"usb-1",
"USB Capture",
true,
100,
);
let recovered = test_device(
"/dev/video3",
"USB Capture",
"uvcvideo",
"usb-1",
"USB Capture",
true,
100,
);
let hint = VideoDeviceRecoveryHint::from(&original);
let selected = select_recovery_device(&[recovered.clone()], &hint).unwrap();
assert_eq!(selected.path, recovered.path);
}
#[test]
fn recovery_selection_falls_back_to_capture_priority() {
let hint = VideoDeviceRecoveryHint {
path: PathBuf::from("/dev/video9"),
name: "Gone".to_string(),
driver: "gone".to_string(),
bus_info: String::new(),
card: "Gone".to_string(),
is_capture_card: true,
};
let lower = test_device("/dev/video1", "A", "uvcvideo", "usb-a", "A", true, 10);
let higher = test_device("/dev/video2", "B", "uvcvideo", "usb-b", "B", true, 20);
let selected = select_recovery_device(&[higher.clone(), lower], &hint).unwrap();
assert_eq!(selected.path, higher.path);
}
}

View File

@@ -152,6 +152,41 @@ impl JpegEncoder {
self.encode_i420_to_jpeg(sequence)
}
/// YVYU → swap chroma to YUYV in scratch, then same as [`Self::encode_yuyv`].
pub fn encode_yvyu(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let width = self.config.resolution.width as usize;
let height = self.config.resolution.height as usize;
let expected_size = width * height * 2;
if data.len() < expected_size {
return Err(AppError::VideoError(format!(
"YVYU data too small: {} < {}",
data.len(),
expected_size
)));
}
// Reuse bgra_buffer as scratch for the swapped YUYV data.
if self.bgra_buffer.len() < expected_size {
self.bgra_buffer.resize(expected_size, 0);
}
let dst = &mut self.bgra_buffer[..expected_size];
let src = &data[..expected_size];
// Swap bytes [1] and [3] in every 4-byte macropixel: Y0 V0 Y1 U0 → Y0 U0 Y1 V0
for (chunk_dst, chunk_src) in dst.chunks_exact_mut(4).zip(src.chunks_exact(4)) {
chunk_dst[0] = chunk_src[0]; // Y0
chunk_dst[1] = chunk_src[3]; // U0
chunk_dst[2] = chunk_src[2]; // Y1
chunk_dst[3] = chunk_src[1]; // V0
}
libyuv::yuy2_to_i420(dst, &mut self.i420_buffer, width as i32, height as i32)
.map_err(|e| AppError::VideoError(format!("libyuv YVYU→I420 failed: {}", e)))?;
self.encode_i420_to_jpeg(sequence)
}
/// Encode NV12 frame to JPEG
pub fn encode_nv12(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
let width = self.config.resolution.width as usize;
@@ -323,7 +358,8 @@ impl crate::video::encoder::traits::Encoder for JpegEncoder {
fn encode(&mut self, data: &[u8], sequence: u64) -> Result<EncodedFrame> {
match self.config.input_format {
PixelFormat::Yuyv | PixelFormat::Yvyu => self.encode_yuyv(data, sequence),
PixelFormat::Yuyv => self.encode_yuyv(data, sequence),
PixelFormat::Yvyu => self.encode_yvyu(data, sequence),
PixelFormat::Nv12 => self.encode_nv12(data, sequence),
PixelFormat::Nv16 => self.encode_nv16(data, sequence),
PixelFormat::Nv24 => self.encode_nv24(data, sequence),

Some files were not shown because too many files have changed in this diff Show More