mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2026-01-29 00:51:53 +08:00
feat: merge upstream master - version 4.94
Merge upstream PiKVM master branch updates: - Bump version from 4.93 to 4.94 - HID: improved jiggler pattern for better compatibility - Streamer: major refactoring for improved performance and maintainability - Prometheus: tidying GPIO channel name formatting - Web: added __gpio-label class for custom styling - HID: customizable /api/hid/print delay configuration - ATX: independent power/reset regions for better control - OLED: added --fill option for display testing - Web: improved keyboard handling in modal dialogs - Web: enhanced login error messages - Switch: added heartbeat functionality - Web: mouse touch code simplification and refactoring - Configs: use systemd-networkd-wait-online --any by default - PKGBUILD: use cp -r to install systemd units properly - Various bug fixes and performance improvements
This commit is contained in:
@@ -65,6 +65,7 @@ from ..validators.basic import valid_string_list
|
||||
|
||||
from ..validators.auth import valid_user
|
||||
from ..validators.auth import valid_users_list
|
||||
from ..validators.auth import valid_expire
|
||||
|
||||
from ..validators.os import valid_abs_path
|
||||
from ..validators.os import valid_abs_file
|
||||
@@ -73,6 +74,7 @@ from ..validators.os import valid_unix_mode
|
||||
from ..validators.os import valid_options
|
||||
from ..validators.os import valid_command
|
||||
|
||||
from ..validators.net import valid_ip
|
||||
from ..validators.net import valid_ip_or_host
|
||||
from ..validators.net import valid_net
|
||||
from ..validators.net import valid_port
|
||||
@@ -190,6 +192,14 @@ def _init_config(config_path: str, override_options: list[str], **load_flags: bo
|
||||
|
||||
|
||||
def _patch_raw(raw_config: dict) -> None: # pylint: disable=too-many-branches
|
||||
for (sub, cmd) in [("iface", "ip_cmd"), ("firewall", "iptables_cmd")]:
|
||||
if isinstance(raw_config.get("otgnet"), dict):
|
||||
if isinstance(raw_config["otgnet"].get(sub), dict):
|
||||
if raw_config["otgnet"][sub].get(cmd):
|
||||
raw_config["otgnet"].setdefault("commands", {})
|
||||
raw_config["otgnet"]["commands"][cmd] = raw_config["otgnet"][sub][cmd]
|
||||
del raw_config["otgnet"][sub][cmd]
|
||||
|
||||
if isinstance(raw_config.get("otg"), dict):
|
||||
for (old, new) in [
|
||||
("msd", "msd"),
|
||||
@@ -357,6 +367,12 @@ def _get_config_scheme() -> dict:
|
||||
|
||||
"auth": {
|
||||
"enabled": Option(True, type=valid_bool),
|
||||
"expire": Option(0, type=valid_expire),
|
||||
|
||||
"usc": {
|
||||
"users": Option([], type=valid_users_list), # PiKVM username has a same regex as a UNIX username
|
||||
"groups": Option(["kvmd-selfauth"], type=valid_users_list), # groupname has a same regex as a username
|
||||
},
|
||||
|
||||
"internal": {
|
||||
"type": Option("htpasswd"),
|
||||
@@ -457,7 +473,7 @@ def _get_config_scheme() -> dict:
|
||||
|
||||
"unix": Option("/run/kvmd/ustreamer.sock", type=valid_abs_path, unpack_as="unix_path"),
|
||||
"timeout": Option(2.0, type=valid_float_f01),
|
||||
"snapshot_timeout": Option(1.0, type=valid_float_f01), # error_delay * 3 + 1
|
||||
"snapshot_timeout": Option(5.0, type=valid_float_f01), # error_delay * 3 + 1
|
||||
|
||||
"process_name_prefix": Option("kvmd/streamer"),
|
||||
|
||||
@@ -504,8 +520,9 @@ def _get_config_scheme() -> dict:
|
||||
},
|
||||
|
||||
"switch": {
|
||||
"device": Option("/dev/kvmd-switch", type=valid_abs_path, unpack_as="device_path"),
|
||||
"default_edid": Option("/etc/kvmd/switch-edid.hex", type=valid_abs_path, unpack_as="default_edid_path"),
|
||||
"device": Option("/dev/kvmd-switch", type=valid_abs_path, unpack_as="device_path"),
|
||||
"default_edid": Option("/etc/kvmd/switch-edid.hex", type=valid_abs_path, unpack_as="default_edid_path"),
|
||||
"ignore_hpd_on_top": Option(False, type=valid_bool),
|
||||
},
|
||||
},
|
||||
|
||||
@@ -558,15 +575,15 @@ def _get_config_scheme() -> dict:
|
||||
"vendor_id": Option(0x1D6B, type=valid_otg_id), # Linux Foundation
|
||||
"product_id": Option(0x0104, type=valid_otg_id), # Multifunction Composite Gadget
|
||||
"manufacturer": Option("PiKVM", type=valid_stripped_string),
|
||||
"product": Option("Composite KVM Device", type=valid_stripped_string),
|
||||
"product": Option("PiKVM Composite Device", type=valid_stripped_string),
|
||||
"serial": Option("CAFEBABE", type=valid_stripped_string, if_none=None),
|
||||
"config": Option("", type=valid_stripped_string),
|
||||
"device_version": Option(-1, type=functools.partial(valid_number, min=-1, max=0xFFFF)),
|
||||
"usb_version": Option(0x0200, type=valid_otg_id),
|
||||
"max_power": Option(250, type=functools.partial(valid_number, min=50, max=500)),
|
||||
"remote_wakeup": Option(True, type=valid_bool),
|
||||
|
||||
"gadget": Option("kvmd", type=valid_otg_gadget),
|
||||
"config": Option("PiKVM device", type=valid_stripped_string_not_empty),
|
||||
"udc": Option("", type=valid_stripped_string),
|
||||
"endpoints": Option(9, type=valid_int_f0),
|
||||
"init_delay": Option(3.0, type=valid_float_f01),
|
||||
@@ -657,8 +674,7 @@ def _get_config_scheme() -> dict:
|
||||
|
||||
"otgnet": {
|
||||
"iface": {
|
||||
"net": Option("172.30.30.0/24", type=functools.partial(valid_net, v6=False)),
|
||||
"ip_cmd": Option(["/usr/bin/ip"], type=valid_command),
|
||||
"net": Option("172.30.30.0/24", type=functools.partial(valid_net, v6=False)),
|
||||
},
|
||||
|
||||
"firewall": {
|
||||
@@ -666,10 +682,13 @@ def _get_config_scheme() -> dict:
|
||||
"allow_tcp": Option([], type=valid_ports_list),
|
||||
"allow_udp": Option([67], type=valid_ports_list),
|
||||
"forward_iface": Option("", type=valid_stripped_string),
|
||||
"iptables_cmd": Option(["/usr/sbin/iptables", "--wait=5"], type=valid_command),
|
||||
},
|
||||
|
||||
"commands": {
|
||||
"ip_cmd": Option(["/usr/bin/ip"], type=valid_command),
|
||||
"iptables_cmd": Option(["/usr/sbin/iptables", "--wait=5"], type=valid_command),
|
||||
"sysctl_cmd": Option(["/usr/sbin/sysctl"], type=valid_command),
|
||||
|
||||
"pre_start_cmd": Option(["/bin/true", "pre-start"], type=valid_command),
|
||||
"pre_start_cmd_remove": Option([], type=valid_options),
|
||||
"pre_start_cmd_append": Option([], type=valid_options),
|
||||
@@ -734,7 +753,7 @@ def _get_config_scheme() -> dict:
|
||||
"desired_fps": Option(30, type=valid_stream_fps),
|
||||
"mouse_output": Option("usb", type=valid_hid_mouse_output),
|
||||
"keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file),
|
||||
"allow_cut_after": Option(3.0, type=valid_float_f0),
|
||||
"scroll_rate": Option(4, type=functools.partial(valid_number, min=1, max=30)),
|
||||
|
||||
"server": {
|
||||
"host": Option("", type=valid_ip_or_host, if_empty=""),
|
||||
@@ -786,8 +805,8 @@ def _get_config_scheme() -> dict:
|
||||
|
||||
"auth": {
|
||||
"vncauth": {
|
||||
"enabled": Option(False, type=valid_bool),
|
||||
"file": Option("/etc/kvmd/vncpasswd", type=valid_abs_file, unpack_as="path"),
|
||||
"enabled": Option(False, type=valid_bool, unpack_as="vncpass_enabled"),
|
||||
"file": Option("/etc/kvmd/vncpasswd", type=valid_abs_file, unpack_as="vncpass_path"),
|
||||
},
|
||||
"vencrypt": {
|
||||
"enabled": Option(True, type=valid_bool, unpack_as="vencrypt_enabled"),
|
||||
@@ -795,13 +814,24 @@ def _get_config_scheme() -> dict:
|
||||
},
|
||||
},
|
||||
|
||||
"localhid": {
|
||||
"kvmd": {
|
||||
"unix": Option("/run/kvmd/kvmd.sock", type=valid_abs_path, unpack_as="unix_path"),
|
||||
"timeout": Option(5.0, type=valid_float_f01),
|
||||
},
|
||||
},
|
||||
|
||||
"nginx": {
|
||||
"http": {
|
||||
"port": Option(80, type=valid_port),
|
||||
"ipv4": Option("0.0.0.0", type=functools.partial(valid_ip, v6=False)),
|
||||
"ipv6": Option("::", type=functools.partial(valid_ip, v4=False)),
|
||||
"port": Option(80, type=valid_port),
|
||||
},
|
||||
"https": {
|
||||
"enabled": Option(True, type=valid_bool),
|
||||
"port": Option(443, type=valid_port),
|
||||
"enabled": Option(True, type=valid_bool),
|
||||
"ipv4": Option("0.0.0.0", type=functools.partial(valid_ip, v6=False)),
|
||||
"ipv6": Option("::", type=functools.partial(valid_ip, v4=False)),
|
||||
"port": Option(443, type=valid_port),
|
||||
},
|
||||
},
|
||||
|
||||
|
||||
@@ -61,6 +61,33 @@ def _print_edid(edid: Edid) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _find_out2_edid_path() -> str:
|
||||
card = os.path.basename(os.readlink("/dev/dri/by-path/platform-gpu-card"))
|
||||
path = f"/sys/devices/platform/gpu/drm/{card}/{card}-HDMI-A-2"
|
||||
with open(os.path.join(path, "status")) as file:
|
||||
if file.read().startswith("d"):
|
||||
raise SystemExit("No display found")
|
||||
return os.path.join(path, "edid")
|
||||
|
||||
|
||||
def _adopt_out2_ids(dest: Edid) -> None:
|
||||
src = Edid.from_file(_find_out2_edid_path())
|
||||
dest.set_monitor_name(src.get_monitor_name())
|
||||
try:
|
||||
dest.get_monitor_serial()
|
||||
except EdidNoBlockError:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
ser = src.get_monitor_serial()
|
||||
except EdidNoBlockError:
|
||||
ser = "{:08X}".format(src.get_serial())
|
||||
dest.set_monitor_serial(ser)
|
||||
dest.set_mfc_id(src.get_mfc_id())
|
||||
dest.set_product_id(src.get_product_id())
|
||||
dest.set_serial(src.get_serial())
|
||||
|
||||
|
||||
# =====
|
||||
def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-branches,too-many-statements
|
||||
# (parent_parser, argv, _) = init(
|
||||
@@ -89,6 +116,10 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
|
||||
help="Import the specified bin/hex EDID to the [--edid] file as a hex text", metavar="<file>")
|
||||
parser.add_argument("--import-preset", choices=presets,
|
||||
help="Restore default EDID or choose the preset", metavar=f"{{ {' | '.join(presets)} }}",)
|
||||
parser.add_argument("--import-display-ids", action="store_true",
|
||||
help="On PiKVM V4, import and adopt IDs from a physical display connected to the OUT2 port")
|
||||
parser.add_argument("--import-display", action="store_true",
|
||||
help="On PiKVM V4, import full EDID from a physical display connected to the OUT2 port")
|
||||
parser.add_argument("--set-audio", type=valid_bool,
|
||||
help="Enable or disable audio", metavar="<yes|no>")
|
||||
parser.add_argument("--set-mfc-id",
|
||||
@@ -120,6 +151,9 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
|
||||
imp = f"_{imp}"
|
||||
options.imp = os.path.join(options.presets_path, f"{imp}.hex")
|
||||
|
||||
if options.import_display:
|
||||
options.imp = _find_out2_edid_path()
|
||||
|
||||
orig_edid_path = options.edid_path
|
||||
if options.imp:
|
||||
options.export_hex = options.edid_path
|
||||
@@ -128,6 +162,10 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
|
||||
edid = Edid.from_file(options.edid_path)
|
||||
changed = False
|
||||
|
||||
if options.import_display_ids:
|
||||
_adopt_out2_ids(edid)
|
||||
changed = True
|
||||
|
||||
for cmd in dir(Edid):
|
||||
if cmd.startswith("set_"):
|
||||
value = getattr(options, cmd)
|
||||
|
||||
@@ -30,27 +30,27 @@ import argparse
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import passlib.apache
|
||||
|
||||
from ...yamlconf import Section
|
||||
|
||||
from ...validators import ValidatorError
|
||||
from ...validators.auth import valid_user
|
||||
from ...validators.auth import valid_passwd
|
||||
|
||||
from ...crypto import KvmdHtpasswdFile
|
||||
|
||||
from .. import init
|
||||
|
||||
|
||||
# =====
|
||||
def _get_htpasswd_path(config: Section) -> str:
|
||||
if config.kvmd.auth.internal.type != "htpasswd":
|
||||
raise SystemExit(f"Error: KVMD internal auth not using 'htpasswd'"
|
||||
raise SystemExit(f"Error: KVMD internal auth does not use 'htpasswd'"
|
||||
f" (now configured {config.kvmd.auth.internal.type!r})")
|
||||
return config.kvmd.auth.internal.file
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _get_htpasswd_for_write(config: Section) -> Generator[passlib.apache.HtpasswdFile, None, None]:
|
||||
def _get_htpasswd_for_write(config: Section) -> Generator[KvmdHtpasswdFile, None, None]:
|
||||
path = _get_htpasswd_path(config)
|
||||
(tmp_fd, tmp_path) = tempfile.mkstemp(
|
||||
prefix=f".{os.path.basename(path)}.",
|
||||
@@ -65,7 +65,7 @@ def _get_htpasswd_for_write(config: Section) -> Generator[passlib.apache.Htpassw
|
||||
os.fchmod(tmp_fd, st.st_mode)
|
||||
finally:
|
||||
os.close(tmp_fd)
|
||||
htpasswd = passlib.apache.HtpasswdFile(tmp_path)
|
||||
htpasswd = KvmdHtpasswdFile(tmp_path)
|
||||
yield htpasswd
|
||||
htpasswd.save()
|
||||
os.rename(tmp_path, path)
|
||||
@@ -96,28 +96,55 @@ def _print_invalidate_tip(prepend_nl: bool) -> None:
|
||||
|
||||
# ====
|
||||
def _cmd_list(config: Section, _: argparse.Namespace) -> None:
|
||||
for user in sorted(passlib.apache.HtpasswdFile(_get_htpasswd_path(config)).users()):
|
||||
for user in sorted(KvmdHtpasswdFile(_get_htpasswd_path(config)).users()):
|
||||
print(user)
|
||||
|
||||
|
||||
def _cmd_set(config: Section, options: argparse.Namespace) -> None:
|
||||
def _change_user(config: Section, options: argparse.Namespace, create: bool) -> None:
|
||||
with _get_htpasswd_for_write(config) as htpasswd:
|
||||
assert options.user == options.user.strip()
|
||||
assert options.user
|
||||
|
||||
has_user = (options.user in htpasswd.users())
|
||||
if create:
|
||||
if has_user:
|
||||
raise SystemExit(f"The user {options.user!r} is already exists")
|
||||
else:
|
||||
if not has_user:
|
||||
raise SystemExit(f"The user {options.user!r} is not exist")
|
||||
|
||||
if options.read_stdin:
|
||||
passwd = valid_passwd(input())
|
||||
else:
|
||||
passwd = valid_passwd(getpass.getpass("Password: ", stream=sys.stderr))
|
||||
if valid_passwd(getpass.getpass("Repeat: ", stream=sys.stderr)) != passwd:
|
||||
raise SystemExit("Sorry, passwords do not match")
|
||||
|
||||
htpasswd.set_password(options.user, passwd)
|
||||
|
||||
if has_user and not options.quiet:
|
||||
_print_invalidate_tip(True)
|
||||
|
||||
|
||||
def _cmd_add(config: Section, options: argparse.Namespace) -> None:
|
||||
_change_user(config, options, create=True)
|
||||
|
||||
|
||||
def _cmd_set(config: Section, options: argparse.Namespace) -> None:
|
||||
_change_user(config, options, create=False)
|
||||
|
||||
|
||||
def _cmd_delete(config: Section, options: argparse.Namespace) -> None:
|
||||
with _get_htpasswd_for_write(config) as htpasswd:
|
||||
assert options.user == options.user.strip()
|
||||
assert options.user
|
||||
|
||||
has_user = (options.user in htpasswd.users())
|
||||
if not has_user:
|
||||
raise SystemExit(f"The user {options.user!r} is not exist")
|
||||
|
||||
htpasswd.delete(options.user)
|
||||
|
||||
if has_user and not options.quiet:
|
||||
_print_invalidate_tip(False)
|
||||
|
||||
@@ -138,19 +165,25 @@ def main(argv: (list[str] | None)=None) -> None:
|
||||
parser.set_defaults(cmd=(lambda *_: parser.print_help()))
|
||||
subparsers = parser.add_subparsers()
|
||||
|
||||
cmd_list_parser = subparsers.add_parser("list", help="List users")
|
||||
cmd_list_parser.set_defaults(cmd=_cmd_list)
|
||||
sub = subparsers.add_parser("list", help="List users")
|
||||
sub.set_defaults(cmd=_cmd_list)
|
||||
|
||||
cmd_set_parser = subparsers.add_parser("set", help="Create user or change password")
|
||||
cmd_set_parser.add_argument("user", type=valid_user)
|
||||
cmd_set_parser.add_argument("-i", "--read-stdin", action="store_true", help="Read password from stdin")
|
||||
cmd_set_parser.add_argument("-q", "--quiet", action="store_true", help="Don't show invalidation note")
|
||||
cmd_set_parser.set_defaults(cmd=_cmd_set)
|
||||
sub = subparsers.add_parser("add", help="Add user")
|
||||
sub.add_argument("user", type=valid_user)
|
||||
sub.add_argument("-i", "--read-stdin", action="store_true", help="Read password from stdin")
|
||||
sub.add_argument("-q", "--quiet", action="store_true", help="Don't show invalidation note")
|
||||
sub.set_defaults(cmd=_cmd_add)
|
||||
|
||||
cmd_delete_parser = subparsers.add_parser("del", help="Delete user")
|
||||
cmd_delete_parser.add_argument("user", type=valid_user)
|
||||
cmd_delete_parser.add_argument("-q", "--quiet", action="store_true", help="Don't show invalidation note")
|
||||
cmd_delete_parser.set_defaults(cmd=_cmd_delete)
|
||||
sub = subparsers.add_parser("set", help="Change user's password")
|
||||
sub.add_argument("user", type=valid_user)
|
||||
sub.add_argument("-i", "--read-stdin", action="store_true", help="Read password from stdin")
|
||||
sub.add_argument("-q", "--quiet", action="store_true", help="Don't show invalidation note")
|
||||
sub.set_defaults(cmd=_cmd_set)
|
||||
|
||||
sub = subparsers.add_parser("del", help="Delete user")
|
||||
sub.add_argument("user", type=valid_user)
|
||||
sub.add_argument("-q", "--quiet", action="store_true", help="Don't show invalidation note")
|
||||
sub.set_defaults(cmd=_cmd_delete)
|
||||
|
||||
options = parser.parse_args(argv[1:])
|
||||
try:
|
||||
|
||||
@@ -20,7 +20,13 @@
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import dataclasses
|
||||
import threading
|
||||
import functools
|
||||
import time
|
||||
|
||||
from ...logging import get_logger
|
||||
|
||||
from ... import tools
|
||||
|
||||
|
||||
# =====
|
||||
@@ -29,60 +35,42 @@ class IpmiPasswdError(Exception):
|
||||
super().__init__(f"Syntax error at {path}:{lineno}: {msg}")
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class IpmiUserCredentials:
|
||||
ipmi_user: str
|
||||
ipmi_passwd: str
|
||||
kvmd_user: str
|
||||
kvmd_passwd: str
|
||||
|
||||
|
||||
class IpmiAuthManager:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.__path = path
|
||||
with open(path) as file:
|
||||
self.__credentials = self.__parse_passwd_file(file.read().split("\n"))
|
||||
self.__lock = threading.Lock()
|
||||
|
||||
def __contains__(self, ipmi_user: str) -> bool:
|
||||
return (ipmi_user in self.__credentials)
|
||||
def get(self, user: str) -> (str | None):
|
||||
creds = self.__get_credentials(int(time.time()))
|
||||
return creds.get(user)
|
||||
|
||||
def __getitem__(self, ipmi_user: str) -> str:
|
||||
return self.__credentials[ipmi_user].ipmi_passwd
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def __get_credentials(self, ts: int) -> dict[str, str]:
|
||||
_ = ts
|
||||
with self.__lock:
|
||||
try:
|
||||
return self.__read_credentials()
|
||||
except Exception as ex:
|
||||
get_logger().error("%s", tools.efmt(ex))
|
||||
return {}
|
||||
|
||||
def get_credentials(self, ipmi_user: str) -> IpmiUserCredentials:
|
||||
return self.__credentials[ipmi_user]
|
||||
def __read_credentials(self) -> dict[str, str]:
|
||||
with open(self.__path) as file:
|
||||
creds: dict[str, str] = {}
|
||||
for (lineno, line) in tools.passwds_splitted(file.read()):
|
||||
if " -> " in line: # Compatibility with old ipmipasswd file format
|
||||
line = line.split(" -> ", 1)[0]
|
||||
|
||||
def __parse_passwd_file(self, lines: list[str]) -> dict[str, IpmiUserCredentials]:
|
||||
credentials: dict[str, IpmiUserCredentials] = {}
|
||||
for (lineno, line) in enumerate(lines):
|
||||
if len(line.strip()) == 0 or line.lstrip().startswith("#"):
|
||||
continue
|
||||
if ":" not in line:
|
||||
raise IpmiPasswdError(self.__path, lineno, "Missing ':' operator")
|
||||
|
||||
if " -> " not in line:
|
||||
raise IpmiPasswdError(self.__path, lineno, "Missing ' -> ' operator")
|
||||
(user, passwd) = line.split(":", 1)
|
||||
user = user.strip()
|
||||
if len(user) == 0:
|
||||
raise IpmiPasswdError(self.__path, lineno, "Empty IPMI user")
|
||||
|
||||
(left, right) = map(str.lstrip, line.split(" -> ", 1))
|
||||
for (name, pair) in [("left", left), ("right", right)]:
|
||||
if ":" not in pair:
|
||||
raise IpmiPasswdError(self.__path, lineno, f"Missing ':' operator in {name} credentials")
|
||||
if user in creds:
|
||||
raise IpmiPasswdError(self.__path, lineno, f"Found duplicating user {user!r}")
|
||||
|
||||
(ipmi_user, ipmi_passwd) = left.split(":")
|
||||
ipmi_user = ipmi_user.strip()
|
||||
if len(ipmi_user) == 0:
|
||||
raise IpmiPasswdError(self.__path, lineno, "Empty IPMI user (left)")
|
||||
|
||||
(kvmd_user, kvmd_passwd) = right.split(":")
|
||||
kvmd_user = kvmd_user.strip()
|
||||
if len(kvmd_user) == 0:
|
||||
raise IpmiPasswdError(self.__path, lineno, "Empty KVMD user (left)")
|
||||
|
||||
if ipmi_user in credentials:
|
||||
raise IpmiPasswdError(self.__path, lineno, f"Found duplicating user {ipmi_user!r} (left)")
|
||||
|
||||
credentials[ipmi_user] = IpmiUserCredentials(
|
||||
ipmi_user=ipmi_user,
|
||||
ipmi_passwd=ipmi_passwd,
|
||||
kvmd_user=kvmd_user,
|
||||
kvmd_passwd=kvmd_passwd,
|
||||
)
|
||||
return credentials
|
||||
creds[user] = passwd
|
||||
return creds
|
||||
|
||||
@@ -70,7 +70,6 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
|
||||
|
||||
super().__init__(authdata=auth_manager, address=host, port=port)
|
||||
|
||||
self.__auth_manager = auth_manager
|
||||
self.__kvmd = kvmd
|
||||
|
||||
self.__host = host
|
||||
@@ -165,11 +164,10 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
|
||||
def __make_request(self, session: IpmiServerSession, name: str, func_path: str, **kwargs): # type: ignore
|
||||
async def runner(): # type: ignore
|
||||
logger = get_logger(0)
|
||||
credentials = self.__auth_manager.get_credentials(session.username.decode())
|
||||
logger.info("[%s]: Performing request %s from user %r (IPMI) as %r (KVMD)",
|
||||
session.sockaddr[0], name, credentials.ipmi_user, credentials.kvmd_user)
|
||||
logger.info("[%s]: Performing request %s from IPMI user %r ...",
|
||||
session.sockaddr[0], name, session.username.decode())
|
||||
try:
|
||||
async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session:
|
||||
async with self.__kvmd.make_session() as kvmd_session:
|
||||
func = functools.reduce(getattr, func_path.split("."), kvmd_session)
|
||||
return (await func(**kwargs))
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
|
||||
|
||||
@@ -21,6 +21,7 @@ class _Netcfg:
|
||||
nat_type: StunNatType = dataclasses.field(default=StunNatType.ERROR)
|
||||
src_ip: str = dataclasses.field(default="")
|
||||
ext_ip: str = dataclasses.field(default="")
|
||||
stun_host: str = dataclasses.field(default="")
|
||||
stun_ip: str = dataclasses.field(default="")
|
||||
stun_port: int = dataclasses.field(default=0)
|
||||
|
||||
@@ -172,7 +173,10 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
|
||||
part.format(**placeholders)
|
||||
for part in cmd
|
||||
]
|
||||
self.__janus_proc = await aioproc.run_process(cmd)
|
||||
self.__janus_proc = await aioproc.run_process(
|
||||
cmd=cmd,
|
||||
env={"JANUS_USTREAMER_WEB_ICE_URL": f"stun:{netcfg.stun_host}:{netcfg.stun_port}"},
|
||||
)
|
||||
get_logger(0).info("Started Janus pid=%d: %s", self.__janus_proc.pid, tools.cmdfmt(cmd))
|
||||
|
||||
async def __kill_janus_proc(self) -> None:
|
||||
|
||||
@@ -30,6 +30,7 @@ class StunInfo:
|
||||
nat_type: StunNatType
|
||||
src_ip: str
|
||||
ext_ip: str
|
||||
stun_host: str
|
||||
stun_ip: str
|
||||
stun_port: int
|
||||
|
||||
@@ -102,6 +103,7 @@ class Stun:
|
||||
nat_type=nat_type,
|
||||
src_ip=src_ip,
|
||||
ext_ip=ext_ip,
|
||||
stun_host=self.__host,
|
||||
stun_ip=self.__stun_ip,
|
||||
stun_port=self.__port,
|
||||
)
|
||||
|
||||
@@ -76,14 +76,17 @@ def main(argv: (list[str] | None)=None) -> None:
|
||||
KvmdServer(
|
||||
auth_manager=AuthManager(
|
||||
enabled=config.auth.enabled,
|
||||
expire=config.auth.expire,
|
||||
usc_users=config.auth.usc.users,
|
||||
usc_groups=config.auth.usc.groups,
|
||||
unauth_paths=([] if config.prometheus.auth.enabled else ["/export/prometheus/metrics"]),
|
||||
|
||||
internal_type=config.auth.internal.type,
|
||||
internal_kwargs=config.auth.internal._unpack(ignore=["type", "force_users"]),
|
||||
force_internal_users=config.auth.internal.force_users,
|
||||
int_type=config.auth.internal.type,
|
||||
int_kwargs=config.auth.internal._unpack(ignore=["type", "force_users"]),
|
||||
force_int_users=config.auth.internal.force_users,
|
||||
|
||||
external_type=config.auth.external.type,
|
||||
external_kwargs=(config.auth.external._unpack(ignore=["type"]) if config.auth.external.type else {}),
|
||||
ext_type=config.auth.external.type,
|
||||
ext_kwargs=(config.auth.external._unpack(ignore=["type"]) if config.auth.external.type else {}),
|
||||
|
||||
totp_secret_path=config.auth.totp.secret.file,
|
||||
),
|
||||
|
||||
@@ -31,9 +31,11 @@ from ....htserver import HttpExposed
|
||||
from ....htserver import exposed_http
|
||||
from ....htserver import make_json_response
|
||||
from ....htserver import set_request_auth_info
|
||||
from ....htserver import get_request_unix_credentials
|
||||
|
||||
from ....validators.auth import valid_user
|
||||
from ....validators.auth import valid_passwd
|
||||
from ....validators.auth import valid_expire
|
||||
from ....validators.auth import valid_auth_token
|
||||
|
||||
from ..auth import AuthManager
|
||||
@@ -43,39 +45,64 @@ from ..auth import AuthManager
|
||||
_COOKIE_AUTH_TOKEN = "auth_token"
|
||||
|
||||
|
||||
async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, req: Request) -> None:
|
||||
if auth_manager.is_auth_required(exposed):
|
||||
user = req.headers.get("X-KVMD-User", "")
|
||||
async def _check_xhdr(auth_manager: AuthManager, _: HttpExposed, req: Request) -> bool:
|
||||
user = req.headers.get("X-KVMD-User", "")
|
||||
if user:
|
||||
user = valid_user(user)
|
||||
passwd = req.headers.get("X-KVMD-Passwd", "")
|
||||
set_request_auth_info(req, f"{user} (xhdr)")
|
||||
if (await auth_manager.authorize(user, valid_passwd(passwd))):
|
||||
return True
|
||||
raise ForbiddenError()
|
||||
return False
|
||||
|
||||
|
||||
async def _check_token(auth_manager: AuthManager, _: HttpExposed, req: Request) -> bool:
|
||||
token = req.cookies.get(_COOKIE_AUTH_TOKEN, "")
|
||||
if token:
|
||||
user = auth_manager.check(valid_auth_token(token))
|
||||
if user:
|
||||
user = valid_user(user)
|
||||
passwd = req.headers.get("X-KVMD-Passwd", "")
|
||||
set_request_auth_info(req, f"{user} (xhdr)")
|
||||
if not (await auth_manager.authorize(user, valid_passwd(passwd))):
|
||||
raise ForbiddenError()
|
||||
return
|
||||
|
||||
token = req.cookies.get(_COOKIE_AUTH_TOKEN, "")
|
||||
if token:
|
||||
user = auth_manager.check(valid_auth_token(token)) # type: ignore
|
||||
if not user:
|
||||
set_request_auth_info(req, "- (token)")
|
||||
raise ForbiddenError()
|
||||
set_request_auth_info(req, f"{user} (token)")
|
||||
return
|
||||
return True
|
||||
set_request_auth_info(req, "- (token)")
|
||||
raise ForbiddenError()
|
||||
return False
|
||||
|
||||
basic_auth = req.headers.get("Authorization", "")
|
||||
if basic_auth and basic_auth[:6].lower() == "basic ":
|
||||
try:
|
||||
(user, passwd) = base64.b64decode(basic_auth[6:]).decode("utf-8").split(":")
|
||||
except Exception:
|
||||
raise UnauthorizedError()
|
||||
user = valid_user(user)
|
||||
set_request_auth_info(req, f"{user} (basic)")
|
||||
if not (await auth_manager.authorize(user, valid_passwd(passwd))):
|
||||
raise ForbiddenError()
|
||||
return
|
||||
|
||||
async def _check_basic(auth_manager: AuthManager, _: HttpExposed, req: Request) -> bool:
|
||||
basic_auth = req.headers.get("Authorization", "")
|
||||
if basic_auth and basic_auth[:6].lower() == "basic ":
|
||||
try:
|
||||
(user, passwd) = base64.b64decode(basic_auth[6:]).decode("utf-8").split(":")
|
||||
except Exception:
|
||||
raise UnauthorizedError()
|
||||
user = valid_user(user)
|
||||
set_request_auth_info(req, f"{user} (basic)")
|
||||
if (await auth_manager.authorize(user, valid_passwd(passwd))):
|
||||
return True
|
||||
raise ForbiddenError()
|
||||
return False
|
||||
|
||||
|
||||
async def _check_usc(auth_manager: AuthManager, exposed: HttpExposed, req: Request) -> bool:
|
||||
if exposed.allow_usc:
|
||||
creds = get_request_unix_credentials(req)
|
||||
if creds is not None:
|
||||
user = auth_manager.check_unix_credentials(creds)
|
||||
if user:
|
||||
set_request_auth_info(req, f"{user}[{creds.uid}] (unix)")
|
||||
return True
|
||||
raise UnauthorizedError()
|
||||
return False
|
||||
|
||||
|
||||
async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, req: Request) -> None:
|
||||
if not auth_manager.is_auth_required(exposed):
|
||||
return
|
||||
for checker in [_check_xhdr, _check_token, _check_basic, _check_usc]:
|
||||
if (await checker(auth_manager, exposed, req)):
|
||||
return
|
||||
raise UnauthorizedError()
|
||||
|
||||
|
||||
class AuthApi:
|
||||
@@ -84,26 +111,28 @@ class AuthApi:
|
||||
|
||||
# =====
|
||||
|
||||
@exposed_http("POST", "/auth/login", auth_required=False)
|
||||
@exposed_http("POST", "/auth/login", auth_required=False, allow_usc=False)
|
||||
async def __login_handler(self, req: Request) -> Response:
|
||||
if self.__auth_manager.is_auth_enabled():
|
||||
credentials = await req.post()
|
||||
token = await self.__auth_manager.login(
|
||||
user=valid_user(credentials.get("user", "")),
|
||||
passwd=valid_passwd(credentials.get("passwd", "")),
|
||||
expire=valid_expire(credentials.get("expire", "0")),
|
||||
)
|
||||
if token:
|
||||
return make_json_response(set_cookies={_COOKIE_AUTH_TOKEN: token})
|
||||
raise ForbiddenError()
|
||||
return make_json_response()
|
||||
|
||||
@exposed_http("POST", "/auth/logout")
|
||||
@exposed_http("POST", "/auth/logout", allow_usc=False)
|
||||
async def __logout_handler(self, req: Request) -> Response:
|
||||
if self.__auth_manager.is_auth_enabled():
|
||||
token = valid_auth_token(req.cookies.get(_COOKIE_AUTH_TOKEN, ""))
|
||||
self.__auth_manager.logout(token)
|
||||
return make_json_response()
|
||||
|
||||
@exposed_http("GET", "/auth/check")
|
||||
# XXX: This handle is used for access control so it should NEVER allow access by socket credentials
|
||||
@exposed_http("GET", "/auth/check", allow_usc=False)
|
||||
async def __check_handler(self, _: Request) -> Response:
|
||||
return make_json_response()
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from typing import Any
|
||||
|
||||
@@ -57,7 +58,7 @@ class ExportApi:
|
||||
async def __get_prometheus_metrics(self) -> str:
|
||||
(atx_state, info_state, gpio_state) = await asyncio.gather(*[
|
||||
self.__atx.get_state(),
|
||||
self.__info_manager.get_state(["hw", "fan"]),
|
||||
self.__info_manager.get_state(["health", "fan"]),
|
||||
self.__user_gpio.get_state(),
|
||||
])
|
||||
rows: list[str] = []
|
||||
@@ -68,10 +69,11 @@ class ExportApi:
|
||||
for mode in sorted(UserGpioModes.ALL):
|
||||
for (channel, ch_state) in gpio_state["state"][f"{mode}s"].items(): # type: ignore
|
||||
if not channel.startswith("__"): # Hide special GPIOs
|
||||
channel = re.sub(r"[^\w]", "_", channel)
|
||||
for key in ["online", "state"]:
|
||||
self.__append_prometheus_rows(rows, ch_state["state"], f"pikvm_gpio_{mode}_{key}_{channel}")
|
||||
|
||||
self.__append_prometheus_rows(rows, info_state["hw"]["health"], "pikvm_hw") # type: ignore
|
||||
self.__append_prometheus_rows(rows, info_state["health"], "pikvm_hw") # type: ignore
|
||||
self.__append_prometheus_rows(rows, info_state["fan"], "pikvm_fan")
|
||||
|
||||
return "\n".join(rows)
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
import os
|
||||
import stat
|
||||
import functools
|
||||
import itertools
|
||||
import struct
|
||||
|
||||
from typing import Iterable
|
||||
@@ -31,8 +32,11 @@ from typing import Callable
|
||||
from aiohttp.web import Request
|
||||
from aiohttp.web import Response
|
||||
|
||||
from ....keyboard.mappings import WEB_TO_EVDEV
|
||||
from ....keyboard.keysym import build_symmap
|
||||
from ....keyboard.printer import text_to_web_keys
|
||||
from ....keyboard.printer import text_to_evdev_keys
|
||||
|
||||
from ....mouse import MOUSE_TO_EVDEV
|
||||
|
||||
from ....htserver import exposed_http
|
||||
from ....htserver import exposed_ws
|
||||
@@ -43,7 +47,9 @@ from ....plugins.hid import BaseHid
|
||||
|
||||
from ....validators import raise_error
|
||||
from ....validators.basic import valid_bool
|
||||
from ....validators.basic import valid_number
|
||||
from ....validators.basic import valid_int_f0
|
||||
from ....validators.basic import valid_string_list
|
||||
from ....validators.os import valid_printable_filename
|
||||
from ....validators.hid import valid_hid_keyboard_output
|
||||
from ....validators.hid import valid_hid_mouse_output
|
||||
@@ -97,6 +103,11 @@ class HidApi:
|
||||
await self.__hid.reset()
|
||||
return make_json_response()
|
||||
|
||||
@exposed_http("GET", "/hid/inactivity")
|
||||
async def __inactivity_handler(self, _: Request) -> Response:
|
||||
secs = self.__hid.get_inactivity_seconds()
|
||||
return make_json_response({"inactivity": secs})
|
||||
|
||||
# =====
|
||||
|
||||
async def get_keymaps(self) -> dict: # Ugly hack to generate hid_keymaps_state (see server.py)
|
||||
@@ -119,15 +130,26 @@ class HidApi:
|
||||
@exposed_http("POST", "/hid/print")
|
||||
async def __print_handler(self, req: Request) -> Response:
|
||||
text = await req.text()
|
||||
limit = int(valid_int_f0(req.query.get("limit", 1024)))
|
||||
limit = valid_int_f0(req.query.get("limit", 1024))
|
||||
if limit > 0:
|
||||
text = text[:limit]
|
||||
symmap = self.__ensure_symmap(req.query.get("keymap", self.__default_keymap_name))
|
||||
slow = valid_bool(req.query.get("slow", False))
|
||||
await self.__hid.send_key_events(text_to_web_keys(text, symmap), no_ignore_keys=True, slow=slow)
|
||||
delay = float(valid_number(
|
||||
arg=req.query.get("delay", (0.02 if slow else 0)),
|
||||
min=0,
|
||||
max=5,
|
||||
type=float,
|
||||
name="keys delay",
|
||||
))
|
||||
await self.__hid.send_key_events(
|
||||
keys=text_to_evdev_keys(text, symmap),
|
||||
no_ignore_keys=True,
|
||||
delay=delay,
|
||||
)
|
||||
return make_json_response()
|
||||
|
||||
def __ensure_symmap(self, keymap_name: str) -> dict[int, dict[int, str]]:
|
||||
def __ensure_symmap(self, keymap_name: str) -> dict[int, dict[int, int]]:
|
||||
keymap_name = valid_printable_filename(keymap_name, "keymap")
|
||||
path = os.path.join(self.__keymaps_dir_path, keymap_name)
|
||||
try:
|
||||
@@ -139,7 +161,7 @@ class HidApi:
|
||||
return self.__inner_ensure_symmap(path, st.st_mtime)
|
||||
|
||||
@functools.lru_cache(maxsize=10)
|
||||
def __inner_ensure_symmap(self, path: str, mod_ts: int) -> dict[int, dict[int, str]]:
|
||||
def __inner_ensure_symmap(self, path: str, mod_ts: int) -> dict[int, dict[int, int]]:
|
||||
_ = mod_ts # For LRU
|
||||
return build_symmap(path)
|
||||
|
||||
@@ -148,9 +170,12 @@ class HidApi:
|
||||
@exposed_ws(1)
|
||||
async def __ws_bin_key_handler(self, _: WsSession, data: bytes) -> None:
|
||||
try:
|
||||
key = valid_hid_key(data[1:].decode("ascii"))
|
||||
state = bool(data[0] & 0b01)
|
||||
finish = bool(data[0] & 0b10)
|
||||
if data[0] & 0b10000000:
|
||||
key = struct.unpack(">H", data[1:])[0]
|
||||
else:
|
||||
key = WEB_TO_EVDEV[valid_hid_key(data[1:33].decode("ascii"))]
|
||||
except Exception:
|
||||
return
|
||||
self.__hid.send_key_event(key, state, finish)
|
||||
@@ -158,7 +183,11 @@ class HidApi:
|
||||
@exposed_ws(2)
|
||||
async def __ws_bin_mouse_button_handler(self, _: WsSession, data: bytes) -> None:
|
||||
try:
|
||||
button = valid_hid_mouse_button(data[1:].decode("ascii"))
|
||||
state = bool(data[0] & 0b01)
|
||||
if data[0] & 0b10000000:
|
||||
button = struct.unpack(">H", data[1:])[0]
|
||||
else:
|
||||
button = MOUSE_TO_EVDEV[valid_hid_mouse_button(data[1:33].decode("ascii"))]
|
||||
state = bool(data[0] & 0b01)
|
||||
except Exception:
|
||||
return
|
||||
@@ -199,7 +228,7 @@ class HidApi:
|
||||
@exposed_ws("key")
|
||||
async def __ws_key_handler(self, _: WsSession, event: dict) -> None:
|
||||
try:
|
||||
key = valid_hid_key(event["key"])
|
||||
key = WEB_TO_EVDEV[valid_hid_key(event["key"])]
|
||||
state = valid_bool(event["state"])
|
||||
finish = valid_bool(event.get("finish", False))
|
||||
except Exception:
|
||||
@@ -209,7 +238,7 @@ class HidApi:
|
||||
@exposed_ws("mouse_button")
|
||||
async def __ws_mouse_button_handler(self, _: WsSession, event: dict) -> None:
|
||||
try:
|
||||
button = valid_hid_mouse_button(event["button"])
|
||||
button = MOUSE_TO_EVDEV[valid_hid_mouse_button(event["button"])]
|
||||
state = valid_bool(event["state"])
|
||||
except Exception:
|
||||
return
|
||||
@@ -246,9 +275,22 @@ class HidApi:
|
||||
|
||||
# =====
|
||||
|
||||
@exposed_http("POST", "/hid/events/send_shortcut")
|
||||
async def __events_send_shortcut_handler(self, req: Request) -> Response:
|
||||
shortcut = valid_string_list(req.query.get("keys"), subval=valid_hid_key)
|
||||
if shortcut:
|
||||
press = [WEB_TO_EVDEV[key] for key in shortcut]
|
||||
release = list(reversed(press))
|
||||
seq = [
|
||||
*zip(press, itertools.repeat(True)),
|
||||
*zip(release, itertools.repeat(False)),
|
||||
]
|
||||
await self.__hid.send_key_events(seq, no_ignore_keys=True, delay=0.05)
|
||||
return make_json_response()
|
||||
|
||||
@exposed_http("POST", "/hid/events/send_key")
|
||||
async def __events_send_key_handler(self, req: Request) -> Response:
|
||||
key = valid_hid_key(req.query.get("key"))
|
||||
key = WEB_TO_EVDEV[valid_hid_key(req.query.get("key"))]
|
||||
if "state" in req.query:
|
||||
state = valid_bool(req.query["state"])
|
||||
finish = valid_bool(req.query.get("finish", False))
|
||||
@@ -259,7 +301,7 @@ class HidApi:
|
||||
|
||||
@exposed_http("POST", "/hid/events/send_mouse_button")
|
||||
async def __events_send_mouse_button_handler(self, req: Request) -> Response:
|
||||
button = valid_hid_mouse_button(req.query.get("button"))
|
||||
button = MOUSE_TO_EVDEV[valid_hid_mouse_button(req.query.get("button"))]
|
||||
if "state" in req.query:
|
||||
state = valid_bool(req.query["state"])
|
||||
self.__hid.send_mouse_button_event(button, state)
|
||||
|
||||
@@ -45,7 +45,10 @@ class InfoApi:
|
||||
|
||||
def __valid_info_fields(self, req: Request) -> list[str]:
|
||||
available = self.__info_manager.get_subs()
|
||||
available.add("hw")
|
||||
default = set(available)
|
||||
default.remove("health")
|
||||
return sorted(valid_info_fields(
|
||||
arg=req.query.get("fields", ",".join(available)),
|
||||
variants=available,
|
||||
arg=req.query.get("fields", ",".join(default)),
|
||||
variants=(available),
|
||||
) or available)
|
||||
|
||||
@@ -52,17 +52,15 @@ class LogApi:
|
||||
raise LogReaderDisabledError()
|
||||
seek = valid_log_seek(req.query.get("seek", 0))
|
||||
follow = valid_bool(req.query.get("follow", False))
|
||||
response = await start_streaming(req, "text/plain")
|
||||
resp = await start_streaming(req, "text/plain")
|
||||
try:
|
||||
async for record in self.__log_reader.poll_log(seek, follow):
|
||||
await response.write(("[%s %s] --- %s" % (
|
||||
await resp.write(("[%s %s] --- %s" % (
|
||||
record["dt"].strftime("%Y-%m-%d %H:%M:%S"),
|
||||
record["service"],
|
||||
record["msg"],
|
||||
)).encode("utf-8") + b"\r\n")
|
||||
except Exception as exception:
|
||||
if record is None:
|
||||
record = exception
|
||||
await response.write(f"Module systemd.journal is unavailable.\n{record}".encode("utf-8"))
|
||||
return response
|
||||
return response
|
||||
await resp.write(f"Module systemd.journal is unavailable.\n{exception}".encode("utf-8"))
|
||||
return resp
|
||||
return resp
|
||||
|
||||
@@ -133,10 +133,10 @@ class MsdApi:
|
||||
src = compressed()
|
||||
size = -1
|
||||
|
||||
response = await start_streaming(req, "application/octet-stream", size, name + suffix)
|
||||
resp = await start_streaming(req, "application/octet-stream", size, name + suffix)
|
||||
async for chunk in src:
|
||||
await response.write(chunk)
|
||||
return response
|
||||
await resp.write(chunk)
|
||||
return resp
|
||||
|
||||
# =====
|
||||
|
||||
@@ -166,11 +166,11 @@ class MsdApi:
|
||||
|
||||
name = ""
|
||||
size = written = 0
|
||||
response: (StreamResponse | None) = None
|
||||
resp: (StreamResponse | None) = None
|
||||
|
||||
async def stream_write_info() -> None:
|
||||
assert response is not None
|
||||
await stream_json(response, self.__make_write_info(name, size, written))
|
||||
assert resp is not None
|
||||
await stream_json(resp, self.__make_write_info(name, size, written))
|
||||
|
||||
try:
|
||||
async with htclient.download(
|
||||
@@ -190,7 +190,7 @@ class MsdApi:
|
||||
get_logger(0).info("Downloading image %r as %r to MSD ...", url, name)
|
||||
async with self.__msd.write_image(name, size, remove_incomplete) as writer:
|
||||
chunk_size = writer.get_chunk_size()
|
||||
response = await start_streaming(req, "application/x-ndjson")
|
||||
resp = await start_streaming(req, "application/x-ndjson")
|
||||
await stream_write_info()
|
||||
last_report_ts = 0
|
||||
async for chunk in remote.content.iter_chunked(chunk_size):
|
||||
@@ -201,12 +201,12 @@ class MsdApi:
|
||||
last_report_ts = now
|
||||
|
||||
await stream_write_info()
|
||||
return response
|
||||
return resp
|
||||
|
||||
except Exception as ex:
|
||||
if response is not None:
|
||||
if resp is not None:
|
||||
await stream_write_info()
|
||||
await stream_json_exception(response, ex)
|
||||
await stream_json_exception(resp, ex)
|
||||
elif isinstance(ex, aiohttp.ClientError):
|
||||
return make_json_exception(ex, 400)
|
||||
raise
|
||||
|
||||
@@ -102,14 +102,26 @@ class RedfishApi:
|
||||
"Actions": {
|
||||
"#ComputerSystem.Reset": {
|
||||
"ResetType@Redfish.AllowableValues": list(self.__actions),
|
||||
"target": "/redfish/v1/Systems/0/Actions/ComputerSystem.Reset"
|
||||
"target": "/redfish/v1/Systems/0/Actions/ComputerSystem.Reset",
|
||||
},
|
||||
"#ComputerSystem.SetDefaultBootOrder": { # https://github.com/pikvm/pikvm/issues/1525
|
||||
"target": "/redfish/v1/Systems/0/Actions/ComputerSystem.SetDefaultBootOrder",
|
||||
},
|
||||
},
|
||||
"Id": "0",
|
||||
"HostName": host,
|
||||
"PowerState": ("On" if atx_state["leds"]["power"] else "Off"), # type: ignore
|
||||
"Boot": {
|
||||
"BootSourceOverrideEnabled": "Disabled",
|
||||
"BootSourceOverrideTarget": None,
|
||||
},
|
||||
}, wrap_result=False)
|
||||
|
||||
@exposed_http("PATCH", "/redfish/v1/Systems/0")
|
||||
async def __patch_handler(self, _: Request) -> Response:
|
||||
# https://github.com/pikvm/pikvm/issues/1525
|
||||
return Response(body=None, status=204)
|
||||
|
||||
@exposed_http("POST", "/redfish/v1/Systems/0/Actions/ComputerSystem.Reset")
|
||||
async def __power_handler(self, req: Request) -> Response:
|
||||
try:
|
||||
|
||||
@@ -28,6 +28,7 @@ from ....htserver import make_json_response
|
||||
|
||||
from ....validators.basic import valid_bool
|
||||
from ....validators.basic import valid_int_f0
|
||||
from ....validators.basic import valid_float_f0
|
||||
from ....validators.basic import valid_stripped_string_not_empty
|
||||
from ....validators.kvm import valid_atx_power_action
|
||||
from ....validators.kvm import valid_atx_button
|
||||
@@ -52,9 +53,19 @@ class SwitchApi:
|
||||
async def __state_handler(self, _: Request) -> Response:
|
||||
return make_json_response(await self.__switch.get_state())
|
||||
|
||||
@exposed_http("POST", "/switch/set_active_prev")
|
||||
async def __set_active_prev_handler(self, _: Request) -> Response:
|
||||
await self.__switch.set_active_prev()
|
||||
return make_json_response()
|
||||
|
||||
@exposed_http("POST", "/switch/set_active_next")
|
||||
async def __set_active_next_handler(self, _: Request) -> Response:
|
||||
await self.__switch.set_active_next()
|
||||
return make_json_response()
|
||||
|
||||
@exposed_http("POST", "/switch/set_active")
|
||||
async def __set_active_port_handler(self, req: Request) -> Response:
|
||||
port = valid_int_f0(req.query.get("port"))
|
||||
port = valid_float_f0(req.query.get("port"))
|
||||
await self.__switch.set_active_port(port)
|
||||
return make_json_response()
|
||||
|
||||
@@ -62,7 +73,7 @@ class SwitchApi:
|
||||
async def __set_beacon_handler(self, req: Request) -> Response:
|
||||
on = valid_bool(req.query.get("state"))
|
||||
if "port" in req.query:
|
||||
port = valid_int_f0(req.query.get("port"))
|
||||
port = valid_float_f0(req.query.get("port"))
|
||||
await self.__switch.set_port_beacon(port, on)
|
||||
elif "uplink" in req.query:
|
||||
unit = valid_int_f0(req.query.get("uplink"))
|
||||
@@ -74,11 +85,12 @@ class SwitchApi:
|
||||
|
||||
@exposed_http("POST", "/switch/set_port_params")
|
||||
async def __set_port_params(self, req: Request) -> Response:
|
||||
port = valid_int_f0(req.query.get("port"))
|
||||
port = valid_float_f0(req.query.get("port"))
|
||||
params = {
|
||||
param: validator(req.query.get(param))
|
||||
for (param, validator) in [
|
||||
("edid_id", (lambda arg: valid_switch_edid_id(arg, allow_default=True))),
|
||||
("dummy", valid_bool),
|
||||
("name", valid_switch_port_name),
|
||||
("atx_click_power_delay", valid_switch_atx_click_delay),
|
||||
("atx_click_power_long_delay", valid_switch_atx_click_delay),
|
||||
@@ -142,7 +154,7 @@ class SwitchApi:
|
||||
|
||||
@exposed_http("POST", "/switch/atx/power")
|
||||
async def __power_handler(self, req: Request) -> Response:
|
||||
port = valid_int_f0(req.query.get("port"))
|
||||
port = valid_float_f0(req.query.get("port"))
|
||||
action = valid_atx_power_action(req.query.get("action"))
|
||||
await ({
|
||||
"on": self.__switch.atx_power_on,
|
||||
@@ -154,7 +166,7 @@ class SwitchApi:
|
||||
|
||||
@exposed_http("POST", "/switch/atx/click")
|
||||
async def __click_handler(self, req: Request) -> Response:
|
||||
port = valid_int_f0(req.query.get("port"))
|
||||
port = valid_float_f0(req.query.get("port"))
|
||||
button = valid_atx_button(req.query.get("button"))
|
||||
await ({
|
||||
"power": self.__switch.atx_click_power,
|
||||
|
||||
@@ -20,6 +20,12 @@
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import pwd
|
||||
import grp
|
||||
import dataclasses
|
||||
import time
|
||||
import datetime
|
||||
|
||||
import secrets
|
||||
import pyotp
|
||||
|
||||
@@ -31,48 +37,79 @@ from ...plugins.auth import BaseAuthService
|
||||
from ...plugins.auth import get_auth_service_class
|
||||
|
||||
from ...htserver import HttpExposed
|
||||
from ...htserver import RequestUnixCredentials
|
||||
|
||||
|
||||
# =====
|
||||
class AuthManager:
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _Session:
|
||||
user: str
|
||||
expire_ts: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.user == self.user.strip()
|
||||
assert self.user
|
||||
assert self.expire_ts >= 0
|
||||
|
||||
|
||||
class AuthManager: # pylint: disable=too-many-arguments,too-many-instance-attributes
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
expire: int,
|
||||
usc_users: list[str],
|
||||
usc_groups: list[str],
|
||||
unauth_paths: list[str],
|
||||
|
||||
internal_type: str,
|
||||
internal_kwargs: dict,
|
||||
force_internal_users: list[str],
|
||||
int_type: str,
|
||||
int_kwargs: dict,
|
||||
force_int_users: list[str],
|
||||
|
||||
external_type: str,
|
||||
external_kwargs: dict,
|
||||
ext_type: str,
|
||||
ext_kwargs: dict,
|
||||
|
||||
totp_secret_path: str,
|
||||
) -> None:
|
||||
|
||||
logger = get_logger(0)
|
||||
|
||||
self.__enabled = enabled
|
||||
if not enabled:
|
||||
get_logger().warning("AUTHORIZATION IS DISABLED")
|
||||
logger.warning("AUTHORIZATION IS DISABLED")
|
||||
|
||||
assert expire >= 0
|
||||
self.__expire = expire
|
||||
if expire > 0:
|
||||
logger.info("Maximum user session time is limited: %s",
|
||||
self.__format_seconds(expire))
|
||||
|
||||
self.__usc_uids = self.__load_usc_uids(usc_users, usc_groups)
|
||||
if self.__usc_uids:
|
||||
logger.info("Selfauth UNIX socket access is allowed for users: %s",
|
||||
list(self.__usc_uids.values()))
|
||||
|
||||
self.__unauth_paths = frozenset(unauth_paths) # To speed up
|
||||
for path in self.__unauth_paths:
|
||||
get_logger().warning("Authorization is disabled for API %r", path)
|
||||
if self.__unauth_paths:
|
||||
logger.info("Authorization is disabled for APIs: %s",
|
||||
list(self.__unauth_paths))
|
||||
|
||||
self.__internal_service: (BaseAuthService | None) = None
|
||||
self.__int_service: (BaseAuthService | None) = None
|
||||
if enabled:
|
||||
self.__internal_service = get_auth_service_class(internal_type)(**internal_kwargs)
|
||||
get_logger().info("Using internal auth service %r", self.__internal_service.get_plugin_name())
|
||||
self.__int_service = get_auth_service_class(int_type)(**int_kwargs)
|
||||
logger.info("Using internal auth service %r",
|
||||
self.__int_service.get_plugin_name())
|
||||
|
||||
self.__force_internal_users = force_internal_users
|
||||
self.__force_int_users = force_int_users
|
||||
|
||||
self.__external_service: (BaseAuthService | None) = None
|
||||
if enabled and external_type:
|
||||
self.__external_service = get_auth_service_class(external_type)(**external_kwargs)
|
||||
get_logger().info("Using external auth service %r", self.__external_service.get_plugin_name())
|
||||
self.__ext_service: (BaseAuthService | None) = None
|
||||
if enabled and ext_type:
|
||||
self.__ext_service = get_auth_service_class(ext_type)(**ext_kwargs)
|
||||
logger.info("Using external auth service %r",
|
||||
self.__ext_service.get_plugin_name())
|
||||
|
||||
self.__totp_secret_path = totp_secret_path
|
||||
|
||||
self.__tokens: dict[str, str] = {} # {token: user}
|
||||
self.__sessions: dict[str, _Session] = {} # {token: session}
|
||||
|
||||
def is_auth_enabled(self) -> bool:
|
||||
return self.__enabled
|
||||
@@ -88,7 +125,8 @@ class AuthManager:
|
||||
assert user == user.strip()
|
||||
assert user
|
||||
assert self.__enabled
|
||||
assert self.__internal_service
|
||||
assert self.__int_service
|
||||
logger = get_logger(0)
|
||||
|
||||
if self.__totp_secret_path:
|
||||
with open(self.__totp_secret_path) as file:
|
||||
@@ -96,60 +134,150 @@ class AuthManager:
|
||||
if secret:
|
||||
code = passwd[-6:]
|
||||
if not pyotp.TOTP(secret).verify(code, valid_window=1):
|
||||
get_logger().error("Got access denied for user %r by TOTP", user)
|
||||
logger.error("Got access denied for user %r by TOTP", user)
|
||||
return False
|
||||
passwd = passwd[:-6]
|
||||
|
||||
if user not in self.__force_internal_users and self.__external_service:
|
||||
service = self.__external_service
|
||||
if user not in self.__force_int_users and self.__ext_service:
|
||||
service = self.__ext_service
|
||||
else:
|
||||
service = self.__internal_service
|
||||
service = self.__int_service
|
||||
|
||||
pname = service.get_plugin_name()
|
||||
ok = (await service.authorize(user, passwd))
|
||||
if ok:
|
||||
get_logger().info("Authorized user %r via auth service %r", user, service.get_plugin_name())
|
||||
logger.info("Authorized user %r via auth service %r", user, pname)
|
||||
else:
|
||||
get_logger().error("Got access denied for user %r from auth service %r", user, service.get_plugin_name())
|
||||
logger.error("Got access denied for user %r from auth service %r", user, pname)
|
||||
return ok
|
||||
|
||||
async def login(self, user: str, passwd: str) -> (str | None):
|
||||
async def login(self, user: str, passwd: str, expire: int) -> (str | None):
|
||||
assert user == user.strip()
|
||||
assert user
|
||||
assert expire >= 0
|
||||
assert self.__enabled
|
||||
|
||||
if (await self.authorize(user, passwd)):
|
||||
token = self.__make_new_token()
|
||||
self.__tokens[token] = user
|
||||
get_logger().info("Logged in user %r", user)
|
||||
session = _Session(
|
||||
user=user,
|
||||
expire_ts=self.__make_expire_ts(expire),
|
||||
)
|
||||
self.__sessions[token] = session
|
||||
get_logger(0).info("Logged in user %r; expire=%s, sessions_now=%d",
|
||||
session.user,
|
||||
self.__format_expire_ts(session.expire_ts),
|
||||
self.__get_sessions_number(session.user))
|
||||
return token
|
||||
else:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def __make_new_token(self) -> str:
|
||||
for _ in range(10):
|
||||
token = secrets.token_hex(32)
|
||||
if token not in self.__tokens:
|
||||
if token not in self.__sessions:
|
||||
return token
|
||||
raise AssertionError("Can't generate new unique token")
|
||||
raise RuntimeError("Can't generate new unique token")
|
||||
|
||||
def __make_expire_ts(self, expire: int) -> int:
|
||||
assert expire >= 0
|
||||
assert self.__expire >= 0
|
||||
|
||||
if expire == 0:
|
||||
# The user requested infinite session: apply global expire.
|
||||
# It will allow this (0) or set a limit.
|
||||
expire = self.__expire
|
||||
else:
|
||||
# The user wants a limited session
|
||||
if self.__expire > 0:
|
||||
# If we have a global limit, override the user limit
|
||||
assert expire > 0
|
||||
expire = min(expire, self.__expire)
|
||||
|
||||
if expire > 0:
|
||||
return (self.__get_now_ts() + expire)
|
||||
|
||||
assert expire == 0
|
||||
return 0
|
||||
|
||||
def __get_now_ts(self) -> int:
|
||||
return int(time.monotonic())
|
||||
|
||||
def __format_expire_ts(self, expire_ts: int) -> str:
|
||||
if expire_ts > 0:
|
||||
seconds = expire_ts - self.__get_now_ts()
|
||||
return f"[{self.__format_seconds(seconds)}]"
|
||||
return "INF"
|
||||
|
||||
def __format_seconds(self, seconds: int) -> str:
|
||||
return str(datetime.timedelta(seconds=seconds))
|
||||
|
||||
def __get_sessions_number(self, user: str) -> int:
|
||||
return sum(
|
||||
1
|
||||
for session in self.__sessions.values()
|
||||
if session.user == user
|
||||
)
|
||||
|
||||
def logout(self, token: str) -> None:
|
||||
assert self.__enabled
|
||||
if token in self.__tokens:
|
||||
user = self.__tokens[token]
|
||||
if token in self.__sessions:
|
||||
user = self.__sessions[token].user
|
||||
count = 0
|
||||
for (r_token, r_user) in list(self.__tokens.items()):
|
||||
if r_user == user:
|
||||
for (key_t, session) in list(self.__sessions.items()):
|
||||
if session.user == user:
|
||||
count += 1
|
||||
del self.__tokens[r_token]
|
||||
get_logger().info("Logged out user %r (%d)", user, count)
|
||||
del self.__sessions[key_t]
|
||||
get_logger(0).info("Logged out user %r; sessions_closed=%d", user, count)
|
||||
|
||||
def check(self, token: str) -> (str | None):
|
||||
assert self.__enabled
|
||||
return self.__tokens.get(token)
|
||||
session = self.__sessions.get(token)
|
||||
if session is not None:
|
||||
if session.expire_ts <= 0:
|
||||
# Infinite session
|
||||
return session.user
|
||||
else:
|
||||
# Limited session
|
||||
if self.__get_now_ts() < session.expire_ts:
|
||||
return session.user
|
||||
else:
|
||||
del self.__sessions[token]
|
||||
get_logger(0).info("The session of user %r is expired; sessions_left=%d",
|
||||
session.user,
|
||||
self.__get_sessions_number(session.user))
|
||||
return None
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def cleanup(self) -> None:
|
||||
if self.__enabled:
|
||||
assert self.__internal_service
|
||||
await self.__internal_service.cleanup()
|
||||
if self.__external_service:
|
||||
await self.__external_service.cleanup()
|
||||
assert self.__int_service
|
||||
await self.__int_service.cleanup()
|
||||
if self.__ext_service:
|
||||
await self.__ext_service.cleanup()
|
||||
|
||||
# =====
|
||||
|
||||
def __load_usc_uids(self, users: list[str], groups: list[str]) -> dict[int, str]:
|
||||
uids: dict[int, str] = {}
|
||||
|
||||
pwds: dict[str, int] = {}
|
||||
for pw in pwd.getpwall():
|
||||
assert pw.pw_name == pw.pw_name.strip()
|
||||
assert pw.pw_name
|
||||
pwds[pw.pw_name] = pw.pw_uid
|
||||
if pw.pw_name in users:
|
||||
uids[pw.pw_uid] = pw.pw_name
|
||||
|
||||
for gr in grp.getgrall():
|
||||
if gr.gr_name in groups:
|
||||
for member in gr.gr_mem:
|
||||
if member in pwds:
|
||||
uid = pwds[member]
|
||||
uids[uid] = member
|
||||
|
||||
return uids
|
||||
|
||||
def check_unix_credentials(self, creds: RequestUnixCredentials) -> (str | None):
|
||||
assert self.__enabled
|
||||
return self.__usc_uids.get(creds.uid)
|
||||
|
||||
@@ -31,7 +31,7 @@ from .auth import AuthInfoSubmanager
|
||||
from .system import SystemInfoSubmanager
|
||||
from .meta import MetaInfoSubmanager
|
||||
from .extras import ExtrasInfoSubmanager
|
||||
from .hw import HwInfoSubmanager
|
||||
from .health import HealthInfoSubmanager
|
||||
from .fan import FanInfoSubmanager
|
||||
|
||||
|
||||
@@ -39,11 +39,11 @@ from .fan import FanInfoSubmanager
|
||||
class InfoManager:
|
||||
def __init__(self, config: Section) -> None:
|
||||
self.__subs: dict[str, BaseInfoSubmanager] = {
|
||||
"system": SystemInfoSubmanager(config.kvmd.streamer.cmd),
|
||||
"system": SystemInfoSubmanager(config.kvmd.info.hw.platform, config.kvmd.streamer.cmd),
|
||||
"auth": AuthInfoSubmanager(config.kvmd.auth.enabled),
|
||||
"meta": MetaInfoSubmanager(config.kvmd.info.meta),
|
||||
"extras": ExtrasInfoSubmanager(config),
|
||||
"hw": HwInfoSubmanager(**config.kvmd.info.hw._unpack()),
|
||||
"health": HealthInfoSubmanager(**config.kvmd.info.hw._unpack(ignore="platform")),
|
||||
"fan": FanInfoSubmanager(**config.kvmd.info.fan._unpack()),
|
||||
}
|
||||
self.__queue: "asyncio.Queue[tuple[str, (dict | None)]]" = asyncio.Queue()
|
||||
@@ -52,12 +52,29 @@ class InfoManager:
|
||||
return set(self.__subs)
|
||||
|
||||
async def get_state(self, fields: (list[str] | None)=None) -> dict:
|
||||
fields = (fields or list(self.__subs))
|
||||
return dict(zip(fields, await asyncio.gather(*[
|
||||
fields_set = set(fields or list(self.__subs))
|
||||
|
||||
hw = ("hw" in fields_set) # Old for compatible
|
||||
system = ("system" in fields_set)
|
||||
if hw:
|
||||
fields_set.remove("hw")
|
||||
fields_set.add("health")
|
||||
fields_set.add("system")
|
||||
|
||||
state = dict(zip(fields_set, await asyncio.gather(*[
|
||||
self.__subs[field].get_state()
|
||||
for field in fields
|
||||
for field in fields_set
|
||||
])))
|
||||
|
||||
if hw:
|
||||
state["hw"] = {
|
||||
"health": state.pop("health"),
|
||||
"platform": (state["system"] or {}).pop("platform"), # {} makes mypy happy
|
||||
}
|
||||
if not system:
|
||||
state.pop("system")
|
||||
return state
|
||||
|
||||
async def trigger_state(self) -> None:
|
||||
await asyncio.gather(*[
|
||||
sub.trigger_state()
|
||||
@@ -70,7 +87,7 @@ class InfoManager:
|
||||
# - auth -- Partial
|
||||
# - meta -- Partial, nullable
|
||||
# - extras -- Partial, nullable
|
||||
# - hw -- Partial
|
||||
# - health -- Partial
|
||||
# - fan -- Partial
|
||||
# ===========================
|
||||
|
||||
|
||||
@@ -99,9 +99,9 @@ class FanInfoSubmanager(BaseInfoSubmanager):
|
||||
async def __get_fan_state(self) -> (dict | None):
|
||||
try:
|
||||
async with self.__make_http_session() as session:
|
||||
async with session.get("http://localhost/state") as response:
|
||||
htclient.raise_not_200(response)
|
||||
return (await response.json())["result"]
|
||||
async with session.get("http://localhost/state") as resp:
|
||||
htclient.raise_not_200(resp)
|
||||
return (await resp.json())["result"]
|
||||
except Exception as ex:
|
||||
get_logger(0).error("Can't read fan state: %s", ex)
|
||||
return None
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import copy
|
||||
|
||||
@@ -45,59 +44,41 @@ _RetvalT = TypeVar("_RetvalT")
|
||||
|
||||
|
||||
# =====
|
||||
class HwInfoSubmanager(BaseInfoSubmanager):
|
||||
class HealthInfoSubmanager(BaseInfoSubmanager):
|
||||
def __init__(
|
||||
self,
|
||||
platform_path: str,
|
||||
vcgencmd_cmd: list[str],
|
||||
ignore_past: bool,
|
||||
state_poll: float,
|
||||
) -> None:
|
||||
|
||||
self.__platform_path = platform_path
|
||||
self.__vcgencmd_cmd = vcgencmd_cmd
|
||||
self.__ignore_past = ignore_past
|
||||
self.__state_poll = state_poll
|
||||
|
||||
self.__dt_cache: dict[str, str] = {}
|
||||
|
||||
self.__notifier = aiotools.AioNotifier()
|
||||
|
||||
async def get_state(self) -> dict:
|
||||
(
|
||||
base,
|
||||
serial,
|
||||
platform,
|
||||
throttling,
|
||||
cpu_percent,
|
||||
cpu_temp,
|
||||
mem,
|
||||
) = await asyncio.gather(
|
||||
self.__read_dt_file("model", _upper=False),
|
||||
self.__read_dt_file("serial-number", _upper=True),
|
||||
self.__read_platform_file(),
|
||||
self.__get_throttling(),
|
||||
self.__get_cpu_percent(),
|
||||
self.__get_cpu_temp(),
|
||||
self.__get_mem(),
|
||||
)
|
||||
return {
|
||||
"platform": {
|
||||
"type": "rpi",
|
||||
"base": base,
|
||||
"serial": serial,
|
||||
**platform, # type: ignore
|
||||
"temp": {
|
||||
"cpu": cpu_temp,
|
||||
},
|
||||
"health": {
|
||||
"temp": {
|
||||
"cpu": cpu_temp,
|
||||
},
|
||||
"cpu": {
|
||||
"percent": cpu_percent,
|
||||
},
|
||||
"mem": mem,
|
||||
"throttling": throttling,
|
||||
"cpu": {
|
||||
"percent": cpu_percent,
|
||||
},
|
||||
"mem": mem,
|
||||
"throttling": throttling,
|
||||
}
|
||||
|
||||
async def trigger_state(self) -> None:
|
||||
@@ -115,36 +96,6 @@ class HwInfoSubmanager(BaseInfoSubmanager):
|
||||
|
||||
# =====
|
||||
|
||||
async def __read_dt_file(self, name: str, _upper: bool) -> (str | None):
|
||||
if name not in self.__dt_cache:
|
||||
path = os.path.join(f"{env.PROCFS_PREFIX}/proc/device-tree", name)
|
||||
if not os.path.exists(path):
|
||||
path = os.path.join(f"{env.PROCFS_PREFIX}/etc/kvmd/hw_info/", name)
|
||||
try:
|
||||
self.__dt_cache[name] = (await aiotools.read_file(path)).strip(" \t\r\n\0")
|
||||
except Exception:
|
||||
# get_logger(0).warn("Can't read DT %s from %s: %s", name, path, err)
|
||||
return None
|
||||
return self.__dt_cache[name]
|
||||
|
||||
async def __read_platform_file(self) -> dict:
|
||||
try:
|
||||
text = await aiotools.read_file(self.__platform_path)
|
||||
parsed: dict[str, str] = {}
|
||||
for row in text.split("\n"):
|
||||
row = row.strip()
|
||||
if row:
|
||||
(key, value) = row.split("=", 1)
|
||||
parsed[key.strip()] = value.strip()
|
||||
return {
|
||||
"model": parsed["PIKVM_MODEL"],
|
||||
"video": parsed["PIKVM_VIDEO"],
|
||||
"board": parsed["PIKVM_BOARD"],
|
||||
}
|
||||
except Exception:
|
||||
get_logger(0).exception("Can't read device model")
|
||||
return {"model": None, "video": None, "board": None}
|
||||
|
||||
async def __get_cpu_temp(self) -> (float | None):
|
||||
temp_path = f"{env.SYSFS_PREFIX}/sys/class/thermal/thermal_zone0/temp"
|
||||
try:
|
||||
@@ -20,6 +20,8 @@
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import socket
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from ....logging import get_logger
|
||||
@@ -39,7 +41,10 @@ class MetaInfoSubmanager(BaseInfoSubmanager):
|
||||
|
||||
async def get_state(self) -> (dict | None):
|
||||
try:
|
||||
return ((await aiotools.run_async(load_yaml_file, self.__meta_path)) or {})
|
||||
meta = ((await aiotools.run_async(load_yaml_file, self.__meta_path)) or {})
|
||||
if meta["server"]["host"] == "@auto":
|
||||
meta["server"]["host"] = socket.getfqdn()
|
||||
return meta
|
||||
except Exception:
|
||||
get_logger(0).exception("Can't parse meta")
|
||||
return None
|
||||
|
||||
@@ -28,6 +28,7 @@ from typing import AsyncGenerator
|
||||
|
||||
from ....logging import get_logger
|
||||
|
||||
from .... import env
|
||||
from .... import aiotools
|
||||
from .... import aioproc
|
||||
|
||||
@@ -38,12 +39,30 @@ from .base import BaseInfoSubmanager
|
||||
|
||||
# =====
|
||||
class SystemInfoSubmanager(BaseInfoSubmanager):
|
||||
def __init__(self, streamer_cmd: list[str]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
platform_path: str,
|
||||
streamer_cmd: list[str],
|
||||
) -> None:
|
||||
|
||||
self.__platform_path = platform_path
|
||||
self.__streamer_cmd = streamer_cmd
|
||||
|
||||
self.__dt_cache: dict[str, str] = {}
|
||||
self.__notifier = aiotools.AioNotifier()
|
||||
|
||||
async def get_state(self) -> dict:
|
||||
streamer_info = await self.__get_streamer_info()
|
||||
(
|
||||
base,
|
||||
serial,
|
||||
pl,
|
||||
streamer_info,
|
||||
) = await asyncio.gather(
|
||||
self.__read_dt_file("model", upper=False),
|
||||
self.__read_dt_file("serial-number", upper=True),
|
||||
self.__read_platform_file(),
|
||||
self.__get_streamer_info(),
|
||||
)
|
||||
uname_info = platform.uname() # Uname using the internal cache
|
||||
return {
|
||||
"kvmd": {"version": __version__},
|
||||
@@ -52,6 +71,12 @@ class SystemInfoSubmanager(BaseInfoSubmanager):
|
||||
field: getattr(uname_info, field)
|
||||
for field in ["system", "release", "version", "machine"]
|
||||
},
|
||||
"platform": {
|
||||
"type": "rpi",
|
||||
"base": base,
|
||||
"serial": serial,
|
||||
**pl, # type: ignore
|
||||
},
|
||||
}
|
||||
|
||||
async def trigger_state(self) -> None:
|
||||
@@ -64,6 +89,35 @@ class SystemInfoSubmanager(BaseInfoSubmanager):
|
||||
|
||||
# =====
|
||||
|
||||
async def __read_dt_file(self, name: str, upper: bool) -> (str | None):
|
||||
if name not in self.__dt_cache:
|
||||
path = os.path.join(f"{env.PROCFS_PREFIX}/proc/device-tree", name)
|
||||
try:
|
||||
value = (await aiotools.read_file(path)).strip(" \t\r\n\0")
|
||||
self.__dt_cache[name] = (value.upper() if upper else value)
|
||||
except Exception as ex:
|
||||
get_logger(0).error("Can't read DT %s from %s: %s", name, path, ex)
|
||||
return None
|
||||
return self.__dt_cache[name]
|
||||
|
||||
async def __read_platform_file(self) -> dict:
|
||||
try:
|
||||
text = await aiotools.read_file(self.__platform_path)
|
||||
parsed: dict[str, str] = {}
|
||||
for row in text.split("\n"):
|
||||
row = row.strip()
|
||||
if row:
|
||||
(key, value) = row.split("=", 1)
|
||||
parsed[key.strip()] = value.strip()
|
||||
return {
|
||||
"model": parsed["PIKVM_MODEL"],
|
||||
"video": parsed["PIKVM_VIDEO"],
|
||||
"board": parsed["PIKVM_BOARD"],
|
||||
}
|
||||
except Exception:
|
||||
get_logger(0).exception("Can't read device model")
|
||||
return {"model": None, "video": None, "board": None}
|
||||
|
||||
async def __get_streamer_info(self) -> dict:
|
||||
version = ""
|
||||
features: dict[str, bool] = {}
|
||||
|
||||
@@ -254,6 +254,10 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
async def __ws_ping_handler(self, ws: WsSession, _: dict) -> None:
|
||||
await ws.send_event("pong", {})
|
||||
|
||||
@exposed_ws(0)
|
||||
async def __ws_bin_ping_handler(self, ws: WsSession, _: bytes) -> None:
|
||||
await ws.send_bin(255, b"") # Ping-pong
|
||||
|
||||
# ===== SYSTEM STUFF
|
||||
|
||||
def run(self, **kwargs: Any) -> None: # type: ignore # pylint: disable=arguments-differ
|
||||
@@ -318,18 +322,17 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
while True:
|
||||
cur = (self.__has_stream_clients() or self.__snapshoter.snapshoting() or self.__stream_forever)
|
||||
if not prev and cur:
|
||||
await self.__streamer.ensure_start(reset=False)
|
||||
await self.__streamer.ensure_start()
|
||||
elif prev and not cur:
|
||||
await self.__streamer.ensure_stop(immediately=False)
|
||||
await self.__streamer.ensure_stop()
|
||||
|
||||
if self.__reset_streamer or self.__new_streamer_params:
|
||||
start = self.__streamer.is_working()
|
||||
await self.__streamer.ensure_stop(immediately=True)
|
||||
if self.__new_streamer_params:
|
||||
self.__streamer.set_params(self.__new_streamer_params)
|
||||
self.__new_streamer_params = {}
|
||||
if start:
|
||||
await self.__streamer.ensure_start(reset=self.__reset_streamer)
|
||||
if self.__new_streamer_params:
|
||||
self.__streamer.set_params(self.__new_streamer_params)
|
||||
self.__new_streamer_params = {}
|
||||
self.__reset_streamer = True
|
||||
|
||||
if self.__reset_streamer:
|
||||
await self.__streamer.ensure_restart()
|
||||
self.__reset_streamer = False
|
||||
|
||||
prev = cur
|
||||
|
||||
@@ -31,6 +31,8 @@ from ... import aiotools
|
||||
|
||||
from ...plugins.hid import BaseHid
|
||||
|
||||
from ...keyboard.mappings import WEB_TO_EVDEV
|
||||
|
||||
from .streamer import Streamer
|
||||
|
||||
|
||||
@@ -63,7 +65,7 @@ class Snapshoter: # pylint: disable=too-many-instance-attributes
|
||||
else:
|
||||
self.__idle_interval = self.__live_interval = 0.0
|
||||
|
||||
self.__wakeup_key = wakeup_key
|
||||
self.__wakeup_key = WEB_TO_EVDEV.get(wakeup_key, 0)
|
||||
self.__wakeup_move = wakeup_move
|
||||
|
||||
self.__online_delay = online_delay
|
||||
@@ -121,8 +123,8 @@ class Snapshoter: # pylint: disable=too-many-instance-attributes
|
||||
async def __wakeup(self) -> None:
|
||||
logger = get_logger(0)
|
||||
|
||||
if self.__wakeup_key:
|
||||
logger.info("Waking up using key %r ...", self.__wakeup_key)
|
||||
if self.__wakeup_key > 0:
|
||||
logger.info("Waking up using keyboard ...")
|
||||
await self.__hid.send_key_events(
|
||||
keys=[(self.__wakeup_key, True), (self.__wakeup_key, False)],
|
||||
no_ignore_keys=True,
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import signal
|
||||
import asyncio
|
||||
import asyncio.subprocess
|
||||
import dataclasses
|
||||
import copy
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ...logging import get_logger
|
||||
|
||||
from ...clients.streamer import StreamerSnapshot
|
||||
from ...clients.streamer import HttpStreamerClient
|
||||
from ...clients.streamer import HttpStreamerClientSession
|
||||
|
||||
from ... import tools
|
||||
from ... import aiotools
|
||||
from ... import aioproc
|
||||
from ... import htclient
|
||||
|
||||
|
||||
# =====
|
||||
class _StreamerParams:
|
||||
__DESIRED_FPS = "desired_fps"
|
||||
|
||||
__QUALITY = "quality"
|
||||
|
||||
__RESOLUTION = "resolution"
|
||||
__AVAILABLE_RESOLUTIONS = "available_resolutions"
|
||||
|
||||
__H264_BITRATE = "h264_bitrate"
|
||||
__H264_GOP = "h264_gop"
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
quality: int,
|
||||
|
||||
resolution: str,
|
||||
available_resolutions: list[str],
|
||||
|
||||
desired_fps: int,
|
||||
desired_fps_min: int,
|
||||
desired_fps_max: int,
|
||||
|
||||
h264_bitrate: int,
|
||||
h264_bitrate_min: int,
|
||||
h264_bitrate_max: int,
|
||||
|
||||
h264_gop: int,
|
||||
h264_gop_min: int,
|
||||
h264_gop_max: int,
|
||||
) -> None:
|
||||
|
||||
self.__has_quality = bool(quality)
|
||||
self.__has_resolution = bool(resolution)
|
||||
self.__has_h264 = bool(h264_bitrate)
|
||||
|
||||
self.__params: dict = {self.__DESIRED_FPS: min(max(desired_fps, desired_fps_min), desired_fps_max)}
|
||||
self.__limits: dict = {self.__DESIRED_FPS: {"min": desired_fps_min, "max": desired_fps_max}}
|
||||
|
||||
if self.__has_quality:
|
||||
self.__params[self.__QUALITY] = quality
|
||||
|
||||
if self.__has_resolution:
|
||||
self.__params[self.__RESOLUTION] = resolution
|
||||
self.__limits[self.__AVAILABLE_RESOLUTIONS] = available_resolutions
|
||||
|
||||
if self.__has_h264:
|
||||
self.__params[self.__H264_BITRATE] = min(max(h264_bitrate, h264_bitrate_min), h264_bitrate_max)
|
||||
self.__limits[self.__H264_BITRATE] = {"min": h264_bitrate_min, "max": h264_bitrate_max}
|
||||
self.__params[self.__H264_GOP] = min(max(h264_gop, h264_gop_min), h264_gop_max)
|
||||
self.__limits[self.__H264_GOP] = {"min": h264_gop_min, "max": h264_gop_max}
|
||||
|
||||
def get_features(self) -> dict:
|
||||
return {
|
||||
self.__QUALITY: self.__has_quality,
|
||||
self.__RESOLUTION: self.__has_resolution,
|
||||
"h264": self.__has_h264,
|
||||
}
|
||||
|
||||
def get_limits(self) -> dict:
|
||||
limits = copy.deepcopy(self.__limits)
|
||||
if self.__has_resolution:
|
||||
limits[self.__AVAILABLE_RESOLUTIONS] = list(limits[self.__AVAILABLE_RESOLUTIONS])
|
||||
return limits
|
||||
|
||||
def get_params(self) -> dict:
|
||||
return dict(self.__params)
|
||||
|
||||
def set_params(self, params: dict) -> None:
|
||||
new_params = dict(self.__params)
|
||||
|
||||
if self.__QUALITY in params and self.__has_quality:
|
||||
new_params[self.__QUALITY] = min(max(params[self.__QUALITY], 1), 100)
|
||||
|
||||
if self.__RESOLUTION in params and self.__has_resolution:
|
||||
if params[self.__RESOLUTION] in self.__limits[self.__AVAILABLE_RESOLUTIONS]:
|
||||
new_params[self.__RESOLUTION] = params[self.__RESOLUTION]
|
||||
|
||||
for (key, enabled) in [
|
||||
(self.__DESIRED_FPS, True),
|
||||
(self.__H264_BITRATE, self.__has_h264),
|
||||
(self.__H264_GOP, self.__has_h264),
|
||||
]:
|
||||
if key in params and enabled:
|
||||
if self.__check_limits_min_max(key, params[key]):
|
||||
new_params[key] = params[key]
|
||||
|
||||
self.__params = new_params
|
||||
|
||||
def __check_limits_min_max(self, key: str, value: int) -> bool:
|
||||
return (self.__limits[key]["min"] <= value <= self.__limits[key]["max"])
|
||||
|
||||
|
||||
class Streamer: # pylint: disable=too-many-instance-attributes
|
||||
__ST_FULL = 0xFF
|
||||
__ST_PARAMS = 0x01
|
||||
__ST_STREAMER = 0x02
|
||||
__ST_SNAPSHOT = 0x04
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments,too-many-locals
|
||||
self,
|
||||
|
||||
reset_delay: float,
|
||||
shutdown_delay: float,
|
||||
state_poll: float,
|
||||
|
||||
unix_path: str,
|
||||
timeout: float,
|
||||
snapshot_timeout: float,
|
||||
|
||||
process_name_prefix: str,
|
||||
|
||||
pre_start_cmd: list[str],
|
||||
pre_start_cmd_remove: list[str],
|
||||
pre_start_cmd_append: list[str],
|
||||
|
||||
cmd: list[str],
|
||||
cmd_remove: list[str],
|
||||
cmd_append: list[str],
|
||||
|
||||
post_stop_cmd: list[str],
|
||||
post_stop_cmd_remove: list[str],
|
||||
post_stop_cmd_append: list[str],
|
||||
|
||||
**params_kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
self.__reset_delay = reset_delay
|
||||
self.__shutdown_delay = shutdown_delay
|
||||
self.__state_poll = state_poll
|
||||
|
||||
self.__unix_path = unix_path
|
||||
self.__snapshot_timeout = snapshot_timeout
|
||||
|
||||
self.__process_name_prefix = process_name_prefix
|
||||
|
||||
self.__pre_start_cmd = tools.build_cmd(pre_start_cmd, pre_start_cmd_remove, pre_start_cmd_append)
|
||||
self.__cmd = tools.build_cmd(cmd, cmd_remove, cmd_append)
|
||||
self.__post_stop_cmd = tools.build_cmd(post_stop_cmd, post_stop_cmd_remove, post_stop_cmd_append)
|
||||
|
||||
self.__params = _StreamerParams(**params_kwargs)
|
||||
|
||||
self.__stop_task: (asyncio.Task | None) = None
|
||||
self.__stop_wip = False
|
||||
|
||||
self.__streamer_task: (asyncio.Task | None) = None
|
||||
self.__streamer_proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member
|
||||
|
||||
self.__client = HttpStreamerClient(
|
||||
name="jpeg",
|
||||
unix_path=self.__unix_path,
|
||||
timeout=timeout,
|
||||
user_agent=htclient.make_user_agent("KVMD"),
|
||||
)
|
||||
self.__client_session: (HttpStreamerClientSession | None) = None
|
||||
|
||||
self.__snapshot: (StreamerSnapshot | None) = None
|
||||
|
||||
self.__notifier = aiotools.AioNotifier()
|
||||
|
||||
# =====
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_start(self, reset: bool) -> None:
|
||||
if not self.__streamer_task or self.__stop_task:
|
||||
logger = get_logger(0)
|
||||
|
||||
if self.__stop_task:
|
||||
if not self.__stop_wip:
|
||||
self.__stop_task.cancel()
|
||||
await asyncio.gather(self.__stop_task, return_exceptions=True)
|
||||
logger.info("Streamer stop cancelled")
|
||||
return
|
||||
else:
|
||||
await asyncio.gather(self.__stop_task, return_exceptions=True)
|
||||
|
||||
if reset and self.__reset_delay > 0:
|
||||
logger.info("Waiting %.2f seconds for reset delay ...", self.__reset_delay)
|
||||
await asyncio.sleep(self.__reset_delay)
|
||||
logger.info("Starting streamer ...")
|
||||
await self.__inner_start()
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_stop(self, immediately: bool) -> None:
|
||||
if self.__streamer_task:
|
||||
logger = get_logger(0)
|
||||
|
||||
if immediately:
|
||||
if self.__stop_task:
|
||||
if not self.__stop_wip:
|
||||
self.__stop_task.cancel()
|
||||
await asyncio.gather(self.__stop_task, return_exceptions=True)
|
||||
logger.info("Stopping streamer immediately ...")
|
||||
await self.__inner_stop()
|
||||
else:
|
||||
await asyncio.gather(self.__stop_task, return_exceptions=True)
|
||||
else:
|
||||
logger.info("Stopping streamer immediately ...")
|
||||
await self.__inner_stop()
|
||||
|
||||
elif not self.__stop_task:
|
||||
|
||||
async def delayed_stop() -> None:
|
||||
try:
|
||||
await asyncio.sleep(self.__shutdown_delay)
|
||||
self.__stop_wip = True
|
||||
logger.info("Stopping streamer after delay ...")
|
||||
await self.__inner_stop()
|
||||
finally:
|
||||
self.__stop_task = None
|
||||
self.__stop_wip = False
|
||||
|
||||
logger.info("Planning to stop streamer in %.2f seconds ...", self.__shutdown_delay)
|
||||
self.__stop_task = asyncio.create_task(delayed_stop())
|
||||
|
||||
def is_working(self) -> bool:
|
||||
# Запущено и не планирует останавливаться
|
||||
return bool(self.__streamer_task and not self.__stop_task)
|
||||
|
||||
# =====
|
||||
|
||||
def set_params(self, params: dict) -> None:
|
||||
assert not self.__streamer_task
|
||||
self.__notifier.notify(self.__ST_PARAMS)
|
||||
return self.__params.set_params(params)
|
||||
|
||||
def get_params(self) -> dict:
|
||||
return self.__params.get_params()
|
||||
|
||||
# =====
|
||||
|
||||
async def get_state(self) -> dict:
|
||||
return {
|
||||
"features": self.__params.get_features(),
|
||||
"limits": self.__params.get_limits(),
|
||||
"params": self.__params.get_params(),
|
||||
"streamer": (await self.__get_streamer_state()),
|
||||
"snapshot": self.__get_snapshot_state(),
|
||||
}
|
||||
|
||||
async def trigger_state(self) -> None:
|
||||
self.__notifier.notify(self.__ST_FULL)
|
||||
|
||||
async def poll_state(self) -> AsyncGenerator[dict, None]:
|
||||
# ==== Granularity table ====
|
||||
# - features -- Full
|
||||
# - limits -- Partial, paired with params
|
||||
# - params -- Partial, paired with limits
|
||||
# - streamer -- Partial, nullable
|
||||
# - snapshot -- Partial
|
||||
# ===========================
|
||||
|
||||
def signal_handler(*_: Any) -> None:
|
||||
get_logger(0).info("Got SIGUSR2, checking the stream state ...")
|
||||
self.__notifier.notify(self.__ST_STREAMER)
|
||||
|
||||
get_logger(0).info("Installing SIGUSR2 streamer handler ...")
|
||||
asyncio.get_event_loop().add_signal_handler(signal.SIGUSR2, signal_handler)
|
||||
|
||||
prev: dict = {}
|
||||
while True:
|
||||
new: dict = {}
|
||||
|
||||
mask = await self.__notifier.wait(timeout=self.__state_poll)
|
||||
if mask == self.__ST_FULL:
|
||||
new = await self.get_state()
|
||||
prev = copy.deepcopy(new)
|
||||
yield new
|
||||
continue
|
||||
|
||||
if mask < 0:
|
||||
mask = self.__ST_STREAMER
|
||||
|
||||
def check_update(key: str, value: (dict | None)) -> None:
|
||||
if prev.get(key) != value:
|
||||
new[key] = value
|
||||
|
||||
if mask & self.__ST_PARAMS:
|
||||
check_update("params", self.__params.get_params())
|
||||
if mask & self.__ST_STREAMER:
|
||||
check_update("streamer", await self.__get_streamer_state())
|
||||
if mask & self.__ST_SNAPSHOT:
|
||||
check_update("snapshot", self.__get_snapshot_state())
|
||||
|
||||
if new and prev != new:
|
||||
prev.update(copy.deepcopy(new))
|
||||
yield new
|
||||
|
||||
async def __get_streamer_state(self) -> (dict | None):
|
||||
if self.__streamer_task:
|
||||
session = self.__ensure_client_session()
|
||||
try:
|
||||
return (await session.get_state())
|
||||
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
|
||||
pass
|
||||
except Exception:
|
||||
get_logger().exception("Invalid streamer response from /state")
|
||||
return None
|
||||
|
||||
def __get_snapshot_state(self) -> dict:
|
||||
if self.__snapshot:
|
||||
snapshot = dataclasses.asdict(self.__snapshot)
|
||||
del snapshot["headers"]
|
||||
del snapshot["data"]
|
||||
return {"saved": snapshot}
|
||||
return {"saved": None}
|
||||
|
||||
# =====
|
||||
|
||||
async def take_snapshot(self, save: bool, load: bool, allow_offline: bool) -> (StreamerSnapshot | None):
|
||||
if load:
|
||||
return self.__snapshot
|
||||
logger = get_logger()
|
||||
session = self.__ensure_client_session()
|
||||
try:
|
||||
snapshot = await session.take_snapshot(self.__snapshot_timeout)
|
||||
if snapshot.online or allow_offline:
|
||||
if save:
|
||||
self.__snapshot = snapshot
|
||||
self.__notifier.notify(self.__ST_SNAPSHOT)
|
||||
return snapshot
|
||||
logger.error("Stream is offline, no signal or so")
|
||||
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as ex:
|
||||
logger.error("Can't connect to streamer: %s", tools.efmt(ex))
|
||||
except Exception:
|
||||
logger.exception("Invalid streamer response from /snapshot")
|
||||
return None
|
||||
|
||||
def remove_snapshot(self) -> None:
|
||||
self.__snapshot = None
|
||||
|
||||
# =====
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def cleanup(self) -> None:
|
||||
await self.ensure_stop(immediately=True)
|
||||
if self.__client_session:
|
||||
await self.__client_session.close()
|
||||
self.__client_session = None
|
||||
|
||||
def __ensure_client_session(self) -> HttpStreamerClientSession:
|
||||
if not self.__client_session:
|
||||
self.__client_session = self.__client.make_session()
|
||||
return self.__client_session
|
||||
|
||||
# =====
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def __inner_start(self) -> None:
|
||||
assert not self.__streamer_task
|
||||
await self.__run_hook("PRE-START-CMD", self.__pre_start_cmd)
|
||||
self.__streamer_task = asyncio.create_task(self.__streamer_task_loop())
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def __inner_stop(self) -> None:
|
||||
assert self.__streamer_task
|
||||
self.__streamer_task.cancel()
|
||||
await asyncio.gather(self.__streamer_task, return_exceptions=True)
|
||||
await self.__kill_streamer_proc()
|
||||
await self.__run_hook("POST-STOP-CMD", self.__post_stop_cmd)
|
||||
self.__streamer_task = None
|
||||
|
||||
# =====
|
||||
|
||||
async def __streamer_task_loop(self) -> None: # pylint: disable=too-many-branches
|
||||
logger = get_logger(0)
|
||||
while True: # pylint: disable=too-many-nested-blocks
|
||||
try:
|
||||
await self.__start_streamer_proc()
|
||||
assert self.__streamer_proc is not None
|
||||
await aioproc.log_stdout_infinite(self.__streamer_proc, logger)
|
||||
raise RuntimeError("Streamer unexpectedly died")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
if self.__streamer_proc:
|
||||
logger.exception("Unexpected streamer error: pid=%d", self.__streamer_proc.pid)
|
||||
else:
|
||||
logger.exception("Can't start streamer")
|
||||
await self.__kill_streamer_proc()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def __make_cmd(self, cmd: list[str]) -> list[str]:
|
||||
return [
|
||||
part.format(
|
||||
unix=self.__unix_path,
|
||||
process_name_prefix=self.__process_name_prefix,
|
||||
**self.__params.get_params(),
|
||||
)
|
||||
for part in cmd
|
||||
]
|
||||
|
||||
async def __run_hook(self, name: str, cmd: list[str]) -> None:
|
||||
logger = get_logger()
|
||||
cmd = self.__make_cmd(cmd)
|
||||
logger.info("%s: %s", name, tools.cmdfmt(cmd))
|
||||
try:
|
||||
await aioproc.log_process(cmd, logger, prefix=name)
|
||||
except Exception as ex:
|
||||
logger.exception("Can't execute command: %s", ex)
|
||||
|
||||
async def __start_streamer_proc(self) -> None:
|
||||
assert self.__streamer_proc is None
|
||||
cmd = self.__make_cmd(self.__cmd)
|
||||
self.__streamer_proc = await aioproc.run_process(cmd)
|
||||
get_logger(0).info("Started streamer pid=%d: %s", self.__streamer_proc.pid, tools.cmdfmt(cmd))
|
||||
|
||||
async def __kill_streamer_proc(self) -> None:
|
||||
if self.__streamer_proc:
|
||||
await aioproc.kill_process(self.__streamer_proc, 1, get_logger(0))
|
||||
self.__streamer_proc = None
|
||||
254
kvmd/apps/kvmd/streamer/__init__.py
Normal file
254
kvmd/apps/kvmd/streamer/__init__.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import signal
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import copy
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ....logging import get_logger
|
||||
|
||||
from ....clients.streamer import StreamerSnapshot
|
||||
from ....clients.streamer import HttpStreamerClient
|
||||
from ....clients.streamer import HttpStreamerClientSession
|
||||
|
||||
from .... import tools
|
||||
from .... import aiotools
|
||||
from .... import htclient
|
||||
|
||||
from .params import Params
|
||||
from .runner import Runner
|
||||
|
||||
|
||||
# =====
|
||||
class Streamer: # pylint: disable=too-many-instance-attributes
|
||||
__ST_FULL = 0xFF
|
||||
__ST_PARAMS = 0x01
|
||||
__ST_STREAMER = 0x02
|
||||
__ST_SNAPSHOT = 0x04
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments,too-many-locals
|
||||
self,
|
||||
|
||||
reset_delay: float,
|
||||
shutdown_delay: float,
|
||||
state_poll: float,
|
||||
|
||||
unix_path: str,
|
||||
timeout: float,
|
||||
snapshot_timeout: float,
|
||||
|
||||
process_name_prefix: str,
|
||||
|
||||
pre_start_cmd: list[str],
|
||||
pre_start_cmd_remove: list[str],
|
||||
pre_start_cmd_append: list[str],
|
||||
|
||||
cmd: list[str],
|
||||
cmd_remove: list[str],
|
||||
cmd_append: list[str],
|
||||
|
||||
post_stop_cmd: list[str],
|
||||
post_stop_cmd_remove: list[str],
|
||||
post_stop_cmd_append: list[str],
|
||||
|
||||
**params_kwargs: Any,
|
||||
) -> None:
|
||||
|
||||
self.__state_poll = state_poll
|
||||
|
||||
self.__unix_path = unix_path
|
||||
self.__snapshot_timeout = snapshot_timeout
|
||||
self.__process_name_prefix = process_name_prefix
|
||||
|
||||
self.__params = Params(**params_kwargs)
|
||||
|
||||
self.__runner = Runner(
|
||||
reset_delay=reset_delay,
|
||||
shutdown_delay=shutdown_delay,
|
||||
pre_start_cmd=tools.build_cmd(pre_start_cmd, pre_start_cmd_remove, pre_start_cmd_append),
|
||||
cmd=tools.build_cmd(cmd, cmd_remove, cmd_append),
|
||||
post_stop_cmd=tools.build_cmd(post_stop_cmd, post_stop_cmd_remove, post_stop_cmd_append),
|
||||
)
|
||||
|
||||
self.__client = HttpStreamerClient(
|
||||
name="jpeg",
|
||||
unix_path=self.__unix_path,
|
||||
timeout=timeout,
|
||||
user_agent=htclient.make_user_agent("KVMD"),
|
||||
)
|
||||
self.__client_session: (HttpStreamerClientSession | None) = None
|
||||
|
||||
self.__snapshot: (StreamerSnapshot | None) = None
|
||||
|
||||
self.__notifier = aiotools.AioNotifier()
|
||||
|
||||
# =====
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_start(self) -> None:
|
||||
await self.__runner.ensure_start(self.__make_params())
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_restart(self) -> None:
|
||||
await self.__runner.ensure_restart(self.__make_params())
|
||||
|
||||
def __make_params(self) -> dict:
|
||||
return {
|
||||
"unix": self.__unix_path,
|
||||
"process_name_prefix": self.__process_name_prefix,
|
||||
**self.__params.get_params(),
|
||||
}
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_stop(self) -> None:
|
||||
await self.__runner.ensure_stop(immediately=False)
|
||||
|
||||
# =====
|
||||
|
||||
def set_params(self, params: dict) -> None:
|
||||
self.__notifier.notify(self.__ST_PARAMS)
|
||||
return self.__params.set_params(params)
|
||||
|
||||
def get_params(self) -> dict:
|
||||
return self.__params.get_params()
|
||||
|
||||
# =====
|
||||
|
||||
async def get_state(self) -> dict:
|
||||
return {
|
||||
"features": self.__params.get_features(),
|
||||
"limits": self.__params.get_limits(),
|
||||
"params": self.__params.get_params(),
|
||||
"streamer": (await self.__get_streamer_state()),
|
||||
"snapshot": self.__get_snapshot_state(),
|
||||
}
|
||||
|
||||
async def trigger_state(self) -> None:
|
||||
self.__notifier.notify(self.__ST_FULL)
|
||||
|
||||
async def poll_state(self) -> AsyncGenerator[dict, None]:
|
||||
# ==== Granularity table ====
|
||||
# - features -- Full
|
||||
# - limits -- Partial, paired with params
|
||||
# - params -- Partial, paired with limits
|
||||
# - streamer -- Partial, nullable
|
||||
# - snapshot -- Partial
|
||||
# ===========================
|
||||
|
||||
def signal_handler(*_: Any) -> None:
|
||||
get_logger(0).info("Got SIGUSR2, checking the stream state ...")
|
||||
self.__notifier.notify(self.__ST_STREAMER)
|
||||
|
||||
get_logger(0).info("Installing SIGUSR2 streamer handler ...")
|
||||
asyncio.get_event_loop().add_signal_handler(signal.SIGUSR2, signal_handler)
|
||||
|
||||
prev: dict = {}
|
||||
while True:
|
||||
new: dict = {}
|
||||
|
||||
mask = await self.__notifier.wait(timeout=self.__state_poll)
|
||||
if mask == self.__ST_FULL:
|
||||
new = await self.get_state()
|
||||
prev = copy.deepcopy(new)
|
||||
yield new
|
||||
continue
|
||||
|
||||
if mask < 0:
|
||||
mask = self.__ST_STREAMER
|
||||
|
||||
def check_update(key: str, value: (dict | None)) -> None:
|
||||
if prev.get(key) != value:
|
||||
new[key] = value
|
||||
|
||||
if mask & self.__ST_PARAMS:
|
||||
check_update("params", self.__params.get_params())
|
||||
if mask & self.__ST_STREAMER:
|
||||
check_update("streamer", await self.__get_streamer_state())
|
||||
if mask & self.__ST_SNAPSHOT:
|
||||
check_update("snapshot", self.__get_snapshot_state())
|
||||
|
||||
if new and prev != new:
|
||||
prev.update(copy.deepcopy(new))
|
||||
yield new
|
||||
|
||||
async def __get_streamer_state(self) -> (dict | None):
|
||||
if self.__runner.is_running():
|
||||
session = self.__ensure_client_session()
|
||||
try:
|
||||
return (await session.get_state())
|
||||
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
|
||||
pass
|
||||
except Exception:
|
||||
get_logger().exception("Invalid streamer response from /state")
|
||||
return None
|
||||
|
||||
def __get_snapshot_state(self) -> dict:
|
||||
if self.__snapshot:
|
||||
snapshot = dataclasses.asdict(self.__snapshot)
|
||||
del snapshot["headers"]
|
||||
del snapshot["data"]
|
||||
return {"saved": snapshot}
|
||||
return {"saved": None}
|
||||
|
||||
# =====
|
||||
|
||||
async def take_snapshot(self, save: bool, load: bool, allow_offline: bool) -> (StreamerSnapshot | None):
|
||||
if load:
|
||||
return self.__snapshot
|
||||
logger = get_logger()
|
||||
session = self.__ensure_client_session()
|
||||
try:
|
||||
snapshot = await session.take_snapshot(self.__snapshot_timeout)
|
||||
if snapshot.online or allow_offline:
|
||||
if save:
|
||||
self.__snapshot = snapshot
|
||||
self.__notifier.notify(self.__ST_SNAPSHOT)
|
||||
return snapshot
|
||||
logger.error("Stream is offline, no signal or so")
|
||||
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as ex:
|
||||
logger.error("Can't connect to streamer: %s", tools.efmt(ex))
|
||||
except Exception:
|
||||
logger.exception("Invalid streamer response from /snapshot")
|
||||
return None
|
||||
|
||||
def remove_snapshot(self) -> None:
|
||||
self.__snapshot = None
|
||||
|
||||
# =====
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def cleanup(self) -> None:
|
||||
await self.__runner.ensure_stop(immediately=True)
|
||||
if self.__client_session:
|
||||
await self.__client_session.close()
|
||||
self.__client_session = None
|
||||
|
||||
def __ensure_client_session(self) -> HttpStreamerClientSession:
|
||||
if not self.__client_session:
|
||||
self.__client_session = self.__client.make_session()
|
||||
return self.__client_session
|
||||
117
kvmd/apps/kvmd/streamer/params.py
Normal file
117
kvmd/apps/kvmd/streamer/params.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
|
||||
# =====
|
||||
class Params:
|
||||
__DESIRED_FPS = "desired_fps"
|
||||
|
||||
__QUALITY = "quality"
|
||||
|
||||
__RESOLUTION = "resolution"
|
||||
__AVAILABLE_RESOLUTIONS = "available_resolutions"
|
||||
|
||||
__H264 = "h264"
|
||||
__H264_BITRATE = "h264_bitrate"
|
||||
__H264_GOP = "h264_gop"
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
quality: int,
|
||||
|
||||
resolution: str,
|
||||
available_resolutions: list[str],
|
||||
|
||||
desired_fps: int,
|
||||
desired_fps_min: int,
|
||||
desired_fps_max: int,
|
||||
|
||||
h264_bitrate: int,
|
||||
h264_bitrate_min: int,
|
||||
h264_bitrate_max: int,
|
||||
|
||||
h264_gop: int,
|
||||
h264_gop_min: int,
|
||||
h264_gop_max: int,
|
||||
) -> None:
|
||||
|
||||
self.__has_quality = bool(quality)
|
||||
self.__has_resolution = bool(resolution)
|
||||
self.__has_h264 = bool(h264_bitrate)
|
||||
|
||||
self.__params: dict = {self.__DESIRED_FPS: min(max(desired_fps, desired_fps_min), desired_fps_max)}
|
||||
self.__limits: dict = {self.__DESIRED_FPS: {"min": desired_fps_min, "max": desired_fps_max}}
|
||||
|
||||
if self.__has_quality:
|
||||
self.__params[self.__QUALITY] = quality
|
||||
|
||||
if self.__has_resolution:
|
||||
self.__params[self.__RESOLUTION] = resolution
|
||||
self.__limits[self.__AVAILABLE_RESOLUTIONS] = available_resolutions
|
||||
|
||||
if self.__has_h264:
|
||||
self.__params[self.__H264_BITRATE] = min(max(h264_bitrate, h264_bitrate_min), h264_bitrate_max)
|
||||
self.__limits[self.__H264_BITRATE] = {"min": h264_bitrate_min, "max": h264_bitrate_max}
|
||||
self.__params[self.__H264_GOP] = min(max(h264_gop, h264_gop_min), h264_gop_max)
|
||||
self.__limits[self.__H264_GOP] = {"min": h264_gop_min, "max": h264_gop_max}
|
||||
|
||||
def get_features(self) -> dict:
|
||||
return {
|
||||
self.__QUALITY: self.__has_quality,
|
||||
self.__RESOLUTION: self.__has_resolution,
|
||||
self.__H264: self.__has_h264,
|
||||
}
|
||||
|
||||
def get_limits(self) -> dict:
|
||||
limits = copy.deepcopy(self.__limits)
|
||||
if self.__has_resolution:
|
||||
limits[self.__AVAILABLE_RESOLUTIONS] = list(limits[self.__AVAILABLE_RESOLUTIONS])
|
||||
return limits
|
||||
|
||||
def get_params(self) -> dict:
|
||||
return dict(self.__params)
|
||||
|
||||
def set_params(self, params: dict) -> None:
|
||||
new = dict(self.__params)
|
||||
|
||||
if self.__QUALITY in params and self.__has_quality:
|
||||
new[self.__QUALITY] = min(max(params[self.__QUALITY], 1), 100)
|
||||
|
||||
if self.__RESOLUTION in params and self.__has_resolution:
|
||||
if params[self.__RESOLUTION] in self.__limits[self.__AVAILABLE_RESOLUTIONS]:
|
||||
new[self.__RESOLUTION] = params[self.__RESOLUTION]
|
||||
|
||||
for (key, enabled) in [
|
||||
(self.__DESIRED_FPS, True),
|
||||
(self.__H264_BITRATE, self.__has_h264),
|
||||
(self.__H264_GOP, self.__has_h264),
|
||||
]:
|
||||
if key in params and enabled:
|
||||
if self.__check_limits_min_max(key, params[key]):
|
||||
new[key] = params[key]
|
||||
|
||||
self.__params = new
|
||||
|
||||
def __check_limits_min_max(self, key: str, value: int) -> bool:
|
||||
return (self.__limits[key]["min"] <= value <= self.__limits[key]["max"])
|
||||
182
kvmd/apps/kvmd/streamer/runner.py
Normal file
182
kvmd/apps/kvmd/streamer/runner.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import asyncio
|
||||
import asyncio.subprocess
|
||||
|
||||
from ....logging import get_logger
|
||||
|
||||
from .... import tools
|
||||
from .... import aiotools
|
||||
from .... import aioproc
|
||||
|
||||
|
||||
# =====
|
||||
class Runner: # pylint: disable=too-many-instance-attributes
|
||||
def __init__(
|
||||
self,
|
||||
reset_delay: float,
|
||||
shutdown_delay: float,
|
||||
|
||||
pre_start_cmd: list[str],
|
||||
cmd: list[str],
|
||||
post_stop_cmd: list[str],
|
||||
) -> None:
|
||||
|
||||
self.__reset_delay = reset_delay
|
||||
self.__shutdown_delay = shutdown_delay
|
||||
|
||||
self.__pre_start_cmd: list[str] = pre_start_cmd
|
||||
self.__cmd: list[str] = cmd
|
||||
self.__post_stop_cmd: list[str] = post_stop_cmd
|
||||
|
||||
self.__proc_params: dict = {}
|
||||
self.__proc_task: (asyncio.Task | None) = None
|
||||
self.__proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member
|
||||
|
||||
self.__stopper_task: (asyncio.Task | None) = None
|
||||
self.__stopper_wip = False
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_start(self, params: dict) -> None:
|
||||
if not self.__proc_task or self.__stopper_task:
|
||||
logger = get_logger(0)
|
||||
|
||||
if self.__stopper_task:
|
||||
if not self.__stopper_wip:
|
||||
self.__stopper_task.cancel()
|
||||
await asyncio.gather(self.__stopper_task, return_exceptions=True)
|
||||
logger.info("Streamer stop cancelled")
|
||||
return
|
||||
else:
|
||||
await asyncio.gather(self.__stopper_task, return_exceptions=True)
|
||||
|
||||
logger.info("Starting streamer ...")
|
||||
await self.__inner_start(params)
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_restart(self, params: dict) -> None:
|
||||
logger = get_logger(0)
|
||||
start = bool(self.__proc_task and not self.__stopper_task) # Если запущено и не планирует останавливаться
|
||||
await self.ensure_stop(immediately=True)
|
||||
if self.__reset_delay > 0:
|
||||
logger.info("Waiting %.2f seconds for reset delay ...", self.__reset_delay)
|
||||
await asyncio.sleep(self.__reset_delay)
|
||||
if start:
|
||||
await self.ensure_start(params)
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def ensure_stop(self, immediately: bool) -> None:
|
||||
if self.__proc_task:
|
||||
logger = get_logger(0)
|
||||
|
||||
if immediately:
|
||||
if self.__stopper_task:
|
||||
if not self.__stopper_wip:
|
||||
self.__stopper_task.cancel()
|
||||
await asyncio.gather(self.__stopper_task, return_exceptions=True)
|
||||
logger.info("Stopping streamer immediately ...")
|
||||
await self.__inner_stop()
|
||||
else:
|
||||
await asyncio.gather(self.__stopper_task, return_exceptions=True)
|
||||
else:
|
||||
logger.info("Stopping streamer immediately ...")
|
||||
await self.__inner_stop()
|
||||
|
||||
elif not self.__stopper_task:
|
||||
|
||||
async def delayed_stop() -> None:
|
||||
try:
|
||||
await asyncio.sleep(self.__shutdown_delay)
|
||||
self.__stopper_wip = True
|
||||
logger.info("Stopping streamer after delay ...")
|
||||
await self.__inner_stop()
|
||||
finally:
|
||||
self.__stopper_task = None
|
||||
self.__stopper_wip = False
|
||||
|
||||
logger.info("Planning to stop streamer in %.2f seconds ...", self.__shutdown_delay)
|
||||
self.__stopper_task = asyncio.create_task(delayed_stop())
|
||||
|
||||
def is_running(self) -> bool:
|
||||
return bool(self.__proc_task)
|
||||
|
||||
# =====
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def __inner_start(self, params: dict) -> None:
|
||||
assert not self.__proc_task
|
||||
self.__proc_params = params
|
||||
await self.__run_hook("PRE-START-CMD", self.__pre_start_cmd)
|
||||
self.__proc_task = asyncio.create_task(self.__process_task_loop())
|
||||
|
||||
@aiotools.atomic_fg
|
||||
async def __inner_stop(self) -> None:
|
||||
assert self.__proc_task
|
||||
self.__proc_task.cancel()
|
||||
await asyncio.gather(self.__proc_task, return_exceptions=True)
|
||||
await self.__kill_process()
|
||||
await self.__run_hook("POST-STOP-CMD", self.__post_stop_cmd)
|
||||
self.__proc_task = None
|
||||
|
||||
# =====
|
||||
|
||||
async def __process_task_loop(self) -> None: # pylint: disable=too-many-branches
|
||||
logger = get_logger(0)
|
||||
while True: # pylint: disable=too-many-nested-blocks
|
||||
try:
|
||||
await self.__start_process()
|
||||
assert self.__proc is not None
|
||||
await aioproc.log_stdout_infinite(self.__proc, logger)
|
||||
raise RuntimeError("Streamer unexpectedly died")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
if self.__proc:
|
||||
logger.exception("Unexpected streamer error: pid=%d", self.__proc.pid)
|
||||
else:
|
||||
logger.exception("Can't start streamer")
|
||||
await self.__kill_process()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def __make_cmd(self, cmd: list[str]) -> list[str]:
|
||||
return [part.format(**self.__proc_params) for part in cmd]
|
||||
|
||||
async def __run_hook(self, name: str, cmd: list[str]) -> None:
|
||||
logger = get_logger()
|
||||
cmd = self.__make_cmd(cmd)
|
||||
logger.info("%s: %s", name, tools.cmdfmt(cmd))
|
||||
try:
|
||||
await aioproc.log_process(cmd, logger, prefix=name)
|
||||
except Exception:
|
||||
logger.exception("Can't execute %s hook: %s", name, tools.cmdfmt(cmd))
|
||||
|
||||
async def __start_process(self) -> None:
|
||||
assert self.__proc is None
|
||||
cmd = self.__make_cmd(self.__cmd)
|
||||
self.__proc = await aioproc.run_process(cmd)
|
||||
get_logger(0).info("Started streamer pid=%d: %s", self.__proc.pid, tools.cmdfmt(cmd))
|
||||
|
||||
async def __kill_process(self) -> None:
|
||||
if self.__proc:
|
||||
await aioproc.kill_process(self.__proc, 1, get_logger(0))
|
||||
self.__proc = None
|
||||
@@ -32,6 +32,7 @@ from .lib import Inotify
|
||||
|
||||
from .types import Edid
|
||||
from .types import Edids
|
||||
from .types import Dummies
|
||||
from .types import Color
|
||||
from .types import Colors
|
||||
from .types import PortNames
|
||||
@@ -68,6 +69,7 @@ class SwitchUnknownEdidError(SwitchOperationError):
|
||||
# =====
|
||||
class Switch: # pylint: disable=too-many-public-methods
|
||||
__X_EDIDS = "edids"
|
||||
__X_DUMMIES = "dummies"
|
||||
__X_COLORS = "colors"
|
||||
__X_PORT_NAMES = "port_names"
|
||||
__X_ATX_CP_DELAYS = "atx_cp_delays"
|
||||
@@ -75,7 +77,7 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
__X_ATX_CR_DELAYS = "atx_cr_delays"
|
||||
|
||||
__X_ALL = frozenset([
|
||||
__X_EDIDS, __X_COLORS, __X_PORT_NAMES,
|
||||
__X_EDIDS, __X_DUMMIES, __X_COLORS, __X_PORT_NAMES,
|
||||
__X_ATX_CP_DELAYS, __X_ATX_CPL_DELAYS, __X_ATX_CR_DELAYS,
|
||||
])
|
||||
|
||||
@@ -84,11 +86,12 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
device_path: str,
|
||||
default_edid_path: str,
|
||||
pst_unix_path: str,
|
||||
ignore_hpd_on_top: bool,
|
||||
) -> None:
|
||||
|
||||
self.__default_edid_path = default_edid_path
|
||||
|
||||
self.__chain = Chain(device_path)
|
||||
self.__chain = Chain(device_path, ignore_hpd_on_top)
|
||||
self.__cache = StateCache()
|
||||
self.__storage = Storage(pst_unix_path)
|
||||
|
||||
@@ -104,6 +107,12 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
if save:
|
||||
self.__save_notifier.notify()
|
||||
|
||||
def __x_set_dummies(self, dummies: Dummies, save: bool=True) -> None:
|
||||
self.__chain.set_dummies(dummies)
|
||||
self.__cache.set_dummies(dummies)
|
||||
if save:
|
||||
self.__save_notifier.notify()
|
||||
|
||||
def __x_set_colors(self, colors: Colors, save: bool=True) -> None:
|
||||
self.__chain.set_colors(colors)
|
||||
self.__cache.set_colors(colors)
|
||||
@@ -132,13 +141,19 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
|
||||
# =====
|
||||
|
||||
async def set_active_port(self, port: int) -> None:
|
||||
self.__chain.set_active_port(port)
|
||||
async def set_active_prev(self) -> None:
|
||||
self.__chain.set_active_prev()
|
||||
|
||||
async def set_active_next(self) -> None:
|
||||
self.__chain.set_active_next()
|
||||
|
||||
async def set_active_port(self, port: float) -> None:
|
||||
self.__chain.set_active_port(self.__chain.translate_port(port))
|
||||
|
||||
# =====
|
||||
|
||||
async def set_port_beacon(self, port: int, on: bool) -> None:
|
||||
self.__chain.set_port_beacon(port, on)
|
||||
async def set_port_beacon(self, port: float, on: bool) -> None:
|
||||
self.__chain.set_port_beacon(self.__chain.translate_port(port), on)
|
||||
|
||||
async def set_uplink_beacon(self, unit: int, on: bool) -> None:
|
||||
self.__chain.set_uplink_beacon(unit, on)
|
||||
@@ -148,33 +163,35 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
|
||||
# =====
|
||||
|
||||
async def atx_power_on(self, port: int) -> None:
|
||||
async def atx_power_on(self, port: float) -> None:
|
||||
self.__inner_atx_cp(port, False, self.__X_ATX_CP_DELAYS)
|
||||
|
||||
async def atx_power_off(self, port: int) -> None:
|
||||
async def atx_power_off(self, port: float) -> None:
|
||||
self.__inner_atx_cp(port, True, self.__X_ATX_CP_DELAYS)
|
||||
|
||||
async def atx_power_off_hard(self, port: int) -> None:
|
||||
async def atx_power_off_hard(self, port: float) -> None:
|
||||
self.__inner_atx_cp(port, True, self.__X_ATX_CPL_DELAYS)
|
||||
|
||||
async def atx_power_reset_hard(self, port: int) -> None:
|
||||
async def atx_power_reset_hard(self, port: float) -> None:
|
||||
self.__inner_atx_cr(port, True)
|
||||
|
||||
async def atx_click_power(self, port: int) -> None:
|
||||
async def atx_click_power(self, port: float) -> None:
|
||||
self.__inner_atx_cp(port, None, self.__X_ATX_CP_DELAYS)
|
||||
|
||||
async def atx_click_power_long(self, port: int) -> None:
|
||||
async def atx_click_power_long(self, port: float) -> None:
|
||||
self.__inner_atx_cp(port, None, self.__X_ATX_CPL_DELAYS)
|
||||
|
||||
async def atx_click_reset(self, port: int) -> None:
|
||||
async def atx_click_reset(self, port: float) -> None:
|
||||
self.__inner_atx_cr(port, None)
|
||||
|
||||
def __inner_atx_cp(self, port: int, if_powered: (bool | None), x_delay: str) -> None:
|
||||
def __inner_atx_cp(self, port: float, if_powered: (bool | None), x_delay: str) -> None:
|
||||
assert x_delay in [self.__X_ATX_CP_DELAYS, self.__X_ATX_CPL_DELAYS]
|
||||
port = self.__chain.translate_port(port)
|
||||
delay = getattr(self.__cache, f"get_{x_delay}")()[port]
|
||||
self.__chain.click_power(port, delay, if_powered)
|
||||
|
||||
def __inner_atx_cr(self, port: int, if_powered: (bool | None)) -> None:
|
||||
def __inner_atx_cr(self, port: float, if_powered: (bool | None)) -> None:
|
||||
port = self.__chain.translate_port(port)
|
||||
delay = self.__cache.get_atx_cr_delays()[port]
|
||||
self.__chain.click_reset(port, delay, if_powered)
|
||||
|
||||
@@ -235,12 +252,14 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
self,
|
||||
port: int,
|
||||
edid_id: (str | None)=None,
|
||||
dummy: (bool | None)=None,
|
||||
name: (str | None)=None,
|
||||
atx_click_power_delay: (float | None)=None,
|
||||
atx_click_power_long_delay: (float | None)=None,
|
||||
atx_click_reset_delay: (float | None)=None,
|
||||
) -> None:
|
||||
|
||||
port = self.__chain.translate_port(port)
|
||||
async with self.__lock:
|
||||
if edid_id is not None:
|
||||
edids = self.__cache.get_edids()
|
||||
@@ -249,15 +268,16 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
edids.assign(port, edid_id)
|
||||
self.__x_set_edids(edids)
|
||||
|
||||
for (key, value) in [
|
||||
(self.__X_PORT_NAMES, name),
|
||||
(self.__X_ATX_CP_DELAYS, atx_click_power_delay),
|
||||
(self.__X_ATX_CPL_DELAYS, atx_click_power_long_delay),
|
||||
(self.__X_ATX_CR_DELAYS, atx_click_reset_delay),
|
||||
for (reset, key, value) in [
|
||||
(None, self.__X_DUMMIES, dummy), # None can't be used now
|
||||
("", self.__X_PORT_NAMES, name),
|
||||
(0, self.__X_ATX_CP_DELAYS, atx_click_power_delay),
|
||||
(0, self.__X_ATX_CPL_DELAYS, atx_click_power_long_delay),
|
||||
(0, self.__X_ATX_CR_DELAYS, atx_click_reset_delay),
|
||||
]:
|
||||
if value is not None:
|
||||
new = getattr(self.__cache, f"get_{key}")()
|
||||
new[port] = (value or None) # None == reset to default
|
||||
new[port] = (None if value == reset else value) # Value or reset default
|
||||
getattr(self, f"_Switch__x_set_{key}")(new)
|
||||
|
||||
# =====
|
||||
@@ -374,7 +394,7 @@ class Switch: # pylint: disable=too-many-public-methods
|
||||
prevs = dict.fromkeys(self.__X_ALL)
|
||||
while True:
|
||||
await self.__save_notifier.wait()
|
||||
while (await self.__save_notifier.wait(5)):
|
||||
while not (await self.__save_notifier.wait(5)):
|
||||
pass
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -34,6 +34,7 @@ from .lib import aiotools
|
||||
from .lib import aioproc
|
||||
|
||||
from .types import Edids
|
||||
from .types import Dummies
|
||||
from .types import Colors
|
||||
|
||||
from .proto import Response
|
||||
@@ -54,6 +55,14 @@ class _CmdSetActual(_BaseCmd):
|
||||
actual: bool
|
||||
|
||||
|
||||
class _CmdSetActivePrev(_BaseCmd):
|
||||
pass
|
||||
|
||||
|
||||
class _CmdSetActiveNext(_BaseCmd):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _CmdSetActivePort(_BaseCmd):
|
||||
port: int
|
||||
@@ -80,6 +89,11 @@ class _CmdSetEdids(_BaseCmd):
|
||||
edids: Edids
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _CmdSetDummies(_BaseCmd):
|
||||
dummies: Dummies
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _CmdSetColors(_BaseCmd):
|
||||
colors: Colors
|
||||
@@ -177,13 +191,19 @@ class UnitAtxLedsEvent(BaseEvent):
|
||||
|
||||
# =====
|
||||
class Chain: # pylint: disable=too-many-instance-attributes
|
||||
def __init__(self, device_path: str) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
device_path: str,
|
||||
ignore_hpd_on_top: bool,
|
||||
) -> None:
|
||||
|
||||
self.__device = Device(device_path)
|
||||
self.__ignore_hpd_on_top = ignore_hpd_on_top
|
||||
|
||||
self.__actual = False
|
||||
|
||||
self.__edids = Edids()
|
||||
|
||||
self.__dummies = Dummies({})
|
||||
self.__colors = Colors()
|
||||
|
||||
self.__units: list[_UnitContext] = []
|
||||
@@ -200,6 +220,24 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
# =====
|
||||
|
||||
def translate_port(self, port: float) -> int:
|
||||
assert port >= 0
|
||||
if int(port) == port:
|
||||
return int(port)
|
||||
(unit, ch) = map(int, str(port).split("."))
|
||||
unit = min(max(unit, 1), 5)
|
||||
ch = min(max(ch, 1), 4)
|
||||
port = min((unit - 1) * 4 + (ch - 1), 19)
|
||||
return port
|
||||
|
||||
# =====
|
||||
|
||||
def set_active_prev(self) -> None:
|
||||
self.__queue_cmd(_CmdSetActivePrev())
|
||||
|
||||
def set_active_next(self) -> None:
|
||||
self.__queue_cmd(_CmdSetActiveNext())
|
||||
|
||||
def set_active_port(self, port: int) -> None:
|
||||
self.__queue_cmd(_CmdSetActivePort(port))
|
||||
|
||||
@@ -219,6 +257,9 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
def set_edids(self, edids: Edids) -> None:
|
||||
self.__queue_cmd(_CmdSetEdids(edids)) # Will be copied because of multiprocessing.Queue()
|
||||
|
||||
def set_dummies(self, dummies: Dummies) -> None:
|
||||
self.__queue_cmd(_CmdSetDummies(dummies))
|
||||
|
||||
def set_colors(self, colors: Colors) -> None:
|
||||
self.__queue_cmd(_CmdSetColors(colors))
|
||||
|
||||
@@ -290,12 +331,21 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
self.__device.request_state()
|
||||
self.__device.request_atx_leds()
|
||||
while not self.__stop_event.is_set():
|
||||
count = 0
|
||||
if self.__select():
|
||||
count = 0
|
||||
for resp in self.__device.read_all():
|
||||
self.__update_units(resp)
|
||||
self.__adjust_quirks()
|
||||
self.__adjust_start_port()
|
||||
self.__finish_changing_request(resp)
|
||||
self.__consume_commands()
|
||||
else:
|
||||
count += 1
|
||||
if count >= 5:
|
||||
# Heartbeat
|
||||
self.__device.request_state()
|
||||
count = 0
|
||||
self.__ensure_config()
|
||||
|
||||
def __select(self) -> bool:
|
||||
@@ -314,10 +364,29 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
case _CmdSetActual():
|
||||
self.__actual = cmd.actual
|
||||
|
||||
case _CmdSetActivePrev():
|
||||
if len(self.__units) > 0:
|
||||
port = self.__active_port
|
||||
port -= 1
|
||||
if port >= 0:
|
||||
self.__active_port = port
|
||||
self.__queue_event(PortActivatedEvent(self.__active_port))
|
||||
|
||||
case _CmdSetActiveNext():
|
||||
port = self.__active_port
|
||||
if port < 0:
|
||||
port = 0
|
||||
else:
|
||||
port += 1
|
||||
if port < len(self.__units) * 4:
|
||||
self.__active_port = port
|
||||
self.__queue_event(PortActivatedEvent(self.__active_port))
|
||||
|
||||
case _CmdSetActivePort():
|
||||
# Может быть вызвано изнутри при синхронизации
|
||||
self.__active_port = cmd.port
|
||||
self.__queue_event(PortActivatedEvent(self.__active_port))
|
||||
if cmd.port < len(self.__units) * 4:
|
||||
self.__active_port = cmd.port
|
||||
self.__queue_event(PortActivatedEvent(self.__active_port))
|
||||
|
||||
case _CmdSetPortBeacon():
|
||||
(unit, ch) = self.get_real_unit_channel(cmd.port)
|
||||
@@ -341,6 +410,9 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
case _CmdSetEdids():
|
||||
self.__edids = cmd.edids
|
||||
|
||||
case _CmdSetDummies():
|
||||
self.__dummies = cmd.dummies
|
||||
|
||||
case _CmdSetColors():
|
||||
self.__colors = cmd.colors
|
||||
|
||||
@@ -364,6 +436,15 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
self.__units[resp.header.unit].atx_leds = resp.body
|
||||
self.__queue_event(UnitAtxLedsEvent(resp.header.unit, resp.body))
|
||||
|
||||
def __adjust_quirks(self) -> None:
|
||||
for (unit, ctx) in enumerate(self.__units):
|
||||
if ctx.state is not None and ctx.state.version.is_fresh(7):
|
||||
ignore_hpd = (unit == 0 and self.__ignore_hpd_on_top)
|
||||
if ctx.state.quirks.ignore_hpd != ignore_hpd:
|
||||
get_logger().info("Applying quirk ignore_hpd=%s to [%d] ...",
|
||||
ignore_hpd, unit)
|
||||
self.__device.request_set_quirks(unit, ignore_hpd)
|
||||
|
||||
def __adjust_start_port(self) -> None:
|
||||
if self.__active_port < 0:
|
||||
for (unit, ctx) in enumerate(self.__units):
|
||||
@@ -387,6 +468,7 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
self.__ensure_config_port(unit, ctx)
|
||||
if self.__actual:
|
||||
self.__ensure_config_edids(unit, ctx)
|
||||
self.__ensure_config_dummies(unit, ctx)
|
||||
self.__ensure_config_colors(unit, ctx)
|
||||
|
||||
def __ensure_config_port(self, unit: int, ctx: _UnitContext) -> None:
|
||||
@@ -413,6 +495,19 @@ class Chain: # pylint: disable=too-many-instance-attributes
|
||||
ctx.changing_rid = self.__device.request_set_edid(unit, ch, edid)
|
||||
break # Busy globally
|
||||
|
||||
def __ensure_config_dummies(self, unit: int, ctx: _UnitContext) -> None:
|
||||
assert ctx.state is not None
|
||||
if ctx.state.version.is_fresh(8) and ctx.can_be_changed():
|
||||
for ch in range(4):
|
||||
port = self.get_virtual_port(unit, ch)
|
||||
dummy = self.__dummies[port]
|
||||
if ctx.state.video_dummies[ch] != dummy:
|
||||
get_logger().info("Changing dummy flag on port %d on [%d:%d]: %d -> %d ...",
|
||||
port, unit, ch,
|
||||
ctx.state.video_dummies[ch], dummy)
|
||||
ctx.changing_rid = self.__device.request_set_dummy(unit, ch, dummy)
|
||||
break # Busy globally (actually not but it can be changed in the firmware)
|
||||
|
||||
def __ensure_config_colors(self, unit: int, ctx: _UnitContext) -> None:
|
||||
assert self.__actual
|
||||
assert ctx.state is not None
|
||||
|
||||
@@ -41,7 +41,9 @@ from .proto import BodySetBeacon
|
||||
from .proto import BodyAtxClick
|
||||
from .proto import BodySetEdid
|
||||
from .proto import BodyClearEdid
|
||||
from .proto import BodySetDummy
|
||||
from .proto import BodySetColors
|
||||
from .proto import BodySetQuirks
|
||||
|
||||
|
||||
# =====
|
||||
@@ -163,9 +165,15 @@ class Device:
|
||||
return self.__send_request(Header.SET_EDID, unit, BodySetEdid(ch, edid))
|
||||
return self.__send_request(Header.CLEAR_EDID, unit, BodyClearEdid(ch))
|
||||
|
||||
def request_set_dummy(self, unit: int, ch: int, on: bool) -> int:
|
||||
return self.__send_request(Header.SET_DUMMY, unit, BodySetDummy(ch, on))
|
||||
|
||||
def request_set_colors(self, unit: int, ch: int, colors: Colors) -> int:
|
||||
return self.__send_request(Header.SET_COLORS, unit, BodySetColors(ch, colors))
|
||||
|
||||
def request_set_quirks(self, unit: int, ignore_hpd: bool) -> int:
|
||||
return self.__send_request(Header.SET_QUIRKS, unit, BodySetQuirks(ignore_hpd))
|
||||
|
||||
def __send_request(self, op: int, unit: int, body: (Packable | None)) -> int:
|
||||
assert self.__tty is not None
|
||||
req = Request(Header(
|
||||
|
||||
@@ -60,6 +60,8 @@ class Header(Packable, Unpackable):
|
||||
SET_EDID = 9
|
||||
CLEAR_EDID = 10
|
||||
SET_COLORS = 12
|
||||
SET_QUIRKS = 13
|
||||
SET_DUMMY = 14
|
||||
|
||||
__struct = struct.Struct("<BHBB")
|
||||
|
||||
@@ -89,17 +91,32 @@ class Nak(Unpackable):
|
||||
return Nak(*cls.__struct.unpack_from(data, offset=offset))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UnitVersion:
|
||||
hw: int
|
||||
sw: int
|
||||
sw_dev: bool
|
||||
|
||||
def is_fresh(self, version: int) -> bool:
|
||||
return (self.sw_dev or (self.sw >= version))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UnitFlags:
|
||||
changing_busy: bool
|
||||
flashing_busy: bool
|
||||
has_downlink: bool
|
||||
has_hpd: bool
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UnitQuirks:
|
||||
ignore_hpd: bool
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UnitState(Unpackable): # pylint: disable=too-many-instance-attributes
|
||||
sw_version: int
|
||||
hw_version: int
|
||||
version: UnitVersion
|
||||
flags: UnitFlags
|
||||
ch: int
|
||||
beacons: tuple[bool, bool, bool, bool, bool, bool]
|
||||
@@ -108,10 +125,12 @@ class UnitState(Unpackable): # pylint: disable=too-many-instance-attributes
|
||||
video_hpd: tuple[bool, bool, bool, bool, bool]
|
||||
video_edid: tuple[bool, bool, bool, bool]
|
||||
video_crc: tuple[int, int, int, int]
|
||||
video_dummies: tuple[bool, bool, bool, bool]
|
||||
usb_5v_sens: tuple[bool, bool, bool, bool]
|
||||
atx_busy: tuple[bool, bool, bool, bool]
|
||||
quirks: UnitQuirks
|
||||
|
||||
__struct = struct.Struct("<HHHBBHHHHHHBBBHHHHBxB30x")
|
||||
__struct = struct.Struct("<HHHBBHHHHHHBBBHHHHBxBBB28x")
|
||||
|
||||
def compare_edid(self, ch: int, edid: Optional["Edid"]) -> bool:
|
||||
if edid is None:
|
||||
@@ -128,15 +147,19 @@ class UnitState(Unpackable): # pylint: disable=too-many-instance-attributes
|
||||
sw_version, hw_version, flags, ch,
|
||||
beacons, nc0, nc1, nc2, nc3, nc4, nc5,
|
||||
video_5v_sens, video_hpd, video_edid, vc0, vc1, vc2, vc3,
|
||||
usb_5v_sens, atx_busy,
|
||||
usb_5v_sens, atx_busy, quirks, video_dummies,
|
||||
) = cls.__struct.unpack_from(data, offset=offset)
|
||||
return UnitState(
|
||||
sw_version,
|
||||
hw_version,
|
||||
version=UnitVersion(
|
||||
hw=hw_version,
|
||||
sw=(sw_version & 0x7FFF),
|
||||
sw_dev=bool(sw_version & 0x8000),
|
||||
),
|
||||
flags=UnitFlags(
|
||||
changing_busy=bool(flags & 0x80),
|
||||
flashing_busy=bool(flags & 0x40),
|
||||
has_downlink=bool(flags & 0x02),
|
||||
has_hpd=bool(flags & 0x04),
|
||||
),
|
||||
ch=ch,
|
||||
beacons=cls.__make_flags6(beacons),
|
||||
@@ -145,8 +168,10 @@ class UnitState(Unpackable): # pylint: disable=too-many-instance-attributes
|
||||
video_hpd=cls.__make_flags5(video_hpd),
|
||||
video_edid=cls.__make_flags4(video_edid),
|
||||
video_crc=(vc0, vc1, vc2, vc3),
|
||||
video_dummies=cls.__make_flags4(video_dummies),
|
||||
usb_5v_sens=cls.__make_flags4(usb_5v_sens),
|
||||
atx_busy=cls.__make_flags4(atx_busy),
|
||||
quirks=UnitQuirks(ignore_hpd=bool(quirks & 0x01)),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -251,6 +276,18 @@ class BodyClearEdid(Packable):
|
||||
return self.ch.to_bytes()
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BodySetDummy(Packable):
|
||||
ch: int
|
||||
on: bool
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert 0 <= self.ch <= 3
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return self.ch.to_bytes() + self.on.to_bytes()
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BodySetColors(Packable):
|
||||
ch: int
|
||||
@@ -263,6 +300,14 @@ class BodySetColors(Packable):
|
||||
return self.ch.to_bytes() + self.colors.pack()
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BodySetQuirks(Packable):
|
||||
ignore_hpd: bool
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return self.ignore_hpd.to_bytes()
|
||||
|
||||
|
||||
# =====
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Request:
|
||||
|
||||
@@ -27,6 +27,7 @@ import time
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .types import Edids
|
||||
from .types import Dummies
|
||||
from .types import Color
|
||||
from .types import Colors
|
||||
from .types import PortNames
|
||||
@@ -48,8 +49,8 @@ class _UnitInfo:
|
||||
|
||||
|
||||
# =====
|
||||
class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
__FW_VERSION = 5
|
||||
class StateCache: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||
__FW_VERSION = 8
|
||||
|
||||
__FULL = 0xFFFF
|
||||
__SUMMARY = 0x01
|
||||
@@ -62,6 +63,7 @@ class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.__edids = Edids()
|
||||
self.__dummies = Dummies({})
|
||||
self.__colors = Colors()
|
||||
self.__port_names = PortNames({})
|
||||
self.__atx_cp_delays = AtxClickPowerDelays({})
|
||||
@@ -77,6 +79,9 @@ class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
def get_edids(self) -> Edids:
|
||||
return self.__edids.copy()
|
||||
|
||||
def get_dummies(self) -> Dummies:
|
||||
return self.__dummies.copy()
|
||||
|
||||
def get_colors(self) -> Colors:
|
||||
return self.__colors
|
||||
|
||||
@@ -158,7 +163,17 @@ class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
},
|
||||
}
|
||||
if x_summary:
|
||||
state["summary"] = {"active_port": self.__active_port, "synced": self.__synced}
|
||||
state["summary"] = {
|
||||
"active_port": self.__active_port,
|
||||
"active_id": (
|
||||
"" if self.__active_port < 0 else (
|
||||
f"{self.__active_port // 4 + 1}.{self.__active_port % 4 + 1}"
|
||||
if len(self.__units) > 1 else
|
||||
f"{self.__active_port + 1}"
|
||||
)
|
||||
),
|
||||
"synced": self.__synced,
|
||||
}
|
||||
if x_edids:
|
||||
state["edids"] = {
|
||||
"all": {
|
||||
@@ -195,7 +210,10 @@ class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
assert ui.state is not None
|
||||
assert ui.atx_leds is not None
|
||||
if x_model:
|
||||
state["model"]["units"].append({"firmware": {"version": ui.state.sw_version}})
|
||||
state["model"]["units"].append({"firmware": {
|
||||
"version": ui.state.version.sw,
|
||||
"devbuild": ui.state.version.sw_dev,
|
||||
}})
|
||||
if x_video:
|
||||
state["video"]["links"].extend(ui.state.video_5v_sens[:4])
|
||||
if x_usb:
|
||||
@@ -216,6 +234,7 @@ class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
"unit": unit,
|
||||
"channel": ch,
|
||||
"name": self.__port_names[port],
|
||||
"id": (f"{unit + 1}.{ch + 1}" if len(self.__units) > 1 else f"{ch + 1}"),
|
||||
"atx": {
|
||||
"click_delays": {
|
||||
"power": self.__atx_cp_delays[port],
|
||||
@@ -223,6 +242,9 @@ class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
"reset": self.__atx_cr_delays[port],
|
||||
},
|
||||
},
|
||||
"video": {
|
||||
"dummy": self.__dummies[port],
|
||||
},
|
||||
})
|
||||
if x_edids:
|
||||
state["edids"]["used"].append(self.__edids.get_id_for_port(port))
|
||||
@@ -324,6 +346,12 @@ class StateCache: # pylint: disable=too-many-instance-attributes
|
||||
if changed:
|
||||
self.__bump_state(self.__EDIDS)
|
||||
|
||||
def set_dummies(self, dummies: Dummies) -> None:
|
||||
changed = (not self.__dummies.compare_on_ports(dummies, self.__get_ports()))
|
||||
self.__dummies = dummies.copy()
|
||||
if changed:
|
||||
self.__bump_state(self.__FULL)
|
||||
|
||||
def set_colors(self, colors: Colors) -> None:
|
||||
changed = (self.__colors != colors)
|
||||
self.__colors = colors
|
||||
|
||||
@@ -39,6 +39,7 @@ from .lib import get_logger
|
||||
|
||||
from .types import Edid
|
||||
from .types import Edids
|
||||
from .types import Dummies
|
||||
from .types import Color
|
||||
from .types import Colors
|
||||
from .types import PortNames
|
||||
@@ -52,6 +53,8 @@ class StorageContext:
|
||||
__F_EDIDS_ALL = "edids_all.json"
|
||||
__F_EDIDS_PORT = "edids_port.json"
|
||||
|
||||
__F_DUMMIES = "dummies.json"
|
||||
|
||||
__F_COLORS = "colors.json"
|
||||
|
||||
__F_PORT_NAMES = "port_names.json"
|
||||
@@ -74,6 +77,9 @@ class StorageContext:
|
||||
})
|
||||
await self.__write_json_keyvals(self.__F_EDIDS_PORT, edids.port)
|
||||
|
||||
async def write_dummies(self, dummies: Dummies) -> None:
|
||||
await self.__write_json_keyvals(self.__F_DUMMIES, dummies.kvs)
|
||||
|
||||
async def write_colors(self, colors: Colors) -> None:
|
||||
await self.__write_json_keyvals(self.__F_COLORS, {
|
||||
role: {
|
||||
@@ -116,6 +122,10 @@ class StorageContext:
|
||||
port_edids = await self.__read_json_keyvals_int(self.__F_EDIDS_PORT)
|
||||
return Edids(all_edids, port_edids)
|
||||
|
||||
async def read_dummies(self) -> Dummies:
|
||||
kvs = await self.__read_json_keyvals_int(self.__F_DUMMIES)
|
||||
return Dummies({key: bool(value) for (key, value) in kvs.items()})
|
||||
|
||||
async def read_colors(self) -> Colors:
|
||||
raw = await self.__read_json_keyvals(self.__F_COLORS)
|
||||
return Colors(**{ # type: ignore
|
||||
|
||||
@@ -59,31 +59,37 @@ class EdidInfo:
|
||||
except ParsedEdidNoBlockError:
|
||||
pass
|
||||
|
||||
audio: bool = False
|
||||
try:
|
||||
audio = parsed.get_audio()
|
||||
except ParsedEdidNoBlockError:
|
||||
pass
|
||||
|
||||
return EdidInfo(
|
||||
mfc_id=parsed.get_mfc_id(),
|
||||
product_id=parsed.get_product_id(),
|
||||
serial=parsed.get_serial(),
|
||||
monitor_name=monitor_name,
|
||||
monitor_serial=monitor_serial,
|
||||
audio=parsed.get_audio(),
|
||||
audio=audio,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Edid:
|
||||
name: str
|
||||
data: bytes
|
||||
crc: int = dataclasses.field(default=0)
|
||||
valid: bool = dataclasses.field(default=False)
|
||||
info: (EdidInfo | None) = dataclasses.field(default=None)
|
||||
|
||||
__HEADER = b"\x00\xFF\xFF\xFF\xFF\xFF\xFF\x00"
|
||||
name: str
|
||||
data: bytes
|
||||
crc: int = dataclasses.field(default=0)
|
||||
valid: bool = dataclasses.field(default=False)
|
||||
info: (EdidInfo | None) = dataclasses.field(default=None)
|
||||
_packed: bytes = dataclasses.field(default=b"")
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert len(self.name) > 0
|
||||
assert len(self.data) == 256
|
||||
object.__setattr__(self, "crc", bitbang.make_crc16(self.data))
|
||||
object.__setattr__(self, "valid", self.data.startswith(self.__HEADER))
|
||||
assert len(self.data) in [128, 256]
|
||||
object.__setattr__(self, "_packed", (self.data + (b"\x00" * 128))[:256])
|
||||
object.__setattr__(self, "crc", bitbang.make_crc16(self._packed)) # Calculate CRC for filled data
|
||||
object.__setattr__(self, "valid", ParsedEdid.is_header_valid(self.data))
|
||||
try:
|
||||
object.__setattr__(self, "info", EdidInfo.from_data(self.data))
|
||||
except Exception:
|
||||
@@ -93,7 +99,7 @@ class Edid:
|
||||
return "".join(f"{item:0{2}X}" for item in self.data)
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return self.data
|
||||
return self._packed
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, name: str, data: (str | bytes | None)) -> "Edid":
|
||||
@@ -101,14 +107,14 @@ class Edid:
|
||||
return Edid(name, b"\x00" * 256)
|
||||
|
||||
if isinstance(data, bytes):
|
||||
if data.startswith(cls.__HEADER):
|
||||
if ParsedEdid.is_header_valid(cls.data):
|
||||
return Edid(name, data) # Бинарный едид
|
||||
data_hex = data.decode() # Текстовый едид, прочитанный как бинарный из файла
|
||||
else: # isinstance(data, str)
|
||||
data_hex = str(data) # Текстовый едид
|
||||
|
||||
data_hex = re.sub(r"\s", "", data_hex)
|
||||
assert len(data_hex) == 512
|
||||
assert len(data_hex) in [256, 512]
|
||||
data = bytes([
|
||||
int(data_hex[index:index + 2], 16)
|
||||
for index in range(0, len(data_hex), 2)
|
||||
@@ -275,6 +281,19 @@ class _PortsDict(Generic[_T]):
|
||||
else:
|
||||
self.kvs[port] = value
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
return (self.kvs == other.kvs)
|
||||
|
||||
|
||||
class Dummies(_PortsDict[bool]):
|
||||
def __init__(self, kvs: dict[int, bool]) -> None:
|
||||
super().__init__(True, kvs)
|
||||
|
||||
def copy(self) -> "Dummies":
|
||||
return Dummies(self.kvs)
|
||||
|
||||
|
||||
class PortNames(_PortsDict[str]):
|
||||
def __init__(self, kvs: dict[int, str]) -> None:
|
||||
|
||||
45
kvmd/apps/localhid/__init__.py
Normal file
45
kvmd/apps/localhid/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2020 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
from ...clients.kvmd import KvmdClient
|
||||
|
||||
from ... import htclient
|
||||
|
||||
from .. import init
|
||||
|
||||
from .server import LocalHidServer
|
||||
|
||||
|
||||
# =====
|
||||
def main(argv: (list[str] | None)=None) -> None:
|
||||
config = init(
|
||||
prog="kvmd-localhid",
|
||||
description=" Local HID to KVMD proxy",
|
||||
check_run=True,
|
||||
argv=argv,
|
||||
)[2].localhid
|
||||
|
||||
user_agent = htclient.make_user_agent("KVMD-LocalHID")
|
||||
|
||||
LocalHidServer(
|
||||
kvmd=KvmdClient(user_agent=user_agent, **config.kvmd._unpack()),
|
||||
).run()
|
||||
24
kvmd/apps/localhid/__main__.py
Normal file
24
kvmd/apps/localhid/__main__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
from . import main
|
||||
main()
|
||||
152
kvmd/apps/localhid/hid.py
Normal file
152
kvmd/apps/localhid/hid.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import Final
|
||||
from typing import Generator
|
||||
|
||||
import evdev
|
||||
from evdev import ecodes
|
||||
|
||||
|
||||
# =====
|
||||
class Hid: # pylint: disable=too-many-instance-attributes
|
||||
KEY: Final[int] = 0
|
||||
MOUSE_BUTTON: Final[int] = 1
|
||||
MOUSE_REL: Final[int] = 2
|
||||
MOUSE_WHEEL: Final[int] = 3
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self.__device = evdev.InputDevice(path)
|
||||
|
||||
caps = self.__device.capabilities(absinfo=False)
|
||||
|
||||
syns = caps.get(ecodes.EV_SYN, [])
|
||||
self.__has_syn = (ecodes.SYN_REPORT in syns)
|
||||
|
||||
leds = caps.get(ecodes.EV_LED, [])
|
||||
self.__has_caps = (ecodes.LED_CAPSL in leds)
|
||||
self.__has_scroll = (ecodes.LED_SCROLLL in leds)
|
||||
self.__has_num = (ecodes.LED_NUML in leds)
|
||||
|
||||
keys = caps.get(ecodes.EV_KEY, [])
|
||||
self.__has_keyboard = (
|
||||
ecodes.KEY_LEFTCTRL in keys
|
||||
or ecodes.KEY_RIGHTCTRL in keys
|
||||
or ecodes.KEY_LEFTSHIFT in keys
|
||||
or ecodes.KEY_RIGHTSHIFT in keys
|
||||
)
|
||||
|
||||
rels = caps.get(ecodes.EV_REL, [])
|
||||
self.__has_mouse_rel = (
|
||||
ecodes.BTN_LEFT in keys
|
||||
and ecodes.REL_X in rels
|
||||
)
|
||||
|
||||
self.__grabbed = False
|
||||
|
||||
def is_suitable(self) -> bool:
|
||||
return (self.__has_keyboard or self.__has_mouse_rel)
|
||||
|
||||
def set_leds(self, caps: bool, scroll: bool, num: bool) -> None:
|
||||
if self.__grabbed:
|
||||
if self.__has_caps:
|
||||
self.__device.set_led(ecodes.LED_CAPSL, caps)
|
||||
if self.__has_scroll:
|
||||
self.__device.set_led(ecodes.LED_SCROLLL, scroll)
|
||||
if self.__has_num:
|
||||
self.__device.set_led(ecodes.LED_NUML, num)
|
||||
|
||||
def set_grabbed(self, grabbed: bool) -> None:
|
||||
if self.__grabbed != grabbed:
|
||||
getattr(self.__device, ("grab" if grabbed else "ungrab"))()
|
||||
self.__grabbed = grabbed
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self.__device.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def poll_to_queue(self, queue: asyncio.Queue[tuple[int, tuple]]) -> None:
|
||||
def put(event: int, args: tuple) -> None:
|
||||
queue.put_nowait((event, args))
|
||||
|
||||
move_x = move_y = 0
|
||||
wheel_x = wheel_y = 0
|
||||
async for event in self.__device.async_read_loop():
|
||||
if not self.__grabbed:
|
||||
# Клавиши перехватываются всегда для обработки хоткеев,
|
||||
# всё остальное пропускается для экономии ресурсов.
|
||||
if event.type == ecodes.EV_KEY and event.value != 2 and (event.code in ecodes.KEY):
|
||||
put(self.KEY, (event.code, bool(event.value)))
|
||||
continue
|
||||
|
||||
if event.type == ecodes.EV_REL:
|
||||
match event.code:
|
||||
case ecodes.REL_X:
|
||||
move_x += event.value
|
||||
case ecodes.REL_Y:
|
||||
move_y += event.value
|
||||
case ecodes.REL_HWHEEL:
|
||||
wheel_x += event.value
|
||||
case ecodes.REL_WHEEL:
|
||||
wheel_y += event.value
|
||||
|
||||
if not self.__has_syn or event.type == ecodes.SYN_REPORT:
|
||||
if move_x or move_y:
|
||||
for xy in self.__splitted_deltas(move_x, move_y):
|
||||
put(self.MOUSE_REL, xy)
|
||||
move_x = move_y = 0
|
||||
if wheel_x or wheel_y:
|
||||
for xy in self.__splitted_deltas(wheel_x, wheel_y):
|
||||
put(self.MOUSE_WHEEL, xy)
|
||||
wheel_x = wheel_y = 0
|
||||
|
||||
elif event.type == ecodes.EV_KEY and event.value != 2:
|
||||
if event.code in ecodes.KEY:
|
||||
put(self.KEY, (event.code, bool(event.value)))
|
||||
elif event.code in ecodes.BTN:
|
||||
put(self.MOUSE_BUTTON, (event.code, bool(event.value)))
|
||||
|
||||
def __splitted_deltas(self, delta_x: int, delta_y: int) -> Generator[tuple[int, int], None, None]:
|
||||
sign_x = (-1 if delta_x < 0 else 1)
|
||||
sign_y = (-1 if delta_y < 0 else 1)
|
||||
delta_x = abs(delta_x)
|
||||
delta_y = abs(delta_y)
|
||||
while delta_x > 0 or delta_y > 0:
|
||||
dx = sign_x * max(min(delta_x, 127), 0)
|
||||
dy = sign_y * max(min(delta_y, 127), 0)
|
||||
yield (dx, dy)
|
||||
delta_x -= 127
|
||||
delta_y -= 127
|
||||
|
||||
def __str__(self) -> str:
|
||||
info: list[str] = []
|
||||
if self.__has_syn:
|
||||
info.append("syn")
|
||||
if self.__has_keyboard:
|
||||
info.append("keyboard")
|
||||
if self.__has_mouse_rel:
|
||||
info.append("mouse_rel")
|
||||
return f"Hid({self.__device.path!r}, {self.__device.name!r}, {self.__device.phys!r}, {', '.join(info)})"
|
||||
178
kvmd/apps/localhid/multi.py
Normal file
178
kvmd/apps/localhid/multi.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2020 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import errno
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pyudev
|
||||
|
||||
from ...logging import get_logger
|
||||
|
||||
from ... import aiotools
|
||||
|
||||
from .hid import Hid
|
||||
|
||||
|
||||
# =====
|
||||
def _udev_check(device: pyudev.Device) -> str:
|
||||
props = device.properties
|
||||
if props.get("ID_INPUT") == "1":
|
||||
path = props.get("DEVNAME")
|
||||
if isinstance(path, str) and path.startswith("/dev/input/event"):
|
||||
return path
|
||||
return ""
|
||||
|
||||
|
||||
async def _follow_udev_hids() -> AsyncGenerator[tuple[bool, str], None]:
|
||||
ctx = pyudev.Context()
|
||||
|
||||
monitor = pyudev.Monitor.from_netlink(pyudev.Context())
|
||||
monitor.filter_by(subsystem="input")
|
||||
monitor.start()
|
||||
fd = monitor.fileno()
|
||||
|
||||
read_event = asyncio.Event()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.add_reader(fd, read_event.set)
|
||||
|
||||
try:
|
||||
for device in ctx.list_devices(subsystem="input"):
|
||||
path = _udev_check(device)
|
||||
if path:
|
||||
yield (True, path)
|
||||
|
||||
while True:
|
||||
await read_event.wait()
|
||||
while True:
|
||||
device = monitor.poll(0)
|
||||
if device is None:
|
||||
read_event.clear()
|
||||
break
|
||||
path = _udev_check(device)
|
||||
if path:
|
||||
if device.action == "add":
|
||||
yield (True, path)
|
||||
elif device.action == "remove":
|
||||
yield (False, path)
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Worker:
|
||||
task: asyncio.Task
|
||||
hid: (Hid | None)
|
||||
|
||||
|
||||
class MultiHid:
|
||||
def __init__(self, queue: asyncio.Queue[tuple[int, tuple]]) -> None:
|
||||
self.__queue = queue
|
||||
self.__workers: dict[str, _Worker] = {}
|
||||
self.__grabbed = True
|
||||
self.__leds = (False, False, False)
|
||||
|
||||
async def run(self) -> None:
|
||||
logger = get_logger(0)
|
||||
logger.info("Starting UDEV loop ...")
|
||||
try:
|
||||
async for (added, path) in _follow_udev_hids():
|
||||
if added:
|
||||
await self.__add_worker(path)
|
||||
else:
|
||||
await self.__remove_worker(path)
|
||||
finally:
|
||||
logger.info("Cleanup ...")
|
||||
await aiotools.shield_fg(self.__cleanup())
|
||||
|
||||
async def __cleanup(self) -> None:
|
||||
for path in list(self.__workers):
|
||||
await self.__remove_worker(path)
|
||||
|
||||
async def __add_worker(self, path: str) -> None:
|
||||
if path in self.__workers:
|
||||
await self.__remove_worker(path)
|
||||
self.__workers[path] = _Worker(asyncio.create_task(self.__worker_task_loop(path)), None)
|
||||
|
||||
async def __remove_worker(self, path: str) -> None:
|
||||
if path not in self.__workers:
|
||||
return
|
||||
try:
|
||||
worker = self.__workers[path]
|
||||
worker.task.cancel()
|
||||
await asyncio.gather(worker.task, return_exceptions=True)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self.__workers.pop(path, None)
|
||||
|
||||
async def __worker_task_loop(self, path: str) -> None:
|
||||
logger = get_logger(0)
|
||||
while True:
|
||||
hid: (Hid | None) = None
|
||||
try:
|
||||
hid = Hid(path)
|
||||
if not hid.is_suitable():
|
||||
break
|
||||
logger.info("Opened: %s", hid)
|
||||
if self.__grabbed:
|
||||
hid.set_grabbed(True)
|
||||
hid.set_leds(*self.__leds)
|
||||
self.__workers[path].hid = hid
|
||||
await hid.poll_to_queue(self.__queue)
|
||||
except Exception as ex:
|
||||
if isinstance(ex, OSError) and ex.errno == errno.ENODEV: # pylint: disable=no-member
|
||||
logger.info("Closed: %s", hid)
|
||||
break
|
||||
logger.exception("Unhandled exception while polling %s", hid)
|
||||
await asyncio.sleep(5)
|
||||
finally:
|
||||
self.__workers[path].hid = None
|
||||
if hid:
|
||||
hid.close()
|
||||
|
||||
def is_grabbed(self) -> bool:
|
||||
return self.__grabbed
|
||||
|
||||
async def set_grabbed(self, grabbed: bool) -> None:
|
||||
await aiotools.run_async(self.__inner_set_grabbed, grabbed)
|
||||
|
||||
def __inner_set_grabbed(self, grabbed: bool) -> None:
|
||||
if self.__grabbed != grabbed:
|
||||
get_logger(0).info("Grabbing ..." if grabbed else "Ungrabbing ...")
|
||||
self.__grabbed = grabbed
|
||||
for worker in self.__workers.values():
|
||||
if worker.hid:
|
||||
worker.hid.set_grabbed(grabbed)
|
||||
self.__inner_set_leds(*self.__leds)
|
||||
|
||||
async def set_leds(self, caps: bool, scroll: bool, num: bool) -> None:
|
||||
await aiotools.run_async(self.__inner_set_leds, caps, scroll, num)
|
||||
|
||||
def __inner_set_leds(self, caps: bool, scroll: bool, num: bool) -> None:
|
||||
self.__leds = (caps, scroll, num)
|
||||
if self.__grabbed:
|
||||
for worker in self.__workers.values():
|
||||
if worker.hid:
|
||||
worker.hid.set_leds(*self.__leds)
|
||||
192
kvmd/apps/localhid/server.py
Normal file
192
kvmd/apps/localhid/server.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2020 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
|
||||
from typing import Callable
|
||||
from typing import Coroutine
|
||||
|
||||
import aiohttp
|
||||
import async_lru
|
||||
|
||||
from evdev import ecodes
|
||||
|
||||
from ...logging import get_logger
|
||||
|
||||
from ... import tools
|
||||
from ... import aiotools
|
||||
|
||||
from ...keyboard.magic import MagicHandler
|
||||
|
||||
from ...clients.kvmd import KvmdClient
|
||||
from ...clients.kvmd import KvmdClientSession
|
||||
from ...clients.kvmd import KvmdClientWs
|
||||
|
||||
from .hid import Hid
|
||||
from .multi import MultiHid
|
||||
|
||||
|
||||
# =====
|
||||
class LocalHidServer: # pylint: disable=too-many-instance-attributes
|
||||
def __init__(self, kvmd: KvmdClient) -> None:
|
||||
self.__kvmd = kvmd
|
||||
|
||||
self.__kvmd_session: (KvmdClientSession | None) = None
|
||||
self.__kvmd_ws: (KvmdClientWs | None) = None
|
||||
|
||||
self.__queue: asyncio.Queue[tuple[int, tuple]] = asyncio.Queue()
|
||||
self.__hid = MultiHid(self.__queue)
|
||||
|
||||
self.__info_switch_units = 0
|
||||
self.__info_switch_active = ""
|
||||
self.__info_mouse_absolute = True
|
||||
self.__info_mouse_outputs: list[str] = []
|
||||
|
||||
self.__magic = MagicHandler(
|
||||
proxy_handler=self.__on_magic_key_proxy,
|
||||
key_handlers={
|
||||
ecodes.KEY_H: self.__on_magic_grab,
|
||||
ecodes.KEY_K: self.__on_magic_ungrab,
|
||||
ecodes.KEY_UP: self.__on_magic_switch_prev,
|
||||
ecodes.KEY_LEFT: self.__on_magic_switch_prev,
|
||||
ecodes.KEY_DOWN: self.__on_magic_switch_next,
|
||||
ecodes.KEY_RIGHT: self.__on_magic_switch_next,
|
||||
},
|
||||
numeric_handler=self.__on_magic_switch_port,
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
aiotools.run(self.__inner_run())
|
||||
finally:
|
||||
get_logger(0).info("Bye-bye")
|
||||
|
||||
async def __inner_run(self) -> None:
|
||||
await aiotools.spawn_and_follow(
|
||||
self.__create_loop(self.__hid.run),
|
||||
self.__create_loop(self.__queue_worker),
|
||||
self.__create_loop(self.__api_worker),
|
||||
)
|
||||
|
||||
async def __create_loop(self, func: Callable[[], Coroutine]) -> None:
|
||||
while True:
|
||||
try:
|
||||
await func()
|
||||
except Exception as ex:
|
||||
if isinstance(ex, OSError) and ex.errno == errno.ENODEV: # pylint: disable=no-member
|
||||
pass # Device disconnected
|
||||
elif isinstance(ex, aiohttp.ClientError):
|
||||
get_logger(0).error("KVMD client error: %s", tools.efmt(ex))
|
||||
else:
|
||||
get_logger(0).exception("Unhandled exception in the loop: %s", func)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def __queue_worker(self) -> None:
|
||||
while True:
|
||||
(event, args) = await self.__queue.get()
|
||||
if event == Hid.KEY:
|
||||
await self.__magic.handle_key(*args)
|
||||
continue
|
||||
elif self.__hid.is_grabbed() and self.__kvmd_session and self.__kvmd_ws:
|
||||
match event:
|
||||
case Hid.MOUSE_BUTTON:
|
||||
await self.__kvmd_ws.send_mouse_button_event(*args)
|
||||
case Hid.MOUSE_REL:
|
||||
await self.__ensure_mouse_relative()
|
||||
await self.__kvmd_ws.send_mouse_relative_event(*args)
|
||||
case Hid.MOUSE_WHEEL:
|
||||
await self.__kvmd_ws.send_mouse_wheel_event(*args)
|
||||
|
||||
async def __api_worker(self) -> None:
|
||||
logger = get_logger(0)
|
||||
async with self.__kvmd.make_session() as session:
|
||||
async with session.ws(stream=False) as ws:
|
||||
logger.info("KVMD session opened")
|
||||
self.__kvmd_session = session
|
||||
self.__kvmd_ws = ws
|
||||
try:
|
||||
async for (event_type, event) in ws.communicate():
|
||||
if event_type == "hid":
|
||||
if "leds" in event.get("keyboard", {}):
|
||||
await self.__hid.set_leds(**event["keyboard"]["leds"])
|
||||
if "absolute" in event.get("mouse", {}):
|
||||
self.__info_mouse_outputs = event["mouse"]["outputs"]["available"]
|
||||
self.__info_mouse_absolute = event["mouse"]["absolute"]
|
||||
elif event_type == "switch":
|
||||
if "model" in event:
|
||||
self.__info_switch_units = len(event["model"]["units"])
|
||||
if "summary" in event:
|
||||
self.__info_switch_active = event["summary"]["active_id"]
|
||||
finally:
|
||||
logger.info("KVMD session closed")
|
||||
self.__kvmd_session = None
|
||||
self.__kvmd_ws = None
|
||||
|
||||
# =====
|
||||
|
||||
async def __ensure_mouse_relative(self) -> None:
|
||||
if self.__info_mouse_absolute:
|
||||
# Avoid unnecessary LRU checks, just to speed up a bit
|
||||
await self.__inner_ensure_mouse_relative()
|
||||
|
||||
@async_lru.alru_cache(maxsize=1, ttl=1)
|
||||
async def __inner_ensure_mouse_relative(self) -> None:
|
||||
if self.__kvmd_session and self.__info_mouse_absolute:
|
||||
for output in ["usb_rel", "ps2"]:
|
||||
if output in self.__info_mouse_outputs:
|
||||
await self.__kvmd_session.hid.set_params(mouse_output=output)
|
||||
|
||||
async def __on_magic_key_proxy(self, key: int, state: bool) -> None:
|
||||
if self.__hid.is_grabbed() and self.__kvmd_ws:
|
||||
await self.__kvmd_ws.send_key_event(key, state)
|
||||
|
||||
async def __on_magic_grab(self) -> None:
|
||||
await self.__hid.set_grabbed(True)
|
||||
|
||||
async def __on_magic_ungrab(self) -> None:
|
||||
await self.__hid.set_grabbed(False)
|
||||
|
||||
async def __on_magic_switch_prev(self) -> None:
|
||||
if self.__kvmd_session and self.__info_switch_units > 0:
|
||||
get_logger(0).info("Switching port to the previous one ...")
|
||||
await self.__kvmd_session.switch.set_active_prev()
|
||||
|
||||
async def __on_magic_switch_next(self) -> None:
|
||||
if self.__kvmd_session and self.__info_switch_units > 0:
|
||||
get_logger(0).info("Switching port to the next one ...")
|
||||
await self.__kvmd_session.switch.set_active_next()
|
||||
|
||||
async def __on_magic_switch_port(self, codes: list[int]) -> bool:
|
||||
assert len(codes) > 0
|
||||
if self.__info_switch_units <= 0:
|
||||
return True
|
||||
elif 1 <= self.__info_switch_units <= 2:
|
||||
port = float(codes[0])
|
||||
else: # self.__info_switch_units > 2:
|
||||
if len(codes) == 1:
|
||||
return False # Wait for the second key
|
||||
port = (codes[0] + 1) + (codes[1] + 1) / 10
|
||||
if self.__kvmd_session:
|
||||
get_logger(0).info("Switching port to %s ...", port)
|
||||
await self.__kvmd_session.switch.set_active(port)
|
||||
return True
|
||||
@@ -52,6 +52,9 @@ class _Source:
|
||||
clients: dict[WsSession, "_Client"] = dataclasses.field(default_factory=dict)
|
||||
key_required: bool = dataclasses.field(default=False)
|
||||
|
||||
def is_diff(self) -> bool:
|
||||
return StreamerFormats.is_diff(self.streamer.get_format())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Client:
|
||||
@@ -98,6 +101,14 @@ class MediaServer(HttpServer):
|
||||
async def __ws_bin_ping_handler(self, ws: WsSession, _: bytes) -> None:
|
||||
await ws.send_bin(255, b"") # Ping-pong
|
||||
|
||||
@exposed_ws(1)
|
||||
async def __ws_bin_key_handler(self, ws: WsSession, _: bytes) -> None:
|
||||
for src in self.__srcs:
|
||||
if ws in src.clients:
|
||||
if src.is_diff():
|
||||
src.key_required = True
|
||||
break
|
||||
|
||||
@exposed_ws("start")
|
||||
async def __ws_start_handler(self, ws: WsSession, event: dict) -> None:
|
||||
try:
|
||||
@@ -145,7 +156,7 @@ class MediaServer(HttpServer):
|
||||
# =====
|
||||
|
||||
async def __sender(self, client: _Client) -> None:
|
||||
need_key = StreamerFormats.is_diff(client.src.streamer.get_format())
|
||||
need_key = client.src.is_diff()
|
||||
if need_key:
|
||||
client.src.key_required = True
|
||||
has_key = False
|
||||
|
||||
@@ -50,8 +50,12 @@ def main(argv: (list[str] | None)=None) -> None:
|
||||
template = in_file.read()
|
||||
|
||||
rendered = mako.template.Template(template).render(
|
||||
http_ipv4=config.nginx.http.ipv4,
|
||||
http_ipv6=config.nginx.http.ipv6,
|
||||
http_port=config.nginx.http.port,
|
||||
https_enabled=config.nginx.https.enabled,
|
||||
https_ipv4=config.nginx.https.ipv4,
|
||||
https_ipv6=config.nginx.https.ipv6,
|
||||
https_port=config.nginx.https.port,
|
||||
ipv6_enabled=network.is_ipv6_enabled(),
|
||||
)
|
||||
|
||||
@@ -78,6 +78,7 @@ def main() -> None: # pylint: disable=too-many-locals,too-many-branches,too-man
|
||||
parser.add_argument("--image", default="", type=(lambda arg: _get_data_path("pics", arg)), help="Display some image, wait a single interval and exit")
|
||||
parser.add_argument("--text", default="", help="Display some text, wait a single interval and exit")
|
||||
parser.add_argument("--pipe", action="store_true", help="Read and display lines from stdin until EOF, wait a single interval and exit")
|
||||
parser.add_argument("--fill", action="store_true", help="Fill the display with 0xFF")
|
||||
parser.add_argument("--clear-on-exit", action="store_true", help="Clear display on exit")
|
||||
parser.add_argument("--contrast", default=64, type=int, help="Set OLED contrast, values from 0 to 255")
|
||||
parser.add_argument("--fahrenheit", action="store_true", help="Display temperature in Fahrenheit instead of Celsius")
|
||||
@@ -121,6 +122,9 @@ def main() -> None: # pylint: disable=too-many-locals,too-many-branches,too-man
|
||||
text = ""
|
||||
time.sleep(options.interval)
|
||||
|
||||
elif options.fill:
|
||||
screen.draw_white()
|
||||
|
||||
else:
|
||||
stop_reason: (str | None) = None
|
||||
|
||||
|
||||
@@ -52,3 +52,7 @@ class Screen:
|
||||
def draw_image(self, image_path: str) -> None:
|
||||
with luma_canvas(self.__device) as draw:
|
||||
draw.bitmap(self.__offset, Image.open(image_path).convert("1"), fill="white")
|
||||
|
||||
def draw_white(self) -> None:
|
||||
with luma_canvas(self.__device) as draw:
|
||||
draw.rectangle((0, 0, self.__device.width, self.__device.height), fill="white")
|
||||
|
||||
@@ -291,8 +291,9 @@ def _cmd_start(config: Section) -> None: # pylint: disable=too-many-statements,
|
||||
|
||||
profile_path = join(gadget_path, usb.G_PROFILE)
|
||||
_mkdir(profile_path)
|
||||
_mkdir(join(profile_path, "strings/0x409"))
|
||||
_write(join(profile_path, "strings/0x409/configuration"), f"Config 1: {config.otg.config}")
|
||||
if config.otg.config:
|
||||
_mkdir(join(profile_path, "strings/0x409"))
|
||||
_write(join(profile_path, "strings/0x409/configuration"), config.otg.config)
|
||||
_write(join(profile_path, "MaxPower"), config.otg.max_power)
|
||||
if config.otg.remote_wakeup:
|
||||
# XXX: Should we use MaxPower=100 with Remote Wakeup?
|
||||
|
||||
@@ -45,6 +45,7 @@ from .netctl import IptablesAllowIcmpCtl
|
||||
from .netctl import IptablesAllowPortCtl
|
||||
from .netctl import IptablesForwardOut
|
||||
from .netctl import IptablesForwardIn
|
||||
from .netctl import SysctlIpv4ForwardCtl
|
||||
from .netctl import CustomCtl
|
||||
|
||||
|
||||
@@ -63,14 +64,16 @@ class _Netcfg: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
class _Service: # pylint: disable=too-many-instance-attributes
|
||||
def __init__(self, config: Section) -> None:
|
||||
self.__ip_cmd: list[str] = config.otgnet.commands.ip_cmd
|
||||
self.__iptables_cmd: list[str] = config.otgnet.commands.iptables_cmd
|
||||
self.__sysctl_cmd: list[str] = config.otgnet.commands.sysctl_cmd
|
||||
|
||||
self.__iface_net: str = config.otgnet.iface.net
|
||||
self.__ip_cmd: list[str] = config.otgnet.iface.ip_cmd
|
||||
|
||||
self.__allow_icmp: bool = config.otgnet.firewall.allow_icmp
|
||||
self.__allow_tcp: list[int] = sorted(set(config.otgnet.firewall.allow_tcp))
|
||||
self.__allow_udp: list[int] = sorted(set(config.otgnet.firewall.allow_udp))
|
||||
self.__forward_iface: str = config.otgnet.firewall.forward_iface
|
||||
self.__iptables_cmd: list[str] = config.otgnet.firewall.iptables_cmd
|
||||
|
||||
def build_cmd(key: str) -> list[str]:
|
||||
return tools.build_cmd(
|
||||
@@ -115,6 +118,7 @@ class _Service: # pylint: disable=too-many-instance-attributes
|
||||
*([IptablesForwardIn(self.__iptables_cmd, netcfg.iface)] if self.__forward_iface else []),
|
||||
IptablesDropAllCtl(self.__iptables_cmd, netcfg.iface),
|
||||
IfaceAddIpCtl(self.__ip_cmd, netcfg.iface, f"{netcfg.iface_ip}/{netcfg.net_prefix}"),
|
||||
*([SysctlIpv4ForwardCtl(self.__sysctl_cmd)] if self.__forward_iface else []),
|
||||
CustomCtl(self.__post_start_cmd, self.__pre_stop_cmd, placeholders),
|
||||
]
|
||||
if direct:
|
||||
@@ -130,6 +134,8 @@ class _Service: # pylint: disable=too-many-instance-attributes
|
||||
async def __run_ctl(self, ctl: BaseCtl, direct: bool) -> bool:
|
||||
logger = get_logger()
|
||||
cmd = ctl.get_command(direct)
|
||||
if not cmd:
|
||||
return True
|
||||
logger.info("CMD: %s", tools.cmdfmt(cmd))
|
||||
try:
|
||||
return (not (await aioproc.log_process(cmd, logger)).returncode)
|
||||
|
||||
@@ -121,6 +121,16 @@ class IptablesForwardIn(BaseCtl):
|
||||
]
|
||||
|
||||
|
||||
class SysctlIpv4ForwardCtl(BaseCtl):
|
||||
def __init__(self, base_cmd: list[str]) -> None:
|
||||
self.__base_cmd = base_cmd
|
||||
|
||||
def get_command(self, direct: bool) -> list[str]:
|
||||
if direct:
|
||||
return [*self.__base_cmd, "net.ipv4.ip_forward=1"]
|
||||
return [] # Don't revert the command because some services can require it too
|
||||
|
||||
|
||||
class CustomCtl(BaseCtl):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -66,22 +66,22 @@ async def _run_process(cmd: list[str], data_path: str) -> asyncio.subprocess.Pro
|
||||
|
||||
async def _run_cmd_ws(cmd: list[str], ws: aiohttp.ClientWebSocketResponse) -> int: # pylint: disable=too-many-branches
|
||||
logger = get_logger(0)
|
||||
receive_task: (asyncio.Task | None) = None
|
||||
recv_task: (asyncio.Task | None) = None
|
||||
proc_task: (asyncio.Task | None) = None
|
||||
proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member
|
||||
|
||||
try: # pylint: disable=too-many-nested-blocks
|
||||
while True:
|
||||
if receive_task is None:
|
||||
receive_task = asyncio.create_task(ws.receive())
|
||||
if recv_task is None:
|
||||
recv_task = asyncio.create_task(ws.receive())
|
||||
if proc_task is None and proc is not None:
|
||||
proc_task = asyncio.create_task(proc.wait())
|
||||
|
||||
tasks = list(filter(None, [receive_task, proc_task]))
|
||||
tasks = list(filter(None, [recv_task, proc_task]))
|
||||
done = (await aiotools.wait_first(*tasks))[0]
|
||||
|
||||
if receive_task in done:
|
||||
msg = receive_task.result()
|
||||
if recv_task in done:
|
||||
msg = recv_task.result()
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
(event_type, event) = htserver.parse_ws_event(msg.data)
|
||||
if event_type == "storage":
|
||||
@@ -98,15 +98,15 @@ async def _run_cmd_ws(cmd: list[str], ws: aiohttp.ClientWebSocketResponse) -> in
|
||||
else:
|
||||
logger.error("Unknown PST message type: %r", msg)
|
||||
break
|
||||
receive_task = None
|
||||
recv_task = None
|
||||
|
||||
if proc_task in done:
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("Unhandled exception")
|
||||
|
||||
if receive_task is not None:
|
||||
receive_task.cancel()
|
||||
if recv_task is not None:
|
||||
recv_task.cancel()
|
||||
if proc_task is not None:
|
||||
proc_task.cancel()
|
||||
if proc is not None:
|
||||
|
||||
@@ -30,7 +30,6 @@ from ... import htclient
|
||||
|
||||
from .. import init
|
||||
|
||||
from .vncauth import VncAuthManager
|
||||
from .server import VncServer
|
||||
|
||||
|
||||
@@ -71,12 +70,12 @@ def main(argv: (list[str] | None)=None) -> None:
|
||||
desired_fps=config.desired_fps,
|
||||
mouse_output=config.mouse_output,
|
||||
keymap_path=config.keymap,
|
||||
allow_cut_after=config.allow_cut_after,
|
||||
scroll_rate=config.scroll_rate,
|
||||
|
||||
kvmd=KvmdClient(user_agent=user_agent, **config.kvmd._unpack()),
|
||||
streamers=streamers,
|
||||
vnc_auth_manager=VncAuthManager(**config.auth.vncauth._unpack()),
|
||||
|
||||
**config.server.keepalive._unpack(),
|
||||
**config.auth.vncauth._unpack(),
|
||||
**config.auth.vencrypt._unpack(),
|
||||
).run()
|
||||
|
||||
@@ -22,17 +22,22 @@
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
import time
|
||||
|
||||
from typing import Callable
|
||||
from typing import Coroutine
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from evdev import ecodes
|
||||
|
||||
from ....logging import get_logger
|
||||
|
||||
from .... import tools
|
||||
from .... import aiotools
|
||||
|
||||
from ....keyboard.keysym import SymmapModifiers
|
||||
from ....keyboard.mappings import EvdevModifiers
|
||||
from ....keyboard.mappings import X11Modifiers
|
||||
from ....keyboard.mappings import AT1_TO_EVDEV
|
||||
from ....mouse import MouseRange
|
||||
|
||||
from .errors import RfbError
|
||||
@@ -47,6 +52,11 @@ from .crypto import rfb_encrypt_challenge
|
||||
from .stream import RfbClientStream
|
||||
|
||||
|
||||
# =====
|
||||
class _SecurityError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# =====
|
||||
class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attributes
|
||||
# https://github.com/rfbproto/rfbproto/blob/master/rfbproto.rst
|
||||
@@ -65,8 +75,10 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
width: int,
|
||||
height: int,
|
||||
name: str,
|
||||
allow_cut_after: float,
|
||||
vnc_passwds: list[str],
|
||||
symmap: dict[int, dict[int, int]],
|
||||
scroll_rate: int,
|
||||
|
||||
vncpasses: set[str],
|
||||
vencrypt: bool,
|
||||
none_auth_only: bool,
|
||||
) -> None:
|
||||
@@ -81,8 +93,10 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
self._width = width
|
||||
self._height = height
|
||||
self.__name = name
|
||||
self.__allow_cut_after = allow_cut_after
|
||||
self.__vnc_passwds = vnc_passwds
|
||||
self.__scroll_rate = scroll_rate
|
||||
self.__symmap = symmap
|
||||
|
||||
self.__vncpasses = vncpasses
|
||||
self.__vencrypt = vencrypt
|
||||
self.__none_auth_only = none_auth_only
|
||||
|
||||
@@ -93,10 +107,16 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
self.__fb_cont_updates = False
|
||||
self.__fb_reset_h264 = False
|
||||
|
||||
self.__allow_cut_since_ts = 0.0
|
||||
self.__authorized = False
|
||||
|
||||
self.__lock = asyncio.Lock()
|
||||
|
||||
# Эти состояния шарить не обязательно - бекенд исключает дублирующиеся события.
|
||||
# Все это нужно только чтобы не посылать лишние события в сокет KVMD
|
||||
self.__modifiers = 0
|
||||
self.__mouse_buttons: dict[int, bool] = {}
|
||||
self.__mouse_move = (-1, -1, -1, -1) # (width, height, X, Y)
|
||||
|
||||
# =====
|
||||
|
||||
async def _run(self, **coros: Coroutine) -> None:
|
||||
@@ -135,6 +155,8 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
async def __main_task_loop(self) -> None:
|
||||
await self.__handshake_version()
|
||||
await self.__handshake_security()
|
||||
if not self.__authorized:
|
||||
raise _SecurityError()
|
||||
await self.__handshake_init()
|
||||
await self.__main_loop()
|
||||
|
||||
@@ -143,21 +165,24 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
async def _authorize_userpass(self, user: str, passwd: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _on_authorized_vnc_passwd(self, passwd: str) -> str:
|
||||
async def _on_authorized_vncpass(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _on_authorized_none(self) -> bool:
|
||||
async def _authorize_none(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
# =====
|
||||
|
||||
async def _on_key_event(self, code: int, state: bool) -> None:
|
||||
async def _on_key_event(self, key: int, state: bool) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _on_ext_key_event(self, code: int, state: bool) -> None:
|
||||
async def _on_mouse_button_event(self, button: int, state: bool) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _on_pointer_event(self, buttons: dict[str, bool], wheel: dict[str, int], move: dict[str, int]) -> None:
|
||||
async def _on_mouse_move_event(self, to_x: int, to_y: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _on_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _on_cut_event(self, text: str) -> None:
|
||||
@@ -235,18 +260,18 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
|
||||
await self._write_struct("handshake server version", "", b"RFB 003.008\n")
|
||||
|
||||
response = await self._read_text("handshake client version", 12)
|
||||
resp = await self._read_text("handshake client version", 12)
|
||||
if (
|
||||
not response.startswith("RFB 003.00")
|
||||
or not response.endswith("\n")
|
||||
or response[-2] not in ["3", "5", "7", "8"]
|
||||
not resp.startswith("RFB 003.00")
|
||||
or not resp.endswith("\n")
|
||||
or resp[-2] not in ["3", "5", "7", "8"]
|
||||
):
|
||||
raise RfbError(f"Invalid version response: {response!r}")
|
||||
raise RfbError(f"Invalid version response: {resp!r}")
|
||||
|
||||
try:
|
||||
version = int(response[-2])
|
||||
version = int(resp[-2])
|
||||
except ValueError:
|
||||
raise RfbError(f"Invalid version response: {response!r}")
|
||||
raise RfbError(f"Invalid version response: {resp!r}")
|
||||
self.__rfb_version = (3 if version == 5 else version)
|
||||
get_logger(0).info("%s [main]: Using RFB version 3.%d", self._remote, self.__rfb_version)
|
||||
|
||||
@@ -258,7 +283,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
sec_types[19] = ("VeNCrypt", self.__handshake_security_vencrypt)
|
||||
if self.__none_auth_only:
|
||||
sec_types[1] = ("None", self.__handshake_security_none)
|
||||
elif self.__vnc_passwds:
|
||||
elif self.__vncpasses:
|
||||
sec_types[2] = ("VNCAuth", self.__handshake_security_vnc_auth)
|
||||
|
||||
if not sec_types:
|
||||
@@ -304,7 +329,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
if self.__x509_cert_path:
|
||||
auth_types[262] = ("VeNCrypt/X509Plain", 2, self.__handshake_security_vencrypt_userpass)
|
||||
auth_types[259] = ("VeNCrypt/TLSPlain", 1, self.__handshake_security_vencrypt_userpass)
|
||||
if self.__vnc_passwds:
|
||||
if self.__vncpasses:
|
||||
# Некоторые клиенты не умеют работать с нешифрованными соединениями внутри VeNCrypt:
|
||||
# - https://github.com/LibVNC/libvncserver/issues/458
|
||||
# - https://bugzilla.redhat.com/show_bug.cgi?id=692048
|
||||
@@ -354,7 +379,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
)
|
||||
|
||||
async def __handshake_security_none(self) -> None:
|
||||
allow = await self._on_authorized_none()
|
||||
allow = await self._authorize_none()
|
||||
await self.__handshake_security_send_result(
|
||||
allow=allow,
|
||||
allow_msg="NoneAuth access granted",
|
||||
@@ -366,20 +391,19 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
challenge = rfb_make_challenge()
|
||||
await self._write_struct("VNCAuth challenge request", "", challenge)
|
||||
|
||||
user = ""
|
||||
allow = False
|
||||
response = (await self._read_struct("VNCAuth challenge response", "16s"))[0]
|
||||
for passwd in self.__vnc_passwds:
|
||||
for passwd in self.__vncpasses:
|
||||
passwd_bytes = passwd.encode("utf-8", errors="ignore")
|
||||
if rfb_encrypt_challenge(challenge, passwd_bytes) == response:
|
||||
user = await self._on_authorized_vnc_passwd(passwd)
|
||||
if user:
|
||||
assert user == user.strip()
|
||||
await self._on_authorized_vncpass()
|
||||
allow = True
|
||||
break
|
||||
|
||||
await self.__handshake_security_send_result(
|
||||
allow=bool(user),
|
||||
allow_msg=f"VNCAuth access granted for user {user!r}",
|
||||
deny_msg="VNCAuth access denied (user not found)",
|
||||
allow=allow,
|
||||
allow_msg="VNCAuth access granted",
|
||||
deny_msg="VNCAuth access denied (passwd not found)",
|
||||
deny_reason="Invalid password",
|
||||
)
|
||||
|
||||
@@ -387,6 +411,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
if allow:
|
||||
get_logger(0).info("%s [main]: %s", self._remote, allow_msg)
|
||||
await self._write_struct("access OK", "L", 0)
|
||||
self.__authorized = True
|
||||
else:
|
||||
await self._write_struct("access denial flag", "L", 1, drain=(self.__rfb_version < 8))
|
||||
if self.__rfb_version >= 8:
|
||||
@@ -396,6 +421,9 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
# =====
|
||||
|
||||
async def __handshake_init(self) -> None:
|
||||
if not self.__authorized:
|
||||
raise _SecurityError()
|
||||
|
||||
await self._read_number("initial shared flag", "B") # Shared flag, ignored
|
||||
|
||||
await self._write_struct("initial FB size", "HH", self._width, self._height, drain=False)
|
||||
@@ -419,7 +447,8 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
# =====
|
||||
|
||||
async def __main_loop(self) -> None:
|
||||
self.__allow_cut_since_ts = time.monotonic() + self.__allow_cut_after
|
||||
if not self.__authorized:
|
||||
raise _SecurityError()
|
||||
handlers = {
|
||||
0: self.__handle_set_pixel_format,
|
||||
2: self.__handle_set_encodings,
|
||||
@@ -486,40 +515,101 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
|
||||
async def __handle_key_event(self) -> None:
|
||||
(state, code) = await self._read_struct("key event", "? xx L")
|
||||
await self._on_key_event(code, state) # type: ignore
|
||||
state = bool(state)
|
||||
|
||||
is_modifier = self.__switch_modifiers_x11(code, state)
|
||||
variants = self.__symmap.get(code)
|
||||
fake_shift = False
|
||||
|
||||
if variants:
|
||||
if is_modifier:
|
||||
key = variants.get(0)
|
||||
else:
|
||||
key = variants.get(self.__modifiers)
|
||||
if key is None:
|
||||
key = variants.get(0)
|
||||
|
||||
if key is None and self.__modifiers == 0 and SymmapModifiers.SHIFT in variants:
|
||||
# JUMP doesn't send shift events:
|
||||
# - https://github.com/pikvm/pikvm/issues/820
|
||||
key = variants[SymmapModifiers.SHIFT]
|
||||
fake_shift = True
|
||||
|
||||
if key:
|
||||
if fake_shift:
|
||||
await self._on_key_event(EvdevModifiers.SHIFT_LEFT, True)
|
||||
await self._on_key_event(key, state)
|
||||
if fake_shift:
|
||||
await self._on_key_event(EvdevModifiers.SHIFT_LEFT, False)
|
||||
|
||||
def __switch_modifiers_x11(self, code: int, state: bool) -> bool:
|
||||
mod = 0
|
||||
if code in X11Modifiers.SHIFTS:
|
||||
mod = SymmapModifiers.SHIFT
|
||||
elif code == X11Modifiers.ALTGR:
|
||||
mod = SymmapModifiers.ALTGR
|
||||
elif code in X11Modifiers.CTRLS:
|
||||
mod = SymmapModifiers.CTRL
|
||||
if mod == 0:
|
||||
return False
|
||||
if state:
|
||||
self.__modifiers |= mod
|
||||
else:
|
||||
self.__modifiers &= ~mod
|
||||
return True
|
||||
|
||||
def __switch_modifiers_evdev(self, key: int, state: bool) -> bool:
|
||||
mod = 0
|
||||
if key in EvdevModifiers.SHIFTS:
|
||||
mod = SymmapModifiers.SHIFT
|
||||
elif key == EvdevModifiers.ALT_RIGHT:
|
||||
mod = SymmapModifiers.ALTGR
|
||||
elif key in EvdevModifiers.CTRLS:
|
||||
mod = SymmapModifiers.CTRL
|
||||
if mod == 0:
|
||||
return False
|
||||
if state:
|
||||
self.__modifiers |= mod
|
||||
else:
|
||||
self.__modifiers &= ~mod
|
||||
return True
|
||||
|
||||
async def __handle_pointer_event(self) -> None:
|
||||
(buttons, to_x, to_y) = await self._read_struct("pointer event", "B HH")
|
||||
ext_buttons = 0
|
||||
if self._encodings.has_ext_mouse and (buttons & 0x80): # Marker bit 7 for ext event
|
||||
ext_buttons = await self._read_number("ext pointer event buttons", "B")
|
||||
await self._on_pointer_event(
|
||||
buttons={
|
||||
"left": bool(buttons & 0x1),
|
||||
"right": bool(buttons & 0x4),
|
||||
"middle": bool(buttons & 0x2),
|
||||
"up": bool(ext_buttons & 0x2),
|
||||
"down": bool(ext_buttons & 0x1),
|
||||
},
|
||||
wheel={
|
||||
"x": (-4 if buttons & 0x40 else (4 if buttons & 0x20 else 0)),
|
||||
"y": (-4 if buttons & 0x10 else (4 if buttons & 0x8 else 0)),
|
||||
},
|
||||
move={
|
||||
"x": tools.remap(to_x, 0, self._width, *MouseRange.RANGE),
|
||||
"y": tools.remap(to_y, 0, self._height, *MouseRange.RANGE),
|
||||
},
|
||||
)
|
||||
|
||||
if buttons & (0x40 | 0x20 | 0x10 | 0x08):
|
||||
sr = self.__scroll_rate
|
||||
await self._on_mouse_wheel_event(
|
||||
(-sr if buttons & 0x40 else (sr if buttons & 0x20 else 0)),
|
||||
(-sr if buttons & 0x10 else (sr if buttons & 0x08 else 0)),
|
||||
)
|
||||
|
||||
move = (self._width, self._height, to_x, to_y)
|
||||
if self.__mouse_move != move:
|
||||
await self._on_mouse_move_event(
|
||||
tools.remap(to_x, 0, self._width - 1, *MouseRange.RANGE),
|
||||
tools.remap(to_y, 0, self._height - 1, *MouseRange.RANGE),
|
||||
)
|
||||
self.__mouse_move = move
|
||||
|
||||
for (code, state) in [
|
||||
(ecodes.BTN_LEFT, bool(buttons & 0x1)),
|
||||
(ecodes.BTN_RIGHT, bool(buttons & 0x4)),
|
||||
(ecodes.BTN_MIDDLE, bool(buttons & 0x2)),
|
||||
(ecodes.BTN_BACK, bool(ext_buttons & 0x2)),
|
||||
(ecodes.BTN_FORWARD, bool(ext_buttons & 0x1)),
|
||||
]:
|
||||
if self.__mouse_buttons.get(code) != state:
|
||||
await self._on_mouse_button_event(code, state)
|
||||
self.__mouse_buttons[code] = state
|
||||
|
||||
async def __handle_client_cut_text(self) -> None:
|
||||
length = (await self._read_struct("cut text length", "xxx L"))[0]
|
||||
text = await self._read_text("cut text data", length)
|
||||
if self.__allow_cut_since_ts > 0 and time.monotonic() >= self.__allow_cut_since_ts:
|
||||
# We should ignore cut event a few seconds after handshake
|
||||
# because bVNC, AVNC and maybe some other clients perform
|
||||
# it right after the connection automatically.
|
||||
# - https://github.com/pikvm/pikvm/issues/1420
|
||||
await self._on_cut_event(text)
|
||||
await self._on_cut_event(text)
|
||||
|
||||
async def __handle_enable_cont_updates(self) -> None:
|
||||
enabled = bool((await self._read_struct("enabled ContUpdates", "B HH HH"))[0])
|
||||
@@ -532,6 +622,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
|
||||
async def __handle_qemu_event(self) -> None:
|
||||
(sub_type, state, code) = await self._read_struct("QEMU event (key?)", "B H xxxx L")
|
||||
state = bool(state)
|
||||
if sub_type != 0:
|
||||
raise RfbError(f"Invalid QEMU sub-message type: {sub_type}")
|
||||
if code == 0xB7:
|
||||
@@ -539,4 +630,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
|
||||
code = 0x54
|
||||
if code & 0x80:
|
||||
code = (0xE0 << 8) | (code & ~0x80)
|
||||
await self._on_ext_key_event(code, bool(state))
|
||||
key = AT1_TO_EVDEV.get(code, 0)
|
||||
if key:
|
||||
self.__switch_modifiers_evdev(key, state) # Предполагаем, что модификаторы всегда известны
|
||||
await self._on_key_event(key, state)
|
||||
|
||||
@@ -110,32 +110,13 @@ class RfbClientStream:
|
||||
# =====
|
||||
|
||||
async def _start_tls(self, ssl_context: ssl.SSLContext, ssl_timeout: float) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
ssl_reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(ssl_reader)
|
||||
|
||||
try:
|
||||
transport = await loop.start_tls(
|
||||
self.__writer.transport,
|
||||
protocol,
|
||||
await self.__writer.start_tls(
|
||||
ssl_context,
|
||||
server_side=True,
|
||||
ssl_handshake_timeout=ssl_timeout,
|
||||
)
|
||||
except ConnectionError as ex:
|
||||
raise RfbConnectionError("Can't start TLS", ex)
|
||||
|
||||
ssl_reader.set_transport(transport) # type: ignore
|
||||
ssl_writer = asyncio.StreamWriter(
|
||||
transport=transport, # type: ignore
|
||||
protocol=protocol,
|
||||
reader=ssl_reader,
|
||||
loop=loop,
|
||||
)
|
||||
|
||||
self.__reader = ssl_reader
|
||||
self.__writer = ssl_writer
|
||||
|
||||
async def _close(self) -> None:
|
||||
await aiotools.close_writer(self.__writer)
|
||||
|
||||
@@ -27,14 +27,14 @@ import dataclasses
|
||||
import contextlib
|
||||
|
||||
import aiohttp
|
||||
import async_lru
|
||||
|
||||
from evdev import ecodes
|
||||
|
||||
from ...logging import get_logger
|
||||
|
||||
from ...keyboard.keysym import SymmapModifiers
|
||||
from ...keyboard.keysym import build_symmap
|
||||
from ...keyboard.mappings import WebModifiers
|
||||
from ...keyboard.mappings import X11Modifiers
|
||||
from ...keyboard.mappings import AT1_TO_WEB
|
||||
from ...keyboard.magic import MagicHandler
|
||||
|
||||
from ...clients.kvmd import KvmdClientWs
|
||||
from ...clients.kvmd import KvmdClientSession
|
||||
@@ -53,9 +53,6 @@ from .rfb import RfbClient
|
||||
from .rfb.stream import rfb_format_remote
|
||||
from .rfb.errors import RfbError
|
||||
|
||||
from .vncauth import VncAuthKvmdCredentials
|
||||
from .vncauth import VncAuthManager
|
||||
|
||||
from .render import make_text_jpeg
|
||||
|
||||
|
||||
@@ -80,29 +77,30 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
|
||||
desired_fps: int,
|
||||
mouse_output: str,
|
||||
keymap_name: str,
|
||||
symmap: dict[int, dict[int, str]],
|
||||
allow_cut_after: float,
|
||||
symmap: dict[int, dict[int, int]],
|
||||
scroll_rate: int,
|
||||
|
||||
kvmd: KvmdClient,
|
||||
streamers: list[BaseStreamerClient],
|
||||
|
||||
vnc_credentials: dict[str, VncAuthKvmdCredentials],
|
||||
vncpasses: set[str],
|
||||
vencrypt: bool,
|
||||
none_auth_only: bool,
|
||||
|
||||
shared_params: _SharedParams,
|
||||
) -> None:
|
||||
|
||||
self.__vnc_credentials = vnc_credentials
|
||||
|
||||
super().__init__(
|
||||
RfbClient.__init__(
|
||||
self,
|
||||
reader=reader,
|
||||
writer=writer,
|
||||
tls_ciphers=tls_ciphers,
|
||||
tls_timeout=tls_timeout,
|
||||
x509_cert_path=x509_cert_path,
|
||||
x509_key_path=x509_key_path,
|
||||
allow_cut_after=allow_cut_after,
|
||||
vnc_passwds=list(vnc_credentials),
|
||||
symmap=symmap,
|
||||
scroll_rate=scroll_rate,
|
||||
vncpasses=vncpasses,
|
||||
vencrypt=vencrypt,
|
||||
none_auth_only=none_auth_only,
|
||||
**dataclasses.asdict(shared_params),
|
||||
@@ -111,7 +109,6 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
|
||||
self.__desired_fps = desired_fps
|
||||
self.__mouse_output = mouse_output
|
||||
self.__keymap_name = keymap_name
|
||||
self.__symmap = symmap
|
||||
|
||||
self.__kvmd = kvmd
|
||||
self.__streamers = streamers
|
||||
@@ -128,12 +125,23 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
|
||||
self.__fb_queue: "asyncio.Queue[dict]" = asyncio.Queue()
|
||||
self.__fb_has_key = False
|
||||
|
||||
# Эти состояния шарить не обязательно - бекенд исключает дублирующиеся события.
|
||||
# Все это нужно только чтобы не посылать лишние жсоны в сокет KVMD
|
||||
self.__mouse_buttons: dict[str, (bool | None)] = dict.fromkeys(["left", "right", "middle", "up", "down"], None)
|
||||
self.__mouse_move = {"x": -1, "y": -1}
|
||||
self.__clipboard = ""
|
||||
|
||||
self.__modifiers = 0
|
||||
self.__info_host = ""
|
||||
self.__info_switch_units = 0
|
||||
self.__info_switch_active = ""
|
||||
|
||||
self.__magic = MagicHandler(
|
||||
proxy_handler=self.__on_magic_key_proxy,
|
||||
key_handlers={
|
||||
ecodes.KEY_P: self.__on_magic_clipboard_print,
|
||||
ecodes.KEY_UP: self.__on_magic_switch_prev,
|
||||
ecodes.KEY_LEFT: self.__on_magic_switch_prev,
|
||||
ecodes.KEY_DOWN: self.__on_magic_switch_next,
|
||||
ecodes.KEY_RIGHT: self.__on_magic_switch_next,
|
||||
},
|
||||
numeric_handler=self.__on_magic_switch_port,
|
||||
)
|
||||
|
||||
# =====
|
||||
|
||||
@@ -179,16 +187,22 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
|
||||
async def __process_ws_event(self, event_type: str, event: dict) -> None:
|
||||
if event_type == "info":
|
||||
if "meta" in event:
|
||||
host = ""
|
||||
try:
|
||||
host = event["meta"]["server"]["host"]
|
||||
if isinstance(event["meta"]["server"]["host"], str):
|
||||
host = event["meta"]["server"]["host"].strip()
|
||||
except Exception:
|
||||
host = None
|
||||
else:
|
||||
if isinstance(host, str):
|
||||
name = f"PiKVM: {host}"
|
||||
if self._encodings.has_rename:
|
||||
await self._send_rename(name)
|
||||
self.__shared_params.name = name
|
||||
pass
|
||||
self.__info_host = host
|
||||
await self.__update_info()
|
||||
|
||||
elif event_type == "switch":
|
||||
if "model" in event:
|
||||
self.__info_switch_units = len(event["model"]["units"])
|
||||
if "summary" in event:
|
||||
self.__info_switch_active = event["summary"]["active_id"]
|
||||
if "model" in event or "summary" in event:
|
||||
await self.__update_info()
|
||||
|
||||
elif event_type == "hid":
|
||||
if (
|
||||
@@ -198,6 +212,17 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
|
||||
):
|
||||
await self._send_leds_state(**event["keyboard"]["leds"])
|
||||
|
||||
async def __update_info(self) -> None:
|
||||
info: list[str] = []
|
||||
if self.__info_switch_units > 0:
|
||||
info.append("Port " + (self.__info_switch_active or "not selected"))
|
||||
if self.__info_host:
|
||||
info.append(self.__info_host)
|
||||
info.append("PiKVM")
|
||||
self.__shared_params.name = " | ".join(info)
|
||||
if self._encodings.has_rename:
|
||||
await self._send_rename(self.__shared_params.name)
|
||||
|
||||
# =====
|
||||
|
||||
async def __streamer_task_loop(self) -> None:
|
||||
@@ -213,10 +238,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
|
||||
if not streaming:
|
||||
logger.info("%s [streamer]: Streaming ...", self._remote)
|
||||
streaming = True
|
||||
if frame["online"]:
|
||||
await self.__queue_frame(frame)
|
||||
else:
|
||||
await self.__queue_frame("No signal")
|
||||
await self.__queue_frame(frame)
|
||||
except StreamerError as ex:
|
||||
if isinstance(ex, StreamerPermError):
|
||||
streamer = self.__get_default_streamer()
|
||||
@@ -317,98 +339,91 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
|
||||
# =====
|
||||
|
||||
async def _authorize_userpass(self, user: str, passwd: str) -> bool:
|
||||
self.__kvmd_session = self.__kvmd.make_session(user, passwd)
|
||||
if (await self.__kvmd_session.auth.check()):
|
||||
self.__kvmd_session = self.__kvmd.make_session()
|
||||
if (await self.__kvmd_session.auth.check(user, passwd)):
|
||||
self.__stage1_authorized.set_passed()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _on_authorized_vnc_passwd(self, passwd: str) -> str:
|
||||
kc = self.__vnc_credentials[passwd]
|
||||
if (await self._authorize_userpass(kc.user, kc.passwd)):
|
||||
return kc.user
|
||||
return ""
|
||||
async def _on_authorized_vncpass(self) -> None:
|
||||
self.__kvmd_session = self.__kvmd.make_session()
|
||||
self.__stage1_authorized.set_passed()
|
||||
|
||||
async def _on_authorized_none(self) -> bool:
|
||||
async def _authorize_none(self) -> bool:
|
||||
return (await self._authorize_userpass("", ""))
|
||||
|
||||
# =====
|
||||
|
||||
async def _on_key_event(self, code: int, state: bool) -> None:
|
||||
is_modifier = self.__switch_modifiers(code, state)
|
||||
variants = self.__symmap.get(code)
|
||||
fake_shift = False
|
||||
async def _on_key_event(self, key: int, state: bool) -> None:
|
||||
assert self.__stage1_authorized.is_passed()
|
||||
await self.__magic.handle_key(key, state)
|
||||
|
||||
if variants:
|
||||
if is_modifier:
|
||||
web_key = variants.get(0)
|
||||
else:
|
||||
web_key = variants.get(self.__modifiers)
|
||||
if web_key is None:
|
||||
web_key = variants.get(0)
|
||||
async def __on_magic_switch_prev(self) -> None:
|
||||
assert self.__kvmd_session
|
||||
if self.__info_switch_units > 0:
|
||||
get_logger(0).info("%s [main]: Switching port to the previous one ...", self._remote)
|
||||
await self.__kvmd_session.switch.set_active_prev()
|
||||
|
||||
if web_key is None and self.__modifiers == 0 and SymmapModifiers.SHIFT in variants:
|
||||
# JUMP doesn't send shift events:
|
||||
# - https://github.com/pikvm/pikvm/issues/820
|
||||
web_key = variants[SymmapModifiers.SHIFT]
|
||||
fake_shift = True
|
||||
async def __on_magic_switch_next(self) -> None:
|
||||
assert self.__kvmd_session
|
||||
if self.__info_switch_units > 0:
|
||||
get_logger(0).info("%s [main]: Switching port to the next one ...", self._remote)
|
||||
await self.__kvmd_session.switch.set_active_next()
|
||||
|
||||
if web_key and self.__kvmd_ws:
|
||||
if fake_shift:
|
||||
await self.__kvmd_ws.send_key_event(WebModifiers.SHIFT_LEFT, True)
|
||||
await self.__kvmd_ws.send_key_event(web_key, state)
|
||||
if fake_shift:
|
||||
await self.__kvmd_ws.send_key_event(WebModifiers.SHIFT_LEFT, False)
|
||||
|
||||
async def _on_ext_key_event(self, code: int, state: bool) -> None:
|
||||
web_key = AT1_TO_WEB.get(code)
|
||||
if web_key:
|
||||
self.__switch_modifiers(web_key, state) # Предполагаем, что модификаторы всегда известны
|
||||
if self.__kvmd_ws:
|
||||
await self.__kvmd_ws.send_key_event(web_key, state)
|
||||
|
||||
def __switch_modifiers(self, key: (int | str), state: bool) -> bool:
|
||||
mod = 0
|
||||
if key in X11Modifiers.SHIFTS or key in WebModifiers.SHIFTS:
|
||||
mod = SymmapModifiers.SHIFT
|
||||
elif key == X11Modifiers.ALTGR or key == WebModifiers.ALT_RIGHT:
|
||||
mod = SymmapModifiers.ALTGR
|
||||
elif key in X11Modifiers.CTRLS or key in WebModifiers.CTRLS:
|
||||
mod = SymmapModifiers.CTRL
|
||||
if mod == 0:
|
||||
return False
|
||||
if state:
|
||||
self.__modifiers |= mod
|
||||
else:
|
||||
self.__modifiers &= ~mod
|
||||
async def __on_magic_switch_port(self, codes: list[int]) -> bool:
|
||||
assert self.__kvmd_session
|
||||
assert len(codes) > 0
|
||||
if self.__info_switch_units <= 0:
|
||||
return True
|
||||
elif 1 <= self.__info_switch_units <= 2:
|
||||
port = float(codes[0])
|
||||
else: # self.__info_switch_units > 2:
|
||||
if len(codes) == 1:
|
||||
return False # Wait for the second key
|
||||
port = (codes[0] + 1) + (codes[1] + 1) / 10
|
||||
get_logger(0).info("%s [main]: Switching port to %s ...", self._remote, port)
|
||||
await self.__kvmd_session.switch.set_active(port)
|
||||
return True
|
||||
|
||||
async def _on_pointer_event(self, buttons: dict[str, bool], wheel: dict[str, int], move: dict[str, int]) -> None:
|
||||
async def __on_magic_clipboard_print(self) -> None:
|
||||
assert self.__kvmd_session
|
||||
if self.__clipboard:
|
||||
logger = get_logger(0)
|
||||
logger.info("%s [main]: Printing %d characters ...", self._remote, len(self.__clipboard))
|
||||
try:
|
||||
(keymap_name, available) = await self.__kvmd_session.hid.get_keymaps()
|
||||
if self.__keymap_name in available:
|
||||
keymap_name = self.__keymap_name
|
||||
await self.__kvmd_session.hid.print(self.__clipboard, 0, keymap_name)
|
||||
except Exception:
|
||||
logger.exception("%s [main]: Can't print characters", self._remote)
|
||||
|
||||
async def __on_magic_key_proxy(self, key: int, state: bool) -> None:
|
||||
if self.__kvmd_ws:
|
||||
if wheel["x"] or wheel["y"]:
|
||||
await self.__kvmd_ws.send_mouse_wheel_event(wheel["x"], wheel["y"])
|
||||
await self.__kvmd_ws.send_key_event(key, state)
|
||||
|
||||
if self.__mouse_move != move:
|
||||
await self.__kvmd_ws.send_mouse_move_event(move["x"], move["y"])
|
||||
self.__mouse_move = move
|
||||
# =====
|
||||
|
||||
for (button, state) in buttons.items():
|
||||
if self.__mouse_buttons[button] != state:
|
||||
await self.__kvmd_ws.send_mouse_button_event(button, state)
|
||||
self.__mouse_buttons[button] = state
|
||||
async def _on_mouse_button_event(self, button: int, state: bool) -> None:
|
||||
assert self.__stage1_authorized.is_passed()
|
||||
if self.__kvmd_ws:
|
||||
await self.__kvmd_ws.send_mouse_button_event(button, state)
|
||||
|
||||
async def _on_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
|
||||
assert self.__stage1_authorized.is_passed()
|
||||
if self.__kvmd_ws:
|
||||
await self.__kvmd_ws.send_mouse_wheel_event(delta_x, delta_y)
|
||||
|
||||
async def _on_mouse_move_event(self, to_x: int, to_y: int) -> None:
|
||||
assert self.__stage1_authorized.is_passed()
|
||||
if self.__kvmd_ws:
|
||||
await self.__kvmd_ws.send_mouse_move_event(to_x, to_y)
|
||||
|
||||
# =====
|
||||
|
||||
async def _on_cut_event(self, text: str) -> None:
|
||||
assert self.__stage1_authorized.is_passed()
|
||||
assert self.__kvmd_session
|
||||
logger = get_logger(0)
|
||||
logger.info("%s [main]: Printing %d characters ...", self._remote, len(text))
|
||||
try:
|
||||
(keymap_name, available) = await self.__kvmd_session.hid.get_keymaps()
|
||||
if self.__keymap_name in available:
|
||||
keymap_name = self.__keymap_name
|
||||
await self.__kvmd_session.hid.print(text, 0, keymap_name)
|
||||
except Exception:
|
||||
logger.exception("%s [main]: Can't print characters", self._remote)
|
||||
self.__clipboard = text
|
||||
|
||||
async def _on_set_encodings(self) -> None:
|
||||
assert self.__stage1_authorized.is_passed()
|
||||
@@ -441,16 +456,17 @@ class VncServer: # pylint: disable=too-many-instance-attributes
|
||||
x509_cert_path: str,
|
||||
x509_key_path: str,
|
||||
|
||||
vncpass_enabled: bool,
|
||||
vncpass_path: str,
|
||||
vencrypt_enabled: bool,
|
||||
|
||||
desired_fps: int,
|
||||
mouse_output: str,
|
||||
keymap_path: str,
|
||||
allow_cut_after: float,
|
||||
scroll_rate: int,
|
||||
|
||||
kvmd: KvmdClient,
|
||||
streamers: list[BaseStreamerClient],
|
||||
vnc_auth_manager: VncAuthManager,
|
||||
) -> None:
|
||||
|
||||
self.__host = network.get_listen_host(host)
|
||||
@@ -460,7 +476,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
|
||||
keymap_name = os.path.basename(keymap_path)
|
||||
symmap = build_symmap(keymap_path)
|
||||
|
||||
self.__vnc_auth_manager = vnc_auth_manager
|
||||
self.__vncpass_enabled = vncpass_enabled
|
||||
self.__vncpass_path = vncpass_path
|
||||
|
||||
shared_params = _SharedParams()
|
||||
|
||||
@@ -487,8 +504,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout) # type: ignore
|
||||
|
||||
try:
|
||||
async with kvmd.make_session("", "") as kvmd_session:
|
||||
none_auth_only = await kvmd_session.auth.check()
|
||||
async with kvmd.make_session() as kvmd_session:
|
||||
none_auth_only = await kvmd_session.auth.check("", "")
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
|
||||
logger.error("%s [entry]: Can't check KVMD auth mode: %s", remote, tools.efmt(ex))
|
||||
return
|
||||
@@ -504,12 +521,12 @@ class VncServer: # pylint: disable=too-many-instance-attributes
|
||||
mouse_output=mouse_output,
|
||||
keymap_name=keymap_name,
|
||||
symmap=symmap,
|
||||
allow_cut_after=allow_cut_after,
|
||||
scroll_rate=scroll_rate,
|
||||
kvmd=kvmd,
|
||||
streamers=streamers,
|
||||
vnc_credentials=(await self.__vnc_auth_manager.read_credentials())[0],
|
||||
none_auth_only=none_auth_only,
|
||||
vncpasses=(await self.__read_vncpasses()),
|
||||
vencrypt=vencrypt_enabled,
|
||||
none_auth_only=none_auth_only,
|
||||
shared_params=shared_params,
|
||||
).run()
|
||||
except Exception:
|
||||
@@ -520,9 +537,6 @@ class VncServer: # pylint: disable=too-many-instance-attributes
|
||||
self.__handle_client = handle_client
|
||||
|
||||
async def __inner_run(self) -> None:
|
||||
if not (await self.__vnc_auth_manager.read_credentials())[1]:
|
||||
raise SystemExit(1)
|
||||
|
||||
get_logger(0).info("Listening VNC on TCP [%s]:%d ...", self.__host, self.__port)
|
||||
(family, _, _, _, addr) = socket.getaddrinfo(self.__host, self.__port, type=socket.SOCK_STREAM)[0]
|
||||
with contextlib.closing(socket.socket(family, socket.SOCK_STREAM)) as sock:
|
||||
@@ -539,6 +553,21 @@ class VncServer: # pylint: disable=too-many-instance-attributes
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
@async_lru.alru_cache(maxsize=1, ttl=1)
|
||||
async def __read_vncpasses(self) -> set[str]:
|
||||
if self.__vncpass_enabled:
|
||||
try:
|
||||
vncpasses: set[str] = set()
|
||||
for (_, line) in tools.passwds_splitted(await aiotools.read_file(self.__vncpass_path)):
|
||||
if " -> " in line: # Compatibility with old ipmipasswd file format
|
||||
line = line.split(" -> ", 1)[0]
|
||||
if len(line.strip()) > 0:
|
||||
vncpasses.add(line)
|
||||
return vncpasses
|
||||
except Exception:
|
||||
get_logger(0).exception("Unhandled exception while reading VNCAuth passwd file")
|
||||
return set()
|
||||
|
||||
def run(self) -> None:
|
||||
aiotools.run(self.__inner_run())
|
||||
get_logger().info("Bye-bye")
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2020 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import dataclasses
|
||||
|
||||
from ...logging import get_logger
|
||||
|
||||
from ... import aiotools
|
||||
|
||||
|
||||
# =====
|
||||
class VncAuthError(Exception):
|
||||
def __init__(self, path: str, lineno: int, msg: str) -> None:
|
||||
super().__init__(f"Syntax error at {path}:{lineno}: {msg}")
|
||||
|
||||
|
||||
# =====
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class VncAuthKvmdCredentials:
|
||||
user: str
|
||||
passwd: str
|
||||
|
||||
|
||||
class VncAuthManager:
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
enabled: bool,
|
||||
) -> None:
|
||||
|
||||
self.__path = path
|
||||
self.__enabled = enabled
|
||||
|
||||
async def read_credentials(self) -> tuple[dict[str, VncAuthKvmdCredentials], bool]:
|
||||
if self.__enabled:
|
||||
try:
|
||||
return (await self.__inner_read_credentials(), True)
|
||||
except VncAuthError as ex:
|
||||
get_logger(0).error(str(ex))
|
||||
except Exception:
|
||||
get_logger(0).exception("Unhandled exception while reading VNCAuth passwd file")
|
||||
return ({}, (not self.__enabled))
|
||||
|
||||
async def __inner_read_credentials(self) -> dict[str, VncAuthKvmdCredentials]:
|
||||
lines = (await aiotools.read_file(self.__path)).split("\n")
|
||||
credentials: dict[str, VncAuthKvmdCredentials] = {}
|
||||
for (lineno, line) in enumerate(lines):
|
||||
if len(line.strip()) == 0 or line.lstrip().startswith("#"):
|
||||
continue
|
||||
|
||||
if " -> " not in line:
|
||||
raise VncAuthError(self.__path, lineno, "Missing ' -> ' operator")
|
||||
|
||||
(vnc_passwd, kvmd_userpass) = map(str.lstrip, line.split(" -> ", 1))
|
||||
if ":" not in kvmd_userpass:
|
||||
raise VncAuthError(self.__path, lineno, "Missing ':' operator in KVMD credentials (right part)")
|
||||
|
||||
(kvmd_user, kvmd_passwd) = kvmd_userpass.split(":")
|
||||
kvmd_user = kvmd_user.strip()
|
||||
if len(kvmd_user) == 0:
|
||||
raise VncAuthError(self.__path, lineno, "Empty KVMD user (right part)")
|
||||
|
||||
if vnc_passwd in credentials:
|
||||
raise VncAuthError(self.__path, lineno, "Duplicating VNC password (left part)")
|
||||
|
||||
credentials[vnc_passwd] = VncAuthKvmdCredentials(kvmd_user, kvmd_passwd)
|
||||
return credentials
|
||||
Reference in New Issue
Block a user