Merge remote-tracking branch 'upstream/master'

This commit is contained in:
mofeng-git 2024-11-20 15:18:34 +00:00
commit eec64ef57c
166 changed files with 5421 additions and 2645 deletions

View File

@ -1,7 +1,7 @@
[bumpversion] [bumpversion]
commit = True commit = True
tag = True tag = True
current_version = 4.3 current_version = 4.20
parse = (?P<major>\d+)\.(?P<minor>\d+)(\.(?P<patch>\d+)(\-(?P<release>[a-z]+))?)? parse = (?P<major>\d+)\.(?P<minor>\d+)(\.(?P<patch>\d+)(\-(?P<release>[a-z]+))?)?
serialize = serialize =
{major}.{minor} {major}.{minor}

View File

@ -39,7 +39,7 @@ for _variant in "${_variants[@]}"; do
pkgname+=(kvmd-platform-$_platform-$_board) pkgname+=(kvmd-platform-$_platform-$_board)
done done
pkgbase=kvmd pkgbase=kvmd
pkgver=4.3 pkgver=4.20
pkgrel=1 pkgrel=1
pkgdesc="The main PiKVM daemon" pkgdesc="The main PiKVM daemon"
url="https://github.com/pikvm/kvmd" url="https://github.com/pikvm/kvmd"
@ -77,6 +77,8 @@ depends=(
python-ldap python-ldap
python-zstandard python-zstandard
python-mako python-mako
python-luma-oled
python-pyusb
"libgpiod>=2.1" "libgpiod>=2.1"
freetype2 freetype2
"v4l-utils>=1.22.1-1" "v4l-utils>=1.22.1-1"
@ -91,7 +93,7 @@ depends=(
certbot certbot
platform-io-access platform-io-access
raspberrypi-utils raspberrypi-utils
"ustreamer>=6.11" "ustreamer>=6.16"
# Systemd UDEV bug # Systemd UDEV bug
"systemd>=248.3-2" "systemd>=248.3-2"
@ -131,6 +133,7 @@ conflicts=(
python-aiohttp-pikvm python-aiohttp-pikvm
platformio platformio
avrdude-pikvm avrdude-pikvm
kvmd-oled
) )
makedepends=( makedepends=(
python-setuptools python-setuptools
@ -206,7 +209,7 @@ for _variant in "${_variants[@]}"; do
cd \"kvmd-\$pkgver\" cd \"kvmd-\$pkgver\"
pkgdesc=\"PiKVM platform configs - $_platform for $_board\" pkgdesc=\"PiKVM platform configs - $_platform for $_board\"
depends=(kvmd=$pkgver-$pkgrel \"linux-rpi-pikvm>=6.6.21-3\") depends=(kvmd=$pkgver-$pkgrel \"linux-rpi-pikvm>=6.6.45-1\" \"raspberrypi-bootloader-pikvm>=20240818-1\")
backup=( backup=(
etc/sysctl.d/99-kvmd.conf etc/sysctl.d/99-kvmd.conf

View File

@ -0,0 +1,98 @@
# Don't touch this file otherwise your device may stop working.
# Use override.yaml to modify required settings.
# You can find a working configuration in /usr/share/kvmd/configs.default/kvmd.
override: !include [override.d, override.yaml]
logging: !include logging.yaml
kvmd:
auth: !include auth.yaml
info:
hw:
ignore_past: true
fan:
unix: /run/kvmd/fan.sock
hid:
type: otg
mouse_alt:
device: /dev/kvmd-hid-mouse-alt
atx:
type: gpio
power_led_pin: 4
hdd_led_pin: 5
power_switch_pin: 23
reset_switch_pin: 27
msd:
type: otg
streamer:
h264_bitrate:
default: 5000
cmd:
- "/usr/bin/ustreamer"
- "--device=/dev/kvmd-video"
- "--persistent"
- "--dv-timings"
- "--format=uyvy"
- "--format-swap-rgb"
- "--buffers=8"
- "--encoder=m2m-image"
- "--workers=3"
- "--quality={quality}"
- "--desired-fps={desired_fps}"
- "--drop-same-frames=30"
- "--unix={unix}"
- "--unix-rm"
- "--unix-mode=0660"
- "--exit-on-parent-death"
- "--process-name-prefix={process_name_prefix}"
- "--notify-parent"
- "--no-log-colors"
- "--jpeg-sink=kvmd::ustreamer::jpeg"
- "--jpeg-sink-mode=0660"
- "--h264-sink=kvmd::ustreamer::h264"
- "--h264-sink-mode=0660"
- "--h264-bitrate={h264_bitrate}"
- "--h264-gop={h264_gop}"
gpio:
drivers:
__v4_locator__:
type: locator
scheme:
__v3_usb_breaker__:
pin: 22
mode: output
initial: true
pulse: false
__v4_locator__:
driver: __v4_locator__
pin: 12
mode: output
pulse: false
__v4_const1__:
pin: 6
mode: output
initial: false
switch: false
pulse: false
vnc:
memsink:
jpeg:
sink: "kvmd::ustreamer::jpeg"
h264:
sink: "kvmd::ustreamer::h264"
otg:
remote_wakeup: true

View File

@ -0,0 +1,12 @@
[Unit]
Description=PiKVM - Display reboot message on the OLED
DefaultDependencies=no
[Service]
Type=oneshot
ExecStart=/bin/bash -c "kill -USR1 `systemctl show -P MainPID kvmd-oled`"
ExecStop=/bin/true
RemainAfterExit=yes
[Install]
WantedBy=reboot.target

View File

@ -0,0 +1,14 @@
[Unit]
Description=PiKVM - Display shutdown message on the OLED
Conflicts=reboot.target
Before=shutdown.target poweroff.target halt.target
DefaultDependencies=no
[Service]
Type=oneshot
ExecStart=/bin/bash -c "kill -USR2 `systemctl show -P MainPID kvmd-oled`"
ExecStop=/bin/true
RemainAfterExit=yes
[Install]
WantedBy=shutdown.target

View File

@ -0,0 +1,15 @@
[Unit]
Description=PiKVM - A small OLED daemon
After=systemd-modules-load.service
ConditionPathExists=/dev/i2c-1
[Service]
Type=simple
Restart=always
RestartSec=3
ExecStartPre=/usr/bin/kvmd-oled --interval=3 --clear-on-exit --image=@hello.ppm
ExecStart=/usr/bin/kvmd-oled
TimeoutStopSec=3
[Install]
WantedBy=multi-user.target

View File

@ -1,15 +0,0 @@
[Unit]
Description=PiKVM - Video Passthrough on V4 Plus
Wants=dev-kvmd\x2dvideo.device
After=dev-kvmd\x2dvideo.device systemd-modules-load.service
[Service]
Type=simple
Restart=always
RestartSec=3
ExecStart=/usr/bin/ustreamer-v4p --unix-follow /run/kvmd/ustreamer.sock
TimeoutStopSec=10
[Install]
WantedBy=multi-user.target

View File

@ -2,11 +2,11 @@
Description=PiKVM - EDID loader for TC358743 Description=PiKVM - EDID loader for TC358743
Wants=dev-kvmd\x2dvideo.device Wants=dev-kvmd\x2dvideo.device
After=dev-kvmd\x2dvideo.device systemd-modules-load.service After=dev-kvmd\x2dvideo.device systemd-modules-load.service
Before=kvmd.service kvmd-pass.service Before=kvmd.service
[Service] [Service]
Type=oneshot Type=oneshot
ExecStart=/usr/bin/v4l2-ctl --device=/dev/kvmd-video --set-edid=file=/etc/kvmd/tc358743-edid.hex --fix-edid-checksums --info-edid ExecStart=/usr/bin/v4l2-ctl --device=/dev/kvmd-video --set-edid=file=/etc/kvmd/tc358743-edid.hex --info-edid
ExecStop=/usr/bin/v4l2-ctl --device=/dev/kvmd-video --clear-edid ExecStop=/usr/bin/v4l2-ctl --device=/dev/kvmd-video --clear-edid
RemainAfterExit=true RemainAfterExit=true

View File

@ -19,6 +19,7 @@ m kvmd gpio
m kvmd uucp m kvmd uucp
m kvmd spi m kvmd spi
m kvmd systemd-journal m kvmd systemd-journal
m kvmd kvmd-pst
m kvmd-pst kvmd m kvmd-pst kvmd

View File

@ -27,7 +27,8 @@ post_upgrade() {
done done
chown kvmd /var/lib/kvmd/msd 2>/dev/null || true chown kvmd /var/lib/kvmd/msd 2>/dev/null || true
chown kvmd-pst /var/lib/kvmd/pst 2>/dev/null || true chown kvmd-pst:kvmd-pst /var/lib/kvmd/pst 2>/dev/null || true
chmod 1775 /var/lib/kvmd/pst 2>/dev/null || true
if [ ! -e /etc/kvmd/nginx/ssl/server.crt ]; then if [ ! -e /etc/kvmd/nginx/ssl/server.crt ]; then
echo "==> Generating KVMD-Nginx certificate ..." echo "==> Generating KVMD-Nginx certificate ..."
@ -92,6 +93,15 @@ disable_overscan=1
EOF EOF
fi fi
if [[ "$(vercmp "$2" 4.4)" -lt 0 ]]; then
systemctl disable kvmd-pass || true
fi
if [[ "$(vercmp "$2" 4.5)" -lt 0 ]]; then
sed -i 's/X-kvmd\.pst-user=kvmd-pst/X-kvmd.pst-user=kvmd-pst,X-kvmd.pst-group=kvmd-pst/g' /etc/fstab
touch -t 200701011000 /etc/fstab
fi
# Some update deletes /etc/motd, WTF # Some update deletes /etc/motd, WTF
# shellcheck disable=SC2015,SC2166 # shellcheck disable=SC2015,SC2166
[ ! -f /etc/motd -a -f /etc/motd.pacsave ] && mv /etc/motd.pacsave /etc/motd || true [ ! -f /etc/motd -a -f /etc/motd.pacsave ] && mv /etc/motd.pacsave /etc/motd || true

View File

@ -20,4 +20,4 @@
# ========================================================================== # # ========================================================================== #
__version__ = "4.3" __version__ = "4.20"

View File

@ -83,9 +83,9 @@ class AioReader: # pylint: disable=too-many-instance-attributes
self.__path, self.__path,
consumer=self.__consumer, consumer=self.__consumer,
config={tuple(pins): gpiod.LineSettings(edge_detection=gpiod.line.Edge.BOTH)}, config={tuple(pins): gpiod.LineSettings(edge_detection=gpiod.line.Edge.BOTH)},
) as line_request: ) as line_req:
line_request.wait_edge_events(0.1) line_req.wait_edge_events(0.1)
self.__values = { self.__values = {
pin: _DebouncedValue( pin: _DebouncedValue(
initial=bool(value.value), initial=bool(value.value),
@ -93,14 +93,14 @@ class AioReader: # pylint: disable=too-many-instance-attributes
notifier=self.__notifier, notifier=self.__notifier,
loop=self.__loop, loop=self.__loop,
) )
for (pin, value) in zip(pins, line_request.get_values(pins)) for (pin, value) in zip(pins, line_req.get_values(pins))
} }
self.__loop.call_soon_threadsafe(self.__notifier.notify) self.__loop.call_soon_threadsafe(self.__notifier.notify)
while not self.__stop_event.is_set(): while not self.__stop_event.is_set():
if line_request.wait_edge_events(1): if line_req.wait_edge_events(1):
new: dict[int, bool] = {} new: dict[int, bool] = {}
for event in line_request.read_edge_events(): for event in line_req.read_edge_events():
(pin, value) = self.__parse_event(event) (pin, value) = self.__parse_event(event)
new[pin] = value new[pin] = value
for (pin, value) in new.items(): for (pin, value) in new.items():
@ -110,7 +110,7 @@ class AioReader: # pylint: disable=too-many-instance-attributes
# Размер буфера ядра - 16 эвентов на линии. При превышении этого числа, # Размер буфера ядра - 16 эвентов на линии. При превышении этого числа,
# новые эвенты потеряются. Это не баг, это фича, как мне объяснили в LKML. # новые эвенты потеряются. Это не баг, это фича, как мне объяснили в LKML.
# Штош. Будем с этим жить и синхронизировать состояния при таймауте. # Штош. Будем с этим жить и синхронизировать состояния при таймауте.
for (pin, value) in zip(pins, line_request.get_values(pins)): for (pin, value) in zip(pins, line_req.get_values(pins)):
self.__values[pin].set(bool(value.value)) # type: ignore self.__values[pin].set(bool(value.value)) # type: ignore
def __parse_event(self, event: gpiod.EdgeEvent) -> tuple[int, bool]: def __parse_event(self, event: gpiod.EdgeEvent) -> tuple[int, bool]:

View File

@ -42,7 +42,7 @@ async def remount(name: str, base_cmd: list[str], rw: bool) -> bool:
if proc.returncode != 0: if proc.returncode != 0:
assert proc.returncode is not None assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, cmd) raise subprocess.CalledProcessError(proc.returncode, cmd)
except Exception as err: except Exception as ex:
logger.error("Can't remount %s storage: %s", name, tools.efmt(err)) logger.error("Can't remount %s storage: %s", name, tools.efmt(ex))
return False return False
return True return True

View File

@ -59,14 +59,25 @@ def queue_get_last_sync( # pylint: disable=invalid-name
# ===== # =====
class AioProcessNotifier: class AioProcessNotifier:
def __init__(self) -> None: def __init__(self) -> None:
self.__queue: "multiprocessing.Queue[None]" = multiprocessing.Queue() self.__queue: "multiprocessing.Queue[int]" = multiprocessing.Queue()
def notify(self) -> None: def notify(self, mask: int=0) -> None:
self.__queue.put_nowait(None) self.__queue.put_nowait(mask)
async def wait(self) -> None: async def wait(self) -> int:
while not (await queue_get_last(self.__queue, 0.1))[0]: while True:
pass mask = await aiotools.run_async(self.__get)
if mask >= 0:
return mask
def __get(self) -> int:
try:
mask = self.__queue.get(timeout=0.1)
while not self.__queue.empty():
mask |= self.__queue.get()
return mask
except queue.Empty:
return -1
# ===== # =====

View File

@ -112,9 +112,9 @@ def shield_fg(aw: Awaitable): # type: ignore
if inner.cancelled(): if inner.cancelled():
outer.forced_cancel() outer.forced_cancel()
else: else:
err = inner.exception() ex = inner.exception()
if err is not None: if ex is not None:
outer.set_exception(err) outer.set_exception(ex)
else: else:
outer.set_result(inner.result()) outer.set_result(inner.result())
@ -232,25 +232,26 @@ async def close_writer(writer: asyncio.StreamWriter) -> bool:
# ===== # =====
class AioNotifier: class AioNotifier:
def __init__(self) -> None: def __init__(self) -> None:
self.__queue: "asyncio.Queue[None]" = asyncio.Queue() self.__queue: "asyncio.Queue[int]" = asyncio.Queue()
def notify(self) -> None: def notify(self, mask: int=0) -> None:
self.__queue.put_nowait(None) self.__queue.put_nowait(mask)
async def wait(self, timeout: (float | None)=None) -> None: async def wait(self, timeout: (float | None)=None) -> int:
mask = 0
if timeout is None: if timeout is None:
await self.__queue.get() mask = await self.__queue.get()
else: else:
try: try:
await asyncio.wait_for( mask = await asyncio.wait_for(
asyncio.ensure_future(self.__queue.get()), asyncio.ensure_future(self.__queue.get()),
timeout=timeout, timeout=timeout,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
return # False return -1
while not self.__queue.empty(): while not self.__queue.empty():
await self.__queue.get() mask |= await self.__queue.get()
# return True return mask
# ===== # =====
@ -296,7 +297,7 @@ class AioExclusiveRegion:
def is_busy(self) -> bool: def is_busy(self) -> bool:
return self.__busy return self.__busy
async def enter(self) -> None: def enter(self) -> None:
if not self.__busy: if not self.__busy:
self.__busy = True self.__busy = True
try: try:
@ -308,22 +309,22 @@ class AioExclusiveRegion:
return return
raise self.__exc_type() raise self.__exc_type()
async def exit(self) -> None: def exit(self) -> None:
self.__busy = False self.__busy = False
if self.__notifier: if self.__notifier:
self.__notifier.notify() self.__notifier.notify()
async def __aenter__(self) -> None: def __enter__(self) -> None:
await self.enter() self.enter()
async def __aexit__( def __exit__(
self, self,
_exc_type: type[BaseException], _exc_type: type[BaseException],
_exc: BaseException, _exc: BaseException,
_tb: types.TracebackType, _tb: types.TracebackType,
) -> None: ) -> None:
await self.exit() self.exit()
async def run_region_task( async def run_region_task(
@ -338,7 +339,7 @@ async def run_region_task(
async def wrapper() -> None: async def wrapper() -> None:
try: try:
async with region: with region:
entered.set_result(None) entered.set_result(None)
await func(*args, **kwargs) await func(*args, **kwargs)
except region.get_exc_type(): except region.get_exc_type():

View File

@ -33,8 +33,6 @@ import pygments.formatters
from .. import tools from .. import tools
from ..mouse import MouseRange
from ..plugins import UnknownPluginError from ..plugins import UnknownPluginError
from ..plugins.auth import get_auth_service_class from ..plugins.auth import get_auth_service_class
from ..plugins.hid import get_hid_class from ..plugins.hid import get_hid_class
@ -171,8 +169,8 @@ def _init_config(config_path: str, override_options: list[str], **load_flags: bo
config_path = os.path.expanduser(config_path) config_path = os.path.expanduser(config_path)
try: try:
raw_config: dict = load_yaml_file(config_path) raw_config: dict = load_yaml_file(config_path)
except Exception as err: except Exception as ex:
raise SystemExit(f"ConfigError: Can't read config file {config_path!r}:\n{tools.efmt(err)}") raise SystemExit(f"ConfigError: Can't read config file {config_path!r}:\n{tools.efmt(ex)}")
if not isinstance(raw_config, dict): if not isinstance(raw_config, dict):
raise SystemExit(f"ConfigError: Top-level of the file {config_path!r} must be a dictionary") raise SystemExit(f"ConfigError: Top-level of the file {config_path!r} must be a dictionary")
@ -187,8 +185,8 @@ def _init_config(config_path: str, override_options: list[str], **load_flags: bo
config = make_config(raw_config, scheme) config = make_config(raw_config, scheme)
return config return config
except (ConfigError, UnknownPluginError) as err: except (ConfigError, UnknownPluginError) as ex:
raise SystemExit(f"ConfigError: {err}") raise SystemExit(f"ConfigError: {ex}")
def _patch_raw(raw_config: dict) -> None: # pylint: disable=too-many-branches def _patch_raw(raw_config: dict) -> None: # pylint: disable=too-many-branches
@ -407,19 +405,7 @@ def _get_config_scheme() -> dict:
"hid": { "hid": {
"type": Option("", type=valid_stripped_string_not_empty), "type": Option("", type=valid_stripped_string_not_empty),
"keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file), "keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file),
"ignore_keys": Option([], type=functools.partial(valid_string_list, subval=valid_hid_key)),
"mouse_x_range": {
"min": Option(MouseRange.MIN, type=valid_hid_mouse_move),
"max": Option(MouseRange.MAX, type=valid_hid_mouse_move),
},
"mouse_y_range": {
"min": Option(MouseRange.MIN, type=valid_hid_mouse_move),
"max": Option(MouseRange.MAX, type=valid_hid_mouse_move),
},
# Dynamic content # Dynamic content
}, },
@ -684,6 +670,7 @@ def _get_config_scheme() -> dict:
"desired_fps": Option(30, type=valid_stream_fps), "desired_fps": Option(30, type=valid_stream_fps),
"mouse_output": Option("usb", type=valid_hid_mouse_output), "mouse_output": Option("usb", type=valid_hid_mouse_output),
"keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file), "keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file),
"allow_cut_after": Option(3.0, type=valid_float_f0),
"server": { "server": {
"host": Option("", type=valid_ip_or_host, if_empty=""), "host": Option("", type=valid_ip_or_host, if_empty=""),

View File

@ -22,259 +22,22 @@
import sys import sys
import os import os
import re
import dataclasses
import contextlib
import subprocess import subprocess
import argparse import argparse
import time import time
from typing import IO
from typing import Generator
from typing import Callable from typing import Callable
from ...validators.basic import valid_bool from ...validators.basic import valid_bool
from ...validators.basic import valid_int_f0 from ...validators.basic import valid_int_f0
from ...edid import EdidNoBlockError
from ...edid import Edid
# from .. import init # from .. import init
# ===== # =====
class NoBlockError(Exception):
pass
@contextlib.contextmanager
def _smart_open(path: str, mode: str) -> Generator[IO, None, None]:
fd = (0 if "r" in mode else 1)
with (os.fdopen(fd, mode, closefd=False) if path == "-" else open(path, mode)) as file:
yield file
if "w" in mode:
file.flush()
@dataclasses.dataclass(frozen=True)
class _CeaBlock:
tag: int
data: bytes
def __post_init__(self) -> None:
assert 0 < self.tag <= 0b111
assert 0 < len(self.data) <= 0b11111
@property
def size(self) -> int:
return len(self.data) + 1
def pack(self) -> bytes:
header = (self.tag << 5) | len(self.data)
return header.to_bytes() + self.data
@classmethod
def first_from_raw(cls, raw: (bytes | list[int])) -> "_CeaBlock":
assert 0 < raw[0] <= 0xFF
tag = (raw[0] & 0b11100000) >> 5
data_size = (raw[0] & 0b00011111)
data = bytes(raw[1:data_size + 1])
return _CeaBlock(tag, data)
_CEA = 128
_CEA_AUDIO = 1
_CEA_SPEAKERS = 4
class _Edid:
# https://en.wikipedia.org/wiki/Extended_Display_Identification_Data
def __init__(self, path: str) -> None:
with _smart_open(path, "rb") as file:
data = file.read()
if data.startswith(b"\x00\xFF\xFF\xFF\xFF\xFF\xFF\x00"):
self.__data = list(data)
else:
text = re.sub(r"\s", "", data.decode())
self.__data = [
int(text[index:index + 2], 16)
for index in range(0, len(text), 2)
]
assert len(self.__data) == 256, f"Invalid EDID length: {len(self.__data)}, should be 256 bytes"
assert self.__data[126] == 1, "Zero extensions number"
assert (self.__data[_CEA + 0], self.__data[_CEA + 1]) == (0x02, 0x03), "Can't find CEA extension"
def write_hex(self, path: str) -> None:
self.__update_checksums()
text = "\n".join(
"".join(
f"{item:0{2}X}"
for item in self.__data[index:index + 16]
)
for index in range(0, len(self.__data), 16)
) + "\n"
with _smart_open(path, "w") as file:
file.write(text)
def write_bin(self, path: str) -> None:
self.__update_checksums()
with _smart_open(path, "wb") as file:
file.write(bytes(self.__data))
def __update_checksums(self) -> None:
self.__data[127] = 256 - (sum(self.__data[:127]) % 256)
self.__data[255] = 256 - (sum(self.__data[128:255]) % 256)
# =====
def get_mfc_id(self) -> str:
raw = self.__data[8] << 8 | self.__data[9]
return bytes([
((raw >> 10) & 0b11111) + 0x40,
((raw >> 5) & 0b11111) + 0x40,
(raw & 0b11111) + 0x40,
]).decode("ascii")
def set_mfc_id(self, mfc_id: str) -> None:
assert len(mfc_id) == 3, "Mfc ID must be 3 characters long"
data = mfc_id.upper().encode("ascii")
for ch in data:
assert 0x41 <= ch <= 0x5A, "Mfc ID must contain only A-Z characters"
raw = (
(data[2] - 0x40)
| ((data[1] - 0x40) << 5)
| ((data[0] - 0x40) << 10)
)
self.__data[8] = (raw >> 8) & 0xFF
self.__data[9] = raw & 0xFF
# =====
def get_product_id(self) -> int:
return (self.__data[10] | self.__data[11] << 8)
def set_product_id(self, product_id: int) -> None:
assert 0 <= product_id <= 0xFFFF, f"Product ID should be from 0 to {0xFFFF}"
self.__data[10] = product_id & 0xFF
self.__data[11] = (product_id >> 8) & 0xFF
# =====
def get_serial(self) -> int:
return (
self.__data[12]
| self.__data[13] << 8
| self.__data[14] << 16
| self.__data[15] << 24
)
def set_serial(self, serial: int) -> None:
assert 0 <= serial <= 0xFFFFFFFF, f"Serial should be from 0 to {0xFFFFFFFF}"
self.__data[12] = serial & 0xFF
self.__data[13] = (serial >> 8) & 0xFF
self.__data[14] = (serial >> 16) & 0xFF
self.__data[15] = (serial >> 24) & 0xFF
# =====
def get_monitor_name(self) -> str:
return self.__get_dtd_text(0xFC, "Monitor Name")
def set_monitor_name(self, text: str) -> None:
self.__set_dtd_text(0xFC, "Monitor Name", text)
def get_monitor_serial(self) -> str:
return self.__get_dtd_text(0xFF, "Monitor Serial")
def set_monitor_serial(self, text: str) -> None:
self.__set_dtd_text(0xFF, "Monitor Serial", text)
def __get_dtd_text(self, d_type: int, name: str) -> str:
index = self.__find_dtd_text(d_type, name)
return bytes(self.__data[index:index + 13]).decode("cp437").strip()
def __set_dtd_text(self, d_type: int, name: str, text: str) -> None:
index = self.__find_dtd_text(d_type, name)
encoded = (text[:13] + "\n" + " " * 12)[:13].encode("cp437")
for (offset, ch) in enumerate(encoded):
self.__data[index + offset] = ch
def __find_dtd_text(self, d_type: int, name: str) -> int:
for index in [54, 72, 90, 108]:
if self.__data[index + 3] == d_type:
return index + 5
raise NoBlockError(f"Can't find DTD {name}")
# ===== CEA =====
def get_audio(self) -> bool:
(cbs, _) = self.__parse_cea()
audio = False
speakers = False
for cb in cbs:
if cb.tag == _CEA_AUDIO:
audio = True
elif cb.tag == _CEA_SPEAKERS:
speakers = True
return (audio and speakers and self.__get_basic_audio())
def set_audio(self, enabled: bool) -> None:
(cbs, dtds) = self.__parse_cea()
cbs = [cb for cb in cbs if cb.tag not in [_CEA_AUDIO, _CEA_SPEAKERS]]
if enabled:
cbs.append(_CeaBlock(_CEA_AUDIO, b"\x09\x7f\x07"))
cbs.append(_CeaBlock(_CEA_SPEAKERS, b"\x01\x00\x00"))
self.__replace_cea(cbs, dtds)
self.__set_basic_audio(enabled)
def __get_basic_audio(self) -> bool:
return bool(self.__data[_CEA + 3] & 0b01000000)
def __set_basic_audio(self, enabled: bool) -> None:
if enabled:
self.__data[_CEA + 3] |= 0b01000000
else:
self.__data[_CEA + 3] &= (0xFF - 0b01000000) # ~X
def __parse_cea(self) -> tuple[list[_CeaBlock], bytes]:
cea = self.__data[_CEA:]
dtd_begin = cea[2]
if dtd_begin == 0:
return ([], b"")
cbs: list[_CeaBlock] = []
if dtd_begin > 4:
raw = cea[4:dtd_begin]
while len(raw) != 0:
cb = _CeaBlock.first_from_raw(raw)
cbs.append(cb)
raw = raw[cb.size:]
dtds = b""
assert dtd_begin >= 4
raw = cea[dtd_begin:]
while len(raw) > (18 + 1) and raw[0] != 0:
dtds += bytes(raw[:18])
raw = raw[18:]
return (cbs, dtds)
def __replace_cea(self, cbs: list[_CeaBlock], dtds: bytes) -> None:
cbs_packed = b""
for cb in cbs:
cbs_packed += cb.pack()
raw = cbs_packed + dtds
assert len(raw) <= (128 - 4 - 1), "Too many CEA blocks or DTDs"
self.__data[_CEA + 2] = (0 if len(raw) == 0 else (len(cbs_packed) + 4))
for index in range(4, 127):
try:
ch = raw[index - 4]
except IndexError:
ch = 0
self.__data[_CEA + index] = ch
def _format_bool(value: bool) -> str: def _format_bool(value: bool) -> str:
return ("yes" if value else "no") return ("yes" if value else "no")
@ -283,7 +46,7 @@ def _make_format_hex(size: int) -> Callable[[int], str]:
return (lambda value: ("0x{:0%dX} ({})" % (size * 2)).format(value, value)) return (lambda value: ("0x{:0%dX} ({})" % (size * 2)).format(value, value))
def _print_edid(edid: _Edid) -> None: def _print_edid(edid: Edid) -> None:
for (key, get, fmt) in [ for (key, get, fmt) in [
("Manufacturer ID:", edid.get_mfc_id, str), ("Manufacturer ID:", edid.get_mfc_id, str),
("Product ID: ", edid.get_product_id, _make_format_hex(2)), ("Product ID: ", edid.get_product_id, _make_format_hex(2)),
@ -294,7 +57,7 @@ def _print_edid(edid: _Edid) -> None:
]: ]:
try: try:
print(key, fmt(get()), file=sys.stderr) # type: ignore print(key, fmt(get()), file=sys.stderr) # type: ignore
except NoBlockError: except EdidNoBlockError:
pass pass
@ -348,12 +111,12 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
help="Presets directory", metavar="<dir>") help="Presets directory", metavar="<dir>")
options = parser.parse_args(argv[1:]) options = parser.parse_args(argv[1:])
base: (_Edid | None) = None base: (Edid | None) = None
if options.import_preset: if options.import_preset:
imp = options.import_preset imp = options.import_preset
if "." in imp: if "." in imp:
(base_name, imp) = imp.split(".", 1) # v3.1080p-by-default (base_name, imp) = imp.split(".", 1) # v3.1080p-by-default
base = _Edid(os.path.join(options.presets_path, f"{base_name}.hex")) base = Edid.from_file(os.path.join(options.presets_path, f"{base_name}.hex"))
imp = f"_{imp}" imp = f"_{imp}"
options.imp = os.path.join(options.presets_path, f"{imp}.hex") options.imp = os.path.join(options.presets_path, f"{imp}.hex")
@ -362,16 +125,16 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
options.export_hex = options.edid_path options.export_hex = options.edid_path
options.edid_path = options.imp options.edid_path = options.imp
edid = _Edid(options.edid_path) edid = Edid.from_file(options.edid_path)
changed = False changed = False
for cmd in dir(_Edid): for cmd in dir(Edid):
if cmd.startswith("set_"): if cmd.startswith("set_"):
value = getattr(options, cmd) value = getattr(options, cmd)
if value is None and base is not None: if value is None and base is not None:
try: try:
value = getattr(base, cmd.replace("set_", "get_"))() value = getattr(base, cmd.replace("set_", "get_"))()
except NoBlockError: except EdidNoBlockError:
pass pass
if value is not None: if value is not None:
getattr(edid, cmd)(value) getattr(edid, cmd)(value)
@ -400,8 +163,7 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
"/usr/bin/v4l2-ctl", "/usr/bin/v4l2-ctl",
f"--device={options.device_path}", f"--device={options.device_path}",
f"--set-edid=file={orig_edid_path}", f"--set-edid=file={orig_edid_path}",
"--fix-edid-checksums",
"--info-edid", "--info-edid",
], stdout=sys.stderr, check=True) ], stdout=sys.stderr, check=True)
except subprocess.CalledProcessError as err: except subprocess.CalledProcessError as ex:
raise SystemExit(str(err)) raise SystemExit(str(ex))

View File

@ -155,5 +155,5 @@ def main(argv: (list[str] | None)=None) -> None:
options = parser.parse_args(argv[1:]) options = parser.parse_args(argv[1:])
try: try:
options.cmd(config, options) options.cmd(config, options)
except ValidatorError as err: except ValidatorError as ex:
raise SystemExit(str(err)) raise SystemExit(str(ex))

View File

@ -101,6 +101,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
# ===== # =====
def handle_raw_request(self, request: dict, session: IpmiServerSession) -> None: def handle_raw_request(self, request: dict, session: IpmiServerSession) -> None:
# Parameter 'request' has been renamed to 'req' in overriding method
handler = { handler = {
(6, 1): (lambda _, session: self.send_device_id(session)), # Get device ID (6, 1): (lambda _, session: self.send_device_id(session)), # Get device ID
(6, 7): self.__get_power_state_handler, # Power state (6, 7): self.__get_power_state_handler, # Power state
@ -145,13 +146,13 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
data = [int(result["leds"]["power"]), 0, 0] data = [int(result["leds"]["power"]), 0, 0]
session.send_ipmi_response(data=data) session.send_ipmi_response(data=data)
def __chassis_control_handler(self, request: dict, session: IpmiServerSession) -> None: def __chassis_control_handler(self, req: dict, session: IpmiServerSession) -> None:
action = { action = {
0: "off_hard", 0: "off_hard",
1: "on", 1: "on",
3: "reset_hard", 3: "reset_hard",
5: "off", 5: "off",
}.get(request["data"][0], "") }.get(req["data"][0], "")
if action: if action:
if not self.__make_request(session, f"atx.switch_power({action})", "atx.switch_power", action=action): if not self.__make_request(session, f"atx.switch_power({action})", "atx.switch_power", action=action):
code = 0xC0 # Try again later code = 0xC0 # Try again later
@ -171,8 +172,8 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session: async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session:
func = functools.reduce(getattr, func_path.split("."), kvmd_session) func = functools.reduce(getattr, func_path.split("."), kvmd_session)
return (await func(**kwargs)) return (await func(**kwargs))
except (aiohttp.ClientError, asyncio.TimeoutError) as err: except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
logger.error("[%s]: Can't perform request %s: %s", session.sockaddr[0], name, err) logger.error("[%s]: Can't perform request %s: %s", session.sockaddr[0], name, ex)
raise raise
return aiotools.run_sync(runner()) return aiotools.run_sync(runner())

View File

@ -11,16 +11,17 @@ from ... import aioproc
from ...logging import get_logger from ...logging import get_logger
from .stun import StunNatType
from .stun import Stun from .stun import Stun
# ===== # =====
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class _Netcfg: class _Netcfg:
nat_type: str = dataclasses.field(default="") nat_type: StunNatType = dataclasses.field(default=StunNatType.ERROR)
src_ip: str = dataclasses.field(default="") src_ip: str = dataclasses.field(default="")
ext_ip: str = dataclasses.field(default="") ext_ip: str = dataclasses.field(default="")
stun_host: str = dataclasses.field(default="") stun_ip: str = dataclasses.field(default="")
stun_port: int = dataclasses.field(default=0) stun_port: int = dataclasses.field(default=0)
@ -92,8 +93,9 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
async def __get_netcfg(self) -> _Netcfg: async def __get_netcfg(self) -> _Netcfg:
src_ip = (self.__get_default_ip() or "0.0.0.0") src_ip = (self.__get_default_ip() or "0.0.0.0")
(stun, (nat_type, ext_ip)) = await self.__get_stun_info(src_ip) info = await self.__stun.get_info(src_ip, 0)
return _Netcfg(nat_type, src_ip, ext_ip, stun.host, stun.port) # В текущей реализации _Netcfg() это копия StunInfo()
return _Netcfg(**dataclasses.asdict(info))
def __get_default_ip(self) -> str: def __get_default_ip(self) -> str:
try: try:
@ -111,17 +113,10 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
for proto in [socket.AF_INET, socket.AF_INET6]: for proto in [socket.AF_INET, socket.AF_INET6]:
if proto in addrs: if proto in addrs:
return addrs[proto][0]["addr"] return addrs[proto][0]["addr"]
except Exception as err: except Exception as ex:
get_logger().error("Can't get default IP: %s", tools.efmt(err)) get_logger().error("Can't get default IP: %s", tools.efmt(ex))
return "" return ""
async def __get_stun_info(self, src_ip: str) -> tuple[Stun, tuple[str, str]]:
try:
return (self.__stun, (await self.__stun.get_info(src_ip, 0)))
except Exception as err:
get_logger().error("Can't get STUN info: %s", tools.efmt(err))
return (self.__stun, ("", ""))
# ===== # =====
@aiotools.atomic_fg @aiotools.atomic_fg
@ -162,7 +157,7 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
async def __start_janus_proc(self, netcfg: _Netcfg) -> None: async def __start_janus_proc(self, netcfg: _Netcfg) -> None:
assert self.__janus_proc is None assert self.__janus_proc is None
placeholders = { placeholders = {
"o_stun_server": f"--stun-server={netcfg.stun_host}:{netcfg.stun_port}", "o_stun_server": f"--stun-server={netcfg.stun_ip}:{netcfg.stun_port}",
**{ **{
key: str(value) key: str(value)
for (key, value) in dataclasses.asdict(netcfg).items() for (key, value) in dataclasses.asdict(netcfg).items()

View File

@ -4,6 +4,7 @@ import ipaddress
import struct import struct
import secrets import secrets
import dataclasses import dataclasses
import enum
from ... import tools from ... import tools
from ... import aiotools from ... import aiotools
@ -12,21 +13,8 @@ from ...logging import get_logger
# ===== # =====
@dataclasses.dataclass(frozen=True) class StunNatType(enum.Enum):
class StunAddress: ERROR = ""
ip: str
port: int
@dataclasses.dataclass(frozen=True)
class StunResponse:
ok: bool
ext: (StunAddress | None) = dataclasses.field(default=None)
src: (StunAddress | None) = dataclasses.field(default=None)
changed: (StunAddress | None) = dataclasses.field(default=None)
class StunNatType:
BLOCKED = "Blocked" BLOCKED = "Blocked"
OPEN_INTERNET = "Open Internet" OPEN_INTERNET = "Open Internet"
SYMMETRIC_UDP_FW = "Symmetric UDP Firewall" SYMMETRIC_UDP_FW = "Symmetric UDP Firewall"
@ -37,6 +25,29 @@ class StunNatType:
CHANGED_ADDR_ERROR = "Error when testing on Changed-IP and Port" CHANGED_ADDR_ERROR = "Error when testing on Changed-IP and Port"
@dataclasses.dataclass(frozen=True)
class StunInfo:
nat_type: StunNatType
src_ip: str
ext_ip: str
stun_ip: str
stun_port: int
@dataclasses.dataclass(frozen=True)
class _StunAddress:
ip: str
port: int
@dataclasses.dataclass(frozen=True)
class _StunResponse:
ok: bool
ext: (_StunAddress | None) = dataclasses.field(default=None)
src: (_StunAddress | None) = dataclasses.field(default=None)
changed: (_StunAddress | None) = dataclasses.field(default=None)
# ===== # =====
class Stun: class Stun:
# Partially based on https://github.com/JohnVillalovos/pystun # Partially based on https://github.com/JohnVillalovos/pystun
@ -50,58 +61,94 @@ class Stun:
retries_delay: float, retries_delay: float,
) -> None: ) -> None:
self.host = host self.__host = host
self.port = port self.__port = port
self.__timeout = timeout self.__timeout = timeout
self.__retries = retries self.__retries = retries
self.__retries_delay = retries_delay self.__retries_delay = retries_delay
self.__stun_ip = ""
self.__sock: (socket.socket | None) = None self.__sock: (socket.socket | None) = None
async def get_info(self, src_ip: str, src_port: int) -> tuple[str, str]: async def get_info(self, src_ip: str, src_port: int) -> StunInfo:
(family, _, _, _, addr) = socket.getaddrinfo(src_ip, src_port, type=socket.SOCK_DGRAM)[0] nat_type = StunNatType.ERROR
ext_ip = ""
try: try:
with socket.socket(family, socket.SOCK_DGRAM) as self.__sock: (src_fam, _, _, _, src_addr) = (await self.__retried_getaddrinfo_udp(src_ip, src_port))[0]
stun_ips = [
stun_addr[0]
for (stun_fam, _, _, _, stun_addr) in (await self.__retried_getaddrinfo_udp(self.__host, self.__port))
if stun_fam == src_fam
]
if not stun_ips:
raise RuntimeError(f"Can't resolve {src_fam.name} address for STUN")
if not self.__stun_ip or self.__stun_ip not in stun_ips:
# On new IP, changed family, etc.
self.__stun_ip = stun_ips[0]
with socket.socket(src_fam, socket.SOCK_DGRAM) as self.__sock:
self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.__sock.settimeout(self.__timeout) self.__sock.settimeout(self.__timeout)
self.__sock.bind(addr) self.__sock.bind(src_addr)
(nat_type, response) = await self.__get_nat_type(src_ip) (nat_type, resp) = await self.__get_nat_type(src_ip)
return (nat_type, (response.ext.ip if response.ext is not None else "")) ext_ip = (resp.ext.ip if resp.ext is not None else "")
except Exception as ex:
get_logger(0).error("Can't get STUN info: %s", tools.efmt(ex))
finally: finally:
self.__sock = None self.__sock = None
async def __get_nat_type(self, src_ip: str) -> tuple[str, StunResponse]: # pylint: disable=too-many-return-statements return StunInfo(
first = await self.__make_request("First probe") nat_type=nat_type,
src_ip=src_ip,
ext_ip=ext_ip,
stun_ip=self.__stun_ip,
stun_port=self.__port,
)
async def __retried_getaddrinfo_udp(self, host: str, port: int) -> list:
retries = self.__retries
while True:
try:
return socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
except Exception:
retries -= 1
if retries == 0:
raise
await asyncio.sleep(self.__retries_delay)
async def __get_nat_type(self, src_ip: str) -> tuple[StunNatType, _StunResponse]: # pylint: disable=too-many-return-statements
first = await self.__make_request("First probe", self.__stun_ip, b"")
if not first.ok: if not first.ok:
return (StunNatType.BLOCKED, first) return (StunNatType.BLOCKED, first)
request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request req = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request
response = await self.__make_request("Change request [ext_ip == src_ip]", request) resp = await self.__make_request("Change request [ext_ip == src_ip]", self.__stun_ip, req)
if first.ext is not None and first.ext.ip == src_ip: if first.ext is not None and first.ext.ip == src_ip:
if response.ok: if resp.ok:
return (StunNatType.OPEN_INTERNET, response) return (StunNatType.OPEN_INTERNET, resp)
return (StunNatType.SYMMETRIC_UDP_FW, response) return (StunNatType.SYMMETRIC_UDP_FW, resp)
if response.ok: if resp.ok:
return (StunNatType.FULL_CONE_NAT, response) return (StunNatType.FULL_CONE_NAT, resp)
if first.changed is None: if first.changed is None:
raise RuntimeError(f"Changed addr is None: {first}") raise RuntimeError(f"Changed addr is None: {first}")
response = await self.__make_request("Change request [ext_ip != src_ip]", addr=first.changed) resp = await self.__make_request("Change request [ext_ip != src_ip]", first.changed, b"")
if not response.ok: if not resp.ok:
return (StunNatType.CHANGED_ADDR_ERROR, response) return (StunNatType.CHANGED_ADDR_ERROR, resp)
if response.ext == first.ext: if resp.ext == first.ext:
request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002) req = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002)
response = await self.__make_request("Change port", request, addr=first.changed.ip) resp = await self.__make_request("Change port", first.changed.ip, req)
if response.ok: if resp.ok:
return (StunNatType.RESTRICTED_NAT, response) return (StunNatType.RESTRICTED_NAT, resp)
return (StunNatType.RESTRICTED_PORT_NAT, response) return (StunNatType.RESTRICTED_PORT_NAT, resp)
return (StunNatType.SYMMETRIC_NAT, response) return (StunNatType.SYMMETRIC_NAT, resp)
async def __make_request(self, ctx: str, request: bytes=b"", addr: (StunAddress | str | None)=None) -> StunResponse: async def __make_request(self, ctx: str, addr: (_StunAddress | str), req: bytes) -> _StunResponse:
# TODO: Support IPv6 and RFC 5389 # TODO: Support IPv6 and RFC 5389
# The first 4 bytes of the response are the Type (2) and Length (2) # The first 4 bytes of the response are the Type (2) and Length (2)
# The 5th byte is Reserved # The 5th byte is Reserved
@ -111,32 +158,29 @@ class Stun:
# More info at: https://tools.ietf.org/html/rfc3489#section-11.2.1 # More info at: https://tools.ietf.org/html/rfc3489#section-11.2.1
# And at: https://tools.ietf.org/html/rfc5389#section-15.1 # And at: https://tools.ietf.org/html/rfc5389#section-15.1
if isinstance(addr, StunAddress): if isinstance(addr, _StunAddress):
addr_t = (addr.ip, addr.port) addr_t = (addr.ip, addr.port)
elif isinstance(addr, str): else: # str
addr_t = (addr, self.port) addr_t = (addr, self.__port)
else:
assert addr is None
addr_t = (self.host, self.port)
# https://datatracker.ietf.org/doc/html/rfc5389#section-6 # https://datatracker.ietf.org/doc/html/rfc5389#section-6
trans_id = b"\x21\x12\xA4\x42" + secrets.token_bytes(12) trans_id = b"\x21\x12\xA4\x42" + secrets.token_bytes(12)
(response, error) = (b"", "") (resp, error) = (b"", "")
for _ in range(self.__retries): for _ in range(self.__retries):
(response, error) = await self.__inner_make_request(trans_id, request, addr_t) (resp, error) = await self.__inner_make_request(trans_id, req, addr_t)
if not error: if not error:
break break
await asyncio.sleep(self.__retries_delay) await asyncio.sleep(self.__retries_delay)
if error: if error:
get_logger(0).error("%s: Can't perform STUN request after %d retries; last error: %s", get_logger(0).error("%s: Can't perform STUN request after %d retries; last error: %s",
ctx, self.__retries, error) ctx, self.__retries, error)
return StunResponse(ok=False) return _StunResponse(ok=False)
parsed: dict[str, StunAddress] = {} parsed: dict[str, _StunAddress] = {}
offset = 0 offset = 0
remaining = len(response) remaining = len(resp)
while remaining > 0: while remaining > 0:
(attr_type, attr_len) = struct.unpack(">HH", response[offset : offset + 4]) # noqa: E203 (attr_type, attr_len) = struct.unpack(">HH", resp[offset : offset + 4]) # noqa: E203
offset += 4 offset += 4
field = { field = {
0x0001: "ext", # MAPPED-ADDRESS 0x0001: "ext", # MAPPED-ADDRESS
@ -145,40 +189,40 @@ class Stun:
0x0005: "changed", # CHANGED-ADDRESS 0x0005: "changed", # CHANGED-ADDRESS
}.get(attr_type) }.get(attr_type)
if field is not None: if field is not None:
parsed[field] = self.__parse_address(response[offset:], (trans_id if attr_type == 0x0020 else b"")) parsed[field] = self.__parse_address(resp[offset:], (trans_id if attr_type == 0x0020 else b""))
offset += attr_len offset += attr_len
remaining -= (4 + attr_len) remaining -= (4 + attr_len)
return StunResponse(ok=True, **parsed) return _StunResponse(ok=True, **parsed)
async def __inner_make_request(self, trans_id: bytes, request: bytes, addr: tuple[str, int]) -> tuple[bytes, str]: async def __inner_make_request(self, trans_id: bytes, req: bytes, addr: tuple[str, int]) -> tuple[bytes, str]:
assert self.__sock is not None assert self.__sock is not None
request = struct.pack(">HH", 0x0001, len(request)) + trans_id + request # Bind Request req = struct.pack(">HH", 0x0001, len(req)) + trans_id + req # Bind Request
try: try:
await aiotools.run_async(self.__sock.sendto, request, addr) await aiotools.run_async(self.__sock.sendto, req, addr)
except Exception as err: except Exception as ex:
return (b"", f"Send error: {tools.efmt(err)}") return (b"", f"Send error: {tools.efmt(ex)}")
try: try:
response = (await aiotools.run_async(self.__sock.recvfrom, 2048))[0] resp = (await aiotools.run_async(self.__sock.recvfrom, 2048))[0]
except Exception as err: except Exception as ex:
return (b"", f"Recv error: {tools.efmt(err)}") return (b"", f"Recv error: {tools.efmt(ex)}")
(response_type, payload_len) = struct.unpack(">HH", response[:4]) (resp_type, payload_len) = struct.unpack(">HH", resp[:4])
if response_type != 0x0101: if resp_type != 0x0101:
return (b"", f"Invalid response type: {response_type:#06x}") return (b"", f"Invalid response type: {resp_type:#06x}")
if trans_id != response[4:20]: if trans_id != resp[4:20]:
return (b"", "Transaction ID mismatch") return (b"", "Transaction ID mismatch")
return (response[20 : 20 + payload_len], "") # noqa: E203 return (resp[20 : 20 + payload_len], "") # noqa: E203
def __parse_address(self, data: bytes, trans_id: bytes) -> StunAddress: def __parse_address(self, data: bytes, trans_id: bytes) -> _StunAddress:
family = data[1] family = data[1]
port = struct.unpack(">H", self.__trans_xor(data[2:4], trans_id))[0] port = struct.unpack(">H", self.__trans_xor(data[2:4], trans_id))[0]
if family == 0x01: if family == 0x01:
return StunAddress(str(ipaddress.IPv4Address(self.__trans_xor(data[4:8], trans_id))), port) return _StunAddress(str(ipaddress.IPv4Address(self.__trans_xor(data[4:8], trans_id))), port)
elif family == 0x02: elif family == 0x02:
return StunAddress(str(ipaddress.IPv6Address(self.__trans_xor(data[4:20], trans_id))), port) return _StunAddress(str(ipaddress.IPv6Address(self.__trans_xor(data[4:20], trans_id))), port)
raise RuntimeError(f"Unknown family; received: {family}") raise RuntimeError(f"Unknown family; received: {family}")
def __trans_xor(self, data: bytes, trans_id: bytes) -> bytes: def __trans_xor(self, data: bytes, trans_id: bytes) -> bytes:

View File

@ -56,7 +56,7 @@ def main(argv: (list[str] | None)=None) -> None:
if config.kvmd.msd.type == "otg": if config.kvmd.msd.type == "otg":
msd_kwargs["gadget"] = config.otg.gadget # XXX: Small crutch to pass gadget name to the plugin msd_kwargs["gadget"] = config.otg.gadget # XXX: Small crutch to pass gadget name to the plugin
hid_kwargs = config.kvmd.hid._unpack(ignore=["type", "keymap", "ignore_keys", "mouse_x_range", "mouse_y_range"]) hid_kwargs = config.kvmd.hid._unpack(ignore=["type", "keymap"])
if config.kvmd.hid.type == "otg": if config.kvmd.hid.type == "otg":
hid_kwargs["udc"] = config.otg.udc # XXX: Small crutch to pass UDC to the plugin hid_kwargs["udc"] = config.otg.udc # XXX: Small crutch to pass UDC to the plugin
@ -103,9 +103,6 @@ def main(argv: (list[str] | None)=None) -> None:
), ),
keymap_path=config.hid.keymap, keymap_path=config.hid.keymap,
ignore_keys=config.hid.ignore_keys,
mouse_x_range=(config.hid.mouse_x_range.min, config.hid.mouse_x_range.max),
mouse_y_range=(config.hid.mouse_y_range.min, config.hid.mouse_y_range.max),
stream_forever=config.streamer.forever, stream_forever=config.streamer.forever,
).run(**config.server._unpack()) ).run(**config.server._unpack())

View File

@ -45,9 +45,9 @@ class AtxApi:
return make_json_response(await self.__atx.get_state()) return make_json_response(await self.__atx.get_state())
@exposed_http("POST", "/atx/power") @exposed_http("POST", "/atx/power")
async def __power_handler(self, request: Request) -> Response: async def __power_handler(self, req: Request) -> Response:
action = valid_atx_power_action(request.query.get("action")) action = valid_atx_power_action(req.query.get("action"))
wait = valid_bool(request.query.get("wait", False)) wait = valid_bool(req.query.get("wait", False))
await ({ await ({
"on": self.__atx.power_on, "on": self.__atx.power_on,
"off": self.__atx.power_off, "off": self.__atx.power_off,
@ -57,9 +57,9 @@ class AtxApi:
return make_json_response() return make_json_response()
@exposed_http("POST", "/atx/click") @exposed_http("POST", "/atx/click")
async def __click_handler(self, request: Request) -> Response: async def __click_handler(self, req: Request) -> Response:
button = valid_atx_button(request.query.get("button")) button = valid_atx_button(req.query.get("button"))
wait = valid_bool(request.query.get("wait", False)) wait = valid_bool(req.query.get("wait", False))
await ({ await ({
"power": self.__atx.click_power, "power": self.__atx.click_power,
"power_long": self.__atx.click_power_long, "power_long": self.__atx.click_power_long,

View File

@ -43,34 +43,34 @@ from ..auth import AuthManager
_COOKIE_AUTH_TOKEN = "auth_token" _COOKIE_AUTH_TOKEN = "auth_token"
async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, request: Request) -> None: async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, req: Request) -> None:
if auth_manager.is_auth_required(exposed): if auth_manager.is_auth_required(exposed):
user = request.headers.get("X-KVMD-User", "") user = req.headers.get("X-KVMD-User", "")
if user: if user:
user = valid_user(user) user = valid_user(user)
passwd = request.headers.get("X-KVMD-Passwd", "") passwd = req.headers.get("X-KVMD-Passwd", "")
set_request_auth_info(request, f"{user} (xhdr)") set_request_auth_info(req, f"{user} (xhdr)")
if not (await auth_manager.authorize(user, valid_passwd(passwd))): if not (await auth_manager.authorize(user, valid_passwd(passwd))):
raise ForbiddenError() raise ForbiddenError()
return return
token = request.cookies.get(_COOKIE_AUTH_TOKEN, "") token = req.cookies.get(_COOKIE_AUTH_TOKEN, "")
if token: if token:
user = auth_manager.check(valid_auth_token(token)) # type: ignore user = auth_manager.check(valid_auth_token(token)) # type: ignore
if not user: if not user:
set_request_auth_info(request, "- (token)") set_request_auth_info(req, "- (token)")
raise ForbiddenError() raise ForbiddenError()
set_request_auth_info(request, f"{user} (token)") set_request_auth_info(req, f"{user} (token)")
return return
basic_auth = request.headers.get("Authorization", "") basic_auth = req.headers.get("Authorization", "")
if basic_auth and basic_auth[:6].lower() == "basic ": if basic_auth and basic_auth[:6].lower() == "basic ":
try: try:
(user, passwd) = base64.b64decode(basic_auth[6:]).decode("utf-8").split(":") (user, passwd) = base64.b64decode(basic_auth[6:]).decode("utf-8").split(":")
except Exception: except Exception:
raise UnauthorizedError() raise UnauthorizedError()
user = valid_user(user) user = valid_user(user)
set_request_auth_info(request, f"{user} (basic)") set_request_auth_info(req, f"{user} (basic)")
if not (await auth_manager.authorize(user, valid_passwd(passwd))): if not (await auth_manager.authorize(user, valid_passwd(passwd))):
raise ForbiddenError() raise ForbiddenError()
return return
@ -85,9 +85,9 @@ class AuthApi:
# ===== # =====
@exposed_http("POST", "/auth/login", auth_required=False) @exposed_http("POST", "/auth/login", auth_required=False)
async def __login_handler(self, request: Request) -> Response: async def __login_handler(self, req: Request) -> Response:
if self.__auth_manager.is_auth_enabled(): if self.__auth_manager.is_auth_enabled():
credentials = await request.post() credentials = await req.post()
token = await self.__auth_manager.login( token = await self.__auth_manager.login(
user=valid_user(credentials.get("user", "")), user=valid_user(credentials.get("user", "")),
passwd=valid_passwd(credentials.get("passwd", "")), passwd=valid_passwd(credentials.get("passwd", "")),
@ -98,9 +98,9 @@ class AuthApi:
return make_json_response() return make_json_response()
@exposed_http("POST", "/auth/logout") @exposed_http("POST", "/auth/logout")
async def __logout_handler(self, request: Request) -> Response: async def __logout_handler(self, req: Request) -> Response:
if self.__auth_manager.is_auth_enabled(): if self.__auth_manager.is_auth_enabled():
token = valid_auth_token(request.cookies.get(_COOKIE_AUTH_TOKEN, "")) token = valid_auth_token(req.cookies.get(_COOKIE_AUTH_TOKEN, ""))
self.__auth_manager.logout(token) self.__auth_manager.logout(token)
return make_json_response() return make_json_response()

View File

@ -55,10 +55,9 @@ class ExportApi:
@async_lru.alru_cache(maxsize=1, ttl=5) @async_lru.alru_cache(maxsize=1, ttl=5)
async def __get_prometheus_metrics(self) -> str: async def __get_prometheus_metrics(self) -> str:
(atx_state, hw_state, fan_state, gpio_state) = await asyncio.gather(*[ (atx_state, info_state, gpio_state) = await asyncio.gather(*[
self.__atx.get_state(), self.__atx.get_state(),
self.__info_manager.get_submanager("hw").get_state(), self.__info_manager.get_state(["hw", "fan"]),
self.__info_manager.get_submanager("fan").get_state(),
self.__user_gpio.get_state(), self.__user_gpio.get_state(),
]) ])
rows: list[str] = [] rows: list[str] = []
@ -72,8 +71,8 @@ class ExportApi:
for key in ["online", "state"]: for key in ["online", "state"]:
self.__append_prometheus_rows(rows, ch_state["state"], f"pikvm_gpio_{mode}_{key}_{channel}") self.__append_prometheus_rows(rows, ch_state["state"], f"pikvm_gpio_{mode}_{key}_{channel}")
self.__append_prometheus_rows(rows, hw_state["health"], "pikvm_hw") # type: ignore self.__append_prometheus_rows(rows, info_state["hw"]["health"], "pikvm_hw") # type: ignore
self.__append_prometheus_rows(rows, fan_state, "pikvm_fan") self.__append_prometheus_rows(rows, info_state["fan"], "pikvm_fan")
return "\n".join(rows) return "\n".join(rows)

View File

@ -25,13 +25,12 @@ import stat
import functools import functools
import struct import struct
from typing import Iterable
from typing import Callable from typing import Callable
from aiohttp.web import Request from aiohttp.web import Request
from aiohttp.web import Response from aiohttp.web import Response
from ....mouse import MouseRange
from ....keyboard.keysym import build_symmap from ....keyboard.keysym import build_symmap
from ....keyboard.printer import text_to_web_keys from ....keyboard.printer import text_to_web_keys
@ -59,12 +58,7 @@ class HidApi:
def __init__( def __init__(
self, self,
hid: BaseHid, hid: BaseHid,
keymap_path: str, keymap_path: str,
ignore_keys: list[str],
mouse_x_range: tuple[int, int],
mouse_y_range: tuple[int, int],
) -> None: ) -> None:
self.__hid = hid self.__hid = hid
@ -73,11 +67,6 @@ class HidApi:
self.__default_keymap_name = os.path.basename(keymap_path) self.__default_keymap_name = os.path.basename(keymap_path)
self.__ensure_symmap(self.__default_keymap_name) self.__ensure_symmap(self.__default_keymap_name)
self.__ignore_keys = ignore_keys
self.__mouse_x_range = mouse_x_range
self.__mouse_y_range = mouse_y_range
# ===== # =====
@exposed_http("GET", "/hid") @exposed_http("GET", "/hid")
@ -85,22 +74,22 @@ class HidApi:
return make_json_response(await self.__hid.get_state()) return make_json_response(await self.__hid.get_state())
@exposed_http("POST", "/hid/set_params") @exposed_http("POST", "/hid/set_params")
async def __set_params_handler(self, request: Request) -> Response: async def __set_params_handler(self, req: Request) -> Response:
params = { params = {
key: validator(request.query.get(key)) key: validator(req.query.get(key))
for (key, validator) in [ for (key, validator) in [
("keyboard_output", valid_hid_keyboard_output), ("keyboard_output", valid_hid_keyboard_output),
("mouse_output", valid_hid_mouse_output), ("mouse_output", valid_hid_mouse_output),
("jiggler", valid_bool), ("jiggler", valid_bool),
] ]
if request.query.get(key) is not None if req.query.get(key) is not None
} }
self.__hid.set_params(**params) # type: ignore self.__hid.set_params(**params) # type: ignore
return make_json_response() return make_json_response()
@exposed_http("POST", "/hid/set_connected") @exposed_http("POST", "/hid/set_connected")
async def __set_connected_handler(self, request: Request) -> Response: async def __set_connected_handler(self, req: Request) -> Response:
self.__hid.set_connected(valid_bool(request.query.get("connected"))) self.__hid.set_connected(valid_bool(req.query.get("connected")))
return make_json_response() return make_json_response()
@exposed_http("POST", "/hid/reset") @exposed_http("POST", "/hid/reset")
@ -128,13 +117,13 @@ class HidApi:
return make_json_response(await self.get_keymaps()) return make_json_response(await self.get_keymaps())
@exposed_http("POST", "/hid/print") @exposed_http("POST", "/hid/print")
async def __print_handler(self, request: Request) -> Response: async def __print_handler(self, req: Request) -> Response:
text = await request.text() text = await req.text()
limit = int(valid_int_f0(request.query.get("limit", 1024))) limit = int(valid_int_f0(req.query.get("limit", 1024)))
if limit > 0: if limit > 0:
text = text[:limit] text = text[:limit]
symmap = self.__ensure_symmap(request.query.get("keymap", self.__default_keymap_name)) symmap = self.__ensure_symmap(req.query.get("keymap", self.__default_keymap_name))
self.__hid.send_key_events(text_to_web_keys(text, symmap)) self.__hid.send_key_events(text_to_web_keys(text, symmap), no_ignore_keys=True)
return make_json_response() return make_json_response()
def __ensure_symmap(self, keymap_name: str) -> dict[int, dict[int, str]]: def __ensure_symmap(self, keymap_name: str) -> dict[int, dict[int, str]]:
@ -162,8 +151,7 @@ class HidApi:
state = valid_bool(data[0]) state = valid_bool(data[0])
except Exception: except Exception:
return return
if key not in self.__ignore_keys: self.__hid.send_key_event(key, state)
self.__hid.send_key_events([(key, state)])
@exposed_ws(2) @exposed_ws(2)
async def __ws_bin_mouse_button_handler(self, _: WsSession, data: bytes) -> None: async def __ws_bin_mouse_button_handler(self, _: WsSession, data: bytes) -> None:
@ -182,17 +170,17 @@ class HidApi:
to_y = valid_hid_mouse_move(to_y) to_y = valid_hid_mouse_move(to_y)
except Exception: except Exception:
return return
self.__send_mouse_move_event(to_x, to_y) self.__hid.send_mouse_move_event(to_x, to_y)
@exposed_ws(4) @exposed_ws(4)
async def __ws_bin_mouse_relative_handler(self, _: WsSession, data: bytes) -> None: async def __ws_bin_mouse_relative_handler(self, _: WsSession, data: bytes) -> None:
self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_relative_event) self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_relative_events)
@exposed_ws(5) @exposed_ws(5)
async def __ws_bin_mouse_wheel_handler(self, _: WsSession, data: bytes) -> None: async def __ws_bin_mouse_wheel_handler(self, _: WsSession, data: bytes) -> None:
self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_wheel_event) self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_wheel_events)
def __process_ws_bin_delta_request(self, data: bytes, handler: Callable[[int, int], None]) -> None: def __process_ws_bin_delta_request(self, data: bytes, handler: Callable[[Iterable[tuple[int, int]], bool], None]) -> None:
try: try:
squash = valid_bool(data[0]) squash = valid_bool(data[0])
data = data[1:] data = data[1:]
@ -202,7 +190,7 @@ class HidApi:
deltas.append((valid_hid_mouse_delta(delta_x), valid_hid_mouse_delta(delta_y))) deltas.append((valid_hid_mouse_delta(delta_x), valid_hid_mouse_delta(delta_y)))
except Exception: except Exception:
return return
self.__send_mouse_delta_event(deltas, squash, handler) handler(deltas, squash)
# ===== # =====
@ -213,8 +201,7 @@ class HidApi:
state = valid_bool(event["state"]) state = valid_bool(event["state"])
except Exception: except Exception:
return return
if key not in self.__ignore_keys: self.__hid.send_key_event(key, state)
self.__hid.send_key_events([(key, state)])
@exposed_ws("mouse_button") @exposed_ws("mouse_button")
async def __ws_mouse_button_handler(self, _: WsSession, event: dict) -> None: async def __ws_mouse_button_handler(self, _: WsSession, event: dict) -> None:
@ -232,17 +219,17 @@ class HidApi:
to_y = valid_hid_mouse_move(event["to"]["y"]) to_y = valid_hid_mouse_move(event["to"]["y"])
except Exception: except Exception:
return return
self.__send_mouse_move_event(to_x, to_y) self.__hid.send_mouse_move_event(to_x, to_y)
@exposed_ws("mouse_relative") @exposed_ws("mouse_relative")
async def __ws_mouse_relative_handler(self, _: WsSession, event: dict) -> None: async def __ws_mouse_relative_handler(self, _: WsSession, event: dict) -> None:
self.__process_ws_delta_event(event, self.__hid.send_mouse_relative_event) self.__process_ws_delta_event(event, self.__hid.send_mouse_relative_events)
@exposed_ws("mouse_wheel") @exposed_ws("mouse_wheel")
async def __ws_mouse_wheel_handler(self, _: WsSession, event: dict) -> None: async def __ws_mouse_wheel_handler(self, _: WsSession, event: dict) -> None:
self.__process_ws_delta_event(event, self.__hid.send_mouse_wheel_event) self.__process_ws_delta_event(event, self.__hid.send_mouse_wheel_events)
def __process_ws_delta_event(self, event: dict, handler: Callable[[int, int], None]) -> None: def __process_ws_delta_event(self, event: dict, handler: Callable[[Iterable[tuple[int, int]], bool], None]) -> None:
try: try:
raw_delta = event["delta"] raw_delta = event["delta"]
deltas = [ deltas = [
@ -252,26 +239,25 @@ class HidApi:
squash = valid_bool(event.get("squash", False)) squash = valid_bool(event.get("squash", False))
except Exception: except Exception:
return return
self.__send_mouse_delta_event(deltas, squash, handler) handler(deltas, squash)
# ===== # =====
@exposed_http("POST", "/hid/events/send_key") @exposed_http("POST", "/hid/events/send_key")
async def __events_send_key_handler(self, request: Request) -> Response: async def __events_send_key_handler(self, req: Request) -> Response:
key = valid_hid_key(request.query.get("key")) key = valid_hid_key(req.query.get("key"))
if key not in self.__ignore_keys: if "state" in req.query:
if "state" in request.query: state = valid_bool(req.query["state"])
state = valid_bool(request.query["state"]) self.__hid.send_key_event(key, state)
self.__hid.send_key_events([(key, state)])
else: else:
self.__hid.send_key_events([(key, True), (key, False)]) self.__hid.send_key_events([(key, True), (key, False)])
return make_json_response() return make_json_response()
@exposed_http("POST", "/hid/events/send_mouse_button") @exposed_http("POST", "/hid/events/send_mouse_button")
async def __events_send_mouse_button_handler(self, request: Request) -> Response: async def __events_send_mouse_button_handler(self, req: Request) -> Response:
button = valid_hid_mouse_button(request.query.get("button")) button = valid_hid_mouse_button(req.query.get("button"))
if "state" in request.query: if "state" in req.query:
state = valid_bool(request.query["state"]) state = valid_bool(req.query["state"])
self.__hid.send_mouse_button_event(button, state) self.__hid.send_mouse_button_event(button, state)
else: else:
self.__hid.send_mouse_button_event(button, True) self.__hid.send_mouse_button_event(button, True)
@ -279,52 +265,22 @@ class HidApi:
return make_json_response() return make_json_response()
@exposed_http("POST", "/hid/events/send_mouse_move") @exposed_http("POST", "/hid/events/send_mouse_move")
async def __events_send_mouse_move_handler(self, request: Request) -> Response: async def __events_send_mouse_move_handler(self, req: Request) -> Response:
to_x = valid_hid_mouse_move(request.query.get("to_x")) to_x = valid_hid_mouse_move(req.query.get("to_x"))
to_y = valid_hid_mouse_move(request.query.get("to_y")) to_y = valid_hid_mouse_move(req.query.get("to_y"))
self.__send_mouse_move_event(to_x, to_y) self.__hid.send_mouse_move_event(to_x, to_y)
return make_json_response() return make_json_response()
@exposed_http("POST", "/hid/events/send_mouse_relative") @exposed_http("POST", "/hid/events/send_mouse_relative")
async def __events_send_mouse_relative_handler(self, request: Request) -> Response: async def __events_send_mouse_relative_handler(self, req: Request) -> Response:
return self.__process_http_delta_event(request, self.__hid.send_mouse_relative_event) return self.__process_http_delta_event(req, self.__hid.send_mouse_relative_event)
@exposed_http("POST", "/hid/events/send_mouse_wheel") @exposed_http("POST", "/hid/events/send_mouse_wheel")
async def __events_send_mouse_wheel_handler(self, request: Request) -> Response: async def __events_send_mouse_wheel_handler(self, req: Request) -> Response:
return self.__process_http_delta_event(request, self.__hid.send_mouse_wheel_event) return self.__process_http_delta_event(req, self.__hid.send_mouse_wheel_event)
def __process_http_delta_event(self, request: Request, handler: Callable[[int, int], None]) -> Response: def __process_http_delta_event(self, req: Request, handler: Callable[[int, int], None]) -> Response:
delta_x = valid_hid_mouse_delta(request.query.get("delta_x")) delta_x = valid_hid_mouse_delta(req.query.get("delta_x"))
delta_y = valid_hid_mouse_delta(request.query.get("delta_y")) delta_y = valid_hid_mouse_delta(req.query.get("delta_y"))
handler(delta_x, delta_y) handler(delta_x, delta_y)
return make_json_response() return make_json_response()
# =====
def __send_mouse_move_event(self, to_x: int, to_y: int) -> None:
if self.__mouse_x_range != MouseRange.RANGE:
to_x = MouseRange.remap(to_x, *self.__mouse_x_range)
if self.__mouse_y_range != MouseRange.RANGE:
to_y = MouseRange.remap(to_y, *self.__mouse_y_range)
self.__hid.send_mouse_move_event(to_x, to_y)
def __send_mouse_delta_event(
self,
deltas: list[tuple[int, int]],
squash: bool,
handler: Callable[[int, int], None],
) -> None:
if squash:
prev = (0, 0)
for cur in deltas:
if abs(prev[0] + cur[0]) > 127 or abs(prev[1] + cur[1]) > 127:
handler(*prev)
prev = cur
else:
prev = (prev[0] + cur[0], prev[1] + cur[1])
if prev[0] or prev[1]:
handler(*prev)
else:
for xy in deltas:
handler(*xy)

View File

@ -20,8 +20,6 @@
# ========================================================================== # # ========================================================================== #
import asyncio
from aiohttp.web import Request from aiohttp.web import Request
from aiohttp.web import Response from aiohttp.web import Response
@ -41,17 +39,13 @@ class InfoApi:
# ===== # =====
@exposed_http("GET", "/info") @exposed_http("GET", "/info")
async def __common_state_handler(self, request: Request) -> Response: async def __common_state_handler(self, req: Request) -> Response:
fields = self.__valid_info_fields(request) fields = self.__valid_info_fields(req)
results = dict(zip(fields, await asyncio.gather(*[ return make_json_response(await self.__info_manager.get_state(fields))
self.__info_manager.get_submanager(field).get_state()
for field in fields
])))
return make_json_response(results)
def __valid_info_fields(self, request: Request) -> list[str]: def __valid_info_fields(self, req: Request) -> list[str]:
subs = self.__info_manager.get_subs() available = self.__info_manager.get_subs()
return sorted(valid_info_fields( return sorted(valid_info_fields(
arg=request.query.get("fields", ",".join(subs)), arg=req.query.get("fields", ",".join(available)),
variants=subs, variants=available,
) or subs) ) or available)

View File

@ -47,12 +47,12 @@ class LogApi:
# ===== # =====
@exposed_http("GET", "/log") @exposed_http("GET", "/log")
async def __log_handler(self, request: Request) -> StreamResponse: async def __log_handler(self, req: Request) -> StreamResponse:
if self.__log_reader is None: if self.__log_reader is None:
raise LogReaderDisabledError() raise LogReaderDisabledError()
seek = valid_log_seek(request.query.get("seek", 0)) seek = valid_log_seek(req.query.get("seek", 0))
follow = valid_bool(request.query.get("follow", False)) follow = valid_bool(req.query.get("follow", False))
response = await start_streaming(request, "text/plain") response = await start_streaming(req, "text/plain")
try: try:
async for record in self.__log_reader.poll_log(seek, follow): async for record in self.__log_reader.poll_log(seek, follow):
await response.write(("[%s %s] --- %s" % ( await response.write(("[%s %s] --- %s" % (

View File

@ -63,32 +63,36 @@ class MsdApi:
@exposed_http("GET", "/msd") @exposed_http("GET", "/msd")
async def __state_handler(self, _: Request) -> Response: async def __state_handler(self, _: Request) -> Response:
return make_json_response(await self.__msd.get_state()) state = await self.__msd.get_state()
if state["storage"] and state["storage"]["parts"]:
state["storage"]["size"] = state["storage"]["parts"][""]["size"] # Legacy API
state["storage"]["free"] = state["storage"]["parts"][""]["free"] # Legacy API
return make_json_response(state)
@exposed_http("POST", "/msd/set_params") @exposed_http("POST", "/msd/set_params")
async def __set_params_handler(self, request: Request) -> Response: async def __set_params_handler(self, req: Request) -> Response:
params = { params = {
key: validator(request.query.get(param)) key: validator(req.query.get(param))
for (param, key, validator) in [ for (param, key, validator) in [
("image", "name", (lambda arg: str(arg).strip() and valid_msd_image_name(arg))), ("image", "name", (lambda arg: str(arg).strip() and valid_msd_image_name(arg))),
("cdrom", "cdrom", valid_bool), ("cdrom", "cdrom", valid_bool),
("rw", "rw", valid_bool), ("rw", "rw", valid_bool),
] ]
if request.query.get(param) is not None if req.query.get(param) is not None
} }
await self.__msd.set_params(**params) # type: ignore await self.__msd.set_params(**params) # type: ignore
return make_json_response() return make_json_response()
@exposed_http("POST", "/msd/set_connected") @exposed_http("POST", "/msd/set_connected")
async def __set_connected_handler(self, request: Request) -> Response: async def __set_connected_handler(self, req: Request) -> Response:
await self.__msd.set_connected(valid_bool(request.query.get("connected"))) await self.__msd.set_connected(valid_bool(req.query.get("connected")))
return make_json_response() return make_json_response()
# ===== # =====
@exposed_http("GET", "/msd/read") @exposed_http("GET", "/msd/read")
async def __read_handler(self, request: Request) -> StreamResponse: async def __read_handler(self, req: Request) -> StreamResponse:
name = valid_msd_image_name(request.query.get("image")) name = valid_msd_image_name(req.query.get("image"))
compressors = { compressors = {
"": ("", None), "": ("", None),
"none": ("", None), "none": ("", None),
@ -96,7 +100,7 @@ class MsdApi:
"zstd": (".zst", (lambda: zstandard.ZstdCompressor().compressobj())), # pylint: disable=unnecessary-lambda "zstd": (".zst", (lambda: zstandard.ZstdCompressor().compressobj())), # pylint: disable=unnecessary-lambda
} }
(suffix, make_compressor) = compressors[check_string_in_list( (suffix, make_compressor) = compressors[check_string_in_list(
arg=request.query.get("compress", ""), arg=req.query.get("compress", ""),
name="Compression mode", name="Compression mode",
variants=set(compressors), variants=set(compressors),
)] )]
@ -127,7 +131,7 @@ class MsdApi:
src = compressed() src = compressed()
size = -1 size = -1
response = await start_streaming(request, "application/octet-stream", size, name + suffix) response = await start_streaming(req, "application/octet-stream", size, name + suffix)
async for chunk in src: async for chunk in src:
await response.write(chunk) await response.write(chunk)
return response return response
@ -135,28 +139,28 @@ class MsdApi:
# ===== # =====
@exposed_http("POST", "/msd/write") @exposed_http("POST", "/msd/write")
async def __write_handler(self, request: Request) -> Response: async def __write_handler(self, req: Request) -> Response:
unsafe_prefix = request.query.get("prefix", "") + "/" unsafe_prefix = req.query.get("prefix", "") + "/"
name = valid_msd_image_name(unsafe_prefix + request.query.get("image", "")) name = valid_msd_image_name(unsafe_prefix + req.query.get("image", ""))
size = valid_int_f0(request.content_length) size = valid_int_f0(req.content_length)
remove_incomplete = self.__get_remove_incomplete(request) remove_incomplete = self.__get_remove_incomplete(req)
written = 0 written = 0
async with self.__msd.write_image(name, size, remove_incomplete) as writer: async with self.__msd.write_image(name, size, remove_incomplete) as writer:
chunk_size = writer.get_chunk_size() chunk_size = writer.get_chunk_size()
while True: while True:
chunk = await request.content.read(chunk_size) chunk = await req.content.read(chunk_size)
if not chunk: if not chunk:
break break
written = await writer.write_chunk(chunk) written = await writer.write_chunk(chunk)
return make_json_response(self.__make_write_info(name, size, written)) return make_json_response(self.__make_write_info(name, size, written))
@exposed_http("POST", "/msd/write_remote") @exposed_http("POST", "/msd/write_remote")
async def __write_remote_handler(self, request: Request) -> (Response | StreamResponse): # pylint: disable=too-many-locals async def __write_remote_handler(self, req: Request) -> (Response | StreamResponse): # pylint: disable=too-many-locals
unsafe_prefix = request.query.get("prefix", "") + "/" unsafe_prefix = req.query.get("prefix", "") + "/"
url = valid_url(request.query.get("url")) url = valid_url(req.query.get("url"))
insecure = valid_bool(request.query.get("insecure", False)) insecure = valid_bool(req.query.get("insecure", False))
timeout = valid_float_f01(request.query.get("timeout", 10.0)) timeout = valid_float_f01(req.query.get("timeout", 10.0))
remove_incomplete = self.__get_remove_incomplete(request) remove_incomplete = self.__get_remove_incomplete(req)
name = "" name = ""
size = written = 0 size = written = 0
@ -174,7 +178,7 @@ class MsdApi:
read_timeout=(7 * 24 * 3600), read_timeout=(7 * 24 * 3600),
) as remote: ) as remote:
name = str(request.query.get("image", "")).strip() name = str(req.query.get("image", "")).strip()
if len(name) == 0: if len(name) == 0:
name = htclient.get_filename(remote) name = htclient.get_filename(remote)
name = valid_msd_image_name(unsafe_prefix + name) name = valid_msd_image_name(unsafe_prefix + name)
@ -184,7 +188,7 @@ class MsdApi:
get_logger(0).info("Downloading image %r as %r to MSD ...", url, name) get_logger(0).info("Downloading image %r as %r to MSD ...", url, name)
async with self.__msd.write_image(name, size, remove_incomplete) as writer: async with self.__msd.write_image(name, size, remove_incomplete) as writer:
chunk_size = writer.get_chunk_size() chunk_size = writer.get_chunk_size()
response = await start_streaming(request, "application/x-ndjson") response = await start_streaming(req, "application/x-ndjson")
await stream_write_info() await stream_write_info()
last_report_ts = 0 last_report_ts = 0
async for chunk in remote.content.iter_chunked(chunk_size): async for chunk in remote.content.iter_chunked(chunk_size):
@ -197,16 +201,16 @@ class MsdApi:
await stream_write_info() await stream_write_info()
return response return response
except Exception as err: except Exception as ex:
if response is not None: if response is not None:
await stream_write_info() await stream_write_info()
await stream_json_exception(response, err) await stream_json_exception(response, ex)
elif isinstance(err, aiohttp.ClientError): elif isinstance(ex, aiohttp.ClientError):
return make_json_exception(err, 400) return make_json_exception(ex, 400)
raise raise
def __get_remove_incomplete(self, request: Request) -> (bool | None): def __get_remove_incomplete(self, req: Request) -> (bool | None):
flag: (str | None) = request.query.get("remove_incomplete") flag: (str | None) = req.query.get("remove_incomplete")
return (valid_bool(flag) if flag is not None else None) return (valid_bool(flag) if flag is not None else None)
def __make_write_info(self, name: str, size: int, written: int) -> dict: def __make_write_info(self, name: str, size: int, written: int) -> dict:
@ -215,8 +219,8 @@ class MsdApi:
# ===== # =====
@exposed_http("POST", "/msd/remove") @exposed_http("POST", "/msd/remove")
async def __remove_handler(self, request: Request) -> Response: async def __remove_handler(self, req: Request) -> Response:
await self.__msd.remove(valid_msd_image_name(request.query.get("image"))) await self.__msd.remove(valid_msd_image_name(req.query.get("image")))
return make_json_response() return make_json_response()
@exposed_http("POST", "/msd/reset") @exposed_http("POST", "/msd/reset")

View File

@ -88,12 +88,12 @@ class RedfishApi:
@exposed_http("GET", "/redfish/v1/Systems/0") @exposed_http("GET", "/redfish/v1/Systems/0")
async def __server_handler(self, _: Request) -> Response: async def __server_handler(self, _: Request) -> Response:
(atx_state, meta_state) = await asyncio.gather(*[ (atx_state, info_state) = await asyncio.gather(*[
self.__atx.get_state(), self.__atx.get_state(),
self.__info_manager.get_submanager("meta").get_state(), self.__info_manager.get_state(["meta"]),
]) ])
try: try:
host = str(meta_state.get("server", {})["host"]) # type: ignore host = str(info_state["meta"].get("server", {})["host"]) # type: ignore
except Exception: except Exception:
host = "" host = ""
return make_json_response({ return make_json_response({
@ -111,10 +111,10 @@ class RedfishApi:
}, wrap_result=False) }, wrap_result=False)
@exposed_http("POST", "/redfish/v1/Systems/0/Actions/ComputerSystem.Reset") @exposed_http("POST", "/redfish/v1/Systems/0/Actions/ComputerSystem.Reset")
async def __power_handler(self, request: Request) -> Response: async def __power_handler(self, req: Request) -> Response:
try: try:
action = check_string_in_list( action = check_string_in_list(
arg=(await request.json())["ResetType"], arg=(await req.json()).get("ResetType"),
name="Redfish ResetType", name="Redfish ResetType",
variants=set(self.__actions), variants=set(self.__actions),
lower=False, lower=False,

View File

@ -52,36 +52,36 @@ class StreamerApi:
return make_json_response(await self.__streamer.get_state()) return make_json_response(await self.__streamer.get_state())
@exposed_http("GET", "/streamer/snapshot") @exposed_http("GET", "/streamer/snapshot")
async def __take_snapshot_handler(self, request: Request) -> Response: async def __take_snapshot_handler(self, req: Request) -> Response:
snapshot = await self.__streamer.take_snapshot( snapshot = await self.__streamer.take_snapshot(
save=valid_bool(request.query.get("save", False)), save=valid_bool(req.query.get("save", False)),
load=valid_bool(request.query.get("load", False)), load=valid_bool(req.query.get("load", False)),
allow_offline=valid_bool(request.query.get("allow_offline", False)), allow_offline=valid_bool(req.query.get("allow_offline", False)),
) )
if snapshot: if snapshot:
if valid_bool(request.query.get("ocr", False)): if valid_bool(req.query.get("ocr", False)):
langs = self.__ocr.get_available_langs() langs = self.__ocr.get_available_langs()
return Response( return Response(
body=(await self.__ocr.recognize( body=(await self.__ocr.recognize(
data=snapshot.data, data=snapshot.data,
langs=valid_string_list( langs=valid_string_list(
arg=str(request.query.get("ocr_langs", "")).strip(), arg=str(req.query.get("ocr_langs", "")).strip(),
subval=(lambda lang: check_string_in_list(lang, "OCR lang", langs)), subval=(lambda lang: check_string_in_list(lang, "OCR lang", langs)),
name="OCR langs list", name="OCR langs list",
), ),
left=int(valid_number(request.query.get("ocr_left", -1))), left=int(valid_number(req.query.get("ocr_left", -1))),
top=int(valid_number(request.query.get("ocr_top", -1))), top=int(valid_number(req.query.get("ocr_top", -1))),
right=int(valid_number(request.query.get("ocr_right", -1))), right=int(valid_number(req.query.get("ocr_right", -1))),
bottom=int(valid_number(request.query.get("ocr_bottom", -1))), bottom=int(valid_number(req.query.get("ocr_bottom", -1))),
)), )),
headers=dict(snapshot.headers), headers=dict(snapshot.headers),
content_type="text/plain", content_type="text/plain",
) )
elif valid_bool(request.query.get("preview", False)): elif valid_bool(req.query.get("preview", False)):
data = await snapshot.make_preview( data = await snapshot.make_preview(
max_width=valid_int_f0(request.query.get("preview_max_width", 0)), max_width=valid_int_f0(req.query.get("preview_max_width", 0)),
max_height=valid_int_f0(request.query.get("preview_max_height", 0)), max_height=valid_int_f0(req.query.get("preview_max_height", 0)),
quality=valid_stream_quality(request.query.get("preview_quality", 80)), quality=valid_stream_quality(req.query.get("preview_quality", 80)),
) )
else: else:
data = snapshot.data data = snapshot.data
@ -97,25 +97,6 @@ class StreamerApi:
self.__streamer.remove_snapshot() self.__streamer.remove_snapshot()
return make_json_response() return make_json_response()
# =====
async def get_ocr(self) -> dict: # XXX: Ugly hack
enabled = self.__ocr.is_available()
default: list[str] = []
available: list[str] = []
if enabled:
default = self.__ocr.get_default_langs()
available = self.__ocr.get_available_langs()
return {
"ocr": {
"enabled": enabled,
"langs": {
"default": default,
"available": available,
},
},
}
@exposed_http("GET", "/streamer/ocr") @exposed_http("GET", "/streamer/ocr")
async def __ocr_handler(self, _: Request) -> Response: async def __ocr_handler(self, _: Request) -> Response:
return make_json_response(await self.get_ocr()) return make_json_response({"ocr": (await self.__ocr.get_state())})

View File

@ -42,23 +42,20 @@ class UserGpioApi:
@exposed_http("GET", "/gpio") @exposed_http("GET", "/gpio")
async def __state_handler(self, _: Request) -> Response: async def __state_handler(self, _: Request) -> Response:
return make_json_response({ return make_json_response(await self.__user_gpio.get_state())
"model": (await self.__user_gpio.get_model()),
"state": (await self.__user_gpio.get_state()),
})
@exposed_http("POST", "/gpio/switch") @exposed_http("POST", "/gpio/switch")
async def __switch_handler(self, request: Request) -> Response: async def __switch_handler(self, req: Request) -> Response:
channel = valid_ugpio_channel(request.query.get("channel")) channel = valid_ugpio_channel(req.query.get("channel"))
state = valid_bool(request.query.get("state")) state = valid_bool(req.query.get("state"))
wait = valid_bool(request.query.get("wait", False)) wait = valid_bool(req.query.get("wait", False))
await self.__user_gpio.switch(channel, state, wait) await self.__user_gpio.switch(channel, state, wait)
return make_json_response() return make_json_response()
@exposed_http("POST", "/gpio/pulse") @exposed_http("POST", "/gpio/pulse")
async def __pulse_handler(self, request: Request) -> Response: async def __pulse_handler(self, req: Request) -> Response:
channel = valid_ugpio_channel(request.query.get("channel")) channel = valid_ugpio_channel(req.query.get("channel"))
delay = valid_float_f0(request.query.get("delay", 0.0)) delay = valid_float_f0(req.query.get("delay", 0.0))
wait = valid_bool(request.query.get("wait", False)) wait = valid_bool(req.query.get("wait", False))
await self.__user_gpio.pulse(channel, delay, wait) await self.__user_gpio.pulse(channel, delay, wait)
return make_json_response() return make_json_response()

View File

@ -20,6 +20,10 @@
# ========================================================================== # # ========================================================================== #
import asyncio
from typing import AsyncGenerator
from ....yamlconf import Section from ....yamlconf import Section
from .base import BaseInfoSubmanager from .base import BaseInfoSubmanager
@ -34,7 +38,7 @@ from .fan import FanInfoSubmanager
# ===== # =====
class InfoManager: class InfoManager:
def __init__(self, config: Section) -> None: def __init__(self, config: Section) -> None:
self.__subs = { self.__subs: dict[str, BaseInfoSubmanager] = {
"system": SystemInfoSubmanager(config.kvmd.streamer.cmd), "system": SystemInfoSubmanager(config.kvmd.streamer.cmd),
"auth": AuthInfoSubmanager(config.kvmd.auth.enabled), "auth": AuthInfoSubmanager(config.kvmd.auth.enabled),
"meta": MetaInfoSubmanager(config.kvmd.info.meta), "meta": MetaInfoSubmanager(config.kvmd.info.meta),
@ -42,9 +46,51 @@ class InfoManager:
"hw": HwInfoSubmanager(**config.kvmd.info.hw._unpack()), "hw": HwInfoSubmanager(**config.kvmd.info.hw._unpack()),
"fan": FanInfoSubmanager(**config.kvmd.info.fan._unpack()), "fan": FanInfoSubmanager(**config.kvmd.info.fan._unpack()),
} }
self.__queue: "asyncio.Queue[tuple[str, (dict | None)]]" = asyncio.Queue()
def get_subs(self) -> set[str]: def get_subs(self) -> set[str]:
return set(self.__subs) return set(self.__subs)
def get_submanager(self, name: str) -> BaseInfoSubmanager: async def get_state(self, fields: (list[str] | None)=None) -> dict:
return self.__subs[name] fields = (fields or list(self.__subs))
return dict(zip(fields, await asyncio.gather(*[
self.__subs[field].get_state()
for field in fields
])))
async def trigger_state(self) -> None:
await asyncio.gather(*[
sub.trigger_state()
for sub in self.__subs.values()
])
async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - system -- Partial
# - auth -- Partial
# - meta -- Partial, nullable
# - extras -- Partial, nullable
# - hw -- Partial
# - fan -- Partial
# ===========================
while True:
(field, value) = await self.__queue.get()
yield {field: value}
async def systask(self) -> None:
tasks = [
asyncio.create_task(self.__poller(field))
for field in self.__subs
]
try:
await asyncio.gather(*tasks)
except Exception:
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
async def __poller(self, field: str) -> None:
async for state in self.__subs[field].poll_state():
self.__queue.put_nowait((field, state))

View File

@ -20,6 +20,10 @@
# ========================================================================== # # ========================================================================== #
from typing import AsyncGenerator
from .... import aiotools
from .base import BaseInfoSubmanager from .base import BaseInfoSubmanager
@ -27,6 +31,15 @@ from .base import BaseInfoSubmanager
class AuthInfoSubmanager(BaseInfoSubmanager): class AuthInfoSubmanager(BaseInfoSubmanager):
def __init__(self, enabled: bool) -> None: def __init__(self, enabled: bool) -> None:
self.__enabled = enabled self.__enabled = enabled
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict: async def get_state(self) -> dict:
return {"enabled": self.__enabled} return {"enabled": self.__enabled}
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())

View File

@ -20,7 +20,17 @@
# ========================================================================== # # ========================================================================== #
from typing import AsyncGenerator
# ===== # =====
class BaseInfoSubmanager: class BaseInfoSubmanager:
async def get_state(self) -> (dict | None): async def get_state(self) -> (dict | None):
raise NotImplementedError raise NotImplementedError
async def trigger_state(self) -> None:
raise NotImplementedError
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
yield None
raise NotImplementedError

View File

@ -24,6 +24,8 @@ import os
import re import re
import asyncio import asyncio
from typing import AsyncGenerator
from ....logging import get_logger from ....logging import get_logger
from ....yamlconf import Section from ....yamlconf import Section
@ -42,13 +44,15 @@ from .base import BaseInfoSubmanager
class ExtrasInfoSubmanager(BaseInfoSubmanager): class ExtrasInfoSubmanager(BaseInfoSubmanager):
def __init__(self, global_config: Section) -> None: def __init__(self, global_config: Section) -> None:
self.__global_config = global_config self.__global_config = global_config
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> (dict | None): async def get_state(self) -> (dict | None):
try: try:
sui = sysunit.SystemdUnitInfo() sui = sysunit.SystemdUnitInfo()
await sui.open() await sui.open()
except Exception as err: except Exception as ex:
get_logger(0).error("Can't open systemd bus to get extras state: %s", tools.efmt(err)) if not os.path.exists("/etc/kvmd/.docker_flag"):
get_logger(0).error("Can't open systemd bus to get extras state: %s", tools.efmt(ex))
sui = None sui = None
try: try:
extras: dict[str, dict] = {} extras: dict[str, dict] = {}
@ -66,6 +70,14 @@ class ExtrasInfoSubmanager(BaseInfoSubmanager):
if sui is not None: if sui is not None:
await aiotools.shield_fg(sui.close()) await aiotools.shield_fg(sui.close())
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())
def __get_extras_path(self, *parts: str) -> str: def __get_extras_path(self, *parts: str) -> str:
return os.path.join(self.__global_config.kvmd.info.extras, *parts) return os.path.join(self.__global_config.kvmd.info.extras, *parts)

View File

@ -21,7 +21,6 @@
import copy import copy
import asyncio
from typing import AsyncGenerator from typing import AsyncGenerator
@ -53,6 +52,8 @@ class FanInfoSubmanager(BaseInfoSubmanager):
self.__timeout = timeout self.__timeout = timeout
self.__state_poll = state_poll self.__state_poll = state_poll
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict: async def get_state(self) -> dict:
monitored = await self.__get_monitored() monitored = await self.__get_monitored()
return { return {
@ -60,24 +61,28 @@ class FanInfoSubmanager(BaseInfoSubmanager):
"state": ((await self.__get_fan_state() if monitored else None)), "state": ((await self.__get_fan_state() if monitored else None)),
} }
async def poll_state(self) -> AsyncGenerator[dict, None]: async def trigger_state(self) -> None:
prev_state: dict = {} self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
prev: dict = {}
while True: while True:
if self.__unix_path: if self.__unix_path:
pure = state = await self.get_state() if (await self.__notifier.wait(timeout=self.__state_poll)) > 0:
prev = {}
new = await self.get_state()
pure = copy.deepcopy(new)
if pure["state"] is not None: if pure["state"] is not None:
try: try:
pure = copy.deepcopy(state)
pure["state"]["service"]["now_ts"] = 0 pure["state"]["service"]["now_ts"] = 0
except Exception: except Exception:
pass pass
if pure != prev_state: if pure != prev:
yield state prev = pure
prev_state = pure yield new
await asyncio.sleep(self.__state_poll)
else: else:
await self.__notifier.wait()
yield (await self.get_state()) yield (await self.get_state())
await aiotools.wait_infinite()
# ===== # =====
@ -87,8 +92,8 @@ class FanInfoSubmanager(BaseInfoSubmanager):
async with sysunit.SystemdUnitInfo() as sui: async with sysunit.SystemdUnitInfo() as sui:
status = await sui.get_status(self.__daemon) status = await sui.get_status(self.__daemon)
return (status[0] or status[1]) return (status[0] or status[1])
except Exception as err: except Exception as ex:
get_logger(0).error("Can't get info about the service %r: %s", self.__daemon, tools.efmt(err)) get_logger(0).error("Can't get info about the service %r: %s", self.__daemon, tools.efmt(ex))
return False return False
async def __get_fan_state(self) -> (dict | None): async def __get_fan_state(self) -> (dict | None):
@ -97,8 +102,8 @@ class FanInfoSubmanager(BaseInfoSubmanager):
async with session.get("http://localhost/state") as response: async with session.get("http://localhost/state") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return (await response.json())["result"] return (await response.json())["result"]
except Exception as err: except Exception as ex:
get_logger(0).error("Can't read fan state: %s", err) get_logger(0).error("Can't read fan state: %s", ex)
return None return None
def __make_http_session(self) -> aiohttp.ClientSession: def __make_http_session(self) -> aiohttp.ClientSession:

View File

@ -22,6 +22,7 @@
import os import os
import asyncio import asyncio
import copy
from typing import Callable from typing import Callable
from typing import AsyncGenerator from typing import AsyncGenerator
@ -60,6 +61,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
self.__dt_cache: dict[str, str] = {} self.__dt_cache: dict[str, str] = {}
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict: async def get_state(self) -> dict:
( (
base, base,
@ -70,8 +73,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
cpu_temp, cpu_temp,
mem, mem,
) = await asyncio.gather( ) = await asyncio.gather(
self.__read_dt_file("model"), self.__read_dt_file("model", upper=False),
self.__read_dt_file("serial-number"), self.__read_dt_file("serial-number", upper=True),
self.__read_platform_file(), self.__read_platform_file(),
self.__get_throttling(), self.__get_throttling(),
self.__get_cpu_percent(), self.__get_cpu_percent(),
@ -97,18 +100,22 @@ class HwInfoSubmanager(BaseInfoSubmanager):
}, },
} }
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {} prev: dict = {}
while True: while True:
state = await self.get_state() if (await self.__notifier.wait(timeout=self.__state_poll)) > 0:
if state != prev_state: prev = {}
yield state new = await self.get_state()
prev_state = state if new != prev:
await asyncio.sleep(self.__state_poll) prev = copy.deepcopy(new)
yield new
# ===== # =====
async def __read_dt_file(self, name: str) -> (str | None): async def __read_dt_file(self, name: str, upper: bool) -> (str | None):
if name not in self.__dt_cache: if name not in self.__dt_cache:
path = os.path.join(f"{env.PROCFS_PREFIX}/proc/device-tree", name) path = os.path.join(f"{env.PROCFS_PREFIX}/proc/device-tree", name)
if not os.path.exists(path): if not os.path.exists(path):
@ -161,8 +168,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
+ system_all / total * 100 + system_all / total * 100
+ (st.steal + st.guest) / total * 100 + (st.steal + st.guest) / total * 100
) )
except Exception as err: except Exception as ex:
get_logger(0).error("Can't get CPU percent: %s", err) get_logger(0).error("Can't get CPU percent: %s", ex)
return None return None
async def __get_mem(self) -> dict: async def __get_mem(self) -> dict:
@ -173,8 +180,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
"total": st.total, "total": st.total,
"available": st.available, "available": st.available,
} }
except Exception as err: except Exception as ex:
get_logger(0).error("Can't get memory info: %s", err) get_logger(0).error("Can't get memory info: %s", ex)
return { return {
"percent": None, "percent": None,
"total": None, "total": None,
@ -217,6 +224,6 @@ class HwInfoSubmanager(BaseInfoSubmanager):
return None return None
try: try:
return parser(text) return parser(text)
except Exception as err: except Exception as ex:
get_logger(0).error("Can't parse [ %s ] output: %r: %s", tools.cmdfmt(cmd), text, tools.efmt(err)) get_logger(0).error("Can't parse [ %s ] output: %r: %s", tools.cmdfmt(cmd), text, tools.efmt(ex))
return None return None

View File

@ -20,6 +20,8 @@
# ========================================================================== # # ========================================================================== #
from typing import AsyncGenerator
from ....logging import get_logger from ....logging import get_logger
from ....yamlconf.loader import load_yaml_file from ....yamlconf.loader import load_yaml_file
@ -33,6 +35,7 @@ from .base import BaseInfoSubmanager
class MetaInfoSubmanager(BaseInfoSubmanager): class MetaInfoSubmanager(BaseInfoSubmanager):
def __init__(self, meta_path: str) -> None: def __init__(self, meta_path: str) -> None:
self.__meta_path = meta_path self.__meta_path = meta_path
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> (dict | None): async def get_state(self) -> (dict | None):
try: try:
@ -40,3 +43,11 @@ class MetaInfoSubmanager(BaseInfoSubmanager):
except Exception: except Exception:
get_logger(0).exception("Can't parse meta") get_logger(0).exception("Can't parse meta")
return None return None
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())

View File

@ -24,8 +24,11 @@ import os
import asyncio import asyncio
import platform import platform
from typing import AsyncGenerator
from ....logging import get_logger from ....logging import get_logger
from .... import aiotools
from .... import aioproc from .... import aioproc
from .... import __version__ from .... import __version__
@ -37,6 +40,7 @@ from .base import BaseInfoSubmanager
class SystemInfoSubmanager(BaseInfoSubmanager): class SystemInfoSubmanager(BaseInfoSubmanager):
def __init__(self, streamer_cmd: list[str]) -> None: def __init__(self, streamer_cmd: list[str]) -> None:
self.__streamer_cmd = streamer_cmd self.__streamer_cmd = streamer_cmd
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict: async def get_state(self) -> dict:
streamer_info = await self.__get_streamer_info() streamer_info = await self.__get_streamer_info()
@ -50,6 +54,14 @@ class SystemInfoSubmanager(BaseInfoSubmanager):
}, },
} }
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())
# ===== # =====
async def __get_streamer_info(self) -> dict: async def __get_streamer_info(self) -> dict:

View File

@ -37,6 +37,7 @@ from ctypes import c_void_p
from ctypes import c_char from ctypes import c_char
from typing import Generator from typing import Generator
from typing import AsyncGenerator
from PIL import ImageOps from PIL import ImageOps
from PIL import Image as PilImage from PIL import Image as PilImage
@ -76,8 +77,8 @@ def _load_libtesseract() -> (ctypes.CDLL | None):
setattr(func, "restype", restype) setattr(func, "restype", restype)
setattr(func, "argtypes", argtypes) setattr(func, "argtypes", argtypes)
return lib return lib
except Exception as err: except Exception as ex:
warnings.warn(f"Can't load libtesseract: {err}", RuntimeWarning) warnings.warn(f"Can't load libtesseract: {ex}", RuntimeWarning)
return None return None
@ -107,9 +108,37 @@ class Ocr:
def __init__(self, data_dir_path: str, default_langs: list[str]) -> None: def __init__(self, data_dir_path: str, default_langs: list[str]) -> None:
self.__data_dir_path = data_dir_path self.__data_dir_path = data_dir_path
self.__default_langs = default_langs self.__default_langs = default_langs
self.__notifier = aiotools.AioNotifier()
def is_available(self) -> bool: async def get_state(self) -> dict:
return bool(_libtess) enabled = bool(_libtess)
default: list[str] = []
available: list[str] = []
if enabled:
default = self.get_default_langs()
available = self.get_available_langs()
return {
"enabled": enabled,
"langs": {
"default": default,
"available": available,
},
}
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[dict, None]:
# ===== Granularity table =====
# - enabled -- Full
# - langs -- Partial
# =============================
while True:
await self.__notifier.wait()
yield (await self.get_state())
# =====
def get_default_langs(self) -> list[str]: def get_default_langs(self) -> list[str]:
return list(self.__default_langs) return list(self.__default_langs)

View File

@ -20,8 +20,6 @@
# ========================================================================== # # ========================================================================== #
import asyncio
import operator
import dataclasses import dataclasses
from typing import Callable from typing import Callable
@ -33,6 +31,8 @@ from aiohttp.web import Request
from aiohttp.web import Response from aiohttp.web import Response
from aiohttp.web import WebSocketResponse from aiohttp.web import WebSocketResponse
from ... import __version__
from ...logging import get_logger from ...logging import get_logger
from ...errors import OperationError from ...errors import OperationError
@ -98,54 +98,46 @@ class StreamerH264NotSupported(OperationError):
# ===== # =====
@dataclasses.dataclass(frozen=True)
class _SubsystemEventSource:
get_state: (Callable[[], Coroutine[Any, Any, dict]] | None) = None
poll_state: (Callable[[], AsyncGenerator[dict, None]] | None) = None
@dataclasses.dataclass @dataclasses.dataclass
class _Subsystem: class _Subsystem:
name: str name: str
event_type: str
sysprep: (Callable[[], None] | None) sysprep: (Callable[[], None] | None)
systask: (Callable[[], Coroutine[Any, Any, None]] | None) systask: (Callable[[], Coroutine[Any, Any, None]] | None)
cleanup: (Callable[[], Coroutine[Any, Any, dict]] | None) cleanup: (Callable[[], Coroutine[Any, Any, dict]] | None)
sources: dict[str, _SubsystemEventSource] trigger_state: (Callable[[], Coroutine[Any, Any, None]] | None) = None
poll_state: (Callable[[], AsyncGenerator[dict, None]] | None) = None
def __post_init__(self) -> None:
if self.event_type:
assert self.trigger_state
assert self.poll_state
@classmethod @classmethod
def make(cls, obj: object, name: str, event_type: str="") -> "_Subsystem": def make(cls, obj: object, name: str, event_type: str="") -> "_Subsystem":
if isinstance(obj, BasePlugin): if isinstance(obj, BasePlugin):
name = f"{name} ({obj.get_plugin_name()})" name = f"{name} ({obj.get_plugin_name()})"
sub = _Subsystem( return _Subsystem(
name=name, name=name,
event_type=event_type,
sysprep=getattr(obj, "sysprep", None), sysprep=getattr(obj, "sysprep", None),
systask=getattr(obj, "systask", None), systask=getattr(obj, "systask", None),
cleanup=getattr(obj, "cleanup", None), cleanup=getattr(obj, "cleanup", None),
sources={}, trigger_state=getattr(obj, "trigger_state", None),
)
if event_type:
sub.add_source(
event_type=event_type,
get_state=getattr(obj, "get_state", None),
poll_state=getattr(obj, "poll_state", None), poll_state=getattr(obj, "poll_state", None),
) )
return sub
def add_source(
self,
event_type: str,
get_state: (Callable[[], Coroutine[Any, Any, dict]] | None),
poll_state: (Callable[[], AsyncGenerator[dict, None]] | None),
) -> "_Subsystem":
assert event_type
assert event_type not in self.sources, (self, event_type)
assert get_state or poll_state, (self, event_type)
self.sources[event_type] = _SubsystemEventSource(get_state, poll_state)
return self
class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes
__EV_GPIO_STATE = "gpio_state"
__EV_HID_STATE = "hid_state"
__EV_ATX_STATE = "atx_state"
__EV_MSD_STATE = "msd_state"
__EV_STREAMER_STATE = "streamer_state"
__EV_OCR_STATE = "ocr_state"
__EV_INFO_STATE = "info_state"
def __init__( # pylint: disable=too-many-arguments,too-many-locals def __init__( # pylint: disable=too-many-arguments,too-many-locals
self, self,
auth_manager: AuthManager, auth_manager: AuthManager,
@ -161,9 +153,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
snapshoter: Snapshoter, snapshoter: Snapshoter,
keymap_path: str, keymap_path: str,
ignore_keys: list[str],
mouse_x_range: tuple[int, int],
mouse_y_range: tuple[int, int],
stream_forever: bool, stream_forever: bool,
) -> None: ) -> None:
@ -177,8 +166,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
self.__stream_forever = stream_forever self.__stream_forever = stream_forever
self.__hid_api = HidApi(hid, keymap_path, ignore_keys, mouse_x_range, mouse_y_range) # Ugly hack to get keymaps state self.__hid_api = HidApi(hid, keymap_path) # Ugly hack to get keymaps state
self.__streamer_api = StreamerApi(streamer, ocr) # Same hack to get ocr langs state
self.__apis: list[object] = [ self.__apis: list[object] = [
self, self,
AuthApi(auth_manager), AuthApi(auth_manager),
@ -188,22 +176,19 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
self.__hid_api, self.__hid_api,
AtxApi(atx), AtxApi(atx),
MsdApi(msd), MsdApi(msd),
self.__streamer_api, StreamerApi(streamer, ocr),
ExportApi(info_manager, atx, user_gpio), ExportApi(info_manager, atx, user_gpio),
RedfishApi(info_manager, atx), RedfishApi(info_manager, atx),
] ]
self.__subsystems = [ self.__subsystems = [
_Subsystem.make(auth_manager, "Auth manager"), _Subsystem.make(auth_manager, "Auth manager"),
_Subsystem.make(user_gpio, "User-GPIO", "gpio_state").add_source("gpio_model_state", user_gpio.get_model, None), _Subsystem.make(user_gpio, "User-GPIO", self.__EV_GPIO_STATE),
_Subsystem.make(hid, "HID", "hid_state").add_source("hid_keymaps_state", self.__hid_api.get_keymaps, None), _Subsystem.make(hid, "HID", self.__EV_HID_STATE),
_Subsystem.make(atx, "ATX", "atx_state"), _Subsystem.make(atx, "ATX", self.__EV_ATX_STATE),
_Subsystem.make(msd, "MSD", "msd_state"), _Subsystem.make(msd, "MSD", self.__EV_MSD_STATE),
_Subsystem.make(streamer, "Streamer", "streamer_state").add_source("streamer_ocr_state", self.__streamer_api.get_ocr, None), _Subsystem.make(streamer, "Streamer", self.__EV_STREAMER_STATE),
*[ _Subsystem.make(ocr, "OCR", self.__EV_OCR_STATE),
_Subsystem.make(info_manager.get_submanager(sub), f"Info manager ({sub})", f"info_{sub}_state",) _Subsystem.make(info_manager, "Info manager", self.__EV_INFO_STATE),
for sub in sorted(info_manager.get_subs())
],
] ]
self.__streamer_notifier = aiotools.AioNotifier() self.__streamer_notifier = aiotools.AioNotifier()
@ -213,7 +198,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
# ===== STREAMER CONTROLLER # ===== STREAMER CONTROLLER
@exposed_http("POST", "/streamer/set_params") @exposed_http("POST", "/streamer/set_params")
async def __streamer_set_params_handler(self, request: Request) -> Response: async def __streamer_set_params_handler(self, req: Request) -> Response:
current_params = self.__streamer.get_params() current_params = self.__streamer.get_params()
for (name, validator, exc_cls) in [ for (name, validator, exc_cls) in [
("quality", valid_stream_quality, StreamerQualityNotSupported), ("quality", valid_stream_quality, StreamerQualityNotSupported),
@ -222,7 +207,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
("h264_bitrate", valid_stream_h264_bitrate, StreamerH264NotSupported), ("h264_bitrate", valid_stream_h264_bitrate, StreamerH264NotSupported),
("h264_gop", valid_stream_h264_gop, StreamerH264NotSupported), ("h264_gop", valid_stream_h264_gop, StreamerH264NotSupported),
]: ]:
value = request.query.get(name) value = req.query.get(name)
if value: if value:
if name not in current_params: if name not in current_params:
assert exc_cls is not None, name assert exc_cls is not None, name
@ -242,24 +227,22 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
# ===== WEBSOCKET # ===== WEBSOCKET
@exposed_http("GET", "/ws") @exposed_http("GET", "/ws")
async def __ws_handler(self, request: Request) -> WebSocketResponse: async def __ws_handler(self, req: Request) -> WebSocketResponse:
stream = valid_bool(request.query.get("stream", True)) stream = valid_bool(req.query.get("stream", True))
async with self._ws_session(request, stream=stream) as ws: legacy = valid_bool(req.query.get("legacy", True))
states = [ async with self._ws_session(req, stream=stream, legacy=legacy) as ws:
(event_type, src.get_state()) (major, minor) = __version__.split(".")
for sub in self.__subsystems await ws.send_event("loop", {
for (event_type, src) in sub.sources.items() "version": {
if src.get_state "major": int(major),
] "minor": int(minor),
events = dict(zip( },
map(operator.itemgetter(0), states), })
await asyncio.gather(*map(operator.itemgetter(1), states)), for sub in self.__subsystems:
)) if sub.event_type:
await asyncio.gather(*[ assert sub.trigger_state
ws.send_event(event_type, events.pop(event_type)) await sub.trigger_state()
for (event_type, _) in states await self._broadcast_ws_event("hid_keymaps_state", await self.__hid_api.get_keymaps()) # FIXME
])
await ws.send_event("loop", {})
return (await self._ws_loop(ws)) return (await self._ws_loop(ws))
@exposed_ws("ping") @exposed_ws("ping")
@ -275,17 +258,17 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
aioproc.rename_process("main") aioproc.rename_process("main")
super().run(**kwargs) super().run(**kwargs)
async def _check_request_auth(self, exposed: HttpExposed, request: Request) -> None: async def _check_request_auth(self, exposed: HttpExposed, req: Request) -> None:
await check_request_auth(self.__auth_manager, exposed, request) await check_request_auth(self.__auth_manager, exposed, req)
async def _init_app(self) -> None: async def _init_app(self) -> None:
aiotools.create_deadly_task("Stream controller", self.__stream_controller()) aiotools.create_deadly_task("Stream controller", self.__stream_controller())
for sub in self.__subsystems: for sub in self.__subsystems:
if sub.systask: if sub.systask:
aiotools.create_deadly_task(sub.name, sub.systask()) aiotools.create_deadly_task(sub.name, sub.systask())
for (event_type, src) in sub.sources.items(): if sub.event_type:
if src.poll_state: assert sub.poll_state
aiotools.create_deadly_task(f"{sub.name} [poller]", self.__poll_state(event_type, src.poll_state())) aiotools.create_deadly_task(f"{sub.name} [poller]", self.__poll_state(sub.event_type, sub.poll_state()))
aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter()) aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter())
self._add_exposed(*self.__apis) self._add_exposed(*self.__apis)
@ -347,12 +330,67 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
prev = cur prev = cur
await self.__streamer_notifier.wait() await self.__streamer_notifier.wait()
async def __poll_state(self, event_type: str, poller: AsyncGenerator[dict, None]) -> None:
async for state in poller:
await self._broadcast_ws_event(event_type, state)
async def __stream_snapshoter(self) -> None: async def __stream_snapshoter(self) -> None:
await self.__snapshoter.run( await self.__snapshoter.run(
is_live=self.__has_stream_clients, is_live=self.__has_stream_clients,
notifier=self.__streamer_notifier, notifier=self.__streamer_notifier,
) )
async def __poll_state(self, event_type: str, poller: AsyncGenerator[dict, None]) -> None:
match event_type:
case self.__EV_GPIO_STATE:
await self.__poll_gpio_state(poller)
case self.__EV_INFO_STATE:
await self.__poll_info_state(poller)
case self.__EV_MSD_STATE:
await self.__poll_msd_state(poller)
case self.__EV_STREAMER_STATE:
await self.__poll_streamer_state(poller)
case self.__EV_OCR_STATE:
await self.__poll_ocr_state(poller)
case _:
async for state in poller:
await self._broadcast_ws_event(event_type, state)
async def __poll_gpio_state(self, poller: AsyncGenerator[dict, None]) -> None:
prev: dict = {"state": {"inputs": {}, "outputs": {}}}
async for state in poller:
await self._broadcast_ws_event(self.__EV_GPIO_STATE, state, legacy=False)
if "model" in state: # We have only "model"+"state" or "model" event
prev = state
await self._broadcast_ws_event("gpio_model_state", prev["model"], legacy=True)
else:
prev["state"]["inputs"].update(state["state"].get("inputs", {}))
prev["state"]["outputs"].update(state["state"].get("outputs", {}))
await self._broadcast_ws_event(self.__EV_GPIO_STATE, prev["state"], legacy=True)
async def __poll_info_state(self, poller: AsyncGenerator[dict, None]) -> None:
async for state in poller:
await self._broadcast_ws_event(self.__EV_INFO_STATE, state, legacy=False)
for (key, value) in state.items():
await self._broadcast_ws_event(f"info_{key}_state", value, legacy=True)
async def __poll_msd_state(self, poller: AsyncGenerator[dict, None]) -> None:
prev: dict = {"storage": None}
async for state in poller:
await self._broadcast_ws_event(self.__EV_MSD_STATE, state, legacy=False)
prev_storage = prev["storage"]
prev.update(state)
if prev["storage"] is not None and prev_storage is not None:
prev_storage.update(prev["storage"])
prev["storage"] = prev_storage
if "online" in prev: # Complete/Full
await self._broadcast_ws_event(self.__EV_MSD_STATE, prev, legacy=True)
async def __poll_streamer_state(self, poller: AsyncGenerator[dict, None]) -> None:
prev: dict = {}
async for state in poller:
await self._broadcast_ws_event(self.__EV_STREAMER_STATE, state, legacy=False)
prev.update(state)
if "features" in prev: # Complete/Full
await self._broadcast_ws_event(self.__EV_STREAMER_STATE, prev, legacy=True)
async def __poll_ocr_state(self, poller: AsyncGenerator[dict, None]) -> None:
async for state in poller:
await self._broadcast_ws_event(self.__EV_OCR_STATE, state, legacy=False)
await self._broadcast_ws_event("streamer_ocr_state", {"ocr": state}, legacy=True)

View File

@ -20,22 +20,23 @@
# ========================================================================== # # ========================================================================== #
import io
import signal import signal
import asyncio import asyncio
import asyncio.subprocess import asyncio.subprocess
import dataclasses import dataclasses
import functools import copy
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Any from typing import Any
import aiohttp import aiohttp
from PIL import Image as PilImage
from ...logging import get_logger from ...logging import get_logger
from ...clients.streamer import StreamerSnapshot
from ...clients.streamer import HttpStreamerClient
from ...clients.streamer import HttpStreamerClientSession
from ... import tools from ... import tools
from ... import aiotools from ... import aiotools
from ... import aioproc from ... import aioproc
@ -43,40 +44,6 @@ from ... import htclient
# ===== # =====
@dataclasses.dataclass(frozen=True)
class StreamerSnapshot:
online: bool
width: int
height: int
headers: tuple[tuple[str, str], ...]
data: bytes
async def make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
assert max_width >= 0
assert max_height >= 0
assert quality > 0
if max_width == 0 and max_height == 0:
max_width = self.width // 5
max_height = self.height // 5
else:
max_width = min((max_width or self.width), self.width)
max_height = min((max_height or self.height), self.height)
if (max_width, max_height) == (self.width, self.height):
return self.data
return (await aiotools.run_async(self.__inner_make_preview, max_width, max_height, quality))
@functools.lru_cache(maxsize=1)
def __inner_make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
with io.BytesIO(self.data) as snapshot_bio:
with io.BytesIO() as preview_bio:
with PilImage.open(snapshot_bio) as image:
image.thumbnail((max_width, max_height), PilImage.Resampling.LANCZOS)
image.save(preview_bio, format="jpeg", quality=quality)
return preview_bio.getvalue()
class _StreamerParams: class _StreamerParams:
__DESIRED_FPS = "desired_fps" __DESIRED_FPS = "desired_fps"
@ -136,7 +103,7 @@ class _StreamerParams:
} }
def get_limits(self) -> dict: def get_limits(self) -> dict:
limits = dict(self.__limits) limits = copy.deepcopy(self.__limits)
if self.__has_resolution: if self.__has_resolution:
limits[self.__AVAILABLE_RESOLUTIONS] = list(limits[self.__AVAILABLE_RESOLUTIONS]) limits[self.__AVAILABLE_RESOLUTIONS] = list(limits[self.__AVAILABLE_RESOLUTIONS])
return limits return limits
@ -170,6 +137,11 @@ class _StreamerParams:
class Streamer: # pylint: disable=too-many-instance-attributes class Streamer: # pylint: disable=too-many-instance-attributes
__ST_FULL = 0xFF
__ST_PARAMS = 0x01
__ST_STREAMER = 0x02
__ST_SNAPSHOT = 0x04
def __init__( # pylint: disable=too-many-arguments,too-many-locals def __init__( # pylint: disable=too-many-arguments,too-many-locals
self, self,
@ -203,7 +175,6 @@ class Streamer: # pylint: disable=too-many-instance-attributes
self.__state_poll = state_poll self.__state_poll = state_poll
self.__unix_path = unix_path self.__unix_path = unix_path
self.__timeout = timeout
self.__snapshot_timeout = snapshot_timeout self.__snapshot_timeout = snapshot_timeout
self.__process_name_prefix = process_name_prefix self.__process_name_prefix = process_name_prefix
@ -220,7 +191,13 @@ class Streamer: # pylint: disable=too-many-instance-attributes
self.__streamer_task: (asyncio.Task | None) = None self.__streamer_task: (asyncio.Task | None) = None
self.__streamer_proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member self.__streamer_proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member
self.__http_session: (aiohttp.ClientSession | None) = None self.__client = HttpStreamerClient(
name="jpeg",
unix_path=self.__unix_path,
timeout=timeout,
user_agent=htclient.make_user_agent("KVMD"),
)
self.__client_session: (HttpStreamerClientSession | None) = None
self.__snapshot: (StreamerSnapshot | None) = None self.__snapshot: (StreamerSnapshot | None) = None
@ -289,6 +266,7 @@ class Streamer: # pylint: disable=too-many-instance-attributes
def set_params(self, params: dict) -> None: def set_params(self, params: dict) -> None:
assert not self.__streamer_task assert not self.__streamer_task
self.__notifier.notify(self.__ST_PARAMS)
return self.__params.set_params(params) return self.__params.set_params(params)
def get_params(self) -> dict: def get_params(self) -> dict:
@ -297,55 +275,80 @@ class Streamer: # pylint: disable=too-many-instance-attributes
# ===== # =====
async def get_state(self) -> dict: async def get_state(self) -> dict:
streamer_state = None
if self.__streamer_task:
session = self.__ensure_http_session()
try:
async with session.get(self.__make_url("state")) as response:
htclient.raise_not_200(response)
streamer_state = (await response.json())["result"]
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
pass
except Exception:
get_logger().exception("Invalid streamer response from /state")
snapshot: (dict | None) = None
if self.__snapshot:
snapshot = dataclasses.asdict(self.__snapshot)
del snapshot["headers"]
del snapshot["data"]
return { return {
"features": self.__params.get_features(),
"limits": self.__params.get_limits(), "limits": self.__params.get_limits(),
"params": self.__params.get_params(), "params": self.__params.get_params(),
"snapshot": {"saved": snapshot}, "streamer": (await self.__get_streamer_state()),
"streamer": streamer_state, "snapshot": self.__get_snapshot_state(),
"features": self.__params.get_features(),
} }
async def trigger_state(self) -> None:
self.__notifier.notify(self.__ST_FULL)
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - features -- Full
# - limits -- Partial, paired with params
# - params -- Partial, paired with limits
# - streamer -- Partial, nullable
# - snapshot -- Partial
# ===========================
def signal_handler(*_: Any) -> None: def signal_handler(*_: Any) -> None:
get_logger(0).info("Got SIGUSR2, checking the stream state ...") get_logger(0).info("Got SIGUSR2, checking the stream state ...")
self.__notifier.notify() self.__notifier.notify(self.__ST_STREAMER)
get_logger(0).info("Installing SIGUSR2 streamer handler ...") get_logger(0).info("Installing SIGUSR2 streamer handler ...")
asyncio.get_event_loop().add_signal_handler(signal.SIGUSR2, signal_handler) asyncio.get_event_loop().add_signal_handler(signal.SIGUSR2, signal_handler)
waiter_task: (asyncio.Task | None) = None prev: dict = {}
prev_state: dict = {}
while True: while True:
state = await self.get_state() new: dict = {}
if state != prev_state:
yield state
prev_state = state
if waiter_task is None: mask = await self.__notifier.wait(timeout=self.__state_poll)
waiter_task = asyncio.create_task(self.__notifier.wait()) if mask == self.__ST_FULL:
if waiter_task in (await aiotools.wait_first( new = await self.get_state()
asyncio.ensure_future(asyncio.sleep(self.__state_poll)), prev = copy.deepcopy(new)
waiter_task, yield new
))[0]: continue
waiter_task = None
if mask < 0:
mask = self.__ST_STREAMER
def check_update(key: str, value: (dict | None)) -> None:
if prev.get(key) != value:
new[key] = value
if mask & self.__ST_PARAMS:
check_update("params", self.__params.get_params())
if mask & self.__ST_STREAMER:
check_update("streamer", await self.__get_streamer_state())
if mask & self.__ST_SNAPSHOT:
check_update("snapshot", self.__get_snapshot_state())
if new and prev != new:
prev.update(copy.deepcopy(new))
yield new
async def __get_streamer_state(self) -> (dict | None):
if self.__streamer_task:
session = self.__ensure_client_session()
try:
return (await session.get_state())
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
pass
except Exception:
get_logger().exception("Invalid streamer response from /state")
return None
def __get_snapshot_state(self) -> dict:
if self.__snapshot:
snapshot = dataclasses.asdict(self.__snapshot)
del snapshot["headers"]
del snapshot["data"]
return {"saved": snapshot}
return {"saved": None}
# ===== # =====
@ -353,41 +356,17 @@ class Streamer: # pylint: disable=too-many-instance-attributes
if load: if load:
return self.__snapshot return self.__snapshot
logger = get_logger() logger = get_logger()
session = self.__ensure_http_session() session = self.__ensure_client_session()
try: try:
async with session.get( snapshot = await session.take_snapshot(self.__snapshot_timeout)
self.__make_url("snapshot"), if snapshot.online or allow_offline:
timeout=self.__snapshot_timeout,
) as response:
htclient.raise_not_200(response)
online = (response.headers["X-UStreamer-Online"] == "true")
if online or allow_offline:
snapshot = StreamerSnapshot(
online=online,
width=int(response.headers["X-UStreamer-Width"]),
height=int(response.headers["X-UStreamer-Height"]),
headers=tuple(
(key, value)
for (key, value) in tools.sorted_kvs(dict(response.headers))
if key.lower().startswith("x-ustreamer-") or key.lower() in [
"x-timestamp",
"access-control-allow-origin",
"cache-control",
"pragma",
"expires",
]
),
data=bytes(await response.read()),
)
if save: if save:
self.__snapshot = snapshot self.__snapshot = snapshot
self.__notifier.notify() self.__notifier.notify(self.__ST_SNAPSHOT)
return snapshot return snapshot
logger.error("Stream is offline, no signal or so") logger.error("Stream is offline, no signal or so")
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as ex:
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as err: logger.error("Can't connect to streamer: %s", tools.efmt(ex))
logger.error("Can't connect to streamer: %s", tools.efmt(err))
except Exception: except Exception:
logger.exception("Invalid streamer response from /snapshot") logger.exception("Invalid streamer response from /snapshot")
return None return None
@ -400,25 +379,14 @@ class Streamer: # pylint: disable=too-many-instance-attributes
@aiotools.atomic_fg @aiotools.atomic_fg
async def cleanup(self) -> None: async def cleanup(self) -> None:
await self.ensure_stop(immediately=True) await self.ensure_stop(immediately=True)
if self.__http_session: if self.__client_session:
await self.__http_session.close() await self.__client_session.close()
self.__http_session = None self.__client_session = None
# ===== def __ensure_client_session(self) -> HttpStreamerClientSession:
if not self.__client_session:
def __ensure_http_session(self) -> aiohttp.ClientSession: self.__client_session = self.__client.make_session()
if not self.__http_session: return self.__client_session
kwargs: dict = {
"headers": {"User-Agent": htclient.make_user_agent("KVMD")},
"connector": aiohttp.UnixConnector(path=self.__unix_path),
"timeout": aiohttp.ClientTimeout(total=self.__timeout),
}
self.__http_session = aiohttp.ClientSession(**kwargs)
return self.__http_session
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://localhost:0/{handle}"
# ===== # =====
@ -473,8 +441,8 @@ class Streamer: # pylint: disable=too-many-instance-attributes
logger.info("%s: %s", name, tools.cmdfmt(cmd)) logger.info("%s: %s", name, tools.cmdfmt(cmd))
try: try:
await aioproc.log_process(cmd, logger, prefix=name) await aioproc.log_process(cmd, logger, prefix=name)
except Exception as err: except Exception as ex:
logger.exception("Can't execute command: %s", err) logger.exception("Can't execute command: %s", ex)
async def __start_streamer_proc(self) -> None: async def __start_streamer_proc(self) -> None:
assert self.__streamer_proc is None assert self.__streamer_proc is None

View File

@ -35,6 +35,7 @@ class SystemdUnitInfo:
self.__bus: (dbus_next.aio.MessageBus | None) = None self.__bus: (dbus_next.aio.MessageBus | None) = None
self.__intr: (dbus_next.introspection.Node | None) = None self.__intr: (dbus_next.introspection.Node | None) = None
self.__manager: (dbus_next.aio.proxy_object.ProxyInterface | None) = None self.__manager: (dbus_next.aio.proxy_object.ProxyInterface | None) = None
self.__requested = False
async def get_status(self, name: str) -> tuple[bool, bool]: async def get_status(self, name: str) -> tuple[bool, bool]:
assert self.__bus is not None assert self.__bus is not None
@ -49,8 +50,9 @@ class SystemdUnitInfo:
unit = self.__bus.get_proxy_object("org.freedesktop.systemd1", unit_p, self.__intr) unit = self.__bus.get_proxy_object("org.freedesktop.systemd1", unit_p, self.__intr)
unit_props = unit.get_interface("org.freedesktop.DBus.Properties") unit_props = unit.get_interface("org.freedesktop.DBus.Properties")
started = ((await unit_props.call_get("org.freedesktop.systemd1.Unit", "ActiveState")).value == "active") # type: ignore started = ((await unit_props.call_get("org.freedesktop.systemd1.Unit", "ActiveState")).value == "active") # type: ignore
except dbus_next.errors.DBusError as err: self.__requested = True
if err.type != "org.freedesktop.systemd1.NoSuchUnit": except dbus_next.errors.DBusError as ex:
if ex.type != "org.freedesktop.systemd1.NoSuchUnit":
raise raise
started = False started = False
enabled = ((await self.__manager.call_get_unit_file_state(name)) in [ # type: ignore enabled = ((await self.__manager.call_get_unit_file_state(name)) in [ # type: ignore
@ -75,6 +77,11 @@ class SystemdUnitInfo:
async def close(self) -> None: async def close(self) -> None:
try: try:
if self.__bus is not None: if self.__bus is not None:
try:
# XXX: Workaround for dbus_next bug: https://github.com/pikvm/kvmd/pull/182
if not self.__requested:
await self.__manager.call_get_default_target() # type: ignore
finally:
self.__bus.disconnect() self.__bus.disconnect()
await self.__bus.wait_for_disconnect() await self.__bus.wait_for_disconnect()
except Exception: except Exception:

View File

@ -21,6 +21,7 @@
import asyncio import asyncio
import copy
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Callable from typing import Callable
@ -68,12 +69,12 @@ class GpioChannelIsBusyError(IsBusyError, GpioError):
class _GpioInput: class _GpioInput:
def __init__( def __init__(
self, self,
channel: str, ch: str,
config: Section, config: Section,
driver: BaseUserGpioDriver, driver: BaseUserGpioDriver,
) -> None: ) -> None:
self.__channel = channel self.__ch = ch
self.__pin: str = str(config.pin) self.__pin: str = str(config.pin)
self.__inverted: bool = config.inverted self.__inverted: bool = config.inverted
@ -100,7 +101,7 @@ class _GpioInput:
} }
def __str__(self) -> str: def __str__(self) -> str:
return f"Input({self.__channel}, driver={self.__driver}, pin={self.__pin})" return f"Input({self.__ch}, driver={self.__driver}, pin={self.__pin})"
__repr__ = __str__ __repr__ = __str__
@ -108,13 +109,13 @@ class _GpioInput:
class _GpioOutput: # pylint: disable=too-many-instance-attributes class _GpioOutput: # pylint: disable=too-many-instance-attributes
def __init__( def __init__(
self, self,
channel: str, ch: str,
config: Section, config: Section,
driver: BaseUserGpioDriver, driver: BaseUserGpioDriver,
notifier: aiotools.AioNotifier, notifier: aiotools.AioNotifier,
) -> None: ) -> None:
self.__channel = channel self.__ch = ch
self.__pin: str = str(config.pin) self.__pin: str = str(config.pin)
self.__inverted: bool = config.inverted self.__inverted: bool = config.inverted
@ -184,7 +185,7 @@ class _GpioOutput: # pylint: disable=too-many-instance-attributes
@aiotools.atomic_fg @aiotools.atomic_fg
async def __run_action(self, wait: bool, name: str, func: Callable, *args: Any) -> None: async def __run_action(self, wait: bool, name: str, func: Callable, *args: Any) -> None:
if wait: if wait:
async with self.__region: with self.__region:
await func(*args) await func(*args)
else: else:
await aiotools.run_region_task( await aiotools.run_region_task(
@ -224,7 +225,7 @@ class _GpioOutput: # pylint: disable=too-many-instance-attributes
await self.__driver.write(self.__pin, (state ^ self.__inverted)) await self.__driver.write(self.__pin, (state ^ self.__inverted))
def __str__(self) -> str: def __str__(self) -> str:
return f"Output({self.__channel}, driver={self.__driver}, pin={self.__pin})" return f"Output({self.__ch}, driver={self.__driver}, pin={self.__pin})"
__repr__ = __str__ __repr__ = __str__
@ -232,8 +233,6 @@ class _GpioOutput: # pylint: disable=too-many-instance-attributes
# ===== # =====
class UserGpio: class UserGpio:
def __init__(self, config: Section, otg_config: Section) -> None: def __init__(self, config: Section, otg_config: Section) -> None:
self.__view = config.view
self.__notifier = aiotools.AioNotifier() self.__notifier = aiotools.AioNotifier()
self.__drivers = { self.__drivers = {
@ -249,45 +248,67 @@ class UserGpio:
self.__inputs: dict[str, _GpioInput] = {} self.__inputs: dict[str, _GpioInput] = {}
self.__outputs: dict[str, _GpioOutput] = {} self.__outputs: dict[str, _GpioOutput] = {}
for (channel, ch_config) in tools.sorted_kvs(config.scheme): for (ch, ch_config) in tools.sorted_kvs(config.scheme):
driver = self.__drivers[ch_config.driver] driver = self.__drivers[ch_config.driver]
if ch_config.mode == UserGpioModes.INPUT: if ch_config.mode == UserGpioModes.INPUT:
self.__inputs[channel] = _GpioInput(channel, ch_config, driver) self.__inputs[ch] = _GpioInput(ch, ch_config, driver)
else: # output: else: # output:
self.__outputs[channel] = _GpioOutput(channel, ch_config, driver, self.__notifier) self.__outputs[ch] = _GpioOutput(ch, ch_config, driver, self.__notifier)
async def get_model(self) -> dict: self.__scheme = self.__make_scheme()
return { self.__view = self.__make_view(config.view)
"scheme": {
"inputs": {channel: gin.get_scheme() for (channel, gin) in self.__inputs.items()},
"outputs": {
channel: gout.get_scheme()
for (channel, gout) in self.__outputs.items()
if not gout.is_const()
},
},
"view": self.__make_view(),
}
async def get_state(self) -> dict: async def get_state(self) -> dict:
return { return {
"inputs": {channel: await gin.get_state() for (channel, gin) in self.__inputs.items()}, "model": {
"scheme": copy.deepcopy(self.__scheme),
"view": copy.deepcopy(self.__view),
},
"state": (await self.__get_io_state()),
}
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - model -- Full
# - state.inputs -- Partial
# - state.outputs -- Partial
# ===========================
prev: dict = {"inputs": {}, "outputs": {}}
while True: # pylint: disable=too-many-nested-blocks
if (await self.__notifier.wait()) > 0:
full = await self.get_state()
prev = copy.deepcopy(full["state"])
yield full
else:
new = await self.__get_io_state()
diff: dict = {}
for sub in ["inputs", "outputs"]:
for ch in new[sub]:
if new[sub][ch] != prev[sub].get(ch):
if sub not in diff:
diff[sub] = {}
diff[sub][ch] = new[sub][ch]
if diff:
prev = copy.deepcopy(new)
yield {"state": diff}
async def __get_io_state(self) -> dict:
return {
"inputs": {
ch: (await gin.get_state())
for (ch, gin) in self.__inputs.items()
},
"outputs": { "outputs": {
channel: await gout.get_state() ch: (await gout.get_state())
for (channel, gout) in self.__outputs.items() for (ch, gout) in self.__outputs.items()
if not gout.is_const() if not gout.is_const()
}, },
} }
async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {}
while True:
state = await self.get_state()
if state != prev_state:
yield state
prev_state = state
await self.__notifier.wait()
def sysprep(self) -> None: def sysprep(self) -> None:
get_logger(0).info("Preparing User-GPIO drivers ...") get_logger(0).info("Preparing User-GPIO drivers ...")
for (_, driver) in tools.sorted_kvs(self.__drivers): for (_, driver) in tools.sorted_kvs(self.__drivers):
@ -307,28 +328,43 @@ class UserGpio:
except Exception: except Exception:
get_logger().exception("Can't cleanup driver %s", driver) get_logger().exception("Can't cleanup driver %s", driver)
async def switch(self, channel: str, state: bool, wait: bool) -> None: async def switch(self, ch: str, state: bool, wait: bool) -> None:
gout = self.__outputs.get(channel) gout = self.__outputs.get(ch)
if gout is None: if gout is None:
raise GpioChannelNotFoundError() raise GpioChannelNotFoundError()
await gout.switch(state, wait) await gout.switch(state, wait)
async def pulse(self, channel: str, delay: float, wait: bool) -> None: async def pulse(self, ch: str, delay: float, wait: bool) -> None:
gout = self.__outputs.get(channel) gout = self.__outputs.get(ch)
if gout is None: if gout is None:
raise GpioChannelNotFoundError() raise GpioChannelNotFoundError()
await gout.pulse(delay, wait) await gout.pulse(delay, wait)
# ===== # =====
def __make_view(self) -> dict: def __make_scheme(self) -> dict:
return { return {
"header": {"title": self.__make_view_title()}, "inputs": {
"table": self.__make_view_table(), ch: gin.get_scheme()
for (ch, gin) in self.__inputs.items()
},
"outputs": {
ch: gout.get_scheme()
for (ch, gout) in self.__outputs.items()
if not gout.is_const()
},
} }
def __make_view_title(self) -> list[dict]: # =====
raw_title = self.__view["header"]["title"]
def __make_view(self, view: dict) -> dict:
return {
"header": {"title": self.__make_view_title(view)},
"table": self.__make_view_table(view),
}
def __make_view_title(self, view: dict) -> list[dict]:
raw_title = view["header"]["title"]
title: list[dict] = [] title: list[dict] = []
if isinstance(raw_title, list): if isinstance(raw_title, list):
for item in raw_title: for item in raw_title:
@ -342,9 +378,9 @@ class UserGpio:
title.append(self.__make_item_label(f"#{raw_title}")) title.append(self.__make_item_label(f"#{raw_title}"))
return title return title
def __make_view_table(self) -> list[list[dict] | None]: def __make_view_table(self, view: dict) -> list[list[dict] | None]:
table: list[list[dict] | None] = [] table: list[list[dict] | None] = []
for row in self.__view["table"]: for row in view["table"]:
if len(row) == 0: if len(row) == 0:
table.append(None) table.append(None)
continue continue

184
kvmd/apps/oled/__init__.py Normal file
View File

@ -0,0 +1,184 @@
#!/usr/bin/env python3
# ========================================================================== #
# #
# KVMD-OLED - A small OLED daemon for PiKVM. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import sys
import os
import signal
import itertools
import logging
import time
import usb.core
from luma.core import cmdline as luma_cmdline
from PIL import ImageFont
from .screen import Screen
from .sensors import Sensors
# =====
_logger = logging.getLogger("oled")
# =====
def _detect_geometry() -> dict:
with open("/proc/device-tree/model") as file:
is_cm4 = ("Compute Module 4" in file.read())
has_usb = bool(list(usb.core.find(find_all=True)))
if is_cm4 and has_usb:
return {"height": 64, "rotate": 2}
return {"height": 32, "rotate": 0}
def _get_data_path(subdir: str, name: str) -> str:
if not name.startswith("@"):
return name # Just a regular system path
name = name[1:]
module_path = sys.modules[__name__].__file__
assert module_path is not None
return os.path.join(os.path.dirname(module_path), subdir, name)
# =====
def main() -> None: # pylint: disable=too-many-locals,too-many-branches,too-many-statements
logging.basicConfig(level=logging.INFO, format="%(message)s")
logging.getLogger("PIL").setLevel(logging.ERROR)
parser = luma_cmdline.create_parser(description="Display FQDN and IP on the OLED")
parser.set_defaults(**_detect_geometry())
parser.add_argument("--font", default="@ProggySquare.ttf", type=(lambda arg: _get_data_path("fonts", arg)), help="Font path")
parser.add_argument("--font-size", default=16, type=int, help="Font size")
parser.add_argument("--font-spacing", default=2, type=int, help="Font line spacing")
parser.add_argument("--offset-x", default=0, type=int, help="Horizontal offset")
parser.add_argument("--offset-y", default=0, type=int, help="Vertical offset")
parser.add_argument("--interval", default=5, type=int, help="Screens interval")
parser.add_argument("--image", default="", type=(lambda arg: _get_data_path("pics", arg)), help="Display some image, wait a single interval and exit")
parser.add_argument("--text", default="", help="Display some text, wait a single interval and exit")
parser.add_argument("--pipe", action="store_true", help="Read and display lines from stdin until EOF, wait a single interval and exit")
parser.add_argument("--clear-on-exit", action="store_true", help="Clear display on exit")
parser.add_argument("--contrast", default=64, type=int, help="Set OLED contrast, values from 0 to 255")
parser.add_argument("--fahrenheit", action="store_true", help="Display temperature in Fahrenheit instead of Celsius")
options = parser.parse_args(sys.argv[1:])
if options.config:
config = luma_cmdline.load_config(options.config)
options = parser.parse_args(config + sys.argv[1:])
device = luma_cmdline.create_device(options)
device.cleanup = (lambda _: None)
screen = Screen(
device=device,
font=ImageFont.truetype(options.font, options.font_size),
font_spacing=options.font_spacing,
offset=(options.offset_x, options.offset_y),
)
if options.display not in luma_cmdline.get_display_types()["emulator"]:
_logger.info("Iface: %s", options.interface)
_logger.info("Display: %s", options.display)
_logger.info("Size: %dx%d", device.width, device.height)
options.contrast = min(max(options.contrast, 0), 255)
_logger.info("Contrast: %d", options.contrast)
device.contrast(options.contrast)
try:
if options.image:
screen.draw_image(options.image)
time.sleep(options.interval)
elif options.text:
screen.draw_text(options.text.replace("\\n", "\n"))
time.sleep(options.interval)
elif options.pipe:
text = ""
for line in sys.stdin:
text += line
if "\0" in text:
screen.draw_text(text.replace("\0", ""))
text = ""
time.sleep(options.interval)
else:
stop_reason: (str | None) = None
def sigusr_handler(signum: int, _) -> None: # type: ignore
nonlocal stop_reason
if signum in (signal.SIGINT, signal.SIGTERM):
stop_reason = ""
elif signum == signal.SIGUSR1:
stop_reason = "Rebooting...\nPlease wait"
elif signum == signal.SIGUSR2:
stop_reason = "Halted"
for signum in [signal.SIGTERM, signal.SIGINT, signal.SIGUSR1, signal.SIGUSR2]:
signal.signal(signum, sigusr_handler)
hb = itertools.cycle(r"/-\|") # Heartbeat
swim = 0
def draw(text: str) -> None:
nonlocal swim
count = 0
while (count < max(options.interval, 1) * 2) and stop_reason is None:
screen.draw_text(
text=text.replace("__hb__", next(hb)),
offset_x=(3 if swim < 0 else 0),
)
count += 1
if swim >= 1200:
swim = -1200
else:
swim += 1
time.sleep(0.5)
sensors = Sensors(options.fahrenheit)
if device.height >= 64:
while stop_reason is None:
text = "{fqdn}\n{ip}\niface: {iface}\ntemp: {temp}\ncpu: {cpu} mem: {mem}\n(__hb__) {uptime}"
draw(sensors.render(text))
else:
summary = True
while stop_reason is None:
if summary:
text = "{fqdn}\n(__hb__) {uptime}\ntemp: {temp}"
else:
text = "{ip}\n(__hb__) iface: {iface}\ncpu: {cpu} mem: {mem}"
draw(sensors.render(text))
summary = (not summary)
if stop_reason is not None:
if len(stop_reason) > 0:
options.clear_on_exit = False
screen.draw_text(stop_reason)
while len(stop_reason) > 0:
time.sleep(0.1)
except (SystemExit, KeyboardInterrupt):
pass
if options.clear_on_exit:
screen.draw_text("")

View File

@ -0,0 +1,24 @@
# ========================================================================== #
# #
# KVMD - The main PiKVM daemon. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
from . import main
main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

54
kvmd/apps/oled/screen.py Normal file
View File

@ -0,0 +1,54 @@
#!/usr/bin/env python3
# ========================================================================== #
# #
# KVMD-OLED - A small OLED daemon for PiKVM. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
from luma.core.device import device as luma_device
from luma.core.render import canvas as luma_canvas
from PIL import Image
from PIL import ImageFont
# =====
class Screen:
def __init__(
self,
device: luma_device,
font: ImageFont.FreeTypeFont,
font_spacing: int,
offset: tuple[int, int],
) -> None:
self.__device = device
self.__font = font
self.__font_spacing = font_spacing
self.__offset = offset
def draw_text(self, text: str, offset_x: int=0) -> None:
with luma_canvas(self.__device) as draw:
offset = list(self.__offset)
offset[0] += offset_x
draw.multiline_text(offset, text, font=self.__font, spacing=self.__font_spacing, fill="white")
def draw_image(self, image_path: str) -> None:
with luma_canvas(self.__device) as draw:
draw.bitmap(self.__offset, Image.open(image_path).convert("1"), fill="white")

126
kvmd/apps/oled/sensors.py Normal file
View File

@ -0,0 +1,126 @@
#!/usr/bin/env python3
# ========================================================================== #
# #
# KVMD-OLED - A small OLED daemon for PiKVM. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import socket
import functools
import datetime
import time
import netifaces
import psutil
# =====
class Sensors:
def __init__(self, fahrenheit: bool) -> None:
self.__fahrenheit = fahrenheit
self.__sensors = {
"fqdn": socket.getfqdn,
"iface": self.__get_iface,
"ip": self.__get_ip,
"uptime": self.__get_uptime,
"temp": self.__get_temp,
"cpu": self.__get_cpu,
"mem": self.__get_mem,
}
def render(self, text: str) -> str:
return text.format_map(self)
def __getitem__(self, key: str) -> str:
return self.__sensors[key]() # type: ignore
# =====
def __get_iface(self) -> str:
return self.__get_netconf(round(time.monotonic() / 0.3))[0]
def __get_ip(self) -> str:
return self.__get_netconf(round(time.monotonic() / 0.3))[1]
@functools.lru_cache(maxsize=1)
def __get_netconf(self, ts: int) -> tuple[str, str]:
_ = ts
try:
gws = netifaces.gateways()
if "default" in gws:
for proto in [socket.AF_INET, socket.AF_INET6]:
if proto in gws["default"]:
iface = gws["default"][proto][1]
addrs = netifaces.ifaddresses(iface)
return (iface, addrs[proto][0]["addr"])
for iface in netifaces.interfaces():
if not iface.startswith(("lo", "docker")):
addrs = netifaces.ifaddresses(iface)
for proto in [socket.AF_INET, socket.AF_INET6]:
if proto in addrs:
return (iface, addrs[proto][0]["addr"])
except Exception:
# _logger.exception("Can't get iface/IP")
pass
return ("<no-iface>", "<no-ip>")
# =====
def __get_uptime(self) -> str:
uptime = datetime.timedelta(seconds=int(time.time() - psutil.boot_time()))
pl = {"days": uptime.days}
(pl["hours"], rem) = divmod(uptime.seconds, 3600)
(pl["mins"], pl["secs"]) = divmod(rem, 60)
return "{days}d {hours}h {mins}m".format(**pl)
# =====
def __get_temp(self) -> str:
try:
with open("/sys/class/thermal/thermal_zone0/temp") as file:
temp = int(file.read().strip()) / 1000
if self.__fahrenheit:
temp = temp * 9 / 5 + 32
return f"{temp:.1f}\u00b0F"
return f"{temp:.1f}\u00b0C"
except Exception:
# _logger.exception("Can't read temp")
return "<no-temp>"
# =====
def __get_cpu(self) -> str:
st = psutil.cpu_times_percent()
user = st.user - st.guest
nice = st.nice - st.guest_nice
idle_all = st.idle + st.iowait
system_all = st.system + st.irq + st.softirq
virtual = st.guest + st.guest_nice
total = max(1, user + nice + system_all + idle_all + st.steal + virtual)
percent = int(
st.nice / total * 100
+ st.user / total * 100
+ system_all / total * 100
+ (st.steal + st.guest) / total * 100
)
return f"{percent}%"
def __get_mem(self) -> str:
return f"{int(psutil.virtual_memory().percent)}%"

View File

@ -350,5 +350,5 @@ def main(argv: (list[str] | None)=None) -> None:
options = parser.parse_args(argv[1:]) options = parser.parse_args(argv[1:])
try: try:
options.cmd(config) options.cmd(config)
except ValidatorError as err: except ValidatorError as ex:
raise SystemExit(str(err)) raise SystemExit(str(ex))

View File

@ -50,9 +50,9 @@ def _set_param(gadget: str, instance: int, param: str, value: str) -> None:
try: try:
with open(_get_param_path(gadget, instance, param), "w") as file: with open(_get_param_path(gadget, instance, param), "w") as file:
file.write(value + "\n") file.write(value + "\n")
except OSError as err: except OSError as ex:
if err.errno == errno.EBUSY: if ex.errno == errno.EBUSY:
raise SystemExit(f"Can't change {param!r} value because device is locked: {err}") raise SystemExit(f"Can't change {param!r} value because device is locked: {ex}")
raise raise

View File

@ -133,8 +133,8 @@ class _Service: # pylint: disable=too-many-instance-attributes
logger.info("CMD: %s", tools.cmdfmt(cmd)) logger.info("CMD: %s", tools.cmdfmt(cmd))
try: try:
return (not (await aioproc.log_process(cmd, logger)).returncode) return (not (await aioproc.log_process(cmd, logger)).returncode)
except Exception as err: except Exception as ex:
logger.exception("Can't execute command: %s", err) logger.exception("Can't execute command: %s", ex)
return False return False
# ===== # =====

View File

@ -50,7 +50,7 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst
super().__init__() super().__init__()
self.__data_path = os.path.join(fstab.find_pst().root_path, "data") self.__data_path = fstab.find_pst().root_path
self.__ro_retries_delay = ro_retries_delay self.__ro_retries_delay = ro_retries_delay
self.__ro_cleanup_delay = ro_cleanup_delay self.__ro_cleanup_delay = ro_cleanup_delay
self.__remount_cmd = remount_cmd self.__remount_cmd = remount_cmd
@ -60,8 +60,8 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst
# ===== WEBSOCKET # ===== WEBSOCKET
@exposed_http("GET", "/ws") @exposed_http("GET", "/ws")
async def __ws_handler(self, request: Request) -> WebSocketResponse: async def __ws_handler(self, req: Request) -> WebSocketResponse:
async with self._ws_session(request) as ws: async with self._ws_session(req) as ws:
await ws.send_event("loop", {}) await ws.send_event("loop", {})
return (await self._ws_loop(ws)) return (await self._ws_loop(ws))
@ -128,9 +128,9 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst
def __is_write_available(self) -> bool: def __is_write_available(self) -> bool:
try: try:
return (not (os.statvfs(self.__data_path).f_flag & os.ST_RDONLY)) return (not (os.statvfs(self.__data_path).f_flag & os.ST_RDONLY))
except Exception as err: except Exception as ex:
get_logger(0).info("Can't get filesystem state of PST (%s): %s", get_logger(0).info("Can't get filesystem state of PST (%s): %s",
self.__data_path, tools.efmt(err)) self.__data_path, tools.efmt(ex))
return False return False
async def __remount_storage(self, rw: bool) -> bool: async def __remount_storage(self, rw: bool) -> bool:

View File

@ -46,8 +46,8 @@ def _preexec() -> None:
if os.isatty(0): if os.isatty(0):
try: try:
os.tcsetpgrp(0, os.getpgid(0)) os.tcsetpgrp(0, os.getpgid(0))
except Exception as err: except Exception as ex:
get_logger(0).info("Can't perform tcsetpgrp(0): %s", tools.efmt(err)) get_logger(0).info("Can't perform tcsetpgrp(0): %s", tools.efmt(ex))
async def _run_process(cmd: list[str], data_path: str) -> asyncio.subprocess.Process: # pylint: disable=no-member async def _run_process(cmd: list[str], data_path: str) -> asyncio.subprocess.Process: # pylint: disable=no-member

View File

@ -21,7 +21,7 @@
from ...clients.kvmd import KvmdClient from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamFormats from ...clients.streamer import StreamerFormats
from ...clients.streamer import BaseStreamerClient from ...clients.streamer import BaseStreamerClient
from ...clients.streamer import HttpStreamerClient from ...clients.streamer import HttpStreamerClient
from ...clients.streamer import MemsinkStreamerClient from ...clients.streamer import MemsinkStreamerClient
@ -51,8 +51,8 @@ def main(argv: (list[str] | None)=None) -> None:
return None return None
streamers: list[BaseStreamerClient] = list(filter(None, [ streamers: list[BaseStreamerClient] = list(filter(None, [
make_memsink_streamer("h264", StreamFormats.H264), make_memsink_streamer("h264", StreamerFormats.H264),
make_memsink_streamer("jpeg", StreamFormats.JPEG), make_memsink_streamer("jpeg", StreamerFormats.JPEG),
HttpStreamerClient(name="JPEG", user_agent=user_agent, **config.streamer._unpack()), HttpStreamerClient(name="JPEG", user_agent=user_agent, **config.streamer._unpack()),
])) ]))
@ -71,6 +71,7 @@ def main(argv: (list[str] | None)=None) -> None:
desired_fps=config.desired_fps, desired_fps=config.desired_fps,
mouse_output=config.mouse_output, mouse_output=config.mouse_output,
keymap_path=config.keymap, keymap_path=config.keymap,
allow_cut_after=config.allow_cut_after,
kvmd=KvmdClient(user_agent=user_agent, **config.kvmd._unpack()), kvmd=KvmdClient(user_agent=user_agent, **config.kvmd._unpack()),
streamers=streamers, streamers=streamers,

View File

@ -22,6 +22,7 @@
import asyncio import asyncio
import ssl import ssl
import time
from typing import Callable from typing import Callable
from typing import Coroutine from typing import Coroutine
@ -64,6 +65,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
width: int, width: int,
height: int, height: int,
name: str, name: str,
allow_cut_after: float,
vnc_passwds: list[str], vnc_passwds: list[str],
vencrypt: bool, vencrypt: bool,
none_auth_only: bool, none_auth_only: bool,
@ -79,6 +81,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
self._width = width self._width = width
self._height = height self._height = height
self.__name = name self.__name = name
self.__allow_cut_after = allow_cut_after
self.__vnc_passwds = vnc_passwds self.__vnc_passwds = vnc_passwds
self.__vencrypt = vencrypt self.__vencrypt = vencrypt
self.__none_auth_only = none_auth_only self.__none_auth_only = none_auth_only
@ -90,6 +93,8 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
self.__fb_cont_updates = False self.__fb_cont_updates = False
self.__fb_reset_h264 = False self.__fb_reset_h264 = False
self.__allow_cut_since_ts = 0.0
self.__lock = asyncio.Lock() self.__lock = asyncio.Lock()
# ===== # =====
@ -120,10 +125,10 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("%s [%s]: Cancelling subtask ...", self._remote, name) logger.info("%s [%s]: Cancelling subtask ...", self._remote, name)
raise raise
except RfbConnectionError as err: except RfbConnectionError as ex:
logger.info("%s [%s]: Gone: %s", self._remote, name, err) logger.info("%s [%s]: Gone: %s", self._remote, name, ex)
except (RfbError, ssl.SSLError) as err: except (RfbError, ssl.SSLError) as ex:
logger.error("%s [%s]: Error: %s", self._remote, name, err) logger.error("%s [%s]: Error: %s", self._remote, name, ex)
except Exception: except Exception:
logger.exception("%s [%s]: Unhandled exception", self._remote, name) logger.exception("%s [%s]: Unhandled exception", self._remote, name)
@ -414,6 +419,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
# ===== # =====
async def __main_loop(self) -> None: async def __main_loop(self) -> None:
self.__allow_cut_since_ts = time.monotonic() + self.__allow_cut_after
handlers = { handlers = {
0: self.__handle_set_pixel_format, 0: self.__handle_set_pixel_format,
2: self.__handle_set_encodings, 2: self.__handle_set_encodings,
@ -499,6 +505,11 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
async def __handle_client_cut_text(self) -> None: async def __handle_client_cut_text(self) -> None:
length = (await self._read_struct("cut text length", "xxx L"))[0] length = (await self._read_struct("cut text length", "xxx L"))[0]
text = await self._read_text("cut text data", length) text = await self._read_text("cut text data", length)
if self.__allow_cut_since_ts > 0 and time.monotonic() >= self.__allow_cut_since_ts:
# We should ignore cut event a few seconds after handshake
# because bVNC, AVNC and maybe some other clients perform
# it right after the connection automatically.
# - https://github.com/pikvm/pikvm/issues/1420
await self._on_cut_event(text) await self._on_cut_event(text)
async def __handle_enable_cont_updates(self) -> None: async def __handle_enable_cont_updates(self) -> None:

View File

@ -29,5 +29,5 @@ class RfbError(Exception):
class RfbConnectionError(RfbError): class RfbConnectionError(RfbError):
def __init__(self, msg: str, err: Exception) -> None: def __init__(self, msg: str, ex: Exception) -> None:
super().__init__(f"{msg}: {tools.efmt(err)}") super().__init__(f"{msg}: {tools.efmt(ex)}")

View File

@ -51,22 +51,22 @@ class RfbClientStream:
else: else:
fmt = f">{fmt}" fmt = f">{fmt}"
return struct.unpack(fmt, await self.__reader.readexactly(struct.calcsize(fmt)))[0] return struct.unpack(fmt, await self.__reader.readexactly(struct.calcsize(fmt)))[0]
except (ConnectionError, asyncio.IncompleteReadError) as err: except (ConnectionError, asyncio.IncompleteReadError) as ex:
raise RfbConnectionError(f"Can't read {msg}", err) raise RfbConnectionError(f"Can't read {msg}", ex)
async def _read_struct(self, msg: str, fmt: str) -> tuple[int, ...]: async def _read_struct(self, msg: str, fmt: str) -> tuple[int, ...]:
assert len(fmt) > 1 assert len(fmt) > 1
try: try:
fmt = f">{fmt}" fmt = f">{fmt}"
return struct.unpack(fmt, (await self.__reader.readexactly(struct.calcsize(fmt)))) return struct.unpack(fmt, (await self.__reader.readexactly(struct.calcsize(fmt))))
except (ConnectionError, asyncio.IncompleteReadError) as err: except (ConnectionError, asyncio.IncompleteReadError) as ex:
raise RfbConnectionError(f"Can't read {msg}", err) raise RfbConnectionError(f"Can't read {msg}", ex)
async def _read_text(self, msg: str, length: int) -> str: async def _read_text(self, msg: str, length: int) -> str:
try: try:
return (await self.__reader.readexactly(length)).decode("utf-8", errors="ignore") return (await self.__reader.readexactly(length)).decode("utf-8", errors="ignore")
except (ConnectionError, asyncio.IncompleteReadError) as err: except (ConnectionError, asyncio.IncompleteReadError) as ex:
raise RfbConnectionError(f"Can't read {msg}", err) raise RfbConnectionError(f"Can't read {msg}", ex)
# ===== # =====
@ -84,8 +84,8 @@ class RfbClientStream:
self.__writer.write(struct.pack(f">{fmt}", *values)) self.__writer.write(struct.pack(f">{fmt}", *values))
if drain: if drain:
await self.__writer.drain() await self.__writer.drain()
except ConnectionError as err: except ConnectionError as ex:
raise RfbConnectionError(f"Can't write {msg}", err) raise RfbConnectionError(f"Can't write {msg}", ex)
async def _write_reason(self, msg: str, text: str, drain: bool=True) -> None: async def _write_reason(self, msg: str, text: str, drain: bool=True) -> None:
encoded = text.encode("utf-8", errors="ignore") encoded = text.encode("utf-8", errors="ignore")
@ -94,8 +94,8 @@ class RfbClientStream:
self.__writer.write(encoded) self.__writer.write(encoded)
if drain: if drain:
await self.__writer.drain() await self.__writer.drain()
except ConnectionError as err: except ConnectionError as ex:
raise RfbConnectionError(f"Can't write {msg}", err) raise RfbConnectionError(f"Can't write {msg}", ex)
async def _write_fb_update(self, msg: str, width: int, height: int, encoding: int, drain: bool=True) -> None: async def _write_fb_update(self, msg: str, width: int, height: int, encoding: int, drain: bool=True) -> None:
await self._write_struct( await self._write_struct(
@ -123,8 +123,8 @@ class RfbClientStream:
server_side=True, server_side=True,
ssl_handshake_timeout=ssl_timeout, ssl_handshake_timeout=ssl_timeout,
) )
except ConnectionError as err: except ConnectionError as ex:
raise RfbConnectionError("Can't start TLS", err) raise RfbConnectionError("Can't start TLS", ex)
ssl_reader.set_transport(transport) # type: ignore ssl_reader.set_transport(transport) # type: ignore
ssl_writer = asyncio.StreamWriter( ssl_writer = asyncio.StreamWriter(

View File

@ -42,7 +42,7 @@ from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamerError from ...clients.streamer import StreamerError
from ...clients.streamer import StreamerPermError from ...clients.streamer import StreamerPermError
from ...clients.streamer import StreamFormats from ...clients.streamer import StreamerFormats
from ...clients.streamer import BaseStreamerClient from ...clients.streamer import BaseStreamerClient
from ... import tools from ... import tools
@ -81,6 +81,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
mouse_output: str, mouse_output: str,
keymap_name: str, keymap_name: str,
symmap: dict[int, dict[int, str]], symmap: dict[int, dict[int, str]],
allow_cut_after: float,
kvmd: KvmdClient, kvmd: KvmdClient,
streamers: list[BaseStreamerClient], streamers: list[BaseStreamerClient],
@ -100,6 +101,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
tls_timeout=tls_timeout, tls_timeout=tls_timeout,
x509_cert_path=x509_cert_path, x509_cert_path=x509_cert_path,
x509_key_path=x509_key_path, x509_key_path=x509_key_path,
allow_cut_after=allow_cut_after,
vnc_passwds=list(vnc_credentials), vnc_passwds=list(vnc_credentials),
vencrypt=vencrypt, vencrypt=vencrypt,
none_auth_only=none_auth_only, none_auth_only=none_auth_only,
@ -175,9 +177,10 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__kvmd_ws = None self.__kvmd_ws = None
async def __process_ws_event(self, event_type: str, event: dict) -> None: async def __process_ws_event(self, event_type: str, event: dict) -> None:
if event_type == "info_meta_state": if event_type == "info_state":
if "meta" in event:
try: try:
host = event["server"]["host"] host = event["meta"]["server"]["host"]
except Exception: except Exception:
host = None host = None
else: else:
@ -188,7 +191,11 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__shared_params.name = name self.__shared_params.name = name
elif event_type == "hid_state": elif event_type == "hid_state":
if self._encodings.has_leds_state: if (
self._encodings.has_leds_state
and ("keyboard" in event)
and ("leds" in event["keyboard"])
):
await self._send_leds_state(**event["keyboard"]["leds"]) await self._send_leds_state(**event["keyboard"]["leds"])
# ===== # =====
@ -210,19 +217,19 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
await self.__queue_frame(frame) await self.__queue_frame(frame)
else: else:
await self.__queue_frame("No signal") await self.__queue_frame("No signal")
except StreamerError as err: except StreamerError as ex:
if isinstance(err, StreamerPermError): if isinstance(ex, StreamerPermError):
streamer = self.__get_default_streamer() streamer = self.__get_default_streamer()
logger.info("%s [streamer]: Permanent error: %s; switching to %s ...", self._remote, err, streamer) logger.info("%s [streamer]: Permanent error: %s; switching to %s ...", self._remote, ex, streamer)
else: else:
logger.info("%s [streamer]: Waiting for stream: %s", self._remote, err) logger.info("%s [streamer]: Waiting for stream: %s", self._remote, ex)
await self.__queue_frame("Waiting for stream ...") await self.__queue_frame("Waiting for stream ...")
await asyncio.sleep(1) await asyncio.sleep(1)
def __get_preferred_streamer(self) -> BaseStreamerClient: def __get_preferred_streamer(self) -> BaseStreamerClient:
formats = { formats = {
StreamFormats.JPEG: "has_tight", StreamerFormats.JPEG: "has_tight",
StreamFormats.H264: "has_h264", StreamerFormats.H264: "has_h264",
} }
streamer: (BaseStreamerClient | None) = None streamer: (BaseStreamerClient | None) = None
for streamer in self.__streamers: for streamer in self.__streamers:
@ -248,7 +255,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
"data": (await make_text_jpeg(self._width, self._height, self._encodings.tight_jpeg_quality, text)), "data": (await make_text_jpeg(self._width, self._height, self._encodings.tight_jpeg_quality, text)),
"width": self._width, "width": self._width,
"height": self._height, "height": self._height,
"format": StreamFormats.JPEG, "format": StreamerFormats.JPEG,
} }
async def __fb_sender_task_loop(self) -> None: # pylint: disable=too-many-branches async def __fb_sender_task_loop(self) -> None: # pylint: disable=too-many-branches
@ -258,21 +265,21 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
frame = await self.__fb_queue.get() frame = await self.__fb_queue.get()
if ( if (
last is None # pylint: disable=too-many-boolean-expressions last is None # pylint: disable=too-many-boolean-expressions
or frame["format"] == StreamFormats.JPEG or frame["format"] == StreamerFormats.JPEG
or last["format"] != frame["format"] or last["format"] != frame["format"]
or (frame["format"] == StreamFormats.H264 and ( or (frame["format"] == StreamerFormats.H264 and (
frame["key"] frame["key"]
or last["width"] != frame["width"] or last["width"] != frame["width"]
or last["height"] != frame["height"] or last["height"] != frame["height"]
or len(last["data"]) + len(frame["data"]) > 4194304 or len(last["data"]) + len(frame["data"]) > 4194304
)) ))
): ):
self.__fb_has_key = (frame["format"] == StreamFormats.H264 and frame["key"]) self.__fb_has_key = (frame["format"] == StreamerFormats.H264 and frame["key"])
last = frame last = frame
if self.__fb_queue.qsize() == 0: if self.__fb_queue.qsize() == 0:
break break
continue continue
assert frame["format"] == StreamFormats.H264 assert frame["format"] == StreamerFormats.H264
last["data"] += frame["data"] last["data"] += frame["data"]
if self.__fb_queue.qsize() == 0: if self.__fb_queue.qsize() == 0:
break break
@ -294,9 +301,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
await self._send_fb_allow_again() await self._send_fb_allow_again()
continue continue
if last["format"] == StreamFormats.JPEG: if last["format"] == StreamerFormats.JPEG:
await self._send_fb_jpeg(last["data"]) await self._send_fb_jpeg(last["data"])
elif last["format"] == StreamFormats.H264: elif last["format"] == StreamerFormats.H264:
if not self._encodings.has_h264: if not self._encodings.has_h264:
raise RfbError("The client doesn't want to accept H264 anymore") raise RfbError("The client doesn't want to accept H264 anymore")
if self.__fb_has_key: if self.__fb_has_key:
@ -439,6 +446,7 @@ class VncServer: # pylint: disable=too-many-instance-attributes
desired_fps: int, desired_fps: int,
mouse_output: str, mouse_output: str,
keymap_path: str, keymap_path: str,
allow_cut_after: float,
kvmd: KvmdClient, kvmd: KvmdClient,
streamers: list[BaseStreamerClient], streamers: list[BaseStreamerClient],
@ -481,8 +489,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
try: try:
async with kvmd.make_session("", "") as kvmd_session: async with kvmd.make_session("", "") as kvmd_session:
none_auth_only = await kvmd_session.auth.check() none_auth_only = await kvmd_session.auth.check()
except (aiohttp.ClientError, asyncio.TimeoutError) as err: except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
logger.error("%s [entry]: Can't check KVMD auth mode: %s", remote, tools.efmt(err)) logger.error("%s [entry]: Can't check KVMD auth mode: %s", remote, tools.efmt(ex))
return return
await _Client( await _Client(
@ -496,6 +504,7 @@ class VncServer: # pylint: disable=too-many-instance-attributes
mouse_output=mouse_output, mouse_output=mouse_output,
keymap_name=keymap_name, keymap_name=keymap_name,
symmap=symmap, symmap=symmap,
allow_cut_after=allow_cut_after,
kvmd=kvmd, kvmd=kvmd,
streamers=streamers, streamers=streamers,
vnc_credentials=(await self.__vnc_auth_manager.read_credentials())[0], vnc_credentials=(await self.__vnc_auth_manager.read_credentials())[0],

View File

@ -54,8 +54,8 @@ class VncAuthManager:
if self.__enabled: if self.__enabled:
try: try:
return (await self.__inner_read_credentials(), True) return (await self.__inner_read_credentials(), True)
except VncAuthError as err: except VncAuthError as ex:
get_logger(0).error(str(err)) get_logger(0).error(str(ex))
except Exception: except Exception:
get_logger(0).exception("Unhandled exception while reading VNCAuth passwd file") get_logger(0).exception("Unhandled exception while reading VNCAuth passwd file")
return ({}, (not self.__enabled)) return ({}, (not self.__enabled))

View File

@ -56,8 +56,8 @@ def _write_int(rtc: int, key: str, value: int) -> None:
def _reset_alarm(rtc: int, timeout: int) -> None: def _reset_alarm(rtc: int, timeout: int) -> None:
try: try:
now = _read_int(rtc, "since_epoch") now = _read_int(rtc, "since_epoch")
except OSError as err: except OSError as ex:
if err.errno != errno.EINVAL: if ex.errno != errno.EINVAL:
raise raise
raise RtcIsNotAvailableError("Can't read since_epoch right now") raise RtcIsNotAvailableError("Can't read since_epoch right now")
if now == 0: if now == 0:
@ -65,8 +65,8 @@ def _reset_alarm(rtc: int, timeout: int) -> None:
try: try:
for wake in [0, now + timeout]: for wake in [0, now + timeout]:
_write_int(rtc, "wakealarm", wake) _write_int(rtc, "wakealarm", wake)
except OSError as err: except OSError as ex:
if err.errno != errno.EIO: if ex.errno != errno.EIO:
raise raise
raise RtcIsNotAvailableError("IO error, probably the supercapacitor is not charged") raise RtcIsNotAvailableError("IO error, probably the supercapacitor is not charged")
@ -80,9 +80,9 @@ def _cmd_run(config: Section) -> None:
while True: while True:
try: try:
_reset_alarm(config.rtc, config.timeout) _reset_alarm(config.rtc, config.timeout)
except RtcIsNotAvailableError as err: except RtcIsNotAvailableError as ex:
if not fail: if not fail:
logger.error("RTC%d is not available now: %s; waiting ...", config.rtc, err) logger.error("RTC%d is not available now: %s; waiting ...", config.rtc, ex)
fail = True fail = True
else: else:
if fail: if fail:

View File

@ -18,3 +18,67 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # # along with this program. If not, see <https://www.gnu.org/licenses/>. #
# # # #
# ========================================================================== # # ========================================================================== #
import types
from typing import Callable
from typing import Self
import aiohttp
# =====
class BaseHttpClientSession:
def __init__(self, make_http_session: Callable[[], aiohttp.ClientSession]) -> None:
self._make_http_session = make_http_session
self.__http_session: (aiohttp.ClientSession | None) = None
def _ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
self.__http_session = self._make_http_session()
return self.__http_session
async def close(self) -> None:
if self.__http_session:
await self.__http_session.close()
self.__http_session = None
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
_exc_type: type[BaseException],
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
await self.close()
class BaseHttpClient:
def __init__(
self,
unix_path: str,
timeout: float,
user_agent: str,
) -> None:
self.__unix_path = unix_path
self.__timeout = timeout
self.__user_agent = user_agent
def make_session(self) -> BaseHttpClientSession:
raise NotImplementedError
def _make_http_session(self, headers: (dict[str, str] | None)=None) -> aiohttp.ClientSession:
return aiohttp.ClientSession(
base_url="http://localhost:0",
headers={
"User-Agent": self.__user_agent,
**(headers or {}),
},
connector=aiohttp.UnixConnector(path=self.__unix_path),
timeout=aiohttp.ClientTimeout(total=self.__timeout),
)

View File

@ -23,7 +23,6 @@
import asyncio import asyncio
import contextlib import contextlib
import struct import struct
import types
from typing import Callable from typing import Callable
from typing import AsyncGenerator from typing import AsyncGenerator
@ -34,22 +33,19 @@ from .. import aiotools
from .. import htclient from .. import htclient
from .. import htserver from .. import htserver
from . import BaseHttpClient
from . import BaseHttpClientSession
# ===== # =====
class _BaseApiPart: class _BaseApiPart:
def __init__( def __init__(self, ensure_http_session: Callable[[], aiohttp.ClientSession]) -> None:
self,
ensure_http_session: Callable[[], aiohttp.ClientSession],
make_url: Callable[[str], str],
) -> None:
self._ensure_http_session = ensure_http_session self._ensure_http_session = ensure_http_session
self._make_url = make_url
async def _set_params(self, handle: str, **params: (int | str | None)) -> None: async def _set_params(self, handle: str, **params: (int | str | None)) -> None:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.post( async with session.post(
url=self._make_url(handle), url=handle,
params={ params={
key: value key: value
for (key, value) in params.items() for (key, value) in params.items()
@ -63,11 +59,11 @@ class _AuthApiPart(_BaseApiPart):
async def check(self) -> bool: async def check(self) -> bool:
session = self._ensure_http_session() session = self._ensure_http_session()
try: try:
async with session.get(self._make_url("auth/check")) as response: async with session.get("/auth/check") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return True return True
except aiohttp.ClientResponseError as err: except aiohttp.ClientResponseError as ex:
if err.status in [400, 401, 403]: if ex.status in [400, 401, 403]:
return False return False
raise raise
@ -75,13 +71,13 @@ class _AuthApiPart(_BaseApiPart):
class _StreamerApiPart(_BaseApiPart): class _StreamerApiPart(_BaseApiPart):
async def get_state(self) -> dict: async def get_state(self) -> dict:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.get(self._make_url("streamer")) as response: async with session.get("/streamer") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return (await response.json())["result"] return (await response.json())["result"]
async def set_params(self, quality: (int | None)=None, desired_fps: (int | None)=None) -> None: async def set_params(self, quality: (int | None)=None, desired_fps: (int | None)=None) -> None:
await self._set_params( await self._set_params(
"streamer/set_params", "/streamer/set_params",
quality=quality, quality=quality,
desired_fps=desired_fps, desired_fps=desired_fps,
) )
@ -90,7 +86,7 @@ class _StreamerApiPart(_BaseApiPart):
class _HidApiPart(_BaseApiPart): class _HidApiPart(_BaseApiPart):
async def get_keymaps(self) -> tuple[str, set[str]]: async def get_keymaps(self) -> tuple[str, set[str]]:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.get(self._make_url("hid/keymaps")) as response: async with session.get("/hid/keymaps") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
result = (await response.json())["result"] result = (await response.json())["result"]
return (result["keymaps"]["default"], set(result["keymaps"]["available"])) return (result["keymaps"]["default"], set(result["keymaps"]["available"]))
@ -98,7 +94,7 @@ class _HidApiPart(_BaseApiPart):
async def print(self, text: str, limit: int, keymap_name: str) -> None: async def print(self, text: str, limit: int, keymap_name: str) -> None:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.post( async with session.post(
url=self._make_url("hid/print"), url="/hid/print",
params={"limit": limit, "keymap": keymap_name}, params={"limit": limit, "keymap": keymap_name},
data=text, data=text,
) as response: ) as response:
@ -106,7 +102,7 @@ class _HidApiPart(_BaseApiPart):
async def set_params(self, keyboard_output: (str | None)=None, mouse_output: (str | None)=None) -> None: async def set_params(self, keyboard_output: (str | None)=None, mouse_output: (str | None)=None) -> None:
await self._set_params( await self._set_params(
"hid/set_params", "/hid/set_params",
keyboard_output=keyboard_output, keyboard_output=keyboard_output,
mouse_output=mouse_output, mouse_output=mouse_output,
) )
@ -115,7 +111,7 @@ class _HidApiPart(_BaseApiPart):
class _AtxApiPart(_BaseApiPart): class _AtxApiPart(_BaseApiPart):
async def get_state(self) -> dict: async def get_state(self) -> dict:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.get(self._make_url("atx")) as response: async with session.get("/atx") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return (await response.json())["result"] return (await response.json())["result"]
@ -123,13 +119,13 @@ class _AtxApiPart(_BaseApiPart):
session = self._ensure_http_session() session = self._ensure_http_session()
try: try:
async with session.post( async with session.post(
url=self._make_url("atx/power"), url="/atx/power",
params={"action": action}, params={"action": action},
) as response: ) as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return True return True
except aiohttp.ClientResponseError as err: except aiohttp.ClientResponseError as ex:
if err.status == 409: if ex.status == 409:
return False return False
raise raise
@ -138,7 +134,6 @@ class _AtxApiPart(_BaseApiPart):
class KvmdClientWs: class KvmdClientWs:
def __init__(self, ws: aiohttp.ClientWebSocketResponse) -> None: def __init__(self, ws: aiohttp.ClientWebSocketResponse) -> None:
self.__ws = ws self.__ws = ws
self.__writer_queue: "asyncio.Queue[tuple[str, dict] | bytes]" = asyncio.Queue() self.__writer_queue: "asyncio.Queue[tuple[str, dict] | bytes]" = asyncio.Queue()
self.__communicated = False self.__communicated = False
@ -200,84 +195,25 @@ class KvmdClientWs:
await self.__writer_queue.put(struct.pack(">bbbb", 5, 0, delta_x, delta_y)) await self.__writer_queue.put(struct.pack(">bbbb", 5, 0, delta_x, delta_y))
class KvmdClientSession: class KvmdClientSession(BaseHttpClientSession):
def __init__( def __init__(self, make_http_session: Callable[[], aiohttp.ClientSession]) -> None:
self, super().__init__(make_http_session)
make_http_session: Callable[[], aiohttp.ClientSession], self.auth = _AuthApiPart(self._ensure_http_session)
make_url: Callable[[str], str], self.streamer = _StreamerApiPart(self._ensure_http_session)
) -> None: self.hid = _HidApiPart(self._ensure_http_session)
self.atx = _AtxApiPart(self._ensure_http_session)
self.__make_http_session = make_http_session
self.__make_url = make_url
self.__http_session: (aiohttp.ClientSession | None) = None
args = (self.__ensure_http_session, make_url)
self.auth = _AuthApiPart(*args)
self.streamer = _StreamerApiPart(*args)
self.hid = _HidApiPart(*args)
self.atx = _AtxApiPart(*args)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def ws(self) -> AsyncGenerator[KvmdClientWs, None]: async def ws(self) -> AsyncGenerator[KvmdClientWs, None]:
session = self.__ensure_http_session() session = self._ensure_http_session()
async with session.ws_connect(self.__make_url("ws")) as ws: async with session.ws_connect("/ws", params={"legacy": "0"}) as ws:
yield KvmdClientWs(ws) yield KvmdClientWs(ws)
def __ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
self.__http_session = self.__make_http_session()
return self.__http_session
async def close(self) -> None: class KvmdClient(BaseHttpClient):
if self.__http_session: def make_session(self, user: str="", passwd: str="") -> KvmdClientSession:
await self.__http_session.close() headers = {
self.__http_session = None
async def __aenter__(self) -> "KvmdClientSession":
return self
async def __aexit__(
self,
_exc_type: type[BaseException],
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
await self.close()
class KvmdClient:
def __init__(
self,
unix_path: str,
timeout: float,
user_agent: str,
) -> None:
self.__unix_path = unix_path
self.__timeout = timeout
self.__user_agent = user_agent
def make_session(self, user: str, passwd: str) -> KvmdClientSession:
return KvmdClientSession(
make_http_session=(lambda: self.__make_http_session(user, passwd)),
make_url=self.__make_url,
)
def __make_http_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
kwargs: dict = {
"headers": {
"X-KVMD-User": user, "X-KVMD-User": user,
"X-KVMD-Passwd": passwd, "X-KVMD-Passwd": passwd,
"User-Agent": self.__user_agent,
},
"connector": aiohttp.UnixConnector(path=self.__unix_path),
"timeout": aiohttp.ClientTimeout(total=self.__timeout),
} }
return aiohttp.ClientSession(**kwargs) return KvmdClientSession(lambda: self._make_http_session(headers))
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://localhost:0/{handle}"

View File

@ -20,7 +20,10 @@
# ========================================================================== # # ========================================================================== #
import io
import contextlib import contextlib
import dataclasses
import functools
import types import types
from typing import Callable from typing import Callable
@ -31,10 +34,15 @@ from typing import AsyncGenerator
import aiohttp import aiohttp
import ustreamer import ustreamer
from PIL import Image as PilImage
from .. import tools from .. import tools
from .. import aiotools from .. import aiotools
from .. import htclient from .. import htclient
from . import BaseHttpClient
from . import BaseHttpClientSession
# ===== # =====
class StreamerError(Exception): class StreamerError(Exception):
@ -50,7 +58,7 @@ class StreamerPermError(StreamerError):
# ===== # =====
class StreamFormats: class StreamerFormats:
JPEG = 1195724874 # V4L2_PIX_FMT_JPEG JPEG = 1195724874 # V4L2_PIX_FMT_JPEG
H264 = 875967048 # V4L2_PIX_FMT_H264 H264 = 875967048 # V4L2_PIX_FMT_H264
_MJPEG = 1196444237 # V4L2_PIX_FMT_MJPEG _MJPEG = 1196444237 # V4L2_PIX_FMT_MJPEG
@ -68,17 +76,85 @@ class BaseStreamerClient:
# ===== # =====
@dataclasses.dataclass(frozen=True)
class StreamerSnapshot:
online: bool
width: int
height: int
headers: tuple[tuple[str, str], ...]
data: bytes
async def make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
assert max_width >= 0
assert max_height >= 0
assert quality > 0
if max_width == 0 and max_height == 0:
max_width = self.width // 5
max_height = self.height // 5
else:
max_width = min((max_width or self.width), self.width)
max_height = min((max_height or self.height), self.height)
if (max_width, max_height) == (self.width, self.height):
return self.data
return (await aiotools.run_async(self.__inner_make_preview, max_width, max_height, quality))
@functools.lru_cache(maxsize=1)
def __inner_make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
with io.BytesIO(self.data) as snapshot_bio:
with io.BytesIO() as preview_bio:
with PilImage.open(snapshot_bio) as image:
image.thumbnail((max_width, max_height), PilImage.Resampling.LANCZOS)
image.save(preview_bio, format="jpeg", quality=quality)
return preview_bio.getvalue()
class HttpStreamerClientSession(BaseHttpClientSession):
async def get_state(self) -> dict:
session = self._ensure_http_session()
async with session.get("/state") as response:
htclient.raise_not_200(response)
return (await response.json())["result"]
async def take_snapshot(self, timeout: float) -> StreamerSnapshot:
session = self._ensure_http_session()
async with session.get(
url="/snapshot",
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
htclient.raise_not_200(response)
return StreamerSnapshot(
online=(response.headers["X-UStreamer-Online"] == "true"),
width=int(response.headers["X-UStreamer-Width"]),
height=int(response.headers["X-UStreamer-Height"]),
headers=tuple(
(key, value)
for (key, value) in tools.sorted_kvs(dict(response.headers))
if key.lower().startswith("x-ustreamer-") or key.lower() in [
"x-timestamp",
"access-control-allow-origin",
"cache-control",
"pragma",
"expires",
]
),
data=bytes(await response.read()),
)
@contextlib.contextmanager @contextlib.contextmanager
def _http_handle_errors() -> Generator[None, None, None]: def _http_reading_handle_errors() -> Generator[None, None, None]:
try: try:
yield yield
except Exception as err: # Тут бывают и ассерты, и KeyError, и прочая херня except Exception as ex: # Тут бывают и ассерты, и KeyError, и прочая херня
if isinstance(err, StreamerTempError): if isinstance(ex, StreamerTempError):
raise raise
raise StreamerTempError(tools.efmt(err)) raise StreamerTempError(tools.efmt(ex))
class HttpStreamerClient(BaseStreamerClient): class HttpStreamerClient(BaseHttpClient, BaseStreamerClient):
def __init__( def __init__(
self, self,
name: str, name: str,
@ -87,29 +163,35 @@ class HttpStreamerClient(BaseStreamerClient):
user_agent: str, user_agent: str,
) -> None: ) -> None:
super().__init__(unix_path, timeout, user_agent)
self.__name = name self.__name = name
self.__unix_path = unix_path
self.__timeout = timeout def make_session(self) -> HttpStreamerClientSession:
self.__user_agent = user_agent return HttpStreamerClientSession(self._make_http_session)
def get_format(self) -> int: def get_format(self) -> int:
return StreamFormats.JPEG return StreamerFormats.JPEG
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]: async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]:
with _http_handle_errors(): with _http_reading_handle_errors():
async with self.__make_http_session() as session: async with self._make_http_session() as session:
async with session.get( async with session.get(
url=self.__make_url("stream"), url="/stream",
params={"extra_headers": "1"}, params={"extra_headers": "1"},
timeout=aiohttp.ClientTimeout(
connect=session.timeout.total,
sock_read=session.timeout.total,
),
) as response: ) as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
reader = aiohttp.MultipartReader.from_response(response) reader = aiohttp.MultipartReader.from_response(response)
self.__patch_stream_reader(reader.resp.content) self.__patch_stream_reader(reader.resp.content)
async def read_frame(key_required: bool) -> dict: async def read_frame(key_required: bool) -> dict:
_ = key_required _ = key_required
with _http_handle_errors(): with _http_reading_handle_errors():
frame = await reader.next() # pylint: disable=not-callable frame = await reader.next() # pylint: disable=not-callable
if not isinstance(frame, aiohttp.BodyPartReader): if not isinstance(frame, aiohttp.BodyPartReader):
raise StreamerTempError("Expected body part") raise StreamerTempError("Expected body part")
@ -123,26 +205,11 @@ class HttpStreamerClient(BaseStreamerClient):
"width": int(frame.headers["X-UStreamer-Width"]), "width": int(frame.headers["X-UStreamer-Width"]),
"height": int(frame.headers["X-UStreamer-Height"]), "height": int(frame.headers["X-UStreamer-Height"]),
"data": data, "data": data,
"format": StreamFormats.JPEG, "format": StreamerFormats.JPEG,
} }
yield read_frame yield read_frame
def __make_http_session(self) -> aiohttp.ClientSession:
kwargs: dict = {
"headers": {"User-Agent": self.__user_agent},
"connector": aiohttp.UnixConnector(path=self.__unix_path),
"timeout": aiohttp.ClientTimeout(
connect=self.__timeout,
sock_read=self.__timeout,
),
}
return aiohttp.ClientSession(**kwargs)
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://localhost:0/{handle}"
def __patch_stream_reader(self, reader: aiohttp.StreamReader) -> None: def __patch_stream_reader(self, reader: aiohttp.StreamReader) -> None:
# https://github.com/pikvm/pikvm/issues/92 # https://github.com/pikvm/pikvm/issues/92
# Infinite looping in BodyPartReader.read() because _at_eof flag. # Infinite looping in BodyPartReader.read() because _at_eof flag.
@ -162,15 +229,15 @@ class HttpStreamerClient(BaseStreamerClient):
# ===== # =====
@contextlib.contextmanager @contextlib.contextmanager
def _memsink_handle_errors() -> Generator[None, None, None]: def _memsink_reading_handle_errors() -> Generator[None, None, None]:
try: try:
yield yield
except StreamerPermError: except StreamerPermError:
raise raise
except FileNotFoundError as err: except FileNotFoundError as ex:
raise StreamerTempError(tools.efmt(err)) raise StreamerTempError(tools.efmt(ex))
except Exception as err: except Exception as ex:
raise StreamerPermError(tools.efmt(err)) raise StreamerPermError(tools.efmt(ex))
class MemsinkStreamerClient(BaseStreamerClient): class MemsinkStreamerClient(BaseStreamerClient):
@ -198,11 +265,11 @@ class MemsinkStreamerClient(BaseStreamerClient):
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]: async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]:
with _memsink_handle_errors(): with _memsink_reading_handle_errors():
with ustreamer.Memsink(**self.__kwargs) as sink: with ustreamer.Memsink(**self.__kwargs) as sink:
async def read_frame(key_required: bool) -> dict: async def read_frame(key_required: bool) -> dict:
key_required = (key_required and self.__fmt == StreamFormats.H264) key_required = (key_required and self.__fmt == StreamerFormats.H264)
with _memsink_handle_errors(): with _memsink_reading_handle_errors():
while True: while True:
frame = await aiotools.run_async(sink.wait_frame, key_required) frame = await aiotools.run_async(sink.wait_frame, key_required)
if frame is not None: if frame is not None:
@ -211,8 +278,8 @@ class MemsinkStreamerClient(BaseStreamerClient):
yield read_frame yield read_frame
def __check_format(self, fmt: int) -> None: def __check_format(self, fmt: int) -> None:
if fmt == StreamFormats._MJPEG: # pylint: disable=protected-access if fmt == StreamerFormats._MJPEG: # pylint: disable=protected-access
fmt = StreamFormats.JPEG fmt = StreamerFormats.JPEG
if fmt != self.__fmt: if fmt != self.__fmt:
raise StreamerPermError("Invalid sink format") raise StreamerPermError("Invalid sink format")

269
kvmd/edid.py Normal file
View File

@ -0,0 +1,269 @@
# ========================================================================== #
# #
# KVMD - The main PiKVM daemon. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import os
import re
import dataclasses
import contextlib
from typing import IO
from typing import Generator
# =====
class EdidNoBlockError(Exception):
pass
@contextlib.contextmanager
def _smart_open(path: str, mode: str) -> Generator[IO, None, None]:
fd = (0 if "r" in mode else 1)
with (os.fdopen(fd, mode, closefd=False) if path == "-" else open(path, mode)) as file:
yield file
if "w" in mode:
file.flush()
@dataclasses.dataclass(frozen=True)
class _CeaBlock:
tag: int
data: bytes
def __post_init__(self) -> None:
assert 0 < self.tag <= 0b111
assert 0 < len(self.data) <= 0b11111
@property
def size(self) -> int:
return len(self.data) + 1
def pack(self) -> bytes:
header = (self.tag << 5) | len(self.data)
return header.to_bytes() + self.data
@classmethod
def first_from_raw(cls, raw: (bytes | list[int])) -> "_CeaBlock":
assert 0 < raw[0] <= 0xFF
tag = (raw[0] & 0b11100000) >> 5
data_size = (raw[0] & 0b00011111)
data = bytes(raw[1:data_size + 1])
return _CeaBlock(tag, data)
_CEA = 128
_CEA_AUDIO = 1
_CEA_SPEAKERS = 4
class Edid:
# https://en.wikipedia.org/wiki/Extended_Display_Identification_Data
def __init__(self, data: bytes) -> None:
assert len(data) == 256
self.__data = list(data)
@classmethod
def from_file(cls, path: str) -> "Edid":
with _smart_open(path, "rb") as file:
data = file.read()
if not data.startswith(b"\x00\xFF\xFF\xFF\xFF\xFF\xFF\x00"):
text = re.sub(r"\s", "", data.decode())
data = bytes([
int(text[index:index + 2], 16)
for index in range(0, len(text), 2)
])
assert len(data) == 256, f"Invalid EDID length: {len(data)}, should be 256 bytes"
assert data[126] == 1, "Zero extensions number"
assert (data[_CEA + 0], data[_CEA + 1]) == (0x02, 0x03), "Can't find CEA extension"
return Edid(data)
def write_hex(self, path: str) -> None:
self.__update_checksums()
text = "\n".join(
"".join(
f"{item:0{2}X}"
for item in self.__data[index:index + 16]
)
for index in range(0, len(self.__data), 16)
) + "\n"
with _smart_open(path, "w") as file:
file.write(text)
def write_bin(self, path: str) -> None:
self.__update_checksums()
with _smart_open(path, "wb") as file:
file.write(bytes(self.__data))
def __update_checksums(self) -> None:
self.__data[127] = 256 - (sum(self.__data[:127]) % 256)
self.__data[255] = 256 - (sum(self.__data[128:255]) % 256)
# =====
def get_mfc_id(self) -> str:
raw = self.__data[8] << 8 | self.__data[9]
return bytes([
((raw >> 10) & 0b11111) + 0x40,
((raw >> 5) & 0b11111) + 0x40,
(raw & 0b11111) + 0x40,
]).decode("ascii")
def set_mfc_id(self, mfc_id: str) -> None:
assert len(mfc_id) == 3, "Mfc ID must be 3 characters long"
data = mfc_id.upper().encode("ascii")
for ch in data:
assert 0x41 <= ch <= 0x5A, "Mfc ID must contain only A-Z characters"
raw = (
(data[2] - 0x40)
| ((data[1] - 0x40) << 5)
| ((data[0] - 0x40) << 10)
)
self.__data[8] = (raw >> 8) & 0xFF
self.__data[9] = raw & 0xFF
# =====
def get_product_id(self) -> int:
return (self.__data[10] | self.__data[11] << 8)
def set_product_id(self, product_id: int) -> None:
assert 0 <= product_id <= 0xFFFF, f"Product ID should be from 0 to {0xFFFF}"
self.__data[10] = product_id & 0xFF
self.__data[11] = (product_id >> 8) & 0xFF
# =====
def get_serial(self) -> int:
return (
self.__data[12]
| self.__data[13] << 8
| self.__data[14] << 16
| self.__data[15] << 24
)
def set_serial(self, serial: int) -> None:
assert 0 <= serial <= 0xFFFFFFFF, f"Serial should be from 0 to {0xFFFFFFFF}"
self.__data[12] = serial & 0xFF
self.__data[13] = (serial >> 8) & 0xFF
self.__data[14] = (serial >> 16) & 0xFF
self.__data[15] = (serial >> 24) & 0xFF
# =====
def get_monitor_name(self) -> str:
return self.__get_dtd_text(0xFC, "Monitor Name")
def set_monitor_name(self, text: str) -> None:
self.__set_dtd_text(0xFC, "Monitor Name", text)
def get_monitor_serial(self) -> str:
return self.__get_dtd_text(0xFF, "Monitor Serial")
def set_monitor_serial(self, text: str) -> None:
self.__set_dtd_text(0xFF, "Monitor Serial", text)
def __get_dtd_text(self, d_type: int, name: str) -> str:
index = self.__find_dtd_text(d_type, name)
return bytes(self.__data[index:index + 13]).decode("cp437").strip()
def __set_dtd_text(self, d_type: int, name: str, text: str) -> None:
index = self.__find_dtd_text(d_type, name)
encoded = (text[:13] + "\n" + " " * 12)[:13].encode("cp437")
for (offset, ch) in enumerate(encoded):
self.__data[index + offset] = ch
def __find_dtd_text(self, d_type: int, name: str) -> int:
for index in [54, 72, 90, 108]:
if self.__data[index + 3] == d_type:
return index + 5
raise EdidNoBlockError(f"Can't find DTD {name}")
# ===== CEA =====
def get_audio(self) -> bool:
(cbs, _) = self.__parse_cea()
audio = False
speakers = False
for cb in cbs:
if cb.tag == _CEA_AUDIO:
audio = True
elif cb.tag == _CEA_SPEAKERS:
speakers = True
return (audio and speakers and self.__get_basic_audio())
def set_audio(self, enabled: bool) -> None:
(cbs, dtds) = self.__parse_cea()
cbs = [cb for cb in cbs if cb.tag not in [_CEA_AUDIO, _CEA_SPEAKERS]]
if enabled:
cbs.append(_CeaBlock(_CEA_AUDIO, b"\x09\x7f\x07"))
cbs.append(_CeaBlock(_CEA_SPEAKERS, b"\x01\x00\x00"))
self.__replace_cea(cbs, dtds)
self.__set_basic_audio(enabled)
def __get_basic_audio(self) -> bool:
return bool(self.__data[_CEA + 3] & 0b01000000)
def __set_basic_audio(self, enabled: bool) -> None:
if enabled:
self.__data[_CEA + 3] |= 0b01000000
else:
self.__data[_CEA + 3] &= (0xFF - 0b01000000) # ~X
def __parse_cea(self) -> tuple[list[_CeaBlock], bytes]:
cea = self.__data[_CEA:]
dtd_begin = cea[2]
if dtd_begin == 0:
return ([], b"")
cbs: list[_CeaBlock] = []
if dtd_begin > 4:
raw = cea[4:dtd_begin]
while len(raw) != 0:
cb = _CeaBlock.first_from_raw(raw)
cbs.append(cb)
raw = raw[cb.size:]
dtds = b""
assert dtd_begin >= 4
raw = cea[dtd_begin:]
while len(raw) > (18 + 1) and raw[0] != 0:
dtds += bytes(raw[:18])
raw = raw[18:]
return (cbs, dtds)
def __replace_cea(self, cbs: list[_CeaBlock], dtds: bytes) -> None:
cbs_packed = b""
for cb in cbs:
cbs_packed += cb.pack()
raw = cbs_packed + dtds
assert len(raw) <= (128 - 4 - 1), "Too many CEA blocks or DTDs"
self.__data[_CEA + 2] = (0 if len(raw) == 0 else (len(cbs_packed) + 4))
for index in range(4, 127):
try:
ch = raw[index - 4]
except IndexError:
ch = 0
self.__data[_CEA + index] = ch

View File

@ -33,6 +33,7 @@ class Partition:
mount_path: str mount_path: str
root_path: str root_path: str
user: str user: str
group: str
# ===== # =====
@ -50,7 +51,7 @@ def _find_single(part_type: str) -> Partition:
if len(parts) == 0: if len(parts) == 0:
if os.path.exists('/var/lib/kvmd/msd'): if os.path.exists('/var/lib/kvmd/msd'):
#set default value #set default value
parts = [Partition(mount_path='/var/lib/kvmd/msd', root_path='/var/lib/kvmd/msd', user='kvmd')] parts = [Partition(mount_path='/var/lib/kvmd/msd', root_path='/var/lib/kvmd/msd',group='kvmd', user='kvmd')]
else: else:
raise RuntimeError(f"Can't find {part_type!r} mountpoint") raise RuntimeError(f"Can't find {part_type!r} mountpoint")
return parts[0] return parts[0]
@ -64,12 +65,13 @@ def _find_partitions(part_type: str, single: bool) -> list[Partition]:
if line and not line.startswith("#"): if line and not line.startswith("#"):
fields = line.split() fields = line.split()
if len(fields) == 6: if len(fields) == 6:
options = dict(re.findall(r"X-kvmd\.%s-(root|user)(?:=([^,]+))?" % (part_type), fields[3])) options = dict(re.findall(r"X-kvmd\.%s-(root|user|group)(?:=([^,]+))?" % (part_type), fields[3]))
if options: if options:
parts.append(Partition( parts.append(Partition(
mount_path=os.path.normpath(fields[1]), mount_path=os.path.normpath(fields[1]),
root_path=os.path.normpath(options.get("root", "") or fields[1]), root_path=os.path.normpath(options.get("root", "") or fields[1]),
user=options.get("user", ""), user=options.get("user", ""),
group=options.get("group", ""),
)) ))
if single: if single:
break break

View File

@ -22,7 +22,9 @@
import sys import sys
import os import os
import stat
import pwd import pwd
import grp
import shutil import shutil
import subprocess import subprocess
@ -44,8 +46,8 @@ def _remount(path: str, rw: bool) -> None:
_log(f"Remounting {path} to {mode.upper()}-mode ...") _log(f"Remounting {path} to {mode.upper()}-mode ...")
try: try:
subprocess.check_call(["/bin/mount", "--options", f"remount,{mode}", path]) subprocess.check_call(["/bin/mount", "--options", f"remount,{mode}", path])
except subprocess.CalledProcessError as err: except subprocess.CalledProcessError as ex:
raise SystemExit(f"Can't remount: {err}") raise SystemExit(f"Can't remount: {ex}")
def _mkdir(path: str) -> None: def _mkdir(path: str) -> None:
@ -53,8 +55,8 @@ def _mkdir(path: str) -> None:
_log(f"MKDIR --- {path}") _log(f"MKDIR --- {path}")
try: try:
os.mkdir(path) os.mkdir(path)
except Exception as err: except Exception as ex:
raise SystemExit(f"Can't create directory: {err}") raise SystemExit(f"Can't create directory: {ex}")
def _rmtree(path: str) -> None: def _rmtree(path: str) -> None:
@ -62,8 +64,8 @@ def _rmtree(path: str) -> None:
_log(f"RMALL --- {path}") _log(f"RMALL --- {path}")
try: try:
shutil.rmtree(path) shutil.rmtree(path)
except Exception as err: except Exception as ex:
raise SystemExit(f"Can't remove directory: {err}") raise SystemExit(f"Can't remove directory: {ex}")
def _rm(path: str) -> None: def _rm(path: str) -> None:
@ -71,25 +73,43 @@ def _rm(path: str) -> None:
_log(f"RM --- {path}") _log(f"RM --- {path}")
try: try:
os.remove(path) os.remove(path)
except Exception as err: except Exception as ex:
raise SystemExit(f"Can't remove file: {err}") raise SystemExit(f"Can't remove file: {ex}")
def _move(src: str, dest: str) -> None: def _move(src: str, dest: str) -> None:
_log(f"MOVE --- {src} --> {dest}") _log(f"MOVE --- {src} --> {dest}")
try: try:
os.rename(src, dest) os.rename(src, dest)
except Exception as err: except Exception as ex:
raise SystemExit(f"Can't move file: {err}") raise SystemExit(f"Can't move file: {ex}")
def _chown(path: str, user: str) -> None: def _chown(path: str, user: str) -> None:
if pwd.getpwuid(os.stat(path).st_uid).pw_name != user: if pwd.getpwuid(os.stat(path).st_uid).pw_name != user:
_log(f"CHOWN --- {user} - {path}") _log(f"CHOWN --- {user} - {path}")
try: try:
shutil.chown(path, user) shutil.chown(path, user=user)
except Exception as err: except Exception as ex:
raise SystemExit(f"Can't change ownership: {err}") raise SystemExit(f"Can't change ownership: {ex}")
def _chgrp(path: str, group: str) -> None:
if grp.getgrgid(os.stat(path).st_gid).gr_name != group:
_log(f"CHGRP --- {group} - {path}")
try:
shutil.chown(path, group=group)
except Exception as ex:
raise SystemExit(f"Can't change group: {ex}")
def _chmod(path: str, mode: int) -> None:
if stat.S_IMODE(os.stat(path).st_mode) != mode:
_log(f"CHMOD --- 0o{mode:o} - {path}")
try:
os.chmod(path, mode)
except Exception as ex:
raise SystemExit(f"Can't change permissions: {ex}")
# ===== # =====
@ -112,13 +132,21 @@ def _fix_msd(part: Partition) -> None:
if part.user: if part.user:
_chown(part.root_path, part.user) _chown(part.root_path, part.user)
if part.group:
_chgrp(part.root_path, part.group)
def _fix_pst(part: Partition) -> None: def _fix_pst(part: Partition) -> None:
path = os.path.join(part.root_path, "data") path = os.path.join(part.root_path, "data")
_mkdir(path) _mkdir(path)
if part.user: if part.user:
_chown(part.root_path, part.user)
_chown(path, part.user) _chown(path, part.user)
if part.group:
_chgrp(part.root_path, part.group)
_chgrp(path, part.group)
if part.user and part.group:
_chmod(part.root_path, 0o1775)
# ===== # =====

View File

@ -36,27 +36,27 @@ def make_user_agent(app: str) -> str:
return f"{app}/{__version__}" return f"{app}/{__version__}"
def raise_not_200(response: aiohttp.ClientResponse) -> None: def raise_not_200(resp: aiohttp.ClientResponse) -> None:
if response.status != 200: if resp.status != 200:
assert response.reason is not None assert resp.reason is not None
response.release() resp.release()
raise aiohttp.ClientResponseError( raise aiohttp.ClientResponseError(
response.request_info, resp.request_info,
response.history, resp.history,
status=response.status, status=resp.status,
message=response.reason, message=resp.reason,
headers=response.headers, headers=resp.headers,
) )
def get_filename(response: aiohttp.ClientResponse) -> str: def get_filename(resp: aiohttp.ClientResponse) -> str:
try: try:
disp = response.headers["Content-Disposition"] disp = resp.headers["Content-Disposition"]
parsed = aiohttp.multipart.parse_content_disposition(disp) parsed = aiohttp.multipart.parse_content_disposition(disp)
return str(parsed[1]["filename"]) return str(parsed[1]["filename"])
except Exception: except Exception:
try: try:
return os.path.basename(response.url.path) return os.path.basename(resp.url.path)
except Exception: except Exception:
raise aiohttp.ClientError("Can't determine filename") raise aiohttp.ClientError("Can't determine filename")
@ -79,6 +79,6 @@ async def download(
), ),
} }
async with aiohttp.ClientSession(**kwargs) as session: async with aiohttp.ClientSession(**kwargs) as session:
async with session.get(url, verify_ssl=verify) as response: async with session.get(url, verify_ssl=verify) as resp: # type: ignore
raise_not_200(response) raise_not_200(resp)
yield response yield resp

View File

@ -157,7 +157,7 @@ def make_json_response(
wrap_result: bool=True, wrap_result: bool=True,
) -> Response: ) -> Response:
response = Response( resp = Response(
text=json.dumps(({ text=json.dumps(({
"ok": (status == 200), "ok": (status == 200),
"result": (result or {}), "result": (result or {}),
@ -167,18 +167,18 @@ def make_json_response(
) )
if set_cookies: if set_cookies:
for (key, value) in set_cookies.items(): for (key, value) in set_cookies.items():
response.set_cookie(key, value, httponly=True, samesite="Strict") resp.set_cookie(key, value, httponly=True, samesite="Strict")
return response return resp
def make_json_exception(err: Exception, status: (int | None)=None) -> Response: def make_json_exception(ex: Exception, status: (int | None)=None) -> Response:
name = type(err).__name__ name = type(ex).__name__
msg = str(err) msg = str(ex)
if isinstance(err, HttpError): if isinstance(ex, HttpError):
status = err.status status = ex.status
else: else:
get_logger().error("API error: %s: %s", name, msg) get_logger().error("API error: %s: %s", name, msg)
assert status is not None, err assert status is not None, ex
return make_json_response({ return make_json_response({
"error": name, "error": name,
"error_msg": msg, "error_msg": msg,
@ -186,35 +186,35 @@ def make_json_exception(err: Exception, status: (int | None)=None) -> Response:
async def start_streaming( async def start_streaming(
request: Request, req: Request,
content_type: str, content_type: str,
content_length: int=-1, content_length: int=-1,
file_name: str="", file_name: str="",
) -> StreamResponse: ) -> StreamResponse:
response = StreamResponse(status=200, reason="OK") resp = StreamResponse(status=200, reason="OK")
response.content_type = content_type resp.content_type = content_type
if content_length >= 0: # pylint: disable=consider-using-min-builtin if content_length >= 0: # pylint: disable=consider-using-min-builtin
response.content_length = content_length resp.content_length = content_length
if file_name: if file_name:
file_name = urllib.parse.quote(file_name, safe="") file_name = urllib.parse.quote(file_name, safe="")
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{file_name}" resp.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{file_name}"
await response.prepare(request) await resp.prepare(req)
return response return resp
async def stream_json(response: StreamResponse, result: dict, ok: bool=True) -> None: async def stream_json(resp: StreamResponse, result: dict, ok: bool=True) -> None:
await response.write(json.dumps({ await resp.write(json.dumps({
"ok": ok, "ok": ok,
"result": result, "result": result,
}).encode("utf-8") + b"\r\n") }).encode("utf-8") + b"\r\n")
async def stream_json_exception(response: StreamResponse, err: Exception) -> None: async def stream_json_exception(resp: StreamResponse, ex: Exception) -> None:
name = type(err).__name__ name = type(ex).__name__
msg = str(err) msg = str(ex)
get_logger().error("API error: %s: %s", name, msg) get_logger().error("API error: %s: %s", name, msg)
await stream_json(response, { await stream_json(resp, {
"error": name, "error": name,
"error_msg": msg, "error_msg": msg,
}, False) }, False)
@ -249,15 +249,15 @@ def parse_ws_event(msg: str) -> tuple[str, dict]:
_REQUEST_AUTH_INFO = "_kvmd_auth_info" _REQUEST_AUTH_INFO = "_kvmd_auth_info"
def _format_P(request: BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name def _format_P(req: BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name
return (getattr(request, _REQUEST_AUTH_INFO, None) or "-") return (getattr(req, _REQUEST_AUTH_INFO, None) or "-")
AccessLogger._format_P = staticmethod(_format_P) # type: ignore # pylint: disable=protected-access AccessLogger._format_P = staticmethod(_format_P) # type: ignore # pylint: disable=protected-access
def set_request_auth_info(request: BaseRequest, info: str) -> None: def set_request_auth_info(req: BaseRequest, info: str) -> None:
setattr(request, _REQUEST_AUTH_INFO, info) setattr(req, _REQUEST_AUTH_INFO, info)
# ===== # =====
@ -318,16 +318,16 @@ class HttpServer:
self.__add_exposed_ws(ws_exposed) self.__add_exposed_ws(ws_exposed)
def __add_exposed_http(self, exposed: HttpExposed) -> None: def __add_exposed_http(self, exposed: HttpExposed) -> None:
async def wrapper(request: Request) -> Response: async def wrapper(req: Request) -> Response:
try: try:
await self._check_request_auth(exposed, request) await self._check_request_auth(exposed, req)
return (await exposed.handler(request)) return (await exposed.handler(req))
except IsBusyError as err: except IsBusyError as ex:
return make_json_exception(err, 409) return make_json_exception(ex, 409)
except (ValidatorError, OperationError) as err: except (ValidatorError, OperationError) as ex:
return make_json_exception(err, 400) return make_json_exception(ex, 400)
except HttpError as err: except HttpError as ex:
return make_json_exception(err) return make_json_exception(ex)
self.__app.router.add_route(exposed.method, exposed.path, wrapper) self.__app.router.add_route(exposed.method, exposed.path, wrapper)
def __add_exposed_ws(self, exposed: WsExposed) -> None: def __add_exposed_ws(self, exposed: WsExposed) -> None:
@ -342,10 +342,10 @@ class HttpServer:
# ===== # =====
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]: async def _ws_session(self, req: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]:
assert self.__ws_heartbeat is not None assert self.__ws_heartbeat is not None
wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat) wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat)
await wsr.prepare(request) await wsr.prepare(req)
ws = WsSession(wsr, kwargs) ws = WsSession(wsr, kwargs)
async with self.__ws_sessions_lock: async with self.__ws_sessions_lock:
@ -364,8 +364,8 @@ class HttpServer:
if msg.type == WSMsgType.TEXT: if msg.type == WSMsgType.TEXT:
try: try:
(event_type, event) = parse_ws_event(msg.data) (event_type, event) = parse_ws_event(msg.data)
except Exception as err: except Exception as ex:
logger.error("Can't parse JSON event from websocket: %r", err) logger.error("Can't parse JSON event from websocket: %r", ex)
else: else:
handler = self.__ws_handlers.get(event_type) handler = self.__ws_handlers.get(event_type)
if handler: if handler:
@ -384,7 +384,7 @@ class HttpServer:
break break
return ws.wsr return ws.wsr
async def _broadcast_ws_event(self, event_type: str, event: (dict | None)) -> None: async def _broadcast_ws_event(self, event_type: str, event: (dict | None), legacy: (bool | None)=None) -> None:
if self.__ws_sessions: if self.__ws_sessions:
await asyncio.gather(*[ await asyncio.gather(*[
ws.send_event(event_type, event) ws.send_event(event_type, event)
@ -393,6 +393,7 @@ class HttpServer:
not ws.wsr.closed not ws.wsr.closed
and ws.wsr._req is not None # pylint: disable=protected-access and ws.wsr._req is not None # pylint: disable=protected-access
and ws.wsr._req.transport is not None # pylint: disable=protected-access and ws.wsr._req.transport is not None # pylint: disable=protected-access
and (legacy is None or ws.kwargs.get("legacy") == legacy)
) )
], return_exceptions=True) ], return_exceptions=True)
@ -417,7 +418,7 @@ class HttpServer:
# ===== # =====
async def _check_request_auth(self, exposed: HttpExposed, request: Request) -> None: async def _check_request_auth(self, exposed: HttpExposed, req: Request) -> None:
pass pass
async def _init_app(self) -> None: async def _init_app(self) -> None:

View File

@ -130,18 +130,25 @@ class InotifyMask:
# | OPEN # | OPEN
# ) # )
# Helper for all modify events # Helper for all changes events except MODIFY, because it fires on each write()
ALL_MODIFY_EVENTS = ( ALL_CHANGES_EVENTS = (
CLOSE_WRITE CLOSE_WRITE
| CREATE | CREATE
| DELETE | DELETE
| DELETE_SELF | DELETE_SELF
| MODIFY
| MOVE_SELF | MOVE_SELF
| MOVED_FROM | MOVED_FROM
| MOVED_TO | MOVED_TO
) )
# Helper for typicals events when we need to restart watcher
ALL_RESTART_EVENTS = (
DELETE_SELF
| MOVE_SELF
| UNMOUNT
| ISDIR
)
# Special flags for watch() # Special flags for watch()
# DONT_FOLLOW = 0x02000000 # Don't follow a symbolic link # DONT_FOLLOW = 0x02000000 # Don't follow a symbolic link
# EXCL_UNLINK = 0x04000000 # Exclude events on unlinked objects # EXCL_UNLINK = 0x04000000 # Exclude events on unlinked objects
@ -172,6 +179,10 @@ class InotifyEvent:
name: str name: str
path: str path: str
@property
def restart(self) -> bool:
return bool(self.mask & InotifyMask.ALL_RESTART_EVENTS)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<InotifyEvent: wd={self.wd}, mask={InotifyMask.to_string(self.mask)}," f"<InotifyEvent: wd={self.wd}, mask={InotifyMask.to_string(self.mask)},"
@ -190,6 +201,9 @@ class Inotify:
self.__events_queue: "asyncio.Queue[InotifyEvent]" = asyncio.Queue() self.__events_queue: "asyncio.Queue[InotifyEvent]" = asyncio.Queue()
async def watch_all_changes(self, *paths: str) -> None:
await self.watch(InotifyMask.ALL_CHANGES_EVENTS, *paths)
async def watch(self, mask: int, *paths: str) -> None: async def watch(self, mask: int, *paths: str) -> None:
for path in paths: for path in paths:
path = os.path.normpath(path) path = os.path.normpath(path)
@ -222,7 +236,7 @@ class Inotify:
except asyncio.TimeoutError: except asyncio.TimeoutError:
return None return None
async def get_series(self, timeout: float) -> list[InotifyEvent]: async def get_series(self, timeout: float, max_series: int=64) -> list[InotifyEvent]:
series: list[InotifyEvent] = [] series: list[InotifyEvent] = []
event = await self.get_event(timeout) event = await self.get_event(timeout)
if event: if event:
@ -231,6 +245,8 @@ class Inotify:
event = await self.get_event(timeout) event = await self.get_event(timeout)
if event: if event:
series.append(event) series.append(event)
if len(series) >= max_series:
break
return series return series
def __read_and_queue_events(self) -> None: def __read_and_queue_events(self) -> None:
@ -271,8 +287,8 @@ class Inotify:
while True: while True:
try: try:
return os.read(self.__fd, _EVENTS_BUFFER_LENGTH) return os.read(self.__fd, _EVENTS_BUFFER_LENGTH)
except OSError as err: except OSError as ex:
if err.errno == errno.EINTR: if ex.errno == errno.EINTR:
pass pass
def __enter__(self) -> "Inotify": def __enter__(self) -> "Inotify":

View File

@ -135,8 +135,8 @@ def _read_keyboard_layout(path: str) -> dict[int, list[At1Key]]: # Keysym to ev
try: try:
at1_code = int(parts[1], 16) at1_code = int(parts[1], 16)
except ValueError as err: except ValueError as ex:
logger.error("Syntax error at %s:%d: %s", path, lineno, err) logger.error("Syntax error at %s:%d: %s", path, lineno, ex)
continue continue
rest = parts[2:] rest = parts[2:]

View File

@ -34,10 +34,10 @@ def is_ipv6_enabled() -> bool:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
sock.bind(("::1", 0)) sock.bind(("::1", 0))
return True return True
except OSError as err: except OSError as ex:
if err.errno in [errno.EADDRNOTAVAIL, errno.EAFNOSUPPORT]: if ex.errno in [errno.EADDRNOTAVAIL, errno.EAFNOSUPPORT]:
return False return False
if err.errno == errno.EADDRINUSE: if ex.errno == errno.EADDRINUSE:
return True return True
raise raise

View File

@ -48,7 +48,16 @@ class BaseAtx(BasePlugin):
async def get_state(self) -> dict: async def get_state(self) -> dict:
raise NotImplementedError raise NotImplementedError
async def trigger_state(self) -> None:
raise NotImplementedError
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - enabled -- Full
# - busy -- Partial
# - leds -- Partial
# ===========================
yield {} yield {}
raise NotImplementedError raise NotImplementedError

View File

@ -36,6 +36,9 @@ class AtxDisabledError(AtxOperationError):
# ===== # =====
class Plugin(BaseAtx): class Plugin(BaseAtx):
def __init__(self) -> None:
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict: async def get_state(self) -> dict:
return { return {
"enabled": False, "enabled": False,
@ -46,10 +49,13 @@ class Plugin(BaseAtx):
}, },
} }
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
while True: while True:
await self.__notifier.wait()
yield (await self.get_state()) yield (await self.get_state())
await aiotools.wait_infinite()
# ===== # =====

View File

@ -21,6 +21,7 @@
import asyncio import asyncio
import copy
from typing import AsyncGenerator from typing import AsyncGenerator
@ -76,7 +77,7 @@ class Plugin(BaseAtx): # pylint: disable=too-many-instance-attributes
self.__notifier = aiotools.AioNotifier() self.__notifier = aiotools.AioNotifier()
self.__region = aiotools.AioExclusiveRegion(AtxIsBusyError, self.__notifier) self.__region = aiotools.AioExclusiveRegion(AtxIsBusyError, self.__notifier)
self.__line_request: (gpiod.LineRequest | None) = None self.__line_req: (gpiod.LineRequest | None) = None
self.__reader = aiogp.AioReader( self.__reader = aiogp.AioReader(
path=self.__device_path, path=self.__device_path,
@ -108,8 +109,8 @@ class Plugin(BaseAtx): # pylint: disable=too-many-instance-attributes
} }
def sysprep(self) -> None: def sysprep(self) -> None:
assert self.__line_request is None assert self.__line_req is None
self.__line_request = gpiod.request_lines( self.__line_req = gpiod.request_lines(
self.__device_path, self.__device_path,
consumer="kvmd::atx", consumer="kvmd::atx",
config={ config={
@ -130,22 +131,26 @@ class Plugin(BaseAtx): # pylint: disable=too-many-instance-attributes
}, },
} }
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {} prev: dict = {}
while True: while True:
state = await self.get_state() if (await self.__notifier.wait()) > 0:
if state != prev_state: prev = {}
yield state new = await self.get_state()
prev_state = state if new != prev:
await self.__notifier.wait() prev = copy.deepcopy(new)
yield new
async def systask(self) -> None: async def systask(self) -> None:
await self.__reader.poll() await self.__reader.poll()
async def cleanup(self) -> None: async def cleanup(self) -> None:
if self.__line_request: if self.__line_req:
try: try:
self.__line_request.release() self.__line_req.release()
except Exception: except Exception:
pass pass
@ -186,7 +191,7 @@ class Plugin(BaseAtx): # pylint: disable=too-many-instance-attributes
@aiotools.atomic_fg @aiotools.atomic_fg
async def __click(self, name: str, pin: int, delay: float, wait: bool) -> None: async def __click(self, name: str, pin: int, delay: float, wait: bool) -> None:
if wait: if wait:
async with self.__region: with self.__region:
await self.__inner_click(name, pin, delay) await self.__inner_click(name, pin, delay)
else: else:
await aiotools.run_region_task( await aiotools.run_region_task(
@ -196,11 +201,11 @@ class Plugin(BaseAtx): # pylint: disable=too-many-instance-attributes
@aiotools.atomic_fg @aiotools.atomic_fg
async def __inner_click(self, name: str, pin: int, delay: float) -> None: async def __inner_click(self, name: str, pin: int, delay: float) -> None:
assert self.__line_request assert self.__line_req
try: try:
self.__line_request.set_value(pin, gpiod.line.Value(True)) self.__line_req.set_value(pin, gpiod.line.Value(True))
await asyncio.sleep(delay) await asyncio.sleep(delay)
finally: finally:
self.__line_request.set_value(pin, gpiod.line.Value(False)) self.__line_req.set_value(pin, gpiod.line.Value(False))
await asyncio.sleep(1) await asyncio.sleep(1)
get_logger(0).info("Clicked ATX button %r", name) get_logger(0).info("Clicked ATX button %r", name)

View File

@ -75,7 +75,7 @@ class Plugin(BaseAuthService):
async with session.request( async with session.request(
method="POST", method="POST",
url=self.__url, url=self.__url,
timeout=self.__timeout, timeout=aiohttp.ClientTimeout(total=self.__timeout),
json={ json={
"user": user, "user": user,
"passwd": passwd, "passwd": passwd,
@ -85,8 +85,8 @@ class Plugin(BaseAuthService):
"User-Agent": htclient.make_user_agent("KVMD"), "User-Agent": htclient.make_user_agent("KVMD"),
"X-KVMD-User": user, "X-KVMD-User": user,
}, },
) as response: ) as resp:
htclient.raise_not_200(response) htclient.raise_not_200(resp)
return True return True
except Exception: except Exception:
get_logger().exception("Failed HTTP auth request for user %r", user) get_logger().exception("Failed HTTP auth request for user %r", user)

View File

@ -100,10 +100,10 @@ class Plugin(BaseAuthService):
return True return True
except ldap.INVALID_CREDENTIALS: except ldap.INVALID_CREDENTIALS:
pass pass
except ldap.SERVER_DOWN as err: except ldap.SERVER_DOWN as ex:
get_logger().error("LDAP server is down: %s", tools.efmt(err)) get_logger().error("LDAP server is down: %s", tools.efmt(ex))
except Exception as err: except Exception as ex:
get_logger().error("Unexpected LDAP error: %s", tools.efmt(err)) get_logger().error("Unexpected LDAP error: %s", tools.efmt(ex))
finally: finally:
if conn is not None: if conn is not None:
try: try:

View File

@ -435,10 +435,10 @@ class Plugin(BaseAuthService):
timeout=self.__timeout, timeout=self.__timeout,
dict=dct, dict=dct,
) )
request = client.CreateAuthPacket(code=pyrad.packet.AccessRequest, User_Name=user) req = client.CreateAuthPacket(code=pyrad.packet.AccessRequest, User_Name=user)
request["User-Password"] = request.PwCrypt(passwd) req["User-Password"] = req.PwCrypt(passwd)
response = client.SendPacket(request) resp = client.SendPacket(req)
return (response.code == pyrad.packet.AccessAccept) return (resp.code == pyrad.packet.AccessAccept)
except Exception: except Exception:
get_logger().exception("Failed RADIUS auth request for user %r", user) get_logger().exception("Failed RADIUS auth request for user %r", user)
return False return False

View File

@ -21,9 +21,11 @@
import asyncio import asyncio
import functools
import time import time
from typing import Iterable from typing import Iterable
from typing import Callable
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Any from typing import Any
@ -31,14 +33,37 @@ from ...yamlconf import Option
from ...validators.basic import valid_bool from ...validators.basic import valid_bool
from ...validators.basic import valid_int_f1 from ...validators.basic import valid_int_f1
from ...validators.basic import valid_string_list
from ...validators.hid import valid_hid_key
from ...validators.hid import valid_hid_mouse_move
from ...mouse import MouseRange
from .. import BasePlugin from .. import BasePlugin
from .. import get_plugin_class from .. import get_plugin_class
# ===== # =====
class BaseHid(BasePlugin): class BaseHid(BasePlugin): # pylint: disable=too-many-instance-attributes
def __init__(self, jiggler_enabled: bool, jiggler_active: bool, jiggler_interval: int) -> None: def __init__(
self,
ignore_keys: list[str],
mouse_x_min: int,
mouse_x_max: int,
mouse_y_min: int,
mouse_y_max: int,
jiggler_enabled: bool,
jiggler_active: bool,
jiggler_interval: int,
) -> None:
self.__ignore_keys = ignore_keys
self.__mouse_x_range = (mouse_x_min, mouse_x_max)
self.__mouse_y_range = (mouse_y_min, mouse_y_max)
self.__jiggler_enabled = jiggler_enabled self.__jiggler_enabled = jiggler_enabled
self.__jiggler_active = jiggler_active self.__jiggler_active = jiggler_active
self.__jiggler_interval = jiggler_interval self.__jiggler_interval = jiggler_interval
@ -46,8 +71,17 @@ class BaseHid(BasePlugin):
self.__activity_ts = 0 self.__activity_ts = 0
@classmethod @classmethod
def _get_jiggler_options(cls) -> dict[str, Any]: def _get_base_options(cls) -> dict[str, Any]:
return { return {
"ignore_keys": Option([], type=functools.partial(valid_string_list, subval=valid_hid_key)),
"mouse_x_range": {
"min": Option(MouseRange.MIN, type=valid_hid_mouse_move, unpack_as="mouse_x_min"),
"max": Option(MouseRange.MAX, type=valid_hid_mouse_move, unpack_as="mouse_x_max"),
},
"mouse_y_range": {
"min": Option(MouseRange.MIN, type=valid_hid_mouse_move, unpack_as="mouse_y_min"),
"max": Option(MouseRange.MAX, type=valid_hid_mouse_move, unpack_as="mouse_y_max"),
},
"jiggler": { "jiggler": {
"enabled": Option(False, type=valid_bool, unpack_as="jiggler_enabled"), "enabled": Option(False, type=valid_bool, unpack_as="jiggler_enabled"),
"active": Option(False, type=valid_bool, unpack_as="jiggler_active"), "active": Option(False, type=valid_bool, unpack_as="jiggler_active"),
@ -63,7 +97,23 @@ class BaseHid(BasePlugin):
async def get_state(self) -> dict: async def get_state(self) -> dict:
raise NotImplementedError raise NotImplementedError
async def trigger_state(self) -> None:
raise NotImplementedError
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - enabled -- Full
# - online -- Partial
# - busy -- Partial
# - connected -- Partial, nullable
# - keyboard.online -- Partial
# - keyboard.outputs -- Partial
# - keyboard.leds -- Partial
# - mouse.online -- Partial
# - mouse.outputs -- Partial, follows with absolute
# - mouse.absolute -- Partial, follows with outputs
# ===========================
yield {} yield {}
raise NotImplementedError raise NotImplementedError
@ -73,25 +123,6 @@ class BaseHid(BasePlugin):
async def cleanup(self) -> None: async def cleanup(self) -> None:
pass pass
# =====
def send_key_events(self, keys: Iterable[tuple[str, bool]]) -> None:
raise NotImplementedError
def send_mouse_button_event(self, button: str, state: bool) -> None:
raise NotImplementedError
def send_mouse_move_event(self, to_x: int, to_y: int) -> None:
_ = to_x
_ = to_y
def send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
_ = delta_x
_ = delta_y
def send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
raise NotImplementedError
def set_params( def set_params(
self, self,
keyboard_output: (str | None)=None, keyboard_output: (str | None)=None,
@ -104,25 +135,100 @@ class BaseHid(BasePlugin):
def set_connected(self, connected: bool) -> None: def set_connected(self, connected: bool) -> None:
_ = connected _ = connected
def clear_events(self) -> None: # =====
def send_key_events(self, keys: Iterable[tuple[str, bool]], no_ignore_keys: bool=False) -> None:
for (key, state) in keys:
if no_ignore_keys or key not in self.__ignore_keys:
self.send_key_event(key, state)
def send_key_event(self, key: str, state: bool) -> None:
self._send_key_event(key, state)
self.__bump_activity()
def _send_key_event(self, key: str, state: bool) -> None:
raise NotImplementedError raise NotImplementedError
# ===== # =====
async def systask(self) -> None: def send_mouse_button_event(self, button: str, state: bool) -> None:
factor = 1 self._send_mouse_button_event(button, state)
while True: self.__bump_activity()
if self.__jiggler_active and (self.__activity_ts + self.__jiggler_interval < int(time.monotonic())):
for _ in range(5):
if self.__jiggler_absolute:
self.send_mouse_move_event(100 * factor, 100 * factor)
else:
self.send_mouse_relative_event(10 * factor, 10 * factor)
factor *= -1
await asyncio.sleep(0.1)
await asyncio.sleep(1)
def _bump_activity(self) -> None: def _send_mouse_button_event(self, button: str, state: bool) -> None:
raise NotImplementedError
# =====
def send_mouse_move_event(self, to_x: int, to_y: int) -> None:
if self.__mouse_x_range != MouseRange.RANGE:
to_x = MouseRange.remap(to_x, *self.__mouse_x_range)
if self.__mouse_y_range != MouseRange.RANGE:
to_y = MouseRange.remap(to_y, *self.__mouse_y_range)
self._send_mouse_move_event(to_x, to_y)
self.__bump_activity()
def _send_mouse_move_event(self, to_x: int, to_y: int) -> None:
_ = to_x # XXX: NotImplementedError
_ = to_y
# =====
def send_mouse_relative_events(self, deltas: Iterable[tuple[int, int]], squash: bool) -> None:
self.__process_mouse_delta_event(deltas, squash, self.send_mouse_relative_event)
def send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self._send_mouse_relative_event(delta_x, delta_y)
self.__bump_activity()
def _send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
_ = delta_x # XXX: NotImplementedError
_ = delta_y
# =====
def send_mouse_wheel_events(self, deltas: Iterable[tuple[int, int]], squash: bool) -> None:
self.__process_mouse_delta_event(deltas, squash, self.send_mouse_wheel_event)
def send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self._send_mouse_wheel_event(delta_x, delta_y)
self.__bump_activity()
def _send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
raise NotImplementedError
# =====
def clear_events(self) -> None:
self._clear_events() # Don't bump activity here
def _clear_events(self) -> None:
raise NotImplementedError
# =====
def __process_mouse_delta_event(
self,
deltas: Iterable[tuple[int, int]],
squash: bool,
handler: Callable[[int, int], None],
) -> None:
if squash:
prev = (0, 0)
for cur in deltas:
if abs(prev[0] + cur[0]) > 127 or abs(prev[1] + cur[1]) > 127:
handler(*prev)
prev = cur
else:
prev = (prev[0] + cur[0], prev[1] + cur[1])
if prev[0] or prev[1]:
handler(*prev)
else:
for xy in deltas:
handler(*xy)
def __bump_activity(self) -> None:
self.__activity_ts = int(time.monotonic()) self.__activity_ts = int(time.monotonic())
def _set_jiggler_absolute(self, absolute: bool) -> None: def _set_jiggler_absolute(self, absolute: bool) -> None:
@ -141,6 +247,21 @@ class BaseHid(BasePlugin):
}, },
} }
# =====
async def systask(self) -> None:
factor = 1
while True:
if self.__jiggler_active and (self.__activity_ts + self.__jiggler_interval < int(time.monotonic())):
for _ in range(5):
if self.__jiggler_absolute:
self.send_mouse_move_event(100 * factor, 100 * factor)
else:
self.send_mouse_relative_event(10 * factor, 10 * factor)
factor *= -1
await asyncio.sleep(0.1)
await asyncio.sleep(1)
# ===== # =====
def get_hid_class(name: str) -> type[BaseHid]: def get_hid_class(name: str) -> type[BaseHid]:

View File

@ -23,9 +23,9 @@
import multiprocessing import multiprocessing
import contextlib import contextlib
import queue import queue
import copy
import time import time
from typing import Iterable
from typing import Generator from typing import Generator
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Any from typing import Any
@ -91,7 +91,7 @@ class _TempRequestError(_RequestError):
# ===== # =====
class BasePhyConnection: class BasePhyConnection:
def send(self, request: bytes) -> bytes: def send(self, req: bytes) -> bytes:
raise NotImplementedError raise NotImplementedError
@ -108,17 +108,22 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
def __init__( # pylint: disable=too-many-arguments,super-init-not-called def __init__( # pylint: disable=too-many-arguments,super-init-not-called
self, self,
phy: BasePhy, phy: BasePhy,
ignore_keys: list[str],
mouse_x_range: dict[str, Any],
mouse_y_range: dict[str, Any],
jiggler: dict[str, Any],
reset_self: bool, reset_self: bool,
read_retries: int, read_retries: int,
common_retries: int, common_retries: int,
retries_delay: float, retries_delay: float,
errors_threshold: int, errors_threshold: int,
noop: bool, noop: bool,
jiggler: dict[str, Any],
**gpio_kwargs: Any, **gpio_kwargs: Any,
) -> None: ) -> None:
BaseHid.__init__(self, **jiggler) BaseHid.__init__(self, ignore_keys=ignore_keys, **mouse_x_range, **mouse_y_range, **jiggler)
multiprocessing.Process.__init__(self, daemon=True) multiprocessing.Process.__init__(self, daemon=True)
self.__read_retries = read_retries self.__read_retries = read_retries
@ -163,7 +168,7 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
"errors_threshold": Option(5, type=valid_int_f0), "errors_threshold": Option(5, type=valid_int_f0),
"noop": Option(False, type=valid_bool), "noop": Option(False, type=valid_bool),
**cls._get_jiggler_options(), **cls._get_base_options(),
} }
def sysprep(self) -> None: def sysprep(self) -> None:
@ -212,6 +217,7 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
mouse_outputs["active"] = active_mouse mouse_outputs["active"] = active_mouse
return { return {
"enabled": True,
"online": online, "online": online,
"busy": bool(state["busy"]), "busy": bool(state["busy"]),
"connected": (bool(outputs2 & 0b01000000) if outputs2 & 0b10000000 else None), "connected": (bool(outputs2 & 0b01000000) if outputs2 & 0b10000000 else None),
@ -232,14 +238,18 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
**self._get_jiggler_state(), **self._get_jiggler_state(),
} }
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {} prev: dict = {}
while True: while True:
state = await self.get_state() if (await self.__notifier.wait()) > 0:
if state != prev_state: prev = {}
yield state new = await self.get_state()
prev_state = state if new != prev:
await self.__notifier.wait() prev = copy.deepcopy(new)
yield new
async def reset(self) -> None: async def reset(self) -> None:
self.__reset_required_event.set() self.__reset_required_event.set()
@ -254,27 +264,6 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
# ===== # =====
def send_key_events(self, keys: Iterable[tuple[str, bool]]) -> None:
for (key, state) in keys:
self.__queue_event(KeyEvent(key, state))
self._bump_activity()
def send_mouse_button_event(self, button: str, state: bool) -> None:
self.__queue_event(MouseButtonEvent(button, state))
self._bump_activity()
def send_mouse_move_event(self, to_x: int, to_y: int) -> None:
self.__queue_event(MouseMoveEvent(to_x, to_y))
self._bump_activity()
def send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_event(MouseRelativeEvent(delta_x, delta_y))
self._bump_activity()
def send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_event(MouseWheelEvent(delta_x, delta_y))
self._bump_activity()
def set_params( def set_params(
self, self,
keyboard_output: (str | None)=None, keyboard_output: (str | None)=None,
@ -296,9 +285,23 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
def set_connected(self, connected: bool) -> None: def set_connected(self, connected: bool) -> None:
self.__queue_event(SetConnectedEvent(connected), clear=True) self.__queue_event(SetConnectedEvent(connected), clear=True)
def clear_events(self) -> None: def _send_key_event(self, key: str, state: bool) -> None:
self.__queue_event(KeyEvent(key, state))
def _send_mouse_button_event(self, button: str, state: bool) -> None:
self.__queue_event(MouseButtonEvent(button, state))
def _send_mouse_move_event(self, to_x: int, to_y: int) -> None:
self.__queue_event(MouseMoveEvent(to_x, to_y))
def _send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_event(MouseRelativeEvent(delta_x, delta_y))
def _send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_event(MouseWheelEvent(delta_x, delta_y))
def _clear_events(self) -> None:
self.__queue_event(ClearEvent(), clear=True) self.__queue_event(ClearEvent(), clear=True)
self._bump_activity()
def __queue_event(self, event: BaseEvent, clear: bool=False) -> None: def __queue_event(self, event: BaseEvent, clear: bool=False) -> None:
if not self.__stop_event.is_set(): if not self.__stop_event.is_set():
@ -374,7 +377,7 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
self.__set_state_online(False) self.__set_state_online(False)
return False return False
def __process_request(self, conn: BasePhyConnection, request: bytes) -> bool: # pylint: disable=too-many-branches def __process_request(self, conn: BasePhyConnection, req: bytes) -> bool: # pylint: disable=too-many-branches
logger = get_logger() logger = get_logger()
error_messages: list[str] = [] error_messages: list[str] = []
live_log_errors = False live_log_errors = False
@ -384,47 +387,47 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
error_retval = False error_retval = False
while self.__gpio.is_powered() and common_retries and read_retries: while self.__gpio.is_powered() and common_retries and read_retries:
response = (RESPONSE_LEGACY_OK if self.__noop else conn.send(request)) resp = (RESPONSE_LEGACY_OK if self.__noop else conn.send(req))
try: try:
if len(response) < 4: if len(resp) < 4:
read_retries -= 1 read_retries -= 1
raise _TempRequestError(f"No response from HID: request={request!r}") raise _TempRequestError(f"No response from HID: request={req!r}")
if not check_response(response): if not check_response(resp):
request = REQUEST_REPEAT req = REQUEST_REPEAT
raise _TempRequestError("Invalid response CRC; requesting response again ...") raise _TempRequestError("Invalid response CRC; requesting response again ...")
code = response[1] code = resp[1]
if code == 0x48: # Request timeout # pylint: disable=no-else-raise if code == 0x48: # Request timeout # pylint: disable=no-else-raise
raise _TempRequestError(f"Got request timeout from HID: request={request!r}") raise _TempRequestError(f"Got request timeout from HID: request={req!r}")
elif code == 0x40: # CRC Error elif code == 0x40: # CRC Error
raise _TempRequestError(f"Got CRC error of request from HID: request={request!r}") raise _TempRequestError(f"Got CRC error of request from HID: request={req!r}")
elif code == 0x45: # Unknown command elif code == 0x45: # Unknown command
raise _PermRequestError(f"HID did not recognize the request={request!r}") raise _PermRequestError(f"HID did not recognize the request={req!r}")
elif code == 0x24: # Rebooted? elif code == 0x24: # Rebooted?
raise _PermRequestError("No previous command state inside HID, seems it was rebooted") raise _PermRequestError("No previous command state inside HID, seems it was rebooted")
elif code == 0x20: # Legacy done elif code == 0x20: # Legacy done
self.__set_state_online(True) self.__set_state_online(True)
return True return True
elif code & 0x80: # Pong/Done with state elif code & 0x80: # Pong/Done with state
self.__set_state_pong(response) self.__set_state_pong(resp)
return True return True
raise _TempRequestError(f"Invalid response from HID: request={request!r}, response=0x{response!r}") raise _TempRequestError(f"Invalid response from HID: request={req!r}, response=0x{resp!r}")
except _RequestError as err: except _RequestError as ex:
common_retries -= 1 common_retries -= 1
if live_log_errors: if live_log_errors:
logger.error(err.msg) logger.error(ex.msg)
else: else:
error_messages.append(err.msg) error_messages.append(ex.msg)
if len(error_messages) > self.__errors_threshold: if len(error_messages) > self.__errors_threshold:
for msg in error_messages: for msg in error_messages:
logger.error(msg) logger.error(msg)
error_messages = [] error_messages = []
live_log_errors = True live_log_errors = True
if isinstance(err, _PermRequestError): if isinstance(ex, _PermRequestError):
error_retval = True error_retval = True
break break
@ -440,7 +443,7 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
for msg in error_messages: for msg in error_messages:
logger.error(msg) logger.error(msg)
if not (common_retries and read_retries): if not (common_retries and read_retries):
logger.error("Can't process HID request due many errors: %r", request) logger.error("Can't process HID request due many errors: %r", req)
return error_retval return error_retval
def __set_state_online(self, online: bool) -> None: def __set_state_online(self, online: bool) -> None:
@ -449,11 +452,11 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
def __set_state_busy(self, busy: bool) -> None: def __set_state_busy(self, busy: bool) -> None:
self.__state_flags.update(busy=int(busy)) self.__state_flags.update(busy=int(busy))
def __set_state_pong(self, response: bytes) -> None: def __set_state_pong(self, resp: bytes) -> None:
status = response[1] << 16 status = resp[1] << 16
if len(response) > 4: if len(resp) > 4:
status |= (response[2] << 8) | response[3] status |= (resp[2] << 8) | resp[3]
reset_required = (1 if response[1] & 0b01000000 else 0) reset_required = (1 if resp[1] & 0b01000000 else 0)
self.__state_flags.update(online=1, busy=reset_required, status=status) self.__state_flags.update(online=1, busy=reset_required, status=status)
if reset_required: if reset_required:
if self.__reset_self: if self.__reset_self:

View File

@ -47,12 +47,12 @@ class Gpio: # pylint: disable=too-many-instance-attributes
self.__reset_inverted = reset_inverted self.__reset_inverted = reset_inverted
self.__reset_delay = reset_delay self.__reset_delay = reset_delay
self.__line_request: (gpiod.LineRequest | None) = None self.__line_req: (gpiod.LineRequest | None) = None
self.__last_power: (bool | None) = None self.__last_power: (bool | None) = None
def __enter__(self) -> None: def __enter__(self) -> None:
if self.__power_detect_pin >= 0 or self.__reset_pin >= 0: if self.__power_detect_pin >= 0 or self.__reset_pin >= 0:
assert self.__line_request is None assert self.__line_req is None
config: dict[int, gpiod.LineSettings] = {} config: dict[int, gpiod.LineSettings] = {}
if self.__power_detect_pin >= 0: if self.__power_detect_pin >= 0:
config[self.__power_detect_pin] = gpiod.LineSettings( config[self.__power_detect_pin] = gpiod.LineSettings(
@ -65,7 +65,7 @@ class Gpio: # pylint: disable=too-many-instance-attributes
output_value=gpiod.line.Value(self.__reset_inverted), output_value=gpiod.line.Value(self.__reset_inverted),
) )
assert len(config) > 0 assert len(config) > 0
self.__line_request = gpiod.request_lines( self.__line_req = gpiod.request_lines(
self.__device_path, self.__device_path,
consumer="kvmd::hid", consumer="kvmd::hid",
config=config, config=config,
@ -78,18 +78,18 @@ class Gpio: # pylint: disable=too-many-instance-attributes
_tb: types.TracebackType, _tb: types.TracebackType,
) -> None: ) -> None:
if self.__line_request: if self.__line_req:
try: try:
self.__line_request.release() self.__line_req.release()
except Exception: except Exception:
pass pass
self.__last_power = None self.__last_power = None
self.__line_request = None self.__line_req = None
def is_powered(self) -> bool: def is_powered(self) -> bool:
if self.__power_detect_pin >= 0: if self.__power_detect_pin >= 0:
assert self.__line_request assert self.__line_req
power = bool(self.__line_request.get_value(self.__power_detect_pin).value) power = bool(self.__line_req.get_value(self.__power_detect_pin).value)
if power != self.__last_power: if power != self.__last_power:
get_logger(0).info("HID power state changed: %s -> %s", self.__last_power, power) get_logger(0).info("HID power state changed: %s -> %s", self.__last_power, power)
self.__last_power = power self.__last_power = power
@ -98,11 +98,11 @@ class Gpio: # pylint: disable=too-many-instance-attributes
def reset(self) -> None: def reset(self) -> None:
if self.__reset_pin >= 0: if self.__reset_pin >= 0:
assert self.__line_request assert self.__line_req
try: try:
self.__line_request.set_value(self.__reset_pin, gpiod.line.Value(not self.__reset_inverted)) self.__line_req.set_value(self.__reset_pin, gpiod.line.Value(not self.__reset_inverted))
time.sleep(self.__reset_delay) time.sleep(self.__reset_delay)
finally: finally:
self.__line_request.set_value(self.__reset_pin, gpiod.line.Value(self.__reset_inverted)) self.__line_req.set_value(self.__reset_pin, gpiod.line.Value(self.__reset_inverted))
time.sleep(1) time.sleep(1)
get_logger(0).info("Reset HID performed") get_logger(0).info("Reset HID performed")

View File

@ -184,17 +184,17 @@ class MouseWheelEvent(BaseEvent):
# ===== # =====
def check_response(response: bytes) -> bool: def check_response(resp: bytes) -> bool:
assert len(response) in (4, 8), response assert len(resp) in (4, 8), resp
return (bitbang.make_crc16(response[:-2]) == struct.unpack(">H", response[-2:])[0]) return (bitbang.make_crc16(resp[:-2]) == struct.unpack(">H", resp[-2:])[0])
def _make_request(command: bytes) -> bytes: def _make_request(cmd: bytes) -> bytes:
assert len(command) == 5, command assert len(cmd) == 5, cmd
request = b"\x33" + command req = b"\x33" + cmd
request += struct.pack(">H", bitbang.make_crc16(request)) req += struct.pack(">H", bitbang.make_crc16(req))
assert len(request) == 8, request assert len(req) == 8, req
return request return req
# ===== # =====

View File

@ -21,9 +21,9 @@
import multiprocessing import multiprocessing
import copy
import time import time
from typing import Iterable
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Any from typing import Any
@ -63,6 +63,11 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
def __init__( # pylint: disable=too-many-arguments,too-many-locals def __init__( # pylint: disable=too-many-arguments,too-many-locals
self, self,
ignore_keys: list[str],
mouse_x_range: dict[str, Any],
mouse_y_range: dict[str, Any],
jiggler: dict[str, Any],
manufacturer: str, manufacturer: str,
product: str, product: str,
description: str, description: str,
@ -78,11 +83,9 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
max_clients: int, max_clients: int,
socket_timeout: float, socket_timeout: float,
select_timeout: float, select_timeout: float,
jiggler: dict[str, Any],
) -> None: ) -> None:
super().__init__(**jiggler) super().__init__(ignore_keys=ignore_keys, **mouse_x_range, **mouse_y_range, **jiggler)
self._set_jiggler_absolute(False) self._set_jiggler_absolute(False)
self.__proc: (multiprocessing.Process | None) = None self.__proc: (multiprocessing.Process | None) = None
@ -126,7 +129,7 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
"socket_timeout": Option(5.0, type=valid_float_f01), "socket_timeout": Option(5.0, type=valid_float_f01),
"select_timeout": Option(1.0, type=valid_float_f01), "select_timeout": Option(1.0, type=valid_float_f01),
**cls._get_jiggler_options(), **cls._get_base_options(),
} }
def sysprep(self) -> None: def sysprep(self) -> None:
@ -138,6 +141,7 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
state = await self.__server.get_state() state = await self.__server.get_state()
outputs: dict = {"available": [], "active": ""} outputs: dict = {"available": [], "active": ""}
return { return {
"enabled": True,
"online": True, "online": True,
"busy": False, "busy": False,
"connected": None, "connected": None,
@ -158,14 +162,18 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
**self._get_jiggler_state(), **self._get_jiggler_state(),
} }
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {} prev: dict = {}
while True: while True:
state = await self.get_state() if (await self.__notifier.wait()) > 0:
if state != prev_state: prev = {}
yield state new = await self.get_state()
prev_state = state if new != prev:
await self.__notifier.wait() prev = copy.deepcopy(new)
yield new
async def reset(self) -> None: async def reset(self) -> None:
self.clear_events() self.clear_events()
@ -182,27 +190,6 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
# ===== # =====
def send_key_events(self, keys: Iterable[tuple[str, bool]]) -> None:
for (key, state) in keys:
self.__server.queue_event(make_keyboard_event(key, state))
self._bump_activity()
def send_mouse_button_event(self, button: str, state: bool) -> None:
self.__server.queue_event(MouseButtonEvent(button, state))
self._bump_activity()
def send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__server.queue_event(MouseRelativeEvent(delta_x, delta_y))
self._bump_activity()
def send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__server.queue_event(MouseWheelEvent(delta_x, delta_y))
self._bump_activity()
def clear_events(self) -> None:
self.__server.clear_events()
self._bump_activity()
def set_params( def set_params(
self, self,
keyboard_output: (str | None)=None, keyboard_output: (str | None)=None,
@ -216,6 +203,21 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
self._set_jiggler_active(jiggler) self._set_jiggler_active(jiggler)
self.__notifier.notify() self.__notifier.notify()
def _send_key_event(self, key: str, state: bool) -> None:
self.__server.queue_event(make_keyboard_event(key, state))
def _send_mouse_button_event(self, button: str, state: bool) -> None:
self.__server.queue_event(MouseButtonEvent(button, state))
def _send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__server.queue_event(MouseRelativeEvent(delta_x, delta_y))
def _send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__server.queue_event(MouseWheelEvent(delta_x, delta_y))
def _clear_events(self) -> None:
self.__server.clear_events()
# ===== # =====
def __server_worker(self) -> None: # pylint: disable=too-many-branches def __server_worker(self) -> None: # pylint: disable=too-many-branches

View File

@ -182,8 +182,8 @@ class BtServer: # pylint: disable=too-many-instance-attributes
self.__close_client("CTL", client, "ctl_sock") self.__close_client("CTL", client, "ctl_sock")
elif data == b"\x71": elif data == b"\x71":
sock.send(b"\x00") sock.send(b"\x00")
except Exception as err: except Exception as ex:
get_logger(0).exception("CTL socket error on %s: %s", client.addr, tools.efmt(err)) get_logger(0).exception("CTL socket error on %s: %s", client.addr, tools.efmt(ex))
self.__close_client("CTL", client, "ctl_sock") self.__close_client("CTL", client, "ctl_sock")
continue continue
@ -196,8 +196,8 @@ class BtServer: # pylint: disable=too-many-instance-attributes
self.__close_client("INT", client, "int_sock") self.__close_client("INT", client, "int_sock")
elif data[:2] == b"\xA2\x01": elif data[:2] == b"\xA2\x01":
self.__process_leds(data[2]) self.__process_leds(data[2])
except Exception as err: except Exception as ex:
get_logger(0).exception("INT socket error on %s: %s", client.addr, tools.efmt(err)) get_logger(0).exception("INT socket error on %s: %s", client.addr, tools.efmt(ex))
self.__close_client("INT", client, "ctl_sock") self.__close_client("INT", client, "ctl_sock")
if qr in ready_read: if qr in ready_read:
@ -279,8 +279,8 @@ class BtServer: # pylint: disable=too-many-instance-attributes
assert client.int_sock is not None assert client.int_sock is not None
try: try:
client.int_sock.send(report) client.int_sock.send(report)
except Exception as err: except Exception as ex:
get_logger(0).info("Can't send %s report to %s: %s", name, client.addr, tools.efmt(err)) get_logger(0).info("Can't send %s report to %s: %s", name, client.addr, tools.efmt(ex))
self.__close_client_pair(client) self.__close_client_pair(client)
def __clear_modifiers(self) -> None: def __clear_modifiers(self) -> None:
@ -371,13 +371,13 @@ class BtServer: # pylint: disable=too-many-instance-attributes
logger.info("Publishing ..." if public else "Unpublishing ...") logger.info("Publishing ..." if public else "Unpublishing ...")
try: try:
self.__iface.set_public(public) self.__iface.set_public(public)
except Exception as err: except Exception as ex:
logger.error("Can't change public mode: %s", tools.efmt(err)) logger.error("Can't change public mode: %s", tools.efmt(ex))
def __unpair_client(self, client: _BtClient) -> None: def __unpair_client(self, client: _BtClient) -> None:
logger = get_logger(0) logger = get_logger(0)
logger.info("Unpairing %s ...", client.addr) logger.info("Unpairing %s ...", client.addr)
try: try:
self.__iface.unpair(client.addr) self.__iface.unpair(client.addr)
except Exception as err: except Exception as ex:
logger.error("Can't unpair %s: %s", client.addr, tools.efmt(err)) logger.error("Can't unpair %s: %s", client.addr, tools.efmt(ex))

View File

@ -22,9 +22,9 @@
import multiprocessing import multiprocessing
import queue import queue
import copy
import time import time
from typing import Iterable
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Any from typing import Any
@ -54,13 +54,17 @@ from .keyboard import Keyboard
class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-instance-attributes class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-instance-attributes
def __init__( # pylint: disable=too-many-arguments,super-init-not-called def __init__( # pylint: disable=too-many-arguments,super-init-not-called
self, self,
ignore_keys: list[str],
mouse_x_range: dict[str, Any],
mouse_y_range: dict[str, Any],
jiggler: dict[str, Any],
device_path: str, device_path: str,
speed: int, speed: int,
read_timeout: float, read_timeout: float,
jiggler: dict[str, Any],
) -> None: ) -> None:
BaseHid.__init__(self, **jiggler) BaseHid.__init__(self, ignore_keys=ignore_keys, **mouse_x_range, **mouse_y_range, **jiggler)
multiprocessing.Process.__init__(self, daemon=True) multiprocessing.Process.__init__(self, daemon=True)
self.__device_path = device_path self.__device_path = device_path
@ -88,7 +92,7 @@ class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-inst
"device": Option("/dev/kvmd-hid", type=valid_abs_path, unpack_as="device_path"), "device": Option("/dev/kvmd-hid", type=valid_abs_path, unpack_as="device_path"),
"speed": Option(9600, type=valid_tty_speed), "speed": Option(9600, type=valid_tty_speed),
"read_timeout": Option(0.3, type=valid_float_f01), "read_timeout": Option(0.3, type=valid_float_f01),
**cls._get_jiggler_options(), **cls._get_base_options(),
} }
def sysprep(self) -> None: def sysprep(self) -> None:
@ -100,6 +104,7 @@ class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-inst
absolute = self.__mouse.is_absolute() absolute = self.__mouse.is_absolute()
leds = await self.__keyboard.get_leds() leds = await self.__keyboard.get_leds()
return { return {
"enabled": True,
"online": state["online"], "online": state["online"],
"busy": False, "busy": False,
"connected": None, "connected": None,
@ -119,14 +124,18 @@ class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-inst
**self._get_jiggler_state(), **self._get_jiggler_state(),
} }
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {} prev: dict = {}
while True: while True:
state = await self.get_state() if (await self.__notifier.wait()) > 0:
if state != prev_state: prev = {}
yield state new = await self.get_state()
prev_state = state if new != prev:
await self.__notifier.wait() prev = copy.deepcopy(new)
yield new
async def reset(self) -> None: async def reset(self) -> None:
self.__reset_required_event.set() self.__reset_required_event.set()
@ -141,27 +150,6 @@ class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-inst
# ===== # =====
def send_key_events(self, keys: Iterable[tuple[str, bool]]) -> None:
for (key, state) in keys:
self.__queue_cmd(self.__keyboard.process_key(key, state))
self._bump_activity()
def send_mouse_button_event(self, button: str, state: bool) -> None:
self.__queue_cmd(self.__mouse.process_button(button, state))
self._bump_activity()
def send_mouse_move_event(self, to_x: int, to_y: int) -> None:
self.__queue_cmd(self.__mouse.process_move(to_x, to_y))
self._bump_activity()
def send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_cmd(self.__mouse.process_wheel(delta_x, delta_y))
self._bump_activity()
def send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_cmd(self.__mouse.process_relative(delta_x, delta_y))
self._bump_activity()
def set_params( def set_params(
self, self,
keyboard_output: (str | None)=None, keyboard_output: (str | None)=None,
@ -180,10 +168,22 @@ class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-inst
self._set_jiggler_active(jiggler) self._set_jiggler_active(jiggler)
self.__notifier.notify() self.__notifier.notify()
def set_connected(self, connected: bool) -> None: def _send_key_event(self, key: str, state: bool) -> None:
pass self.__queue_cmd(self.__keyboard.process_key(key, state))
def clear_events(self) -> None: def _send_mouse_button_event(self, button: str, state: bool) -> None:
self.__queue_cmd(self.__mouse.process_button(button, state))
def _send_mouse_move_event(self, to_x: int, to_y: int) -> None:
self.__queue_cmd(self.__mouse.process_move(to_x, to_y))
def _send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_cmd(self.__mouse.process_wheel(delta_x, delta_y))
def _send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__queue_cmd(self.__mouse.process_relative(delta_x, delta_y))
def _clear_events(self) -> None:
tools.clear_queue(self.__cmd_queue) tools.clear_queue(self.__cmd_queue)
def __queue_cmd(self, cmd: bytes, clear: bool=False) -> None: def __queue_cmd(self, cmd: bytes, clear: bool=False) -> None:
@ -230,9 +230,9 @@ class Plugin(BaseHid, multiprocessing.Process): # pylint: disable=too-many-inst
def __process_cmd(self, conn: ChipConnection, cmd: bytes) -> bool: # pylint: disable=too-many-branches def __process_cmd(self, conn: ChipConnection, cmd: bytes) -> bool: # pylint: disable=too-many-branches
try: try:
led_byte = conn.xfer(cmd) led_byte = conn.xfer(cmd)
except ChipResponseError as err: except ChipResponseError as ex:
self.__set_state_online(False) self.__set_state_online(False)
get_logger(0).info(err) get_logger(0).error("Invalid chip response: %s", tools.efmt(ex))
time.sleep(2) time.sleep(2)
else: else:
if led_byte >= 0: if led_byte >= 0:

View File

@ -20,7 +20,8 @@
# ========================================================================== # # ========================================================================== #
from typing import Iterable import copy
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Any from typing import Any
@ -46,15 +47,20 @@ from .mouse import MouseProcess
class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
def __init__( def __init__(
self, self,
ignore_keys: list[str],
mouse_x_range: dict[str, Any],
mouse_y_range: dict[str, Any],
jiggler: dict[str, Any],
keyboard: dict[str, Any], keyboard: dict[str, Any],
mouse: dict[str, Any], mouse: dict[str, Any],
mouse_alt: dict[str, Any], mouse_alt: dict[str, Any],
jiggler: dict[str, Any],
noop: bool, noop: bool,
udc: str, # XXX: Not from options, see /kvmd/apps/kvmd/__init__.py for details udc: str, # XXX: Not from options, see /kvmd/apps/kvmd/__init__.py for details
) -> None: ) -> None:
super().__init__(**jiggler) super().__init__(ignore_keys=ignore_keys, **mouse_x_range, **mouse_y_range, **jiggler)
self.__udc = udc self.__udc = udc
@ -113,7 +119,7 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
"horizontal_wheel": Option(True, type=valid_bool), "horizontal_wheel": Option(True, type=valid_bool),
}, },
"noop": Option(False, type=valid_bool), "noop": Option(False, type=valid_bool),
**cls._get_jiggler_options(), **cls._get_base_options(),
} }
def sysprep(self) -> None: def sysprep(self) -> None:
@ -128,6 +134,7 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
keyboard_state = await self.__keyboard_proc.get_state() keyboard_state = await self.__keyboard_proc.get_state()
mouse_state = await self.__mouse_current.get_state() mouse_state = await self.__mouse_current.get_state()
return { return {
"enabled": True,
"online": True, "online": True,
"busy": False, "busy": False,
"connected": None, "connected": None,
@ -150,14 +157,18 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
**self._get_jiggler_state(), **self._get_jiggler_state(),
} }
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {} prev: dict = {}
while True: while True:
state = await self.get_state() if (await self.__notifier.wait()) > 0:
if state != prev_state: prev = {}
yield state new = await self.get_state()
prev_state = state if new != prev:
await self.__notifier.wait() prev = copy.deepcopy(new)
yield new
async def reset(self) -> None: async def reset(self) -> None:
self.__keyboard_proc.send_reset_event() self.__keyboard_proc.send_reset_event()
@ -177,26 +188,6 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
# ===== # =====
def send_key_events(self, keys: Iterable[tuple[str, bool]]) -> None:
self.__keyboard_proc.send_key_events(keys)
self._bump_activity()
def send_mouse_button_event(self, button: str, state: bool) -> None:
self.__mouse_current.send_button_event(button, state)
self._bump_activity()
def send_mouse_move_event(self, to_x: int, to_y: int) -> None:
self.__mouse_current.send_move_event(to_x, to_y)
self._bump_activity()
def send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__mouse_current.send_relative_event(delta_x, delta_y)
self._bump_activity()
def send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__mouse_current.send_wheel_event(delta_x, delta_y)
self._bump_activity()
def set_params( def set_params(
self, self,
keyboard_output: (str | None)=None, keyboard_output: (str | None)=None,
@ -215,12 +206,26 @@ class Plugin(BaseHid): # pylint: disable=too-many-instance-attributes
self._set_jiggler_active(jiggler) self._set_jiggler_active(jiggler)
self.__notifier.notify() self.__notifier.notify()
def clear_events(self) -> None: def _send_key_event(self, key: str, state: bool) -> None:
self.__keyboard_proc.send_key_event(key, state)
def _send_mouse_button_event(self, button: str, state: bool) -> None:
self.__mouse_current.send_button_event(button, state)
def _send_mouse_move_event(self, to_x: int, to_y: int) -> None:
self.__mouse_current.send_move_event(to_x, to_y)
def _send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None:
self.__mouse_current.send_relative_event(delta_x, delta_y)
def _send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None:
self.__mouse_current.send_wheel_event(delta_x, delta_y)
def _clear_events(self) -> None:
self.__keyboard_proc.send_clear_event() self.__keyboard_proc.send_clear_event()
self.__mouse_proc.send_clear_event() self.__mouse_proc.send_clear_event()
if self.__mouse_alt_proc: if self.__mouse_alt_proc:
self.__mouse_alt_proc.send_clear_event() self.__mouse_alt_proc.send_clear_event()
self._bump_activity()
# ===== # =====

View File

@ -192,13 +192,13 @@ class BaseDeviceProcess(multiprocessing.Process): # pylint: disable=too-many-in
else: else:
logger.error("HID-%s write() error: written (%s) != report length (%d)", logger.error("HID-%s write() error: written (%s) != report length (%d)",
self.__name, written, len(report)) self.__name, written, len(report))
except Exception as err: except Exception as ex:
if isinstance(err, OSError) and ( if isinstance(ex, OSError) and (
# https://github.com/raspberrypi/linux/commit/61b7f805dc2fd364e0df682de89227e94ce88e25 # https://github.com/raspberrypi/linux/commit/61b7f805dc2fd364e0df682de89227e94ce88e25
err.errno == errno.EAGAIN # pylint: disable=no-member ex.errno == errno.EAGAIN # pylint: disable=no-member
or err.errno == errno.ESHUTDOWN # pylint: disable=no-member or ex.errno == errno.ESHUTDOWN # pylint: disable=no-member
): ):
logger.debug("HID-%s busy/unplugged (write): %s", self.__name, tools.efmt(err)) logger.debug("HID-%s busy/unplugged (write): %s", self.__name, tools.efmt(ex))
else: else:
logger.exception("Can't write report to HID-%s", self.__name) logger.exception("Can't write report to HID-%s", self.__name)
@ -216,16 +216,16 @@ class BaseDeviceProcess(multiprocessing.Process): # pylint: disable=too-many-in
while read: while read:
try: try:
read = bool(select.select([self.__fd], [], [], 0)[0]) read = bool(select.select([self.__fd], [], [], 0)[0])
except Exception as err: except Exception as ex:
logger.error("Can't select() for read HID-%s: %s", self.__name, tools.efmt(err)) logger.error("Can't select() for read HID-%s: %s", self.__name, tools.efmt(ex))
break break
if read: if read:
try: try:
report = os.read(self.__fd, self.__read_size) report = os.read(self.__fd, self.__read_size)
except Exception as err: except Exception as ex:
if isinstance(err, OSError) and err.errno == errno.EAGAIN: # pylint: disable=no-member if isinstance(ex, OSError) and ex.errno == errno.EAGAIN: # pylint: disable=no-member
logger.debug("HID-%s busy/unplugged (read): %s", self.__name, tools.efmt(err)) logger.debug("HID-%s busy/unplugged (read): %s", self.__name, tools.efmt(ex))
else: else:
logger.exception("Can't read report from HID-%s", self.__name) logger.exception("Can't read report from HID-%s", self.__name)
else: else:
@ -255,9 +255,9 @@ class BaseDeviceProcess(multiprocessing.Process): # pylint: disable=too-many-in
flags = os.O_NONBLOCK flags = os.O_NONBLOCK
flags |= (os.O_RDWR if self.__read_size else os.O_WRONLY) flags |= (os.O_RDWR if self.__read_size else os.O_WRONLY)
self.__fd = os.open(self.__device_path, flags) self.__fd = os.open(self.__device_path, flags)
except Exception as err: except Exception as ex:
logger.error("Can't open HID-%s device %s: %s", logger.error("Can't open HID-%s device %s: %s",
self.__name, self.__device_path, tools.efmt(err)) self.__name, self.__device_path, tools.efmt(ex))
if self.__fd >= 0: if self.__fd >= 0:
try: try:
@ -268,8 +268,8 @@ class BaseDeviceProcess(multiprocessing.Process): # pylint: disable=too-many-in
else: else:
# Если запись недоступна, то скорее всего устройство отключено # Если запись недоступна, то скорее всего устройство отключено
logger.debug("HID-%s is busy/unplugged (write select)", self.__name) logger.debug("HID-%s is busy/unplugged (write select)", self.__name)
except Exception as err: except Exception as ex:
logger.error("Can't select() for write HID-%s: %s", self.__name, tools.efmt(err)) logger.error("Can't select() for write HID-%s: %s", self.__name, tools.efmt(ex))
self.__state_flags.update(online=False) self.__state_flags.update(online=False)
return False return False

View File

@ -20,7 +20,6 @@
# ========================================================================== # # ========================================================================== #
from typing import Iterable
from typing import Generator from typing import Generator
from typing import Any from typing import Any
@ -68,8 +67,7 @@ class KeyboardProcess(BaseDeviceProcess):
self._clear_queue() self._clear_queue()
self._queue_event(ResetEvent()) self._queue_event(ResetEvent())
def send_key_events(self, keys: Iterable[tuple[str, bool]]) -> None: def send_key_event(self, key: str, state: bool) -> None:
for (key, state) in keys:
self._queue_event(make_keyboard_event(key, state)) self._queue_event(make_keyboard_event(key, state))
# ===== # =====

View File

@ -44,12 +44,12 @@ class _SerialPhyConnection(BasePhyConnection):
def __init__(self, tty: serial.Serial) -> None: def __init__(self, tty: serial.Serial) -> None:
self.__tty = tty self.__tty = tty
def send(self, request: bytes) -> bytes: def send(self, req: bytes) -> bytes:
assert len(request) == 8 assert len(req) == 8
assert request[0] == 0x33 assert req[0] == 0x33
if self.__tty.in_waiting: if self.__tty.in_waiting:
self.__tty.read_all() self.__tty.read_all()
assert self.__tty.write(request) == 8 assert self.__tty.write(req) == 8
data = self.__tty.read(4) data = self.__tty.read(4)
if len(data) == 4: if len(data) == 4:
if data[0] == 0x34: # New response protocol if data[0] == 0x34: # New response protocol

View File

@ -57,9 +57,9 @@ class _SpiPhyConnection(BasePhyConnection):
self.__xfer = xfer self.__xfer = xfer
self.__read_timeout = read_timeout self.__read_timeout = read_timeout
def send(self, request: bytes) -> bytes: def send(self, req: bytes) -> bytes:
assert len(request) == 8 assert len(req) == 8
assert request[0] == 0x33 assert req[0] == 0x33
deadline_ts = time.monotonic() + self.__read_timeout deadline_ts = time.monotonic() + self.__read_timeout
dummy = b"\x00" * 10 dummy = b"\x00" * 10
@ -70,26 +70,26 @@ class _SpiPhyConnection(BasePhyConnection):
get_logger(0).error("SPI timeout reached while garbage reading") get_logger(0).error("SPI timeout reached while garbage reading")
return b"" return b""
self.__xfer(request) self.__xfer(req)
response: list[int] = [] resp: list[int] = []
deadline_ts = time.monotonic() + self.__read_timeout deadline_ts = time.monotonic() + self.__read_timeout
found = False found = False
while time.monotonic() < deadline_ts: while time.monotonic() < deadline_ts:
for byte in self.__xfer(b"\x00" * (9 - len(response))): for byte in self.__xfer(b"\x00" * (9 - len(resp))):
if not found: if not found:
if byte == 0: if byte == 0:
continue continue
found = True found = True
response.append(byte) resp.append(byte)
if len(response) == 8: if len(resp) == 8:
break break
if len(response) == 8: if len(resp) == 8:
break break
else: else:
get_logger(0).error("SPI timeout reached while responce waiting") get_logger(0).error("SPI timeout reached while responce waiting")
return b"" return b""
return bytes(response) return bytes(resp)
class _SpiPhy(BasePhy): # pylint: disable=too-many-instance-attributes class _SpiPhy(BasePhy): # pylint: disable=too-many-instance-attributes

View File

@ -117,7 +117,22 @@ class BaseMsd(BasePlugin):
async def get_state(self) -> dict: async def get_state(self) -> dict:
raise NotImplementedError() raise NotImplementedError()
async def trigger_state(self) -> None:
raise NotImplementedError()
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - enabled -- Full
# - online -- Partial
# - busy -- Partial
# - drive -- Partial, nullable
# - storage -- Partial, nullable
# - storage.parts -- Partial
# - storage.images -- Partial
# - storage.downloading -- Partial, nullable
# - storage.uploading -- Partial, nullable
# ===========================
if self is not None: # XXX: Vulture and pylint hack if self is not None: # XXX: Vulture and pylint hack
raise NotImplementedError() raise NotImplementedError()
yield yield
@ -263,16 +278,18 @@ class MsdFileWriter(BaseMsdWriter): # pylint: disable=too-many-instance-attribu
return self.__written return self.__written
def is_complete(self) -> bool:
return (self.__written >= self.__file_size)
async def open(self) -> "MsdFileWriter": async def open(self) -> "MsdFileWriter":
assert self.__file is None assert self.__file is None
get_logger(1).info("Writing %r image (%d bytes) to MSD ...", self.__name, self.__file_size) get_logger(1).info("Writing %r image (%d bytes) to MSD ...", self.__name, self.__file_size)
await aiofiles.os.makedirs(os.path.dirname(self.__path), exist_ok=True) await aiofiles.os.makedirs(os.path.dirname(self.__path), exist_ok=True)
self.__file = await aiofiles.open(self.__path, mode="w+b", buffering=0) # type: ignore self.__file = await aiofiles.open(self.__path, mode="w+b", buffering=0) # type: ignore
await aiotools.run_async(os.ftruncate, self.__file.fileno(), self.__file_size) # type: ignore
return self return self
async def finish(self) -> bool:
await self.__sync()
return (self.__written >= self.__file_size)
async def close(self) -> None: async def close(self) -> None:
assert self.__file is not None assert self.__file is not None
logger = get_logger() logger = get_logger()
@ -285,9 +302,6 @@ class MsdFileWriter(BaseMsdWriter): # pylint: disable=too-many-instance-attribu
else: # written > size else: # written > size
(log, result) = (logger.warning, "OVERFLOW") (log, result) = (logger.warning, "OVERFLOW")
log("Written %d of %d bytes to MSD image %r: %s", self.__written, self.__file_size, self.__name, result) log("Written %d of %d bytes to MSD image %r: %s", self.__written, self.__file_size, self.__name, result)
try:
await self.__sync()
finally:
await self.__file.close() # type: ignore await self.__file.close() # type: ignore
except Exception: except Exception:
logger.exception("Can't close image writer") logger.exception("Can't close image writer")

View File

@ -40,6 +40,9 @@ class MsdDisabledError(MsdOperationError):
# ===== # =====
class Plugin(BaseMsd): class Plugin(BaseMsd):
def __init__(self) -> None:
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict: async def get_state(self) -> dict:
return { return {
"enabled": False, "enabled": False,
@ -49,10 +52,13 @@ class Plugin(BaseMsd):
"drive": None, "drive": None,
} }
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[dict, None]: async def poll_state(self) -> AsyncGenerator[dict, None]:
while True: while True:
await self.__notifier.wait()
yield (await self.get_state()) yield (await self.get_state())
await aiotools.wait_infinite()
async def reset(self) -> None: async def reset(self) -> None:
raise MsdDisabledError() raise MsdDisabledError()

View File

@ -26,12 +26,12 @@ import dataclasses
import functools import functools
import time import time
import os import os
import copy
from typing import AsyncGenerator from typing import AsyncGenerator
from ....logging import get_logger from ....logging import get_logger
from ....inotify import InotifyMask
from ....inotify import Inotify from ....inotify import Inotify
from ....yamlconf import Option from ....yamlconf import Option
@ -97,7 +97,8 @@ class _State:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def busy(self, check_online: bool=True) -> AsyncGenerator[None, None]: async def busy(self, check_online: bool=True) -> AsyncGenerator[None, None]:
async with self._region: try:
with self._region:
async with self._lock: async with self._lock:
self.__notifier.notify() self.__notifier.notify()
if check_online: if check_online:
@ -105,6 +106,7 @@ class _State:
raise MsdOfflineError() raise MsdOfflineError()
assert self.storage assert self.storage
yield yield
finally:
self.__notifier.notify() self.__notifier.notify()
def is_busy(self) -> bool: def is_busy(self) -> bool:
@ -141,10 +143,11 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__notifier = aiotools.AioNotifier() self.__notifier = aiotools.AioNotifier()
self.__state = _State(self.__notifier) self.__state = _State(self.__notifier)
self.__reset = False
logger = get_logger(0) logger = get_logger(0)
logger.info("Using OTG gadget %r as MSD", gadget) logger.info("Using OTG gadget %r as MSD", gadget)
aiotools.run_sync(self.__reload_state(notify=False)) aiotools.run_sync(self.__unsafe_reload_state())
@classmethod @classmethod
def get_plugin_options(cls) -> dict: def get_plugin_options(cls) -> dict:
@ -164,14 +167,13 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
}, },
} }
# =====
async def get_state(self) -> dict: async def get_state(self) -> dict:
async with self.__state._lock: # pylint: disable=protected-access async with self.__state._lock: # pylint: disable=protected-access
storage: (dict | None) = None storage: (dict | None) = None
if self.__state.storage: if self.__state.storage:
if self.__writer: assert self.__state.vd
# При загрузке файла показываем актуальную статистику вручную
await self.__storage.reload_parts_info()
storage = dataclasses.asdict(self.__state.storage) storage = dataclasses.asdict(self.__state.storage)
for name in list(storage["images"]): for name in list(storage["images"]):
del storage["images"][name]["name"] del storage["images"][name]["name"]
@ -185,34 +187,50 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
vd: (dict | None) = None vd: (dict | None) = None
if self.__state.vd: if self.__state.vd:
assert self.__state.storage
vd = dataclasses.asdict(self.__state.vd) vd = dataclasses.asdict(self.__state.vd)
if vd["image"]: if vd["image"]:
del vd["image"]["path"] del vd["image"]["path"]
return { return {
"enabled": True, "enabled": True,
"online": (bool(self.__state.vd) and self.__drive.is_enabled()), "online": (bool(vd) and self.__drive.is_enabled()),
"busy": self.__state.is_busy(), "busy": self.__state.is_busy(),
"storage": storage, "storage": storage,
"drive": vd, "drive": vd,
} }
async def poll_state(self) -> AsyncGenerator[dict, None]: async def trigger_state(self) -> None:
prev_state: dict = {} self.__notifier.notify(1)
while True:
state = await self.get_state()
if state != prev_state:
yield state
prev_state = state
await self.__notifier.wait()
async def systask(self) -> None: async def poll_state(self) -> AsyncGenerator[dict, None]:
await self.__watch_inotify() prev: dict = {}
while True:
if (await self.__notifier.wait()) > 0:
prev = {}
new = await self.get_state()
if not prev or (prev.get("online") != new["online"]):
prev = copy.deepcopy(new)
yield new
else:
diff: dict = {}
for sub in ["busy", "drive"]:
if prev.get(sub) != new[sub]:
diff[sub] = new[sub]
for sub in ["images", "parts", "downloading", "uploading"]:
if (prev.get("storage") or {}).get(sub) != (new["storage"] or {}).get(sub):
if "storage" not in diff:
diff["storage"] = {}
diff["storage"][sub] = new["storage"][sub]
if diff:
prev = copy.deepcopy(new)
yield diff
@aiotools.atomic_fg @aiotools.atomic_fg
async def reset(self) -> None: async def reset(self) -> None:
async with self.__state.busy(check_online=False): async with self.__state.busy(check_online=False):
try: try:
self.__reset = True
self.__drive.set_image_path("") self.__drive.set_image_path("")
self.__drive.set_cdrom_flag(False) self.__drive.set_cdrom_flag(False)
self.__drive.set_rw_flag(False) self.__drive.set_rw_flag(False)
@ -220,11 +238,6 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
except Exception: except Exception:
get_logger(0).exception("Can't reset MSD properly") get_logger(0).exception("Can't reset MSD properly")
@aiotools.atomic_fg
async def cleanup(self) -> None:
await self.__close_reader()
await self.__close_writer()
# ===== # =====
@aiotools.atomic_fg @aiotools.atomic_fg
@ -296,11 +309,12 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def read_image(self, name: str) -> AsyncGenerator[MsdFileReader, None]: async def read_image(self, name: str) -> AsyncGenerator[MsdFileReader, None]:
try: try:
async with self.__state._region: # pylint: disable=protected-access with self.__state._region: # pylint: disable=protected-access
try: try:
async with self.__state._lock: # pylint: disable=protected-access async with self.__state._lock: # pylint: disable=protected-access
self.__notifier.notify() self.__notifier.notify()
self.__STATE_check_disconnected() self.__STATE_check_disconnected()
image = await self.__STATE_get_storage_image(name) image = await self.__STATE_get_storage_image(name)
self.__reader = await MsdFileReader( self.__reader = await MsdFileReader(
notifier=self.__notifier, notifier=self.__notifier,
@ -308,7 +322,10 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
path=image.path, path=image.path,
chunk_size=self.__read_chunk_size, chunk_size=self.__read_chunk_size,
).open() ).open()
self.__notifier.notify()
yield self.__reader yield self.__reader
finally: finally:
await aiotools.shield_fg(self.__close_reader()) await aiotools.shield_fg(self.__close_reader())
finally: finally:
@ -316,18 +333,40 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def write_image(self, name: str, size: int, remove_incomplete: (bool | None)) -> AsyncGenerator[MsdFileWriter, None]: async def write_image(self, name: str, size: int, remove_incomplete: (bool | None)) -> AsyncGenerator[MsdFileWriter, None]:
try:
async with self.__state._region: # pylint: disable=protected-access
image: (Image | None) = None image: (Image | None) = None
complete = False
async def finish_writing() -> None:
# Делаем под блокировкой, чтобы эвент айнотифи не был обработан
# до того, как мы не закончим все процедуры.
async with self.__state._lock: # pylint: disable=protected-access
try:
self.__notifier.notify()
finally:
try:
if image:
await image.set_complete(complete)
finally:
try:
if image and remove_incomplete and not complete:
await image.remove(fatal=False)
finally:
try:
await self.__close_writer()
finally:
if image:
await image.remount_rw(False, fatal=False)
try:
with self.__state._region: # pylint: disable=protected-access
try: try:
async with self.__state._lock: # pylint: disable=protected-access async with self.__state._lock: # pylint: disable=protected-access
self.__notifier.notify() self.__notifier.notify()
self.__STATE_check_disconnected() self.__STATE_check_disconnected()
image = await self.__STORAGE_create_new_image(name)
image = await self.__STORAGE_create_new_image(name)
await image.remount_rw(True) await image.remount_rw(True)
await image.set_complete(False) await image.set_complete(False)
self.__writer = await MsdFileWriter( self.__writer = await MsdFileWriter(
notifier=self.__notifier, notifier=self.__notifier,
name=image.name, name=image.name,
@ -339,22 +378,12 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__notifier.notify() self.__notifier.notify()
yield self.__writer yield self.__writer
await image.set_complete(self.__writer.is_complete()) complete = await self.__writer.finish()
finally: finally:
try: await aiotools.shield_fg(finish_writing())
if image and remove_incomplete and self.__writer and not self.__writer.is_complete():
await image.remove(fatal=False)
finally: finally:
try: self.__notifier.notify()
await aiotools.shield_fg(self.__close_writer())
finally:
if image:
await aiotools.shield_fg(image.remount_rw(False, fatal=False))
finally:
# Между закрытием файла и эвентом айнотифи состояние может быть не обновлено,
# так что форсим обновление вручную, чтобы получить актуальное состояние.
await aiotools.shield_fg(self.__reload_state())
@aiotools.atomic_fg @aiotools.atomic_fg
async def remove(self, name: str) -> None: async def remove(self, name: str) -> None:
@ -404,17 +433,26 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
async def __close_reader(self) -> None: async def __close_reader(self) -> None:
if self.__reader: if self.__reader:
try:
await self.__reader.close() await self.__reader.close()
finally:
self.__reader = None self.__reader = None
async def __close_writer(self) -> None: async def __close_writer(self) -> None:
if self.__writer: if self.__writer:
try:
await self.__writer.close() await self.__writer.close()
finally:
self.__writer = None self.__writer = None
# ===== # =====
async def __watch_inotify(self) -> None: @aiotools.atomic_fg
async def cleanup(self) -> None:
await self.__close_reader()
await self.__close_writer()
async def systask(self) -> None:
logger = get_logger(0) logger = get_logger(0)
while True: while True:
try: try:
@ -426,19 +464,25 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
await asyncio.sleep(5) await asyncio.sleep(5)
with Inotify() as inotify: with Inotify() as inotify:
await inotify.watch(InotifyMask.ALL_MODIFY_EVENTS, *self.__storage.get_watchable_paths()) # Из-за гонки между первым релоадом и установкой вотчеров,
await inotify.watch(InotifyMask.ALL_MODIFY_EVENTS, *self.__drive.get_watchable_paths()) # мы можем потерять какие-то каталоги стораджа, но это допустимо,
# так как всегда есть ручной перезапуск.
await inotify.watch_all_changes(*self.__storage.get_watchable_paths())
await inotify.watch_all_changes(*self.__drive.get_watchable_paths())
# После установки вотчеров еще раз проверяем стейт, чтобы ничего не потерять # После установки вотчеров еще раз проверяем стейт,
# чтобы не потерять состояние привода.
await self.__reload_state() await self.__reload_state()
while self.__state.vd: # Если живы после предыдущей проверки while self.__state.vd: # Если живы после предыдущей проверки
need_restart = False need_restart = self.__reset
self.__reset = False
need_reload_state = False need_reload_state = False
for event in (await inotify.get_series(timeout=1)): for event in (await inotify.get_series(timeout=1)):
need_reload_state = True need_reload_state = True
if event.mask & (InotifyMask.DELETE_SELF | InotifyMask.MOVE_SELF | InotifyMask.UNMOUNT | InotifyMask.ISDIR): if event.restart:
# Если выгрузили OTG, изменили каталоги, что-то отмонтировали или делают еще какую-то странную фигню # Если выгрузили OTG, изменили каталоги, что-то отмонтировали или делают еще какую-то странную фигню.
# Проверяется маска InotifyMask.ALL_RESTART_EVENTS
logger.info("Got a big inotify event: %s; reinitializing MSD ...", event) logger.info("Got a big inotify event: %s; reinitializing MSD ...", event)
need_restart = True need_restart = True
break break
@ -446,13 +490,30 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
break break
if need_reload_state: if need_reload_state:
await self.__reload_state() await self.__reload_state()
elif self.__writer:
# При загрузке файла обновляем статистику раз в секунду (по таймауту).
# Это не нужно при обычном релоаде, потому что там и так проверяются все разделы.
await self.__reload_parts_info()
except Exception: except Exception:
logger.exception("Unexpected MSD watcher error") logger.exception("Unexpected MSD watcher error")
time.sleep(1) await asyncio.sleep(1)
async def __reload_state(self, notify: bool=True) -> None: async def __reload_state(self) -> None:
logger = get_logger(0)
async with self.__state._lock: # pylint: disable=protected-access async with self.__state._lock: # pylint: disable=protected-access
await self.__unsafe_reload_state()
self.__notifier.notify()
async def __reload_parts_info(self) -> None:
assert self.__writer # Использовать только при записи образа
async with self.__state._lock: # pylint: disable=protected-access
await self.__storage.reload_parts_info()
self.__notifier.notify()
# ===== Don't call this directly ====
async def __unsafe_reload_state(self) -> None:
logger = get_logger(0)
try: try:
path = self.__drive.get_image_path() path = self.__drive.get_image_path()
drive_state = _DriveState( drive_state = _DriveState(
@ -469,7 +530,7 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
logger.info("Probing to remount storage ...") logger.info("Probing to remount storage ...")
await self.__storage.remount_rw(True) await self.__storage.remount_rw(True)
await self.__storage.remount_rw(False) await self.__storage.remount_rw(False)
await self.__setup_initial() await self.__unsafe_setup_initial()
except Exception: except Exception:
logger.exception("Error while reloading MSD state; switching to offline") logger.exception("Error while reloading MSD state; switching to offline")
@ -492,10 +553,8 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__state.vd.image = None self.__state.vd.image = None
self.__state.vd.connected = False self.__state.vd.connected = False
if notify:
self.__notifier.notify()
async def __setup_initial(self) -> None: async def __unsafe_setup_initial(self) -> None:
if self.__initial_image: if self.__initial_image:
logger = get_logger(0) logger = get_logger(0)
image = await self.__storage.make_image_by_name(self.__initial_image) image = await self.__storage.make_image_by_name(self.__initial_image)

View File

@ -88,7 +88,7 @@ class Drive:
try: try:
with open(os.path.join(self.__lun_path, param), "w") as file: with open(os.path.join(self.__lun_path, param), "w") as file:
file.write(value + "\n") file.write(value + "\n")
except OSError as err: except OSError as ex:
if err.errno == errno.EBUSY: if ex.errno == errno.EBUSY:
raise MsdDriveLockedError() raise MsdDriveLockedError()
raise raise

View File

@ -169,8 +169,6 @@ class _Part(_PartDc):
# ===== # =====
@dataclasses.dataclass(frozen=True, eq=False) @dataclasses.dataclass(frozen=True, eq=False)
class _StorageDc: class _StorageDc:
size: int = dataclasses.field(init=False)
free: int = dataclasses.field(init=False)
images: dict[str, Image] = dataclasses.field(init=False) images: dict[str, Image] = dataclasses.field(init=False)
parts: dict[str, _Part] = dataclasses.field(init=False) parts: dict[str, _Part] = dataclasses.field(init=False)
@ -185,25 +183,15 @@ class Storage(_StorageDc):
self.__images: (dict[str, Image] | None) = None self.__images: (dict[str, Image] | None) = None
self.__parts: (dict[str, _Part] | None) = None self.__parts: (dict[str, _Part] | None) = None
@property
def size(self) -> int: # API Legacy
assert self.__parts is not None
return self.__parts[""].size
@property
def free(self) -> int: # API Legacy
assert self.__parts is not None
return self.__parts[""].free
@property @property
def images(self) -> dict[str, Image]: def images(self) -> dict[str, Image]:
assert self.__images is not None assert self.__images is not None
return self.__images return dict(self.__images)
@property @property
def parts(self) -> dict[str, _Part]: def parts(self) -> dict[str, _Part]:
assert self.__parts is not None assert self.__parts is not None
return self.__parts return dict(self.__parts)
async def reload(self) -> None: async def reload(self) -> None:
self.__watchable_paths = None self.__watchable_paths = None
@ -222,6 +210,7 @@ class Storage(_StorageDc):
part = _Part(name, root_path) part = _Part(name, root_path)
await part._reload() # pylint: disable=protected-access await part._reload() # pylint: disable=protected-access
parts[name] = part parts[name] = part
assert "" in parts, parts
self.__watchable_paths = watchable_paths self.__watchable_paths = watchable_paths
self.__images = images self.__images = images

View File

@ -113,13 +113,13 @@ class Plugin(BaseUserGpioDriver): # pylint: disable=too-many-instance-attribute
while True: while True:
session = self.__ensure_http_session() session = self.__ensure_http_session()
try: try:
async with session.get(f"{self.__url}/strg.cfg") as response: async with session.get(f"{self.__url}/strg.cfg") as resp:
htclient.raise_not_200(response) htclient.raise_not_200(resp)
parts = (await response.text()).split(";") parts = (await resp.text()).split(";")
for pin in self.__state: for pin in self.__state:
self.__state[pin] = (parts[1 + int(pin) * 5] == "1") self.__state[pin] = (parts[1 + int(pin) * 5] == "1")
except Exception as err: except Exception as ex:
get_logger().error("Failed ANELPWR bulk GET request: %s", tools.efmt(err)) get_logger().error("Failed ANELPWR bulk GET request: %s", tools.efmt(ex))
self.__state = dict.fromkeys(self.__state, None) self.__state = dict.fromkeys(self.__state, None)
if self.__state != prev_state: if self.__state != prev_state:
self._notifier.notify() self._notifier.notify()
@ -143,10 +143,10 @@ class Plugin(BaseUserGpioDriver): # pylint: disable=too-many-instance-attribute
url=f"{self.__url}/ctrl.htm", url=f"{self.__url}/ctrl.htm",
data=f"F{pin}={int(state)}", data=f"F{pin}={int(state)}",
headers={"Content-Type": "text/plain"}, headers={"Content-Type": "text/plain"},
) as response: ) as resp:
htclient.raise_not_200(response) htclient.raise_not_200(resp)
except Exception as err: except Exception as ex:
get_logger().error("Failed ANELPWR POST request to pin %s: %s", pin, tools.efmt(err)) get_logger().error("Failed ANELPWR POST request to pin %s: %s", pin, tools.efmt(ex))
raise GpioDriverOfflineError(self) raise GpioDriverOfflineError(self)
self.__update_notifier.notify() self.__update_notifier.notify()

View File

@ -78,9 +78,9 @@ class Plugin(BaseUserGpioDriver): # pylint: disable=too-many-instance-attribute
proc = await aioproc.log_process(self.__cmd, logger=get_logger(0), prefix=str(self)) proc = await aioproc.log_process(self.__cmd, logger=get_logger(0), prefix=str(self))
if proc.returncode != 0: if proc.returncode != 0:
raise RuntimeError(f"Custom command error: retcode={proc.returncode}") raise RuntimeError(f"Custom command error: retcode={proc.returncode}")
except Exception as err: except Exception as ex:
get_logger(0).error("Can't run custom command [ %s ]: %s", get_logger(0).error("Can't run custom command [ %s ]: %s",
tools.cmdfmt(self.__cmd), tools.efmt(err)) tools.cmdfmt(self.__cmd), tools.efmt(ex))
raise GpioDriverOfflineError(self) raise GpioDriverOfflineError(self)
def __str__(self) -> str: def __str__(self) -> str:

Some files were not shown because too many files have changed in this diff Show More