Merge remote-tracking branch 'upstream/master'

This commit is contained in:
mofeng-git
2024-11-20 15:18:34 +00:00
166 changed files with 5421 additions and 2645 deletions

View File

@@ -33,8 +33,6 @@ import pygments.formatters
from .. import tools
from ..mouse import MouseRange
from ..plugins import UnknownPluginError
from ..plugins.auth import get_auth_service_class
from ..plugins.hid import get_hid_class
@@ -171,8 +169,8 @@ def _init_config(config_path: str, override_options: list[str], **load_flags: bo
config_path = os.path.expanduser(config_path)
try:
raw_config: dict = load_yaml_file(config_path)
except Exception as err:
raise SystemExit(f"ConfigError: Can't read config file {config_path!r}:\n{tools.efmt(err)}")
except Exception as ex:
raise SystemExit(f"ConfigError: Can't read config file {config_path!r}:\n{tools.efmt(ex)}")
if not isinstance(raw_config, dict):
raise SystemExit(f"ConfigError: Top-level of the file {config_path!r} must be a dictionary")
@@ -187,8 +185,8 @@ def _init_config(config_path: str, override_options: list[str], **load_flags: bo
config = make_config(raw_config, scheme)
return config
except (ConfigError, UnknownPluginError) as err:
raise SystemExit(f"ConfigError: {err}")
except (ConfigError, UnknownPluginError) as ex:
raise SystemExit(f"ConfigError: {ex}")
def _patch_raw(raw_config: dict) -> None: # pylint: disable=too-many-branches
@@ -407,19 +405,7 @@ def _get_config_scheme() -> dict:
"hid": {
"type": Option("", type=valid_stripped_string_not_empty),
"keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file),
"ignore_keys": Option([], type=functools.partial(valid_string_list, subval=valid_hid_key)),
"mouse_x_range": {
"min": Option(MouseRange.MIN, type=valid_hid_mouse_move),
"max": Option(MouseRange.MAX, type=valid_hid_mouse_move),
},
"mouse_y_range": {
"min": Option(MouseRange.MIN, type=valid_hid_mouse_move),
"max": Option(MouseRange.MAX, type=valid_hid_mouse_move),
},
"keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file),
# Dynamic content
},
@@ -681,9 +667,10 @@ def _get_config_scheme() -> dict:
},
"vnc": {
"desired_fps": Option(30, type=valid_stream_fps),
"mouse_output": Option("usb", type=valid_hid_mouse_output),
"keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file),
"desired_fps": Option(30, type=valid_stream_fps),
"mouse_output": Option("usb", type=valid_hid_mouse_output),
"keymap": Option("/usr/share/kvmd/keymaps/en-us", type=valid_abs_file),
"allow_cut_after": Option(3.0, type=valid_float_f0),
"server": {
"host": Option("", type=valid_ip_or_host, if_empty=""),

View File

@@ -22,259 +22,22 @@
import sys
import os
import re
import dataclasses
import contextlib
import subprocess
import argparse
import time
from typing import IO
from typing import Generator
from typing import Callable
from ...validators.basic import valid_bool
from ...validators.basic import valid_int_f0
from ...edid import EdidNoBlockError
from ...edid import Edid
# from .. import init
# =====
class NoBlockError(Exception):
pass
@contextlib.contextmanager
def _smart_open(path: str, mode: str) -> Generator[IO, None, None]:
fd = (0 if "r" in mode else 1)
with (os.fdopen(fd, mode, closefd=False) if path == "-" else open(path, mode)) as file:
yield file
if "w" in mode:
file.flush()
@dataclasses.dataclass(frozen=True)
class _CeaBlock:
tag: int
data: bytes
def __post_init__(self) -> None:
assert 0 < self.tag <= 0b111
assert 0 < len(self.data) <= 0b11111
@property
def size(self) -> int:
return len(self.data) + 1
def pack(self) -> bytes:
header = (self.tag << 5) | len(self.data)
return header.to_bytes() + self.data
@classmethod
def first_from_raw(cls, raw: (bytes | list[int])) -> "_CeaBlock":
assert 0 < raw[0] <= 0xFF
tag = (raw[0] & 0b11100000) >> 5
data_size = (raw[0] & 0b00011111)
data = bytes(raw[1:data_size + 1])
return _CeaBlock(tag, data)
_CEA = 128
_CEA_AUDIO = 1
_CEA_SPEAKERS = 4
class _Edid:
# https://en.wikipedia.org/wiki/Extended_Display_Identification_Data
def __init__(self, path: str) -> None:
with _smart_open(path, "rb") as file:
data = file.read()
if data.startswith(b"\x00\xFF\xFF\xFF\xFF\xFF\xFF\x00"):
self.__data = list(data)
else:
text = re.sub(r"\s", "", data.decode())
self.__data = [
int(text[index:index + 2], 16)
for index in range(0, len(text), 2)
]
assert len(self.__data) == 256, f"Invalid EDID length: {len(self.__data)}, should be 256 bytes"
assert self.__data[126] == 1, "Zero extensions number"
assert (self.__data[_CEA + 0], self.__data[_CEA + 1]) == (0x02, 0x03), "Can't find CEA extension"
def write_hex(self, path: str) -> None:
self.__update_checksums()
text = "\n".join(
"".join(
f"{item:0{2}X}"
for item in self.__data[index:index + 16]
)
for index in range(0, len(self.__data), 16)
) + "\n"
with _smart_open(path, "w") as file:
file.write(text)
def write_bin(self, path: str) -> None:
self.__update_checksums()
with _smart_open(path, "wb") as file:
file.write(bytes(self.__data))
def __update_checksums(self) -> None:
self.__data[127] = 256 - (sum(self.__data[:127]) % 256)
self.__data[255] = 256 - (sum(self.__data[128:255]) % 256)
# =====
def get_mfc_id(self) -> str:
raw = self.__data[8] << 8 | self.__data[9]
return bytes([
((raw >> 10) & 0b11111) + 0x40,
((raw >> 5) & 0b11111) + 0x40,
(raw & 0b11111) + 0x40,
]).decode("ascii")
def set_mfc_id(self, mfc_id: str) -> None:
assert len(mfc_id) == 3, "Mfc ID must be 3 characters long"
data = mfc_id.upper().encode("ascii")
for ch in data:
assert 0x41 <= ch <= 0x5A, "Mfc ID must contain only A-Z characters"
raw = (
(data[2] - 0x40)
| ((data[1] - 0x40) << 5)
| ((data[0] - 0x40) << 10)
)
self.__data[8] = (raw >> 8) & 0xFF
self.__data[9] = raw & 0xFF
# =====
def get_product_id(self) -> int:
return (self.__data[10] | self.__data[11] << 8)
def set_product_id(self, product_id: int) -> None:
assert 0 <= product_id <= 0xFFFF, f"Product ID should be from 0 to {0xFFFF}"
self.__data[10] = product_id & 0xFF
self.__data[11] = (product_id >> 8) & 0xFF
# =====
def get_serial(self) -> int:
return (
self.__data[12]
| self.__data[13] << 8
| self.__data[14] << 16
| self.__data[15] << 24
)
def set_serial(self, serial: int) -> None:
assert 0 <= serial <= 0xFFFFFFFF, f"Serial should be from 0 to {0xFFFFFFFF}"
self.__data[12] = serial & 0xFF
self.__data[13] = (serial >> 8) & 0xFF
self.__data[14] = (serial >> 16) & 0xFF
self.__data[15] = (serial >> 24) & 0xFF
# =====
def get_monitor_name(self) -> str:
return self.__get_dtd_text(0xFC, "Monitor Name")
def set_monitor_name(self, text: str) -> None:
self.__set_dtd_text(0xFC, "Monitor Name", text)
def get_monitor_serial(self) -> str:
return self.__get_dtd_text(0xFF, "Monitor Serial")
def set_monitor_serial(self, text: str) -> None:
self.__set_dtd_text(0xFF, "Monitor Serial", text)
def __get_dtd_text(self, d_type: int, name: str) -> str:
index = self.__find_dtd_text(d_type, name)
return bytes(self.__data[index:index + 13]).decode("cp437").strip()
def __set_dtd_text(self, d_type: int, name: str, text: str) -> None:
index = self.__find_dtd_text(d_type, name)
encoded = (text[:13] + "\n" + " " * 12)[:13].encode("cp437")
for (offset, ch) in enumerate(encoded):
self.__data[index + offset] = ch
def __find_dtd_text(self, d_type: int, name: str) -> int:
for index in [54, 72, 90, 108]:
if self.__data[index + 3] == d_type:
return index + 5
raise NoBlockError(f"Can't find DTD {name}")
# ===== CEA =====
def get_audio(self) -> bool:
(cbs, _) = self.__parse_cea()
audio = False
speakers = False
for cb in cbs:
if cb.tag == _CEA_AUDIO:
audio = True
elif cb.tag == _CEA_SPEAKERS:
speakers = True
return (audio and speakers and self.__get_basic_audio())
def set_audio(self, enabled: bool) -> None:
(cbs, dtds) = self.__parse_cea()
cbs = [cb for cb in cbs if cb.tag not in [_CEA_AUDIO, _CEA_SPEAKERS]]
if enabled:
cbs.append(_CeaBlock(_CEA_AUDIO, b"\x09\x7f\x07"))
cbs.append(_CeaBlock(_CEA_SPEAKERS, b"\x01\x00\x00"))
self.__replace_cea(cbs, dtds)
self.__set_basic_audio(enabled)
def __get_basic_audio(self) -> bool:
return bool(self.__data[_CEA + 3] & 0b01000000)
def __set_basic_audio(self, enabled: bool) -> None:
if enabled:
self.__data[_CEA + 3] |= 0b01000000
else:
self.__data[_CEA + 3] &= (0xFF - 0b01000000) # ~X
def __parse_cea(self) -> tuple[list[_CeaBlock], bytes]:
cea = self.__data[_CEA:]
dtd_begin = cea[2]
if dtd_begin == 0:
return ([], b"")
cbs: list[_CeaBlock] = []
if dtd_begin > 4:
raw = cea[4:dtd_begin]
while len(raw) != 0:
cb = _CeaBlock.first_from_raw(raw)
cbs.append(cb)
raw = raw[cb.size:]
dtds = b""
assert dtd_begin >= 4
raw = cea[dtd_begin:]
while len(raw) > (18 + 1) and raw[0] != 0:
dtds += bytes(raw[:18])
raw = raw[18:]
return (cbs, dtds)
def __replace_cea(self, cbs: list[_CeaBlock], dtds: bytes) -> None:
cbs_packed = b""
for cb in cbs:
cbs_packed += cb.pack()
raw = cbs_packed + dtds
assert len(raw) <= (128 - 4 - 1), "Too many CEA blocks or DTDs"
self.__data[_CEA + 2] = (0 if len(raw) == 0 else (len(cbs_packed) + 4))
for index in range(4, 127):
try:
ch = raw[index - 4]
except IndexError:
ch = 0
self.__data[_CEA + index] = ch
def _format_bool(value: bool) -> str:
return ("yes" if value else "no")
@@ -283,7 +46,7 @@ def _make_format_hex(size: int) -> Callable[[int], str]:
return (lambda value: ("0x{:0%dX} ({})" % (size * 2)).format(value, value))
def _print_edid(edid: _Edid) -> None:
def _print_edid(edid: Edid) -> None:
for (key, get, fmt) in [
("Manufacturer ID:", edid.get_mfc_id, str),
("Product ID: ", edid.get_product_id, _make_format_hex(2)),
@@ -294,7 +57,7 @@ def _print_edid(edid: _Edid) -> None:
]:
try:
print(key, fmt(get()), file=sys.stderr) # type: ignore
except NoBlockError:
except EdidNoBlockError:
pass
@@ -348,12 +111,12 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
help="Presets directory", metavar="<dir>")
options = parser.parse_args(argv[1:])
base: (_Edid | None) = None
base: (Edid | None) = None
if options.import_preset:
imp = options.import_preset
if "." in imp:
(base_name, imp) = imp.split(".", 1) # v3.1080p-by-default
base = _Edid(os.path.join(options.presets_path, f"{base_name}.hex"))
base = Edid.from_file(os.path.join(options.presets_path, f"{base_name}.hex"))
imp = f"_{imp}"
options.imp = os.path.join(options.presets_path, f"{imp}.hex")
@@ -362,16 +125,16 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
options.export_hex = options.edid_path
options.edid_path = options.imp
edid = _Edid(options.edid_path)
edid = Edid.from_file(options.edid_path)
changed = False
for cmd in dir(_Edid):
for cmd in dir(Edid):
if cmd.startswith("set_"):
value = getattr(options, cmd)
if value is None and base is not None:
try:
value = getattr(base, cmd.replace("set_", "get_"))()
except NoBlockError:
except EdidNoBlockError:
pass
if value is not None:
getattr(edid, cmd)(value)
@@ -400,8 +163,7 @@ def main(argv: (list[str] | None)=None) -> None: # pylint: disable=too-many-bra
"/usr/bin/v4l2-ctl",
f"--device={options.device_path}",
f"--set-edid=file={orig_edid_path}",
"--fix-edid-checksums",
"--info-edid",
], stdout=sys.stderr, check=True)
except subprocess.CalledProcessError as err:
raise SystemExit(str(err))
except subprocess.CalledProcessError as ex:
raise SystemExit(str(ex))

View File

@@ -155,5 +155,5 @@ def main(argv: (list[str] | None)=None) -> None:
options = parser.parse_args(argv[1:])
try:
options.cmd(config, options)
except ValidatorError as err:
raise SystemExit(str(err))
except ValidatorError as ex:
raise SystemExit(str(ex))

View File

@@ -101,6 +101,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
# =====
def handle_raw_request(self, request: dict, session: IpmiServerSession) -> None:
# Parameter 'request' has been renamed to 'req' in overriding method
handler = {
(6, 1): (lambda _, session: self.send_device_id(session)), # Get device ID
(6, 7): self.__get_power_state_handler, # Power state
@@ -145,13 +146,13 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
data = [int(result["leds"]["power"]), 0, 0]
session.send_ipmi_response(data=data)
def __chassis_control_handler(self, request: dict, session: IpmiServerSession) -> None:
def __chassis_control_handler(self, req: dict, session: IpmiServerSession) -> None:
action = {
0: "off_hard",
1: "on",
3: "reset_hard",
5: "off",
}.get(request["data"][0], "")
}.get(req["data"][0], "")
if action:
if not self.__make_request(session, f"atx.switch_power({action})", "atx.switch_power", action=action):
code = 0xC0 # Try again later
@@ -171,8 +172,8 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session:
func = functools.reduce(getattr, func_path.split("."), kvmd_session)
return (await func(**kwargs))
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("[%s]: Can't perform request %s: %s", session.sockaddr[0], name, err)
except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
logger.error("[%s]: Can't perform request %s: %s", session.sockaddr[0], name, ex)
raise
return aiotools.run_sync(runner())

View File

@@ -11,16 +11,17 @@ from ... import aioproc
from ...logging import get_logger
from .stun import StunNatType
from .stun import Stun
# =====
@dataclasses.dataclass(frozen=True)
class _Netcfg:
nat_type: str = dataclasses.field(default="")
src_ip: str = dataclasses.field(default="")
ext_ip: str = dataclasses.field(default="")
stun_host: str = dataclasses.field(default="")
nat_type: StunNatType = dataclasses.field(default=StunNatType.ERROR)
src_ip: str = dataclasses.field(default="")
ext_ip: str = dataclasses.field(default="")
stun_ip: str = dataclasses.field(default="")
stun_port: int = dataclasses.field(default=0)
@@ -92,8 +93,9 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
async def __get_netcfg(self) -> _Netcfg:
src_ip = (self.__get_default_ip() or "0.0.0.0")
(stun, (nat_type, ext_ip)) = await self.__get_stun_info(src_ip)
return _Netcfg(nat_type, src_ip, ext_ip, stun.host, stun.port)
info = await self.__stun.get_info(src_ip, 0)
# В текущей реализации _Netcfg() это копия StunInfo()
return _Netcfg(**dataclasses.asdict(info))
def __get_default_ip(self) -> str:
try:
@@ -111,17 +113,10 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
for proto in [socket.AF_INET, socket.AF_INET6]:
if proto in addrs:
return addrs[proto][0]["addr"]
except Exception as err:
get_logger().error("Can't get default IP: %s", tools.efmt(err))
except Exception as ex:
get_logger().error("Can't get default IP: %s", tools.efmt(ex))
return ""
async def __get_stun_info(self, src_ip: str) -> tuple[Stun, tuple[str, str]]:
try:
return (self.__stun, (await self.__stun.get_info(src_ip, 0)))
except Exception as err:
get_logger().error("Can't get STUN info: %s", tools.efmt(err))
return (self.__stun, ("", ""))
# =====
@aiotools.atomic_fg
@@ -162,7 +157,7 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
async def __start_janus_proc(self, netcfg: _Netcfg) -> None:
assert self.__janus_proc is None
placeholders = {
"o_stun_server": f"--stun-server={netcfg.stun_host}:{netcfg.stun_port}",
"o_stun_server": f"--stun-server={netcfg.stun_ip}:{netcfg.stun_port}",
**{
key: str(value)
for (key, value) in dataclasses.asdict(netcfg).items()

View File

@@ -4,6 +4,7 @@ import ipaddress
import struct
import secrets
import dataclasses
import enum
from ... import tools
from ... import aiotools
@@ -12,29 +13,39 @@ from ...logging import get_logger
# =====
class StunNatType(enum.Enum):
ERROR = ""
BLOCKED = "Blocked"
OPEN_INTERNET = "Open Internet"
SYMMETRIC_UDP_FW = "Symmetric UDP Firewall"
FULL_CONE_NAT = "Full Cone NAT"
RESTRICTED_NAT = "Restricted NAT"
RESTRICTED_PORT_NAT = "Restricted Port NAT"
SYMMETRIC_NAT = "Symmetric NAT"
CHANGED_ADDR_ERROR = "Error when testing on Changed-IP and Port"
@dataclasses.dataclass(frozen=True)
class StunAddress:
ip: str
class StunInfo:
nat_type: StunNatType
src_ip: str
ext_ip: str
stun_ip: str
stun_port: int
@dataclasses.dataclass(frozen=True)
class _StunAddress:
ip: str
port: int
@dataclasses.dataclass(frozen=True)
class StunResponse:
ok: bool
ext: (StunAddress | None) = dataclasses.field(default=None)
src: (StunAddress | None) = dataclasses.field(default=None)
changed: (StunAddress | None) = dataclasses.field(default=None)
class StunNatType:
BLOCKED = "Blocked"
OPEN_INTERNET = "Open Internet"
SYMMETRIC_UDP_FW = "Symmetric UDP Firewall"
FULL_CONE_NAT = "Full Cone NAT"
RESTRICTED_NAT = "Restricted NAT"
RESTRICTED_PORT_NAT = "Restricted Port NAT"
SYMMETRIC_NAT = "Symmetric NAT"
CHANGED_ADDR_ERROR = "Error when testing on Changed-IP and Port"
class _StunResponse:
ok: bool
ext: (_StunAddress | None) = dataclasses.field(default=None)
src: (_StunAddress | None) = dataclasses.field(default=None)
changed: (_StunAddress | None) = dataclasses.field(default=None)
# =====
@@ -50,58 +61,94 @@ class Stun:
retries_delay: float,
) -> None:
self.host = host
self.port = port
self.__host = host
self.__port = port
self.__timeout = timeout
self.__retries = retries
self.__retries_delay = retries_delay
self.__stun_ip = ""
self.__sock: (socket.socket | None) = None
async def get_info(self, src_ip: str, src_port: int) -> tuple[str, str]:
(family, _, _, _, addr) = socket.getaddrinfo(src_ip, src_port, type=socket.SOCK_DGRAM)[0]
async def get_info(self, src_ip: str, src_port: int) -> StunInfo:
nat_type = StunNatType.ERROR
ext_ip = ""
try:
with socket.socket(family, socket.SOCK_DGRAM) as self.__sock:
(src_fam, _, _, _, src_addr) = (await self.__retried_getaddrinfo_udp(src_ip, src_port))[0]
stun_ips = [
stun_addr[0]
for (stun_fam, _, _, _, stun_addr) in (await self.__retried_getaddrinfo_udp(self.__host, self.__port))
if stun_fam == src_fam
]
if not stun_ips:
raise RuntimeError(f"Can't resolve {src_fam.name} address for STUN")
if not self.__stun_ip or self.__stun_ip not in stun_ips:
# On new IP, changed family, etc.
self.__stun_ip = stun_ips[0]
with socket.socket(src_fam, socket.SOCK_DGRAM) as self.__sock:
self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.__sock.settimeout(self.__timeout)
self.__sock.bind(addr)
(nat_type, response) = await self.__get_nat_type(src_ip)
return (nat_type, (response.ext.ip if response.ext is not None else ""))
self.__sock.bind(src_addr)
(nat_type, resp) = await self.__get_nat_type(src_ip)
ext_ip = (resp.ext.ip if resp.ext is not None else "")
except Exception as ex:
get_logger(0).error("Can't get STUN info: %s", tools.efmt(ex))
finally:
self.__sock = None
async def __get_nat_type(self, src_ip: str) -> tuple[str, StunResponse]: # pylint: disable=too-many-return-statements
first = await self.__make_request("First probe")
return StunInfo(
nat_type=nat_type,
src_ip=src_ip,
ext_ip=ext_ip,
stun_ip=self.__stun_ip,
stun_port=self.__port,
)
async def __retried_getaddrinfo_udp(self, host: str, port: int) -> list:
retries = self.__retries
while True:
try:
return socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
except Exception:
retries -= 1
if retries == 0:
raise
await asyncio.sleep(self.__retries_delay)
async def __get_nat_type(self, src_ip: str) -> tuple[StunNatType, _StunResponse]: # pylint: disable=too-many-return-statements
first = await self.__make_request("First probe", self.__stun_ip, b"")
if not first.ok:
return (StunNatType.BLOCKED, first)
request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request
response = await self.__make_request("Change request [ext_ip == src_ip]", request)
req = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request
resp = await self.__make_request("Change request [ext_ip == src_ip]", self.__stun_ip, req)
if first.ext is not None and first.ext.ip == src_ip:
if response.ok:
return (StunNatType.OPEN_INTERNET, response)
return (StunNatType.SYMMETRIC_UDP_FW, response)
if resp.ok:
return (StunNatType.OPEN_INTERNET, resp)
return (StunNatType.SYMMETRIC_UDP_FW, resp)
if response.ok:
return (StunNatType.FULL_CONE_NAT, response)
if resp.ok:
return (StunNatType.FULL_CONE_NAT, resp)
if first.changed is None:
raise RuntimeError(f"Changed addr is None: {first}")
response = await self.__make_request("Change request [ext_ip != src_ip]", addr=first.changed)
if not response.ok:
return (StunNatType.CHANGED_ADDR_ERROR, response)
resp = await self.__make_request("Change request [ext_ip != src_ip]", first.changed, b"")
if not resp.ok:
return (StunNatType.CHANGED_ADDR_ERROR, resp)
if response.ext == first.ext:
request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002)
response = await self.__make_request("Change port", request, addr=first.changed.ip)
if response.ok:
return (StunNatType.RESTRICTED_NAT, response)
return (StunNatType.RESTRICTED_PORT_NAT, response)
if resp.ext == first.ext:
req = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002)
resp = await self.__make_request("Change port", first.changed.ip, req)
if resp.ok:
return (StunNatType.RESTRICTED_NAT, resp)
return (StunNatType.RESTRICTED_PORT_NAT, resp)
return (StunNatType.SYMMETRIC_NAT, response)
return (StunNatType.SYMMETRIC_NAT, resp)
async def __make_request(self, ctx: str, request: bytes=b"", addr: (StunAddress | str | None)=None) -> StunResponse:
async def __make_request(self, ctx: str, addr: (_StunAddress | str), req: bytes) -> _StunResponse:
# TODO: Support IPv6 and RFC 5389
# The first 4 bytes of the response are the Type (2) and Length (2)
# The 5th byte is Reserved
@@ -111,32 +158,29 @@ class Stun:
# More info at: https://tools.ietf.org/html/rfc3489#section-11.2.1
# And at: https://tools.ietf.org/html/rfc5389#section-15.1
if isinstance(addr, StunAddress):
if isinstance(addr, _StunAddress):
addr_t = (addr.ip, addr.port)
elif isinstance(addr, str):
addr_t = (addr, self.port)
else:
assert addr is None
addr_t = (self.host, self.port)
else: # str
addr_t = (addr, self.__port)
# https://datatracker.ietf.org/doc/html/rfc5389#section-6
trans_id = b"\x21\x12\xA4\x42" + secrets.token_bytes(12)
(response, error) = (b"", "")
(resp, error) = (b"", "")
for _ in range(self.__retries):
(response, error) = await self.__inner_make_request(trans_id, request, addr_t)
(resp, error) = await self.__inner_make_request(trans_id, req, addr_t)
if not error:
break
await asyncio.sleep(self.__retries_delay)
if error:
get_logger(0).error("%s: Can't perform STUN request after %d retries; last error: %s",
ctx, self.__retries, error)
return StunResponse(ok=False)
return _StunResponse(ok=False)
parsed: dict[str, StunAddress] = {}
parsed: dict[str, _StunAddress] = {}
offset = 0
remaining = len(response)
remaining = len(resp)
while remaining > 0:
(attr_type, attr_len) = struct.unpack(">HH", response[offset : offset + 4]) # noqa: E203
(attr_type, attr_len) = struct.unpack(">HH", resp[offset : offset + 4]) # noqa: E203
offset += 4
field = {
0x0001: "ext", # MAPPED-ADDRESS
@@ -145,40 +189,40 @@ class Stun:
0x0005: "changed", # CHANGED-ADDRESS
}.get(attr_type)
if field is not None:
parsed[field] = self.__parse_address(response[offset:], (trans_id if attr_type == 0x0020 else b""))
parsed[field] = self.__parse_address(resp[offset:], (trans_id if attr_type == 0x0020 else b""))
offset += attr_len
remaining -= (4 + attr_len)
return StunResponse(ok=True, **parsed)
return _StunResponse(ok=True, **parsed)
async def __inner_make_request(self, trans_id: bytes, request: bytes, addr: tuple[str, int]) -> tuple[bytes, str]:
async def __inner_make_request(self, trans_id: bytes, req: bytes, addr: tuple[str, int]) -> tuple[bytes, str]:
assert self.__sock is not None
request = struct.pack(">HH", 0x0001, len(request)) + trans_id + request # Bind Request
req = struct.pack(">HH", 0x0001, len(req)) + trans_id + req # Bind Request
try:
await aiotools.run_async(self.__sock.sendto, request, addr)
except Exception as err:
return (b"", f"Send error: {tools.efmt(err)}")
await aiotools.run_async(self.__sock.sendto, req, addr)
except Exception as ex:
return (b"", f"Send error: {tools.efmt(ex)}")
try:
response = (await aiotools.run_async(self.__sock.recvfrom, 2048))[0]
except Exception as err:
return (b"", f"Recv error: {tools.efmt(err)}")
resp = (await aiotools.run_async(self.__sock.recvfrom, 2048))[0]
except Exception as ex:
return (b"", f"Recv error: {tools.efmt(ex)}")
(response_type, payload_len) = struct.unpack(">HH", response[:4])
if response_type != 0x0101:
return (b"", f"Invalid response type: {response_type:#06x}")
if trans_id != response[4:20]:
(resp_type, payload_len) = struct.unpack(">HH", resp[:4])
if resp_type != 0x0101:
return (b"", f"Invalid response type: {resp_type:#06x}")
if trans_id != resp[4:20]:
return (b"", "Transaction ID mismatch")
return (response[20 : 20 + payload_len], "") # noqa: E203
return (resp[20 : 20 + payload_len], "") # noqa: E203
def __parse_address(self, data: bytes, trans_id: bytes) -> StunAddress:
def __parse_address(self, data: bytes, trans_id: bytes) -> _StunAddress:
family = data[1]
port = struct.unpack(">H", self.__trans_xor(data[2:4], trans_id))[0]
if family == 0x01:
return StunAddress(str(ipaddress.IPv4Address(self.__trans_xor(data[4:8], trans_id))), port)
return _StunAddress(str(ipaddress.IPv4Address(self.__trans_xor(data[4:8], trans_id))), port)
elif family == 0x02:
return StunAddress(str(ipaddress.IPv6Address(self.__trans_xor(data[4:20], trans_id))), port)
return _StunAddress(str(ipaddress.IPv6Address(self.__trans_xor(data[4:20], trans_id))), port)
raise RuntimeError(f"Unknown family; received: {family}")
def __trans_xor(self, data: bytes, trans_id: bytes) -> bytes:

View File

@@ -56,7 +56,7 @@ def main(argv: (list[str] | None)=None) -> None:
if config.kvmd.msd.type == "otg":
msd_kwargs["gadget"] = config.otg.gadget # XXX: Small crutch to pass gadget name to the plugin
hid_kwargs = config.kvmd.hid._unpack(ignore=["type", "keymap", "ignore_keys", "mouse_x_range", "mouse_y_range"])
hid_kwargs = config.kvmd.hid._unpack(ignore=["type", "keymap"])
if config.kvmd.hid.type == "otg":
hid_kwargs["udc"] = config.otg.udc # XXX: Small crutch to pass UDC to the plugin
@@ -103,9 +103,6 @@ def main(argv: (list[str] | None)=None) -> None:
),
keymap_path=config.hid.keymap,
ignore_keys=config.hid.ignore_keys,
mouse_x_range=(config.hid.mouse_x_range.min, config.hid.mouse_x_range.max),
mouse_y_range=(config.hid.mouse_y_range.min, config.hid.mouse_y_range.max),
stream_forever=config.streamer.forever,
).run(**config.server._unpack())

View File

@@ -45,9 +45,9 @@ class AtxApi:
return make_json_response(await self.__atx.get_state())
@exposed_http("POST", "/atx/power")
async def __power_handler(self, request: Request) -> Response:
action = valid_atx_power_action(request.query.get("action"))
wait = valid_bool(request.query.get("wait", False))
async def __power_handler(self, req: Request) -> Response:
action = valid_atx_power_action(req.query.get("action"))
wait = valid_bool(req.query.get("wait", False))
await ({
"on": self.__atx.power_on,
"off": self.__atx.power_off,
@@ -57,9 +57,9 @@ class AtxApi:
return make_json_response()
@exposed_http("POST", "/atx/click")
async def __click_handler(self, request: Request) -> Response:
button = valid_atx_button(request.query.get("button"))
wait = valid_bool(request.query.get("wait", False))
async def __click_handler(self, req: Request) -> Response:
button = valid_atx_button(req.query.get("button"))
wait = valid_bool(req.query.get("wait", False))
await ({
"power": self.__atx.click_power,
"power_long": self.__atx.click_power_long,

View File

@@ -43,34 +43,34 @@ from ..auth import AuthManager
_COOKIE_AUTH_TOKEN = "auth_token"
async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, request: Request) -> None:
async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, req: Request) -> None:
if auth_manager.is_auth_required(exposed):
user = request.headers.get("X-KVMD-User", "")
user = req.headers.get("X-KVMD-User", "")
if user:
user = valid_user(user)
passwd = request.headers.get("X-KVMD-Passwd", "")
set_request_auth_info(request, f"{user} (xhdr)")
passwd = req.headers.get("X-KVMD-Passwd", "")
set_request_auth_info(req, f"{user} (xhdr)")
if not (await auth_manager.authorize(user, valid_passwd(passwd))):
raise ForbiddenError()
return
token = request.cookies.get(_COOKIE_AUTH_TOKEN, "")
token = req.cookies.get(_COOKIE_AUTH_TOKEN, "")
if token:
user = auth_manager.check(valid_auth_token(token)) # type: ignore
if not user:
set_request_auth_info(request, "- (token)")
set_request_auth_info(req, "- (token)")
raise ForbiddenError()
set_request_auth_info(request, f"{user} (token)")
set_request_auth_info(req, f"{user} (token)")
return
basic_auth = request.headers.get("Authorization", "")
basic_auth = req.headers.get("Authorization", "")
if basic_auth and basic_auth[:6].lower() == "basic ":
try:
(user, passwd) = base64.b64decode(basic_auth[6:]).decode("utf-8").split(":")
except Exception:
raise UnauthorizedError()
user = valid_user(user)
set_request_auth_info(request, f"{user} (basic)")
set_request_auth_info(req, f"{user} (basic)")
if not (await auth_manager.authorize(user, valid_passwd(passwd))):
raise ForbiddenError()
return
@@ -85,9 +85,9 @@ class AuthApi:
# =====
@exposed_http("POST", "/auth/login", auth_required=False)
async def __login_handler(self, request: Request) -> Response:
async def __login_handler(self, req: Request) -> Response:
if self.__auth_manager.is_auth_enabled():
credentials = await request.post()
credentials = await req.post()
token = await self.__auth_manager.login(
user=valid_user(credentials.get("user", "")),
passwd=valid_passwd(credentials.get("passwd", "")),
@@ -98,9 +98,9 @@ class AuthApi:
return make_json_response()
@exposed_http("POST", "/auth/logout")
async def __logout_handler(self, request: Request) -> Response:
async def __logout_handler(self, req: Request) -> Response:
if self.__auth_manager.is_auth_enabled():
token = valid_auth_token(request.cookies.get(_COOKIE_AUTH_TOKEN, ""))
token = valid_auth_token(req.cookies.get(_COOKIE_AUTH_TOKEN, ""))
self.__auth_manager.logout(token)
return make_json_response()

View File

@@ -55,10 +55,9 @@ class ExportApi:
@async_lru.alru_cache(maxsize=1, ttl=5)
async def __get_prometheus_metrics(self) -> str:
(atx_state, hw_state, fan_state, gpio_state) = await asyncio.gather(*[
(atx_state, info_state, gpio_state) = await asyncio.gather(*[
self.__atx.get_state(),
self.__info_manager.get_submanager("hw").get_state(),
self.__info_manager.get_submanager("fan").get_state(),
self.__info_manager.get_state(["hw", "fan"]),
self.__user_gpio.get_state(),
])
rows: list[str] = []
@@ -72,8 +71,8 @@ class ExportApi:
for key in ["online", "state"]:
self.__append_prometheus_rows(rows, ch_state["state"], f"pikvm_gpio_{mode}_{key}_{channel}")
self.__append_prometheus_rows(rows, hw_state["health"], "pikvm_hw") # type: ignore
self.__append_prometheus_rows(rows, fan_state, "pikvm_fan")
self.__append_prometheus_rows(rows, info_state["hw"]["health"], "pikvm_hw") # type: ignore
self.__append_prometheus_rows(rows, info_state["fan"], "pikvm_fan")
return "\n".join(rows)

View File

@@ -25,13 +25,12 @@ import stat
import functools
import struct
from typing import Iterable
from typing import Callable
from aiohttp.web import Request
from aiohttp.web import Response
from ....mouse import MouseRange
from ....keyboard.keysym import build_symmap
from ....keyboard.printer import text_to_web_keys
@@ -59,12 +58,7 @@ class HidApi:
def __init__(
self,
hid: BaseHid,
keymap_path: str,
ignore_keys: list[str],
mouse_x_range: tuple[int, int],
mouse_y_range: tuple[int, int],
) -> None:
self.__hid = hid
@@ -73,11 +67,6 @@ class HidApi:
self.__default_keymap_name = os.path.basename(keymap_path)
self.__ensure_symmap(self.__default_keymap_name)
self.__ignore_keys = ignore_keys
self.__mouse_x_range = mouse_x_range
self.__mouse_y_range = mouse_y_range
# =====
@exposed_http("GET", "/hid")
@@ -85,22 +74,22 @@ class HidApi:
return make_json_response(await self.__hid.get_state())
@exposed_http("POST", "/hid/set_params")
async def __set_params_handler(self, request: Request) -> Response:
async def __set_params_handler(self, req: Request) -> Response:
params = {
key: validator(request.query.get(key))
key: validator(req.query.get(key))
for (key, validator) in [
("keyboard_output", valid_hid_keyboard_output),
("mouse_output", valid_hid_mouse_output),
("jiggler", valid_bool),
]
if request.query.get(key) is not None
if req.query.get(key) is not None
}
self.__hid.set_params(**params) # type: ignore
return make_json_response()
@exposed_http("POST", "/hid/set_connected")
async def __set_connected_handler(self, request: Request) -> Response:
self.__hid.set_connected(valid_bool(request.query.get("connected")))
async def __set_connected_handler(self, req: Request) -> Response:
self.__hid.set_connected(valid_bool(req.query.get("connected")))
return make_json_response()
@exposed_http("POST", "/hid/reset")
@@ -128,13 +117,13 @@ class HidApi:
return make_json_response(await self.get_keymaps())
@exposed_http("POST", "/hid/print")
async def __print_handler(self, request: Request) -> Response:
text = await request.text()
limit = int(valid_int_f0(request.query.get("limit", 1024)))
async def __print_handler(self, req: Request) -> Response:
text = await req.text()
limit = int(valid_int_f0(req.query.get("limit", 1024)))
if limit > 0:
text = text[:limit]
symmap = self.__ensure_symmap(request.query.get("keymap", self.__default_keymap_name))
self.__hid.send_key_events(text_to_web_keys(text, symmap))
symmap = self.__ensure_symmap(req.query.get("keymap", self.__default_keymap_name))
self.__hid.send_key_events(text_to_web_keys(text, symmap), no_ignore_keys=True)
return make_json_response()
def __ensure_symmap(self, keymap_name: str) -> dict[int, dict[int, str]]:
@@ -162,8 +151,7 @@ class HidApi:
state = valid_bool(data[0])
except Exception:
return
if key not in self.__ignore_keys:
self.__hid.send_key_events([(key, state)])
self.__hid.send_key_event(key, state)
@exposed_ws(2)
async def __ws_bin_mouse_button_handler(self, _: WsSession, data: bytes) -> None:
@@ -182,17 +170,17 @@ class HidApi:
to_y = valid_hid_mouse_move(to_y)
except Exception:
return
self.__send_mouse_move_event(to_x, to_y)
self.__hid.send_mouse_move_event(to_x, to_y)
@exposed_ws(4)
async def __ws_bin_mouse_relative_handler(self, _: WsSession, data: bytes) -> None:
self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_relative_event)
self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_relative_events)
@exposed_ws(5)
async def __ws_bin_mouse_wheel_handler(self, _: WsSession, data: bytes) -> None:
self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_wheel_event)
self.__process_ws_bin_delta_request(data, self.__hid.send_mouse_wheel_events)
def __process_ws_bin_delta_request(self, data: bytes, handler: Callable[[int, int], None]) -> None:
def __process_ws_bin_delta_request(self, data: bytes, handler: Callable[[Iterable[tuple[int, int]], bool], None]) -> None:
try:
squash = valid_bool(data[0])
data = data[1:]
@@ -202,7 +190,7 @@ class HidApi:
deltas.append((valid_hid_mouse_delta(delta_x), valid_hid_mouse_delta(delta_y)))
except Exception:
return
self.__send_mouse_delta_event(deltas, squash, handler)
handler(deltas, squash)
# =====
@@ -213,8 +201,7 @@ class HidApi:
state = valid_bool(event["state"])
except Exception:
return
if key not in self.__ignore_keys:
self.__hid.send_key_events([(key, state)])
self.__hid.send_key_event(key, state)
@exposed_ws("mouse_button")
async def __ws_mouse_button_handler(self, _: WsSession, event: dict) -> None:
@@ -232,17 +219,17 @@ class HidApi:
to_y = valid_hid_mouse_move(event["to"]["y"])
except Exception:
return
self.__send_mouse_move_event(to_x, to_y)
self.__hid.send_mouse_move_event(to_x, to_y)
@exposed_ws("mouse_relative")
async def __ws_mouse_relative_handler(self, _: WsSession, event: dict) -> None:
self.__process_ws_delta_event(event, self.__hid.send_mouse_relative_event)
self.__process_ws_delta_event(event, self.__hid.send_mouse_relative_events)
@exposed_ws("mouse_wheel")
async def __ws_mouse_wheel_handler(self, _: WsSession, event: dict) -> None:
self.__process_ws_delta_event(event, self.__hid.send_mouse_wheel_event)
self.__process_ws_delta_event(event, self.__hid.send_mouse_wheel_events)
def __process_ws_delta_event(self, event: dict, handler: Callable[[int, int], None]) -> None:
def __process_ws_delta_event(self, event: dict, handler: Callable[[Iterable[tuple[int, int]], bool], None]) -> None:
try:
raw_delta = event["delta"]
deltas = [
@@ -252,26 +239,25 @@ class HidApi:
squash = valid_bool(event.get("squash", False))
except Exception:
return
self.__send_mouse_delta_event(deltas, squash, handler)
handler(deltas, squash)
# =====
@exposed_http("POST", "/hid/events/send_key")
async def __events_send_key_handler(self, request: Request) -> Response:
key = valid_hid_key(request.query.get("key"))
if key not in self.__ignore_keys:
if "state" in request.query:
state = valid_bool(request.query["state"])
self.__hid.send_key_events([(key, state)])
else:
self.__hid.send_key_events([(key, True), (key, False)])
async def __events_send_key_handler(self, req: Request) -> Response:
key = valid_hid_key(req.query.get("key"))
if "state" in req.query:
state = valid_bool(req.query["state"])
self.__hid.send_key_event(key, state)
else:
self.__hid.send_key_events([(key, True), (key, False)])
return make_json_response()
@exposed_http("POST", "/hid/events/send_mouse_button")
async def __events_send_mouse_button_handler(self, request: Request) -> Response:
button = valid_hid_mouse_button(request.query.get("button"))
if "state" in request.query:
state = valid_bool(request.query["state"])
async def __events_send_mouse_button_handler(self, req: Request) -> Response:
button = valid_hid_mouse_button(req.query.get("button"))
if "state" in req.query:
state = valid_bool(req.query["state"])
self.__hid.send_mouse_button_event(button, state)
else:
self.__hid.send_mouse_button_event(button, True)
@@ -279,52 +265,22 @@ class HidApi:
return make_json_response()
@exposed_http("POST", "/hid/events/send_mouse_move")
async def __events_send_mouse_move_handler(self, request: Request) -> Response:
to_x = valid_hid_mouse_move(request.query.get("to_x"))
to_y = valid_hid_mouse_move(request.query.get("to_y"))
self.__send_mouse_move_event(to_x, to_y)
async def __events_send_mouse_move_handler(self, req: Request) -> Response:
to_x = valid_hid_mouse_move(req.query.get("to_x"))
to_y = valid_hid_mouse_move(req.query.get("to_y"))
self.__hid.send_mouse_move_event(to_x, to_y)
return make_json_response()
@exposed_http("POST", "/hid/events/send_mouse_relative")
async def __events_send_mouse_relative_handler(self, request: Request) -> Response:
return self.__process_http_delta_event(request, self.__hid.send_mouse_relative_event)
async def __events_send_mouse_relative_handler(self, req: Request) -> Response:
return self.__process_http_delta_event(req, self.__hid.send_mouse_relative_event)
@exposed_http("POST", "/hid/events/send_mouse_wheel")
async def __events_send_mouse_wheel_handler(self, request: Request) -> Response:
return self.__process_http_delta_event(request, self.__hid.send_mouse_wheel_event)
async def __events_send_mouse_wheel_handler(self, req: Request) -> Response:
return self.__process_http_delta_event(req, self.__hid.send_mouse_wheel_event)
def __process_http_delta_event(self, request: Request, handler: Callable[[int, int], None]) -> Response:
delta_x = valid_hid_mouse_delta(request.query.get("delta_x"))
delta_y = valid_hid_mouse_delta(request.query.get("delta_y"))
def __process_http_delta_event(self, req: Request, handler: Callable[[int, int], None]) -> Response:
delta_x = valid_hid_mouse_delta(req.query.get("delta_x"))
delta_y = valid_hid_mouse_delta(req.query.get("delta_y"))
handler(delta_x, delta_y)
return make_json_response()
# =====
def __send_mouse_move_event(self, to_x: int, to_y: int) -> None:
if self.__mouse_x_range != MouseRange.RANGE:
to_x = MouseRange.remap(to_x, *self.__mouse_x_range)
if self.__mouse_y_range != MouseRange.RANGE:
to_y = MouseRange.remap(to_y, *self.__mouse_y_range)
self.__hid.send_mouse_move_event(to_x, to_y)
def __send_mouse_delta_event(
self,
deltas: list[tuple[int, int]],
squash: bool,
handler: Callable[[int, int], None],
) -> None:
if squash:
prev = (0, 0)
for cur in deltas:
if abs(prev[0] + cur[0]) > 127 or abs(prev[1] + cur[1]) > 127:
handler(*prev)
prev = cur
else:
prev = (prev[0] + cur[0], prev[1] + cur[1])
if prev[0] or prev[1]:
handler(*prev)
else:
for xy in deltas:
handler(*xy)

View File

@@ -20,8 +20,6 @@
# ========================================================================== #
import asyncio
from aiohttp.web import Request
from aiohttp.web import Response
@@ -41,17 +39,13 @@ class InfoApi:
# =====
@exposed_http("GET", "/info")
async def __common_state_handler(self, request: Request) -> Response:
fields = self.__valid_info_fields(request)
results = dict(zip(fields, await asyncio.gather(*[
self.__info_manager.get_submanager(field).get_state()
for field in fields
])))
return make_json_response(results)
async def __common_state_handler(self, req: Request) -> Response:
fields = self.__valid_info_fields(req)
return make_json_response(await self.__info_manager.get_state(fields))
def __valid_info_fields(self, request: Request) -> list[str]:
subs = self.__info_manager.get_subs()
def __valid_info_fields(self, req: Request) -> list[str]:
available = self.__info_manager.get_subs()
return sorted(valid_info_fields(
arg=request.query.get("fields", ",".join(subs)),
variants=subs,
) or subs)
arg=req.query.get("fields", ",".join(available)),
variants=available,
) or available)

View File

@@ -47,12 +47,12 @@ class LogApi:
# =====
@exposed_http("GET", "/log")
async def __log_handler(self, request: Request) -> StreamResponse:
async def __log_handler(self, req: Request) -> StreamResponse:
if self.__log_reader is None:
raise LogReaderDisabledError()
seek = valid_log_seek(request.query.get("seek", 0))
follow = valid_bool(request.query.get("follow", False))
response = await start_streaming(request, "text/plain")
seek = valid_log_seek(req.query.get("seek", 0))
follow = valid_bool(req.query.get("follow", False))
response = await start_streaming(req, "text/plain")
try:
async for record in self.__log_reader.poll_log(seek, follow):
await response.write(("[%s %s] --- %s" % (

View File

@@ -63,32 +63,36 @@ class MsdApi:
@exposed_http("GET", "/msd")
async def __state_handler(self, _: Request) -> Response:
return make_json_response(await self.__msd.get_state())
state = await self.__msd.get_state()
if state["storage"] and state["storage"]["parts"]:
state["storage"]["size"] = state["storage"]["parts"][""]["size"] # Legacy API
state["storage"]["free"] = state["storage"]["parts"][""]["free"] # Legacy API
return make_json_response(state)
@exposed_http("POST", "/msd/set_params")
async def __set_params_handler(self, request: Request) -> Response:
async def __set_params_handler(self, req: Request) -> Response:
params = {
key: validator(request.query.get(param))
key: validator(req.query.get(param))
for (param, key, validator) in [
("image", "name", (lambda arg: str(arg).strip() and valid_msd_image_name(arg))),
("cdrom", "cdrom", valid_bool),
("rw", "rw", valid_bool),
]
if request.query.get(param) is not None
if req.query.get(param) is not None
}
await self.__msd.set_params(**params) # type: ignore
return make_json_response()
@exposed_http("POST", "/msd/set_connected")
async def __set_connected_handler(self, request: Request) -> Response:
await self.__msd.set_connected(valid_bool(request.query.get("connected")))
async def __set_connected_handler(self, req: Request) -> Response:
await self.__msd.set_connected(valid_bool(req.query.get("connected")))
return make_json_response()
# =====
@exposed_http("GET", "/msd/read")
async def __read_handler(self, request: Request) -> StreamResponse:
name = valid_msd_image_name(request.query.get("image"))
async def __read_handler(self, req: Request) -> StreamResponse:
name = valid_msd_image_name(req.query.get("image"))
compressors = {
"": ("", None),
"none": ("", None),
@@ -96,7 +100,7 @@ class MsdApi:
"zstd": (".zst", (lambda: zstandard.ZstdCompressor().compressobj())), # pylint: disable=unnecessary-lambda
}
(suffix, make_compressor) = compressors[check_string_in_list(
arg=request.query.get("compress", ""),
arg=req.query.get("compress", ""),
name="Compression mode",
variants=set(compressors),
)]
@@ -127,7 +131,7 @@ class MsdApi:
src = compressed()
size = -1
response = await start_streaming(request, "application/octet-stream", size, name + suffix)
response = await start_streaming(req, "application/octet-stream", size, name + suffix)
async for chunk in src:
await response.write(chunk)
return response
@@ -135,28 +139,28 @@ class MsdApi:
# =====
@exposed_http("POST", "/msd/write")
async def __write_handler(self, request: Request) -> Response:
unsafe_prefix = request.query.get("prefix", "") + "/"
name = valid_msd_image_name(unsafe_prefix + request.query.get("image", ""))
size = valid_int_f0(request.content_length)
remove_incomplete = self.__get_remove_incomplete(request)
async def __write_handler(self, req: Request) -> Response:
unsafe_prefix = req.query.get("prefix", "") + "/"
name = valid_msd_image_name(unsafe_prefix + req.query.get("image", ""))
size = valid_int_f0(req.content_length)
remove_incomplete = self.__get_remove_incomplete(req)
written = 0
async with self.__msd.write_image(name, size, remove_incomplete) as writer:
chunk_size = writer.get_chunk_size()
while True:
chunk = await request.content.read(chunk_size)
chunk = await req.content.read(chunk_size)
if not chunk:
break
written = await writer.write_chunk(chunk)
return make_json_response(self.__make_write_info(name, size, written))
@exposed_http("POST", "/msd/write_remote")
async def __write_remote_handler(self, request: Request) -> (Response | StreamResponse): # pylint: disable=too-many-locals
unsafe_prefix = request.query.get("prefix", "") + "/"
url = valid_url(request.query.get("url"))
insecure = valid_bool(request.query.get("insecure", False))
timeout = valid_float_f01(request.query.get("timeout", 10.0))
remove_incomplete = self.__get_remove_incomplete(request)
async def __write_remote_handler(self, req: Request) -> (Response | StreamResponse): # pylint: disable=too-many-locals
unsafe_prefix = req.query.get("prefix", "") + "/"
url = valid_url(req.query.get("url"))
insecure = valid_bool(req.query.get("insecure", False))
timeout = valid_float_f01(req.query.get("timeout", 10.0))
remove_incomplete = self.__get_remove_incomplete(req)
name = ""
size = written = 0
@@ -174,7 +178,7 @@ class MsdApi:
read_timeout=(7 * 24 * 3600),
) as remote:
name = str(request.query.get("image", "")).strip()
name = str(req.query.get("image", "")).strip()
if len(name) == 0:
name = htclient.get_filename(remote)
name = valid_msd_image_name(unsafe_prefix + name)
@@ -184,7 +188,7 @@ class MsdApi:
get_logger(0).info("Downloading image %r as %r to MSD ...", url, name)
async with self.__msd.write_image(name, size, remove_incomplete) as writer:
chunk_size = writer.get_chunk_size()
response = await start_streaming(request, "application/x-ndjson")
response = await start_streaming(req, "application/x-ndjson")
await stream_write_info()
last_report_ts = 0
async for chunk in remote.content.iter_chunked(chunk_size):
@@ -197,16 +201,16 @@ class MsdApi:
await stream_write_info()
return response
except Exception as err:
except Exception as ex:
if response is not None:
await stream_write_info()
await stream_json_exception(response, err)
elif isinstance(err, aiohttp.ClientError):
return make_json_exception(err, 400)
await stream_json_exception(response, ex)
elif isinstance(ex, aiohttp.ClientError):
return make_json_exception(ex, 400)
raise
def __get_remove_incomplete(self, request: Request) -> (bool | None):
flag: (str | None) = request.query.get("remove_incomplete")
def __get_remove_incomplete(self, req: Request) -> (bool | None):
flag: (str | None) = req.query.get("remove_incomplete")
return (valid_bool(flag) if flag is not None else None)
def __make_write_info(self, name: str, size: int, written: int) -> dict:
@@ -215,8 +219,8 @@ class MsdApi:
# =====
@exposed_http("POST", "/msd/remove")
async def __remove_handler(self, request: Request) -> Response:
await self.__msd.remove(valid_msd_image_name(request.query.get("image")))
async def __remove_handler(self, req: Request) -> Response:
await self.__msd.remove(valid_msd_image_name(req.query.get("image")))
return make_json_response()
@exposed_http("POST", "/msd/reset")

View File

@@ -88,12 +88,12 @@ class RedfishApi:
@exposed_http("GET", "/redfish/v1/Systems/0")
async def __server_handler(self, _: Request) -> Response:
(atx_state, meta_state) = await asyncio.gather(*[
(atx_state, info_state) = await asyncio.gather(*[
self.__atx.get_state(),
self.__info_manager.get_submanager("meta").get_state(),
self.__info_manager.get_state(["meta"]),
])
try:
host = str(meta_state.get("server", {})["host"]) # type: ignore
host = str(info_state["meta"].get("server", {})["host"]) # type: ignore
except Exception:
host = ""
return make_json_response({
@@ -111,10 +111,10 @@ class RedfishApi:
}, wrap_result=False)
@exposed_http("POST", "/redfish/v1/Systems/0/Actions/ComputerSystem.Reset")
async def __power_handler(self, request: Request) -> Response:
async def __power_handler(self, req: Request) -> Response:
try:
action = check_string_in_list(
arg=(await request.json())["ResetType"],
arg=(await req.json()).get("ResetType"),
name="Redfish ResetType",
variants=set(self.__actions),
lower=False,

View File

@@ -52,36 +52,36 @@ class StreamerApi:
return make_json_response(await self.__streamer.get_state())
@exposed_http("GET", "/streamer/snapshot")
async def __take_snapshot_handler(self, request: Request) -> Response:
async def __take_snapshot_handler(self, req: Request) -> Response:
snapshot = await self.__streamer.take_snapshot(
save=valid_bool(request.query.get("save", False)),
load=valid_bool(request.query.get("load", False)),
allow_offline=valid_bool(request.query.get("allow_offline", False)),
save=valid_bool(req.query.get("save", False)),
load=valid_bool(req.query.get("load", False)),
allow_offline=valid_bool(req.query.get("allow_offline", False)),
)
if snapshot:
if valid_bool(request.query.get("ocr", False)):
if valid_bool(req.query.get("ocr", False)):
langs = self.__ocr.get_available_langs()
return Response(
body=(await self.__ocr.recognize(
data=snapshot.data,
langs=valid_string_list(
arg=str(request.query.get("ocr_langs", "")).strip(),
arg=str(req.query.get("ocr_langs", "")).strip(),
subval=(lambda lang: check_string_in_list(lang, "OCR lang", langs)),
name="OCR langs list",
),
left=int(valid_number(request.query.get("ocr_left", -1))),
top=int(valid_number(request.query.get("ocr_top", -1))),
right=int(valid_number(request.query.get("ocr_right", -1))),
bottom=int(valid_number(request.query.get("ocr_bottom", -1))),
left=int(valid_number(req.query.get("ocr_left", -1))),
top=int(valid_number(req.query.get("ocr_top", -1))),
right=int(valid_number(req.query.get("ocr_right", -1))),
bottom=int(valid_number(req.query.get("ocr_bottom", -1))),
)),
headers=dict(snapshot.headers),
content_type="text/plain",
)
elif valid_bool(request.query.get("preview", False)):
elif valid_bool(req.query.get("preview", False)):
data = await snapshot.make_preview(
max_width=valid_int_f0(request.query.get("preview_max_width", 0)),
max_height=valid_int_f0(request.query.get("preview_max_height", 0)),
quality=valid_stream_quality(request.query.get("preview_quality", 80)),
max_width=valid_int_f0(req.query.get("preview_max_width", 0)),
max_height=valid_int_f0(req.query.get("preview_max_height", 0)),
quality=valid_stream_quality(req.query.get("preview_quality", 80)),
)
else:
data = snapshot.data
@@ -97,25 +97,6 @@ class StreamerApi:
self.__streamer.remove_snapshot()
return make_json_response()
# =====
async def get_ocr(self) -> dict: # XXX: Ugly hack
enabled = self.__ocr.is_available()
default: list[str] = []
available: list[str] = []
if enabled:
default = self.__ocr.get_default_langs()
available = self.__ocr.get_available_langs()
return {
"ocr": {
"enabled": enabled,
"langs": {
"default": default,
"available": available,
},
},
}
@exposed_http("GET", "/streamer/ocr")
async def __ocr_handler(self, _: Request) -> Response:
return make_json_response(await self.get_ocr())
return make_json_response({"ocr": (await self.__ocr.get_state())})

View File

@@ -42,23 +42,20 @@ class UserGpioApi:
@exposed_http("GET", "/gpio")
async def __state_handler(self, _: Request) -> Response:
return make_json_response({
"model": (await self.__user_gpio.get_model()),
"state": (await self.__user_gpio.get_state()),
})
return make_json_response(await self.__user_gpio.get_state())
@exposed_http("POST", "/gpio/switch")
async def __switch_handler(self, request: Request) -> Response:
channel = valid_ugpio_channel(request.query.get("channel"))
state = valid_bool(request.query.get("state"))
wait = valid_bool(request.query.get("wait", False))
async def __switch_handler(self, req: Request) -> Response:
channel = valid_ugpio_channel(req.query.get("channel"))
state = valid_bool(req.query.get("state"))
wait = valid_bool(req.query.get("wait", False))
await self.__user_gpio.switch(channel, state, wait)
return make_json_response()
@exposed_http("POST", "/gpio/pulse")
async def __pulse_handler(self, request: Request) -> Response:
channel = valid_ugpio_channel(request.query.get("channel"))
delay = valid_float_f0(request.query.get("delay", 0.0))
wait = valid_bool(request.query.get("wait", False))
async def __pulse_handler(self, req: Request) -> Response:
channel = valid_ugpio_channel(req.query.get("channel"))
delay = valid_float_f0(req.query.get("delay", 0.0))
wait = valid_bool(req.query.get("wait", False))
await self.__user_gpio.pulse(channel, delay, wait)
return make_json_response()

View File

@@ -20,6 +20,10 @@
# ========================================================================== #
import asyncio
from typing import AsyncGenerator
from ....yamlconf import Section
from .base import BaseInfoSubmanager
@@ -34,17 +38,59 @@ from .fan import FanInfoSubmanager
# =====
class InfoManager:
def __init__(self, config: Section) -> None:
self.__subs = {
self.__subs: dict[str, BaseInfoSubmanager] = {
"system": SystemInfoSubmanager(config.kvmd.streamer.cmd),
"auth": AuthInfoSubmanager(config.kvmd.auth.enabled),
"meta": MetaInfoSubmanager(config.kvmd.info.meta),
"auth": AuthInfoSubmanager(config.kvmd.auth.enabled),
"meta": MetaInfoSubmanager(config.kvmd.info.meta),
"extras": ExtrasInfoSubmanager(config),
"hw": HwInfoSubmanager(**config.kvmd.info.hw._unpack()),
"fan": FanInfoSubmanager(**config.kvmd.info.fan._unpack()),
"hw": HwInfoSubmanager(**config.kvmd.info.hw._unpack()),
"fan": FanInfoSubmanager(**config.kvmd.info.fan._unpack()),
}
self.__queue: "asyncio.Queue[tuple[str, (dict | None)]]" = asyncio.Queue()
def get_subs(self) -> set[str]:
return set(self.__subs)
def get_submanager(self, name: str) -> BaseInfoSubmanager:
return self.__subs[name]
async def get_state(self, fields: (list[str] | None)=None) -> dict:
fields = (fields or list(self.__subs))
return dict(zip(fields, await asyncio.gather(*[
self.__subs[field].get_state()
for field in fields
])))
async def trigger_state(self) -> None:
await asyncio.gather(*[
sub.trigger_state()
for sub in self.__subs.values()
])
async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - system -- Partial
# - auth -- Partial
# - meta -- Partial, nullable
# - extras -- Partial, nullable
# - hw -- Partial
# - fan -- Partial
# ===========================
while True:
(field, value) = await self.__queue.get()
yield {field: value}
async def systask(self) -> None:
tasks = [
asyncio.create_task(self.__poller(field))
for field in self.__subs
]
try:
await asyncio.gather(*tasks)
except Exception:
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
async def __poller(self, field: str) -> None:
async for state in self.__subs[field].poll_state():
self.__queue.put_nowait((field, state))

View File

@@ -20,6 +20,10 @@
# ========================================================================== #
from typing import AsyncGenerator
from .... import aiotools
from .base import BaseInfoSubmanager
@@ -27,6 +31,15 @@ from .base import BaseInfoSubmanager
class AuthInfoSubmanager(BaseInfoSubmanager):
def __init__(self, enabled: bool) -> None:
self.__enabled = enabled
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict:
return {"enabled": self.__enabled}
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())

View File

@@ -20,7 +20,17 @@
# ========================================================================== #
from typing import AsyncGenerator
# =====
class BaseInfoSubmanager:
async def get_state(self) -> (dict | None):
raise NotImplementedError
async def trigger_state(self) -> None:
raise NotImplementedError
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
yield None
raise NotImplementedError

View File

@@ -24,6 +24,8 @@ import os
import re
import asyncio
from typing import AsyncGenerator
from ....logging import get_logger
from ....yamlconf import Section
@@ -42,13 +44,15 @@ from .base import BaseInfoSubmanager
class ExtrasInfoSubmanager(BaseInfoSubmanager):
def __init__(self, global_config: Section) -> None:
self.__global_config = global_config
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> (dict | None):
try:
sui = sysunit.SystemdUnitInfo()
await sui.open()
except Exception as err:
get_logger(0).error("Can't open systemd bus to get extras state: %s", tools.efmt(err))
except Exception as ex:
if not os.path.exists("/etc/kvmd/.docker_flag"):
get_logger(0).error("Can't open systemd bus to get extras state: %s", tools.efmt(ex))
sui = None
try:
extras: dict[str, dict] = {}
@@ -66,6 +70,14 @@ class ExtrasInfoSubmanager(BaseInfoSubmanager):
if sui is not None:
await aiotools.shield_fg(sui.close())
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())
def __get_extras_path(self, *parts: str) -> str:
return os.path.join(self.__global_config.kvmd.info.extras, *parts)

View File

@@ -21,7 +21,6 @@
import copy
import asyncio
from typing import AsyncGenerator
@@ -53,6 +52,8 @@ class FanInfoSubmanager(BaseInfoSubmanager):
self.__timeout = timeout
self.__state_poll = state_poll
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict:
monitored = await self.__get_monitored()
return {
@@ -60,24 +61,28 @@ class FanInfoSubmanager(BaseInfoSubmanager):
"state": ((await self.__get_fan_state() if monitored else None)),
}
async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {}
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
prev: dict = {}
while True:
if self.__unix_path:
pure = state = await self.get_state()
if (await self.__notifier.wait(timeout=self.__state_poll)) > 0:
prev = {}
new = await self.get_state()
pure = copy.deepcopy(new)
if pure["state"] is not None:
try:
pure = copy.deepcopy(state)
pure["state"]["service"]["now_ts"] = 0
except Exception:
pass
if pure != prev_state:
yield state
prev_state = pure
await asyncio.sleep(self.__state_poll)
if pure != prev:
prev = pure
yield new
else:
await self.__notifier.wait()
yield (await self.get_state())
await aiotools.wait_infinite()
# =====
@@ -87,8 +92,8 @@ class FanInfoSubmanager(BaseInfoSubmanager):
async with sysunit.SystemdUnitInfo() as sui:
status = await sui.get_status(self.__daemon)
return (status[0] or status[1])
except Exception as err:
get_logger(0).error("Can't get info about the service %r: %s", self.__daemon, tools.efmt(err))
except Exception as ex:
get_logger(0).error("Can't get info about the service %r: %s", self.__daemon, tools.efmt(ex))
return False
async def __get_fan_state(self) -> (dict | None):
@@ -97,8 +102,8 @@ class FanInfoSubmanager(BaseInfoSubmanager):
async with session.get("http://localhost/state") as response:
htclient.raise_not_200(response)
return (await response.json())["result"]
except Exception as err:
get_logger(0).error("Can't read fan state: %s", err)
except Exception as ex:
get_logger(0).error("Can't read fan state: %s", ex)
return None
def __make_http_session(self) -> aiohttp.ClientSession:

View File

@@ -22,6 +22,7 @@
import os
import asyncio
import copy
from typing import Callable
from typing import AsyncGenerator
@@ -60,6 +61,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
self.__dt_cache: dict[str, str] = {}
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict:
(
base,
@@ -70,8 +73,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
cpu_temp,
mem,
) = await asyncio.gather(
self.__read_dt_file("model"),
self.__read_dt_file("serial-number"),
self.__read_dt_file("model", upper=False),
self.__read_dt_file("serial-number", upper=True),
self.__read_platform_file(),
self.__get_throttling(),
self.__get_cpu_percent(),
@@ -97,18 +100,22 @@ class HwInfoSubmanager(BaseInfoSubmanager):
},
}
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {}
prev: dict = {}
while True:
state = await self.get_state()
if state != prev_state:
yield state
prev_state = state
await asyncio.sleep(self.__state_poll)
if (await self.__notifier.wait(timeout=self.__state_poll)) > 0:
prev = {}
new = await self.get_state()
if new != prev:
prev = copy.deepcopy(new)
yield new
# =====
async def __read_dt_file(self, name: str) -> (str | None):
async def __read_dt_file(self, name: str, upper: bool) -> (str | None):
if name not in self.__dt_cache:
path = os.path.join(f"{env.PROCFS_PREFIX}/proc/device-tree", name)
if not os.path.exists(path):
@@ -161,8 +168,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
+ system_all / total * 100
+ (st.steal + st.guest) / total * 100
)
except Exception as err:
get_logger(0).error("Can't get CPU percent: %s", err)
except Exception as ex:
get_logger(0).error("Can't get CPU percent: %s", ex)
return None
async def __get_mem(self) -> dict:
@@ -173,8 +180,8 @@ class HwInfoSubmanager(BaseInfoSubmanager):
"total": st.total,
"available": st.available,
}
except Exception as err:
get_logger(0).error("Can't get memory info: %s", err)
except Exception as ex:
get_logger(0).error("Can't get memory info: %s", ex)
return {
"percent": None,
"total": None,
@@ -217,6 +224,6 @@ class HwInfoSubmanager(BaseInfoSubmanager):
return None
try:
return parser(text)
except Exception as err:
get_logger(0).error("Can't parse [ %s ] output: %r: %s", tools.cmdfmt(cmd), text, tools.efmt(err))
except Exception as ex:
get_logger(0).error("Can't parse [ %s ] output: %r: %s", tools.cmdfmt(cmd), text, tools.efmt(ex))
return None

View File

@@ -20,6 +20,8 @@
# ========================================================================== #
from typing import AsyncGenerator
from ....logging import get_logger
from ....yamlconf.loader import load_yaml_file
@@ -33,6 +35,7 @@ from .base import BaseInfoSubmanager
class MetaInfoSubmanager(BaseInfoSubmanager):
def __init__(self, meta_path: str) -> None:
self.__meta_path = meta_path
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> (dict | None):
try:
@@ -40,3 +43,11 @@ class MetaInfoSubmanager(BaseInfoSubmanager):
except Exception:
get_logger(0).exception("Can't parse meta")
return None
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())

View File

@@ -24,8 +24,11 @@ import os
import asyncio
import platform
from typing import AsyncGenerator
from ....logging import get_logger
from .... import aiotools
from .... import aioproc
from .... import __version__
@@ -37,6 +40,7 @@ from .base import BaseInfoSubmanager
class SystemInfoSubmanager(BaseInfoSubmanager):
def __init__(self, streamer_cmd: list[str]) -> None:
self.__streamer_cmd = streamer_cmd
self.__notifier = aiotools.AioNotifier()
async def get_state(self) -> dict:
streamer_info = await self.__get_streamer_info()
@@ -50,6 +54,14 @@ class SystemInfoSubmanager(BaseInfoSubmanager):
},
}
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[(dict | None), None]:
while True:
await self.__notifier.wait()
yield (await self.get_state())
# =====
async def __get_streamer_info(self) -> dict:

View File

@@ -37,6 +37,7 @@ from ctypes import c_void_p
from ctypes import c_char
from typing import Generator
from typing import AsyncGenerator
from PIL import ImageOps
from PIL import Image as PilImage
@@ -76,8 +77,8 @@ def _load_libtesseract() -> (ctypes.CDLL | None):
setattr(func, "restype", restype)
setattr(func, "argtypes", argtypes)
return lib
except Exception as err:
warnings.warn(f"Can't load libtesseract: {err}", RuntimeWarning)
except Exception as ex:
warnings.warn(f"Can't load libtesseract: {ex}", RuntimeWarning)
return None
@@ -107,9 +108,37 @@ class Ocr:
def __init__(self, data_dir_path: str, default_langs: list[str]) -> None:
self.__data_dir_path = data_dir_path
self.__default_langs = default_langs
self.__notifier = aiotools.AioNotifier()
def is_available(self) -> bool:
return bool(_libtess)
async def get_state(self) -> dict:
enabled = bool(_libtess)
default: list[str] = []
available: list[str] = []
if enabled:
default = self.get_default_langs()
available = self.get_available_langs()
return {
"enabled": enabled,
"langs": {
"default": default,
"available": available,
},
}
async def trigger_state(self) -> None:
self.__notifier.notify()
async def poll_state(self) -> AsyncGenerator[dict, None]:
# ===== Granularity table =====
# - enabled -- Full
# - langs -- Partial
# =============================
while True:
await self.__notifier.wait()
yield (await self.get_state())
# =====
def get_default_langs(self) -> list[str]:
return list(self.__default_langs)

View File

@@ -20,8 +20,6 @@
# ========================================================================== #
import asyncio
import operator
import dataclasses
from typing import Callable
@@ -33,6 +31,8 @@ from aiohttp.web import Request
from aiohttp.web import Response
from aiohttp.web import WebSocketResponse
from ... import __version__
from ...logging import get_logger
from ...errors import OperationError
@@ -98,54 +98,46 @@ class StreamerH264NotSupported(OperationError):
# =====
@dataclasses.dataclass(frozen=True)
class _SubsystemEventSource:
get_state: (Callable[[], Coroutine[Any, Any, dict]] | None) = None
poll_state: (Callable[[], AsyncGenerator[dict, None]] | None) = None
@dataclasses.dataclass
class _Subsystem:
name: str
sysprep: (Callable[[], None] | None)
systask: (Callable[[], Coroutine[Any, Any, None]] | None)
cleanup: (Callable[[], Coroutine[Any, Any, dict]] | None)
sources: dict[str, _SubsystemEventSource]
name: str
event_type: str
sysprep: (Callable[[], None] | None)
systask: (Callable[[], Coroutine[Any, Any, None]] | None)
cleanup: (Callable[[], Coroutine[Any, Any, dict]] | None)
trigger_state: (Callable[[], Coroutine[Any, Any, None]] | None) = None
poll_state: (Callable[[], AsyncGenerator[dict, None]] | None) = None
def __post_init__(self) -> None:
if self.event_type:
assert self.trigger_state
assert self.poll_state
@classmethod
def make(cls, obj: object, name: str, event_type: str="") -> "_Subsystem":
if isinstance(obj, BasePlugin):
name = f"{name} ({obj.get_plugin_name()})"
sub = _Subsystem(
return _Subsystem(
name=name,
event_type=event_type,
sysprep=getattr(obj, "sysprep", None),
systask=getattr(obj, "systask", None),
cleanup=getattr(obj, "cleanup", None),
sources={},
trigger_state=getattr(obj, "trigger_state", None),
poll_state=getattr(obj, "poll_state", None),
)
if event_type:
sub.add_source(
event_type=event_type,
get_state=getattr(obj, "get_state", None),
poll_state=getattr(obj, "poll_state", None),
)
return sub
def add_source(
self,
event_type: str,
get_state: (Callable[[], Coroutine[Any, Any, dict]] | None),
poll_state: (Callable[[], AsyncGenerator[dict, None]] | None),
) -> "_Subsystem":
assert event_type
assert event_type not in self.sources, (self, event_type)
assert get_state or poll_state, (self, event_type)
self.sources[event_type] = _SubsystemEventSource(get_state, poll_state)
return self
class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes
__EV_GPIO_STATE = "gpio_state"
__EV_HID_STATE = "hid_state"
__EV_ATX_STATE = "atx_state"
__EV_MSD_STATE = "msd_state"
__EV_STREAMER_STATE = "streamer_state"
__EV_OCR_STATE = "ocr_state"
__EV_INFO_STATE = "info_state"
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
auth_manager: AuthManager,
@@ -161,9 +153,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
snapshoter: Snapshoter,
keymap_path: str,
ignore_keys: list[str],
mouse_x_range: tuple[int, int],
mouse_y_range: tuple[int, int],
stream_forever: bool,
) -> None:
@@ -177,8 +166,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
self.__stream_forever = stream_forever
self.__hid_api = HidApi(hid, keymap_path, ignore_keys, mouse_x_range, mouse_y_range) # Ugly hack to get keymaps state
self.__streamer_api = StreamerApi(streamer, ocr) # Same hack to get ocr langs state
self.__hid_api = HidApi(hid, keymap_path) # Ugly hack to get keymaps state
self.__apis: list[object] = [
self,
AuthApi(auth_manager),
@@ -188,22 +176,19 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
self.__hid_api,
AtxApi(atx),
MsdApi(msd),
self.__streamer_api,
StreamerApi(streamer, ocr),
ExportApi(info_manager, atx, user_gpio),
RedfishApi(info_manager, atx),
]
self.__subsystems = [
_Subsystem.make(auth_manager, "Auth manager"),
_Subsystem.make(user_gpio, "User-GPIO", "gpio_state").add_source("gpio_model_state", user_gpio.get_model, None),
_Subsystem.make(hid, "HID", "hid_state").add_source("hid_keymaps_state", self.__hid_api.get_keymaps, None),
_Subsystem.make(atx, "ATX", "atx_state"),
_Subsystem.make(msd, "MSD", "msd_state"),
_Subsystem.make(streamer, "Streamer", "streamer_state").add_source("streamer_ocr_state", self.__streamer_api.get_ocr, None),
*[
_Subsystem.make(info_manager.get_submanager(sub), f"Info manager ({sub})", f"info_{sub}_state",)
for sub in sorted(info_manager.get_subs())
],
_Subsystem.make(user_gpio, "User-GPIO", self.__EV_GPIO_STATE),
_Subsystem.make(hid, "HID", self.__EV_HID_STATE),
_Subsystem.make(atx, "ATX", self.__EV_ATX_STATE),
_Subsystem.make(msd, "MSD", self.__EV_MSD_STATE),
_Subsystem.make(streamer, "Streamer", self.__EV_STREAMER_STATE),
_Subsystem.make(ocr, "OCR", self.__EV_OCR_STATE),
_Subsystem.make(info_manager, "Info manager", self.__EV_INFO_STATE),
]
self.__streamer_notifier = aiotools.AioNotifier()
@@ -213,16 +198,16 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
# ===== STREAMER CONTROLLER
@exposed_http("POST", "/streamer/set_params")
async def __streamer_set_params_handler(self, request: Request) -> Response:
async def __streamer_set_params_handler(self, req: Request) -> Response:
current_params = self.__streamer.get_params()
for (name, validator, exc_cls) in [
("quality", valid_stream_quality, StreamerQualityNotSupported),
("desired_fps", valid_stream_fps, None),
("resolution", valid_stream_resolution, StreamerResolutionNotSupported),
("quality", valid_stream_quality, StreamerQualityNotSupported),
("desired_fps", valid_stream_fps, None),
("resolution", valid_stream_resolution, StreamerResolutionNotSupported),
("h264_bitrate", valid_stream_h264_bitrate, StreamerH264NotSupported),
("h264_gop", valid_stream_h264_gop, StreamerH264NotSupported),
("h264_gop", valid_stream_h264_gop, StreamerH264NotSupported),
]:
value = request.query.get(name)
value = req.query.get(name)
if value:
if name not in current_params:
assert exc_cls is not None, name
@@ -242,24 +227,22 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
# ===== WEBSOCKET
@exposed_http("GET", "/ws")
async def __ws_handler(self, request: Request) -> WebSocketResponse:
stream = valid_bool(request.query.get("stream", True))
async with self._ws_session(request, stream=stream) as ws:
states = [
(event_type, src.get_state())
for sub in self.__subsystems
for (event_type, src) in sub.sources.items()
if src.get_state
]
events = dict(zip(
map(operator.itemgetter(0), states),
await asyncio.gather(*map(operator.itemgetter(1), states)),
))
await asyncio.gather(*[
ws.send_event(event_type, events.pop(event_type))
for (event_type, _) in states
])
await ws.send_event("loop", {})
async def __ws_handler(self, req: Request) -> WebSocketResponse:
stream = valid_bool(req.query.get("stream", True))
legacy = valid_bool(req.query.get("legacy", True))
async with self._ws_session(req, stream=stream, legacy=legacy) as ws:
(major, minor) = __version__.split(".")
await ws.send_event("loop", {
"version": {
"major": int(major),
"minor": int(minor),
},
})
for sub in self.__subsystems:
if sub.event_type:
assert sub.trigger_state
await sub.trigger_state()
await self._broadcast_ws_event("hid_keymaps_state", await self.__hid_api.get_keymaps()) # FIXME
return (await self._ws_loop(ws))
@exposed_ws("ping")
@@ -275,17 +258,17 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
aioproc.rename_process("main")
super().run(**kwargs)
async def _check_request_auth(self, exposed: HttpExposed, request: Request) -> None:
await check_request_auth(self.__auth_manager, exposed, request)
async def _check_request_auth(self, exposed: HttpExposed, req: Request) -> None:
await check_request_auth(self.__auth_manager, exposed, req)
async def _init_app(self) -> None:
aiotools.create_deadly_task("Stream controller", self.__stream_controller())
for sub in self.__subsystems:
if sub.systask:
aiotools.create_deadly_task(sub.name, sub.systask())
for (event_type, src) in sub.sources.items():
if src.poll_state:
aiotools.create_deadly_task(f"{sub.name} [poller]", self.__poll_state(event_type, src.poll_state()))
if sub.event_type:
assert sub.poll_state
aiotools.create_deadly_task(f"{sub.name} [poller]", self.__poll_state(sub.event_type, sub.poll_state()))
aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter())
self._add_exposed(*self.__apis)
@@ -347,12 +330,67 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
prev = cur
await self.__streamer_notifier.wait()
async def __poll_state(self, event_type: str, poller: AsyncGenerator[dict, None]) -> None:
async for state in poller:
await self._broadcast_ws_event(event_type, state)
async def __stream_snapshoter(self) -> None:
await self.__snapshoter.run(
is_live=self.__has_stream_clients,
notifier=self.__streamer_notifier,
)
async def __poll_state(self, event_type: str, poller: AsyncGenerator[dict, None]) -> None:
match event_type:
case self.__EV_GPIO_STATE:
await self.__poll_gpio_state(poller)
case self.__EV_INFO_STATE:
await self.__poll_info_state(poller)
case self.__EV_MSD_STATE:
await self.__poll_msd_state(poller)
case self.__EV_STREAMER_STATE:
await self.__poll_streamer_state(poller)
case self.__EV_OCR_STATE:
await self.__poll_ocr_state(poller)
case _:
async for state in poller:
await self._broadcast_ws_event(event_type, state)
async def __poll_gpio_state(self, poller: AsyncGenerator[dict, None]) -> None:
prev: dict = {"state": {"inputs": {}, "outputs": {}}}
async for state in poller:
await self._broadcast_ws_event(self.__EV_GPIO_STATE, state, legacy=False)
if "model" in state: # We have only "model"+"state" or "model" event
prev = state
await self._broadcast_ws_event("gpio_model_state", prev["model"], legacy=True)
else:
prev["state"]["inputs"].update(state["state"].get("inputs", {}))
prev["state"]["outputs"].update(state["state"].get("outputs", {}))
await self._broadcast_ws_event(self.__EV_GPIO_STATE, prev["state"], legacy=True)
async def __poll_info_state(self, poller: AsyncGenerator[dict, None]) -> None:
async for state in poller:
await self._broadcast_ws_event(self.__EV_INFO_STATE, state, legacy=False)
for (key, value) in state.items():
await self._broadcast_ws_event(f"info_{key}_state", value, legacy=True)
async def __poll_msd_state(self, poller: AsyncGenerator[dict, None]) -> None:
prev: dict = {"storage": None}
async for state in poller:
await self._broadcast_ws_event(self.__EV_MSD_STATE, state, legacy=False)
prev_storage = prev["storage"]
prev.update(state)
if prev["storage"] is not None and prev_storage is not None:
prev_storage.update(prev["storage"])
prev["storage"] = prev_storage
if "online" in prev: # Complete/Full
await self._broadcast_ws_event(self.__EV_MSD_STATE, prev, legacy=True)
async def __poll_streamer_state(self, poller: AsyncGenerator[dict, None]) -> None:
prev: dict = {}
async for state in poller:
await self._broadcast_ws_event(self.__EV_STREAMER_STATE, state, legacy=False)
prev.update(state)
if "features" in prev: # Complete/Full
await self._broadcast_ws_event(self.__EV_STREAMER_STATE, prev, legacy=True)
async def __poll_ocr_state(self, poller: AsyncGenerator[dict, None]) -> None:
async for state in poller:
await self._broadcast_ws_event(self.__EV_OCR_STATE, state, legacy=False)
await self._broadcast_ws_event("streamer_ocr_state", {"ocr": state}, legacy=True)

View File

@@ -20,22 +20,23 @@
# ========================================================================== #
import io
import signal
import asyncio
import asyncio.subprocess
import dataclasses
import functools
import copy
from typing import AsyncGenerator
from typing import Any
import aiohttp
from PIL import Image as PilImage
from ...logging import get_logger
from ...clients.streamer import StreamerSnapshot
from ...clients.streamer import HttpStreamerClient
from ...clients.streamer import HttpStreamerClientSession
from ... import tools
from ... import aiotools
from ... import aioproc
@@ -43,40 +44,6 @@ from ... import htclient
# =====
@dataclasses.dataclass(frozen=True)
class StreamerSnapshot:
online: bool
width: int
height: int
headers: tuple[tuple[str, str], ...]
data: bytes
async def make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
assert max_width >= 0
assert max_height >= 0
assert quality > 0
if max_width == 0 and max_height == 0:
max_width = self.width // 5
max_height = self.height // 5
else:
max_width = min((max_width or self.width), self.width)
max_height = min((max_height or self.height), self.height)
if (max_width, max_height) == (self.width, self.height):
return self.data
return (await aiotools.run_async(self.__inner_make_preview, max_width, max_height, quality))
@functools.lru_cache(maxsize=1)
def __inner_make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
with io.BytesIO(self.data) as snapshot_bio:
with io.BytesIO() as preview_bio:
with PilImage.open(snapshot_bio) as image:
image.thumbnail((max_width, max_height), PilImage.Resampling.LANCZOS)
image.save(preview_bio, format="jpeg", quality=quality)
return preview_bio.getvalue()
class _StreamerParams:
__DESIRED_FPS = "desired_fps"
@@ -136,7 +103,7 @@ class _StreamerParams:
}
def get_limits(self) -> dict:
limits = dict(self.__limits)
limits = copy.deepcopy(self.__limits)
if self.__has_resolution:
limits[self.__AVAILABLE_RESOLUTIONS] = list(limits[self.__AVAILABLE_RESOLUTIONS])
return limits
@@ -170,6 +137,11 @@ class _StreamerParams:
class Streamer: # pylint: disable=too-many-instance-attributes
__ST_FULL = 0xFF
__ST_PARAMS = 0x01
__ST_STREAMER = 0x02
__ST_SNAPSHOT = 0x04
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
@@ -203,7 +175,6 @@ class Streamer: # pylint: disable=too-many-instance-attributes
self.__state_poll = state_poll
self.__unix_path = unix_path
self.__timeout = timeout
self.__snapshot_timeout = snapshot_timeout
self.__process_name_prefix = process_name_prefix
@@ -220,7 +191,13 @@ class Streamer: # pylint: disable=too-many-instance-attributes
self.__streamer_task: (asyncio.Task | None) = None
self.__streamer_proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member
self.__http_session: (aiohttp.ClientSession | None) = None
self.__client = HttpStreamerClient(
name="jpeg",
unix_path=self.__unix_path,
timeout=timeout,
user_agent=htclient.make_user_agent("KVMD"),
)
self.__client_session: (HttpStreamerClientSession | None) = None
self.__snapshot: (StreamerSnapshot | None) = None
@@ -289,6 +266,7 @@ class Streamer: # pylint: disable=too-many-instance-attributes
def set_params(self, params: dict) -> None:
assert not self.__streamer_task
self.__notifier.notify(self.__ST_PARAMS)
return self.__params.set_params(params)
def get_params(self) -> dict:
@@ -297,55 +275,80 @@ class Streamer: # pylint: disable=too-many-instance-attributes
# =====
async def get_state(self) -> dict:
streamer_state = None
if self.__streamer_task:
session = self.__ensure_http_session()
try:
async with session.get(self.__make_url("state")) as response:
htclient.raise_not_200(response)
streamer_state = (await response.json())["result"]
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
pass
except Exception:
get_logger().exception("Invalid streamer response from /state")
snapshot: (dict | None) = None
if self.__snapshot:
snapshot = dataclasses.asdict(self.__snapshot)
del snapshot["headers"]
del snapshot["data"]
return {
"features": self.__params.get_features(),
"limits": self.__params.get_limits(),
"params": self.__params.get_params(),
"snapshot": {"saved": snapshot},
"streamer": streamer_state,
"features": self.__params.get_features(),
"streamer": (await self.__get_streamer_state()),
"snapshot": self.__get_snapshot_state(),
}
async def trigger_state(self) -> None:
self.__notifier.notify(self.__ST_FULL)
async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - features -- Full
# - limits -- Partial, paired with params
# - params -- Partial, paired with limits
# - streamer -- Partial, nullable
# - snapshot -- Partial
# ===========================
def signal_handler(*_: Any) -> None:
get_logger(0).info("Got SIGUSR2, checking the stream state ...")
self.__notifier.notify()
self.__notifier.notify(self.__ST_STREAMER)
get_logger(0).info("Installing SIGUSR2 streamer handler ...")
asyncio.get_event_loop().add_signal_handler(signal.SIGUSR2, signal_handler)
waiter_task: (asyncio.Task | None) = None
prev_state: dict = {}
prev: dict = {}
while True:
state = await self.get_state()
if state != prev_state:
yield state
prev_state = state
new: dict = {}
if waiter_task is None:
waiter_task = asyncio.create_task(self.__notifier.wait())
if waiter_task in (await aiotools.wait_first(
asyncio.ensure_future(asyncio.sleep(self.__state_poll)),
waiter_task,
))[0]:
waiter_task = None
mask = await self.__notifier.wait(timeout=self.__state_poll)
if mask == self.__ST_FULL:
new = await self.get_state()
prev = copy.deepcopy(new)
yield new
continue
if mask < 0:
mask = self.__ST_STREAMER
def check_update(key: str, value: (dict | None)) -> None:
if prev.get(key) != value:
new[key] = value
if mask & self.__ST_PARAMS:
check_update("params", self.__params.get_params())
if mask & self.__ST_STREAMER:
check_update("streamer", await self.__get_streamer_state())
if mask & self.__ST_SNAPSHOT:
check_update("snapshot", self.__get_snapshot_state())
if new and prev != new:
prev.update(copy.deepcopy(new))
yield new
async def __get_streamer_state(self) -> (dict | None):
if self.__streamer_task:
session = self.__ensure_client_session()
try:
return (await session.get_state())
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
pass
except Exception:
get_logger().exception("Invalid streamer response from /state")
return None
def __get_snapshot_state(self) -> dict:
if self.__snapshot:
snapshot = dataclasses.asdict(self.__snapshot)
del snapshot["headers"]
del snapshot["data"]
return {"saved": snapshot}
return {"saved": None}
# =====
@@ -353,41 +356,17 @@ class Streamer: # pylint: disable=too-many-instance-attributes
if load:
return self.__snapshot
logger = get_logger()
session = self.__ensure_http_session()
session = self.__ensure_client_session()
try:
async with session.get(
self.__make_url("snapshot"),
timeout=self.__snapshot_timeout,
) as response:
htclient.raise_not_200(response)
online = (response.headers["X-UStreamer-Online"] == "true")
if online or allow_offline:
snapshot = StreamerSnapshot(
online=online,
width=int(response.headers["X-UStreamer-Width"]),
height=int(response.headers["X-UStreamer-Height"]),
headers=tuple(
(key, value)
for (key, value) in tools.sorted_kvs(dict(response.headers))
if key.lower().startswith("x-ustreamer-") or key.lower() in [
"x-timestamp",
"access-control-allow-origin",
"cache-control",
"pragma",
"expires",
]
),
data=bytes(await response.read()),
)
if save:
self.__snapshot = snapshot
self.__notifier.notify()
return snapshot
logger.error("Stream is offline, no signal or so")
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as err:
logger.error("Can't connect to streamer: %s", tools.efmt(err))
snapshot = await session.take_snapshot(self.__snapshot_timeout)
if snapshot.online or allow_offline:
if save:
self.__snapshot = snapshot
self.__notifier.notify(self.__ST_SNAPSHOT)
return snapshot
logger.error("Stream is offline, no signal or so")
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as ex:
logger.error("Can't connect to streamer: %s", tools.efmt(ex))
except Exception:
logger.exception("Invalid streamer response from /snapshot")
return None
@@ -400,25 +379,14 @@ class Streamer: # pylint: disable=too-many-instance-attributes
@aiotools.atomic_fg
async def cleanup(self) -> None:
await self.ensure_stop(immediately=True)
if self.__http_session:
await self.__http_session.close()
self.__http_session = None
if self.__client_session:
await self.__client_session.close()
self.__client_session = None
# =====
def __ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
kwargs: dict = {
"headers": {"User-Agent": htclient.make_user_agent("KVMD")},
"connector": aiohttp.UnixConnector(path=self.__unix_path),
"timeout": aiohttp.ClientTimeout(total=self.__timeout),
}
self.__http_session = aiohttp.ClientSession(**kwargs)
return self.__http_session
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://localhost:0/{handle}"
def __ensure_client_session(self) -> HttpStreamerClientSession:
if not self.__client_session:
self.__client_session = self.__client.make_session()
return self.__client_session
# =====
@@ -473,8 +441,8 @@ class Streamer: # pylint: disable=too-many-instance-attributes
logger.info("%s: %s", name, tools.cmdfmt(cmd))
try:
await aioproc.log_process(cmd, logger, prefix=name)
except Exception as err:
logger.exception("Can't execute command: %s", err)
except Exception as ex:
logger.exception("Can't execute command: %s", ex)
async def __start_streamer_proc(self) -> None:
assert self.__streamer_proc is None

View File

@@ -35,6 +35,7 @@ class SystemdUnitInfo:
self.__bus: (dbus_next.aio.MessageBus | None) = None
self.__intr: (dbus_next.introspection.Node | None) = None
self.__manager: (dbus_next.aio.proxy_object.ProxyInterface | None) = None
self.__requested = False
async def get_status(self, name: str) -> tuple[bool, bool]:
assert self.__bus is not None
@@ -49,8 +50,9 @@ class SystemdUnitInfo:
unit = self.__bus.get_proxy_object("org.freedesktop.systemd1", unit_p, self.__intr)
unit_props = unit.get_interface("org.freedesktop.DBus.Properties")
started = ((await unit_props.call_get("org.freedesktop.systemd1.Unit", "ActiveState")).value == "active") # type: ignore
except dbus_next.errors.DBusError as err:
if err.type != "org.freedesktop.systemd1.NoSuchUnit":
self.__requested = True
except dbus_next.errors.DBusError as ex:
if ex.type != "org.freedesktop.systemd1.NoSuchUnit":
raise
started = False
enabled = ((await self.__manager.call_get_unit_file_state(name)) in [ # type: ignore
@@ -75,8 +77,13 @@ class SystemdUnitInfo:
async def close(self) -> None:
try:
if self.__bus is not None:
self.__bus.disconnect()
await self.__bus.wait_for_disconnect()
try:
# XXX: Workaround for dbus_next bug: https://github.com/pikvm/kvmd/pull/182
if not self.__requested:
await self.__manager.call_get_default_target() # type: ignore
finally:
self.__bus.disconnect()
await self.__bus.wait_for_disconnect()
except Exception:
pass
self.__manager = None

View File

@@ -21,6 +21,7 @@
import asyncio
import copy
from typing import AsyncGenerator
from typing import Callable
@@ -68,12 +69,12 @@ class GpioChannelIsBusyError(IsBusyError, GpioError):
class _GpioInput:
def __init__(
self,
channel: str,
ch: str,
config: Section,
driver: BaseUserGpioDriver,
) -> None:
self.__channel = channel
self.__ch = ch
self.__pin: str = str(config.pin)
self.__inverted: bool = config.inverted
@@ -100,7 +101,7 @@ class _GpioInput:
}
def __str__(self) -> str:
return f"Input({self.__channel}, driver={self.__driver}, pin={self.__pin})"
return f"Input({self.__ch}, driver={self.__driver}, pin={self.__pin})"
__repr__ = __str__
@@ -108,13 +109,13 @@ class _GpioInput:
class _GpioOutput: # pylint: disable=too-many-instance-attributes
def __init__(
self,
channel: str,
ch: str,
config: Section,
driver: BaseUserGpioDriver,
notifier: aiotools.AioNotifier,
) -> None:
self.__channel = channel
self.__ch = ch
self.__pin: str = str(config.pin)
self.__inverted: bool = config.inverted
@@ -184,7 +185,7 @@ class _GpioOutput: # pylint: disable=too-many-instance-attributes
@aiotools.atomic_fg
async def __run_action(self, wait: bool, name: str, func: Callable, *args: Any) -> None:
if wait:
async with self.__region:
with self.__region:
await func(*args)
else:
await aiotools.run_region_task(
@@ -224,7 +225,7 @@ class _GpioOutput: # pylint: disable=too-many-instance-attributes
await self.__driver.write(self.__pin, (state ^ self.__inverted))
def __str__(self) -> str:
return f"Output({self.__channel}, driver={self.__driver}, pin={self.__pin})"
return f"Output({self.__ch}, driver={self.__driver}, pin={self.__pin})"
__repr__ = __str__
@@ -232,8 +233,6 @@ class _GpioOutput: # pylint: disable=too-many-instance-attributes
# =====
class UserGpio:
def __init__(self, config: Section, otg_config: Section) -> None:
self.__view = config.view
self.__notifier = aiotools.AioNotifier()
self.__drivers = {
@@ -249,45 +248,67 @@ class UserGpio:
self.__inputs: dict[str, _GpioInput] = {}
self.__outputs: dict[str, _GpioOutput] = {}
for (channel, ch_config) in tools.sorted_kvs(config.scheme):
for (ch, ch_config) in tools.sorted_kvs(config.scheme):
driver = self.__drivers[ch_config.driver]
if ch_config.mode == UserGpioModes.INPUT:
self.__inputs[channel] = _GpioInput(channel, ch_config, driver)
self.__inputs[ch] = _GpioInput(ch, ch_config, driver)
else: # output:
self.__outputs[channel] = _GpioOutput(channel, ch_config, driver, self.__notifier)
self.__outputs[ch] = _GpioOutput(ch, ch_config, driver, self.__notifier)
async def get_model(self) -> dict:
return {
"scheme": {
"inputs": {channel: gin.get_scheme() for (channel, gin) in self.__inputs.items()},
"outputs": {
channel: gout.get_scheme()
for (channel, gout) in self.__outputs.items()
if not gout.is_const()
},
},
"view": self.__make_view(),
}
self.__scheme = self.__make_scheme()
self.__view = self.__make_view(config.view)
async def get_state(self) -> dict:
return {
"inputs": {channel: await gin.get_state() for (channel, gin) in self.__inputs.items()},
"model": {
"scheme": copy.deepcopy(self.__scheme),
"view": copy.deepcopy(self.__view),
},
"state": (await self.__get_io_state()),
}
async def trigger_state(self) -> None:
self.__notifier.notify(1)
async def poll_state(self) -> AsyncGenerator[dict, None]:
# ==== Granularity table ====
# - model -- Full
# - state.inputs -- Partial
# - state.outputs -- Partial
# ===========================
prev: dict = {"inputs": {}, "outputs": {}}
while True: # pylint: disable=too-many-nested-blocks
if (await self.__notifier.wait()) > 0:
full = await self.get_state()
prev = copy.deepcopy(full["state"])
yield full
else:
new = await self.__get_io_state()
diff: dict = {}
for sub in ["inputs", "outputs"]:
for ch in new[sub]:
if new[sub][ch] != prev[sub].get(ch):
if sub not in diff:
diff[sub] = {}
diff[sub][ch] = new[sub][ch]
if diff:
prev = copy.deepcopy(new)
yield {"state": diff}
async def __get_io_state(self) -> dict:
return {
"inputs": {
ch: (await gin.get_state())
for (ch, gin) in self.__inputs.items()
},
"outputs": {
channel: await gout.get_state()
for (channel, gout) in self.__outputs.items()
ch: (await gout.get_state())
for (ch, gout) in self.__outputs.items()
if not gout.is_const()
},
}
async def poll_state(self) -> AsyncGenerator[dict, None]:
prev_state: dict = {}
while True:
state = await self.get_state()
if state != prev_state:
yield state
prev_state = state
await self.__notifier.wait()
def sysprep(self) -> None:
get_logger(0).info("Preparing User-GPIO drivers ...")
for (_, driver) in tools.sorted_kvs(self.__drivers):
@@ -307,28 +328,43 @@ class UserGpio:
except Exception:
get_logger().exception("Can't cleanup driver %s", driver)
async def switch(self, channel: str, state: bool, wait: bool) -> None:
gout = self.__outputs.get(channel)
async def switch(self, ch: str, state: bool, wait: bool) -> None:
gout = self.__outputs.get(ch)
if gout is None:
raise GpioChannelNotFoundError()
await gout.switch(state, wait)
async def pulse(self, channel: str, delay: float, wait: bool) -> None:
gout = self.__outputs.get(channel)
async def pulse(self, ch: str, delay: float, wait: bool) -> None:
gout = self.__outputs.get(ch)
if gout is None:
raise GpioChannelNotFoundError()
await gout.pulse(delay, wait)
# =====
def __make_view(self) -> dict:
def __make_scheme(self) -> dict:
return {
"header": {"title": self.__make_view_title()},
"table": self.__make_view_table(),
"inputs": {
ch: gin.get_scheme()
for (ch, gin) in self.__inputs.items()
},
"outputs": {
ch: gout.get_scheme()
for (ch, gout) in self.__outputs.items()
if not gout.is_const()
},
}
def __make_view_title(self) -> list[dict]:
raw_title = self.__view["header"]["title"]
# =====
def __make_view(self, view: dict) -> dict:
return {
"header": {"title": self.__make_view_title(view)},
"table": self.__make_view_table(view),
}
def __make_view_title(self, view: dict) -> list[dict]:
raw_title = view["header"]["title"]
title: list[dict] = []
if isinstance(raw_title, list):
for item in raw_title:
@@ -342,9 +378,9 @@ class UserGpio:
title.append(self.__make_item_label(f"#{raw_title}"))
return title
def __make_view_table(self) -> list[list[dict] | None]:
def __make_view_table(self, view: dict) -> list[list[dict] | None]:
table: list[list[dict] | None] = []
for row in self.__view["table"]:
for row in view["table"]:
if len(row) == 0:
table.append(None)
continue

184
kvmd/apps/oled/__init__.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
# ========================================================================== #
# #
# KVMD-OLED - A small OLED daemon for PiKVM. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import sys
import os
import signal
import itertools
import logging
import time
import usb.core
from luma.core import cmdline as luma_cmdline
from PIL import ImageFont
from .screen import Screen
from .sensors import Sensors
# =====
_logger = logging.getLogger("oled")
# =====
def _detect_geometry() -> dict:
with open("/proc/device-tree/model") as file:
is_cm4 = ("Compute Module 4" in file.read())
has_usb = bool(list(usb.core.find(find_all=True)))
if is_cm4 and has_usb:
return {"height": 64, "rotate": 2}
return {"height": 32, "rotate": 0}
def _get_data_path(subdir: str, name: str) -> str:
if not name.startswith("@"):
return name # Just a regular system path
name = name[1:]
module_path = sys.modules[__name__].__file__
assert module_path is not None
return os.path.join(os.path.dirname(module_path), subdir, name)
# =====
def main() -> None: # pylint: disable=too-many-locals,too-many-branches,too-many-statements
logging.basicConfig(level=logging.INFO, format="%(message)s")
logging.getLogger("PIL").setLevel(logging.ERROR)
parser = luma_cmdline.create_parser(description="Display FQDN and IP on the OLED")
parser.set_defaults(**_detect_geometry())
parser.add_argument("--font", default="@ProggySquare.ttf", type=(lambda arg: _get_data_path("fonts", arg)), help="Font path")
parser.add_argument("--font-size", default=16, type=int, help="Font size")
parser.add_argument("--font-spacing", default=2, type=int, help="Font line spacing")
parser.add_argument("--offset-x", default=0, type=int, help="Horizontal offset")
parser.add_argument("--offset-y", default=0, type=int, help="Vertical offset")
parser.add_argument("--interval", default=5, type=int, help="Screens interval")
parser.add_argument("--image", default="", type=(lambda arg: _get_data_path("pics", arg)), help="Display some image, wait a single interval and exit")
parser.add_argument("--text", default="", help="Display some text, wait a single interval and exit")
parser.add_argument("--pipe", action="store_true", help="Read and display lines from stdin until EOF, wait a single interval and exit")
parser.add_argument("--clear-on-exit", action="store_true", help="Clear display on exit")
parser.add_argument("--contrast", default=64, type=int, help="Set OLED contrast, values from 0 to 255")
parser.add_argument("--fahrenheit", action="store_true", help="Display temperature in Fahrenheit instead of Celsius")
options = parser.parse_args(sys.argv[1:])
if options.config:
config = luma_cmdline.load_config(options.config)
options = parser.parse_args(config + sys.argv[1:])
device = luma_cmdline.create_device(options)
device.cleanup = (lambda _: None)
screen = Screen(
device=device,
font=ImageFont.truetype(options.font, options.font_size),
font_spacing=options.font_spacing,
offset=(options.offset_x, options.offset_y),
)
if options.display not in luma_cmdline.get_display_types()["emulator"]:
_logger.info("Iface: %s", options.interface)
_logger.info("Display: %s", options.display)
_logger.info("Size: %dx%d", device.width, device.height)
options.contrast = min(max(options.contrast, 0), 255)
_logger.info("Contrast: %d", options.contrast)
device.contrast(options.contrast)
try:
if options.image:
screen.draw_image(options.image)
time.sleep(options.interval)
elif options.text:
screen.draw_text(options.text.replace("\\n", "\n"))
time.sleep(options.interval)
elif options.pipe:
text = ""
for line in sys.stdin:
text += line
if "\0" in text:
screen.draw_text(text.replace("\0", ""))
text = ""
time.sleep(options.interval)
else:
stop_reason: (str | None) = None
def sigusr_handler(signum: int, _) -> None: # type: ignore
nonlocal stop_reason
if signum in (signal.SIGINT, signal.SIGTERM):
stop_reason = ""
elif signum == signal.SIGUSR1:
stop_reason = "Rebooting...\nPlease wait"
elif signum == signal.SIGUSR2:
stop_reason = "Halted"
for signum in [signal.SIGTERM, signal.SIGINT, signal.SIGUSR1, signal.SIGUSR2]:
signal.signal(signum, sigusr_handler)
hb = itertools.cycle(r"/-\|") # Heartbeat
swim = 0
def draw(text: str) -> None:
nonlocal swim
count = 0
while (count < max(options.interval, 1) * 2) and stop_reason is None:
screen.draw_text(
text=text.replace("__hb__", next(hb)),
offset_x=(3 if swim < 0 else 0),
)
count += 1
if swim >= 1200:
swim = -1200
else:
swim += 1
time.sleep(0.5)
sensors = Sensors(options.fahrenheit)
if device.height >= 64:
while stop_reason is None:
text = "{fqdn}\n{ip}\niface: {iface}\ntemp: {temp}\ncpu: {cpu} mem: {mem}\n(__hb__) {uptime}"
draw(sensors.render(text))
else:
summary = True
while stop_reason is None:
if summary:
text = "{fqdn}\n(__hb__) {uptime}\ntemp: {temp}"
else:
text = "{ip}\n(__hb__) iface: {iface}\ncpu: {cpu} mem: {mem}"
draw(sensors.render(text))
summary = (not summary)
if stop_reason is not None:
if len(stop_reason) > 0:
options.clear_on_exit = False
screen.draw_text(stop_reason)
while len(stop_reason) > 0:
time.sleep(0.1)
except (SystemExit, KeyboardInterrupt):
pass
if options.clear_on_exit:
screen.draw_text("")

View File

@@ -0,0 +1,24 @@
# ========================================================================== #
# #
# KVMD - The main PiKVM daemon. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
from . import main
main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

54
kvmd/apps/oled/screen.py Normal file
View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# ========================================================================== #
# #
# KVMD-OLED - A small OLED daemon for PiKVM. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
from luma.core.device import device as luma_device
from luma.core.render import canvas as luma_canvas
from PIL import Image
from PIL import ImageFont
# =====
class Screen:
def __init__(
self,
device: luma_device,
font: ImageFont.FreeTypeFont,
font_spacing: int,
offset: tuple[int, int],
) -> None:
self.__device = device
self.__font = font
self.__font_spacing = font_spacing
self.__offset = offset
def draw_text(self, text: str, offset_x: int=0) -> None:
with luma_canvas(self.__device) as draw:
offset = list(self.__offset)
offset[0] += offset_x
draw.multiline_text(offset, text, font=self.__font, spacing=self.__font_spacing, fill="white")
def draw_image(self, image_path: str) -> None:
with luma_canvas(self.__device) as draw:
draw.bitmap(self.__offset, Image.open(image_path).convert("1"), fill="white")

126
kvmd/apps/oled/sensors.py Normal file
View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# ========================================================================== #
# #
# KVMD-OLED - A small OLED daemon for PiKVM. #
# #
# Copyright (C) 2018-2024 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import socket
import functools
import datetime
import time
import netifaces
import psutil
# =====
class Sensors:
def __init__(self, fahrenheit: bool) -> None:
self.__fahrenheit = fahrenheit
self.__sensors = {
"fqdn": socket.getfqdn,
"iface": self.__get_iface,
"ip": self.__get_ip,
"uptime": self.__get_uptime,
"temp": self.__get_temp,
"cpu": self.__get_cpu,
"mem": self.__get_mem,
}
def render(self, text: str) -> str:
return text.format_map(self)
def __getitem__(self, key: str) -> str:
return self.__sensors[key]() # type: ignore
# =====
def __get_iface(self) -> str:
return self.__get_netconf(round(time.monotonic() / 0.3))[0]
def __get_ip(self) -> str:
return self.__get_netconf(round(time.monotonic() / 0.3))[1]
@functools.lru_cache(maxsize=1)
def __get_netconf(self, ts: int) -> tuple[str, str]:
_ = ts
try:
gws = netifaces.gateways()
if "default" in gws:
for proto in [socket.AF_INET, socket.AF_INET6]:
if proto in gws["default"]:
iface = gws["default"][proto][1]
addrs = netifaces.ifaddresses(iface)
return (iface, addrs[proto][0]["addr"])
for iface in netifaces.interfaces():
if not iface.startswith(("lo", "docker")):
addrs = netifaces.ifaddresses(iface)
for proto in [socket.AF_INET, socket.AF_INET6]:
if proto in addrs:
return (iface, addrs[proto][0]["addr"])
except Exception:
# _logger.exception("Can't get iface/IP")
pass
return ("<no-iface>", "<no-ip>")
# =====
def __get_uptime(self) -> str:
uptime = datetime.timedelta(seconds=int(time.time() - psutil.boot_time()))
pl = {"days": uptime.days}
(pl["hours"], rem) = divmod(uptime.seconds, 3600)
(pl["mins"], pl["secs"]) = divmod(rem, 60)
return "{days}d {hours}h {mins}m".format(**pl)
# =====
def __get_temp(self) -> str:
try:
with open("/sys/class/thermal/thermal_zone0/temp") as file:
temp = int(file.read().strip()) / 1000
if self.__fahrenheit:
temp = temp * 9 / 5 + 32
return f"{temp:.1f}\u00b0F"
return f"{temp:.1f}\u00b0C"
except Exception:
# _logger.exception("Can't read temp")
return "<no-temp>"
# =====
def __get_cpu(self) -> str:
st = psutil.cpu_times_percent()
user = st.user - st.guest
nice = st.nice - st.guest_nice
idle_all = st.idle + st.iowait
system_all = st.system + st.irq + st.softirq
virtual = st.guest + st.guest_nice
total = max(1, user + nice + system_all + idle_all + st.steal + virtual)
percent = int(
st.nice / total * 100
+ st.user / total * 100
+ system_all / total * 100
+ (st.steal + st.guest) / total * 100
)
return f"{percent}%"
def __get_mem(self) -> str:
return f"{int(psutil.virtual_memory().percent)}%"

View File

@@ -350,5 +350,5 @@ def main(argv: (list[str] | None)=None) -> None:
options = parser.parse_args(argv[1:])
try:
options.cmd(config)
except ValidatorError as err:
raise SystemExit(str(err))
except ValidatorError as ex:
raise SystemExit(str(ex))

View File

@@ -50,9 +50,9 @@ def _set_param(gadget: str, instance: int, param: str, value: str) -> None:
try:
with open(_get_param_path(gadget, instance, param), "w") as file:
file.write(value + "\n")
except OSError as err:
if err.errno == errno.EBUSY:
raise SystemExit(f"Can't change {param!r} value because device is locked: {err}")
except OSError as ex:
if ex.errno == errno.EBUSY:
raise SystemExit(f"Can't change {param!r} value because device is locked: {ex}")
raise

View File

@@ -133,8 +133,8 @@ class _Service: # pylint: disable=too-many-instance-attributes
logger.info("CMD: %s", tools.cmdfmt(cmd))
try:
return (not (await aioproc.log_process(cmd, logger)).returncode)
except Exception as err:
logger.exception("Can't execute command: %s", err)
except Exception as ex:
logger.exception("Can't execute command: %s", ex)
return False
# =====

View File

@@ -50,7 +50,7 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst
super().__init__()
self.__data_path = os.path.join(fstab.find_pst().root_path, "data")
self.__data_path = fstab.find_pst().root_path
self.__ro_retries_delay = ro_retries_delay
self.__ro_cleanup_delay = ro_cleanup_delay
self.__remount_cmd = remount_cmd
@@ -60,8 +60,8 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst
# ===== WEBSOCKET
@exposed_http("GET", "/ws")
async def __ws_handler(self, request: Request) -> WebSocketResponse:
async with self._ws_session(request) as ws:
async def __ws_handler(self, req: Request) -> WebSocketResponse:
async with self._ws_session(req) as ws:
await ws.send_event("loop", {})
return (await self._ws_loop(ws))
@@ -128,9 +128,9 @@ class PstServer(HttpServer): # pylint: disable=too-many-arguments,too-many-inst
def __is_write_available(self) -> bool:
try:
return (not (os.statvfs(self.__data_path).f_flag & os.ST_RDONLY))
except Exception as err:
except Exception as ex:
get_logger(0).info("Can't get filesystem state of PST (%s): %s",
self.__data_path, tools.efmt(err))
self.__data_path, tools.efmt(ex))
return False
async def __remount_storage(self, rw: bool) -> bool:

View File

@@ -46,8 +46,8 @@ def _preexec() -> None:
if os.isatty(0):
try:
os.tcsetpgrp(0, os.getpgid(0))
except Exception as err:
get_logger(0).info("Can't perform tcsetpgrp(0): %s", tools.efmt(err))
except Exception as ex:
get_logger(0).info("Can't perform tcsetpgrp(0): %s", tools.efmt(ex))
async def _run_process(cmd: list[str], data_path: str) -> asyncio.subprocess.Process: # pylint: disable=no-member

View File

@@ -21,7 +21,7 @@
from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamFormats
from ...clients.streamer import StreamerFormats
from ...clients.streamer import BaseStreamerClient
from ...clients.streamer import HttpStreamerClient
from ...clients.streamer import MemsinkStreamerClient
@@ -51,8 +51,8 @@ def main(argv: (list[str] | None)=None) -> None:
return None
streamers: list[BaseStreamerClient] = list(filter(None, [
make_memsink_streamer("h264", StreamFormats.H264),
make_memsink_streamer("jpeg", StreamFormats.JPEG),
make_memsink_streamer("h264", StreamerFormats.H264),
make_memsink_streamer("jpeg", StreamerFormats.JPEG),
HttpStreamerClient(name="JPEG", user_agent=user_agent, **config.streamer._unpack()),
]))
@@ -71,6 +71,7 @@ def main(argv: (list[str] | None)=None) -> None:
desired_fps=config.desired_fps,
mouse_output=config.mouse_output,
keymap_path=config.keymap,
allow_cut_after=config.allow_cut_after,
kvmd=KvmdClient(user_agent=user_agent, **config.kvmd._unpack()),
streamers=streamers,

View File

@@ -22,6 +22,7 @@
import asyncio
import ssl
import time
from typing import Callable
from typing import Coroutine
@@ -64,6 +65,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
width: int,
height: int,
name: str,
allow_cut_after: float,
vnc_passwds: list[str],
vencrypt: bool,
none_auth_only: bool,
@@ -79,6 +81,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
self._width = width
self._height = height
self.__name = name
self.__allow_cut_after = allow_cut_after
self.__vnc_passwds = vnc_passwds
self.__vencrypt = vencrypt
self.__none_auth_only = none_auth_only
@@ -90,6 +93,8 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
self.__fb_cont_updates = False
self.__fb_reset_h264 = False
self.__allow_cut_since_ts = 0.0
self.__lock = asyncio.Lock()
# =====
@@ -120,10 +125,10 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
except asyncio.CancelledError:
logger.info("%s [%s]: Cancelling subtask ...", self._remote, name)
raise
except RfbConnectionError as err:
logger.info("%s [%s]: Gone: %s", self._remote, name, err)
except (RfbError, ssl.SSLError) as err:
logger.error("%s [%s]: Error: %s", self._remote, name, err)
except RfbConnectionError as ex:
logger.info("%s [%s]: Gone: %s", self._remote, name, ex)
except (RfbError, ssl.SSLError) as ex:
logger.error("%s [%s]: Error: %s", self._remote, name, ex)
except Exception:
logger.exception("%s [%s]: Unhandled exception", self._remote, name)
@@ -414,6 +419,7 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
# =====
async def __main_loop(self) -> None:
self.__allow_cut_since_ts = time.monotonic() + self.__allow_cut_after
handlers = {
0: self.__handle_set_pixel_format,
2: self.__handle_set_encodings,
@@ -499,7 +505,12 @@ class RfbClient(RfbClientStream): # pylint: disable=too-many-instance-attribute
async def __handle_client_cut_text(self) -> None:
length = (await self._read_struct("cut text length", "xxx L"))[0]
text = await self._read_text("cut text data", length)
await self._on_cut_event(text)
if self.__allow_cut_since_ts > 0 and time.monotonic() >= self.__allow_cut_since_ts:
# We should ignore cut event a few seconds after handshake
# because bVNC, AVNC and maybe some other clients perform
# it right after the connection automatically.
# - https://github.com/pikvm/pikvm/issues/1420
await self._on_cut_event(text)
async def __handle_enable_cont_updates(self) -> None:
enabled = bool((await self._read_struct("enabled ContUpdates", "B HH HH"))[0])

View File

@@ -29,5 +29,5 @@ class RfbError(Exception):
class RfbConnectionError(RfbError):
def __init__(self, msg: str, err: Exception) -> None:
super().__init__(f"{msg}: {tools.efmt(err)}")
def __init__(self, msg: str, ex: Exception) -> None:
super().__init__(f"{msg}: {tools.efmt(ex)}")

View File

@@ -51,22 +51,22 @@ class RfbClientStream:
else:
fmt = f">{fmt}"
return struct.unpack(fmt, await self.__reader.readexactly(struct.calcsize(fmt)))[0]
except (ConnectionError, asyncio.IncompleteReadError) as err:
raise RfbConnectionError(f"Can't read {msg}", err)
except (ConnectionError, asyncio.IncompleteReadError) as ex:
raise RfbConnectionError(f"Can't read {msg}", ex)
async def _read_struct(self, msg: str, fmt: str) -> tuple[int, ...]:
assert len(fmt) > 1
try:
fmt = f">{fmt}"
return struct.unpack(fmt, (await self.__reader.readexactly(struct.calcsize(fmt))))
except (ConnectionError, asyncio.IncompleteReadError) as err:
raise RfbConnectionError(f"Can't read {msg}", err)
except (ConnectionError, asyncio.IncompleteReadError) as ex:
raise RfbConnectionError(f"Can't read {msg}", ex)
async def _read_text(self, msg: str, length: int) -> str:
try:
return (await self.__reader.readexactly(length)).decode("utf-8", errors="ignore")
except (ConnectionError, asyncio.IncompleteReadError) as err:
raise RfbConnectionError(f"Can't read {msg}", err)
except (ConnectionError, asyncio.IncompleteReadError) as ex:
raise RfbConnectionError(f"Can't read {msg}", ex)
# =====
@@ -84,8 +84,8 @@ class RfbClientStream:
self.__writer.write(struct.pack(f">{fmt}", *values))
if drain:
await self.__writer.drain()
except ConnectionError as err:
raise RfbConnectionError(f"Can't write {msg}", err)
except ConnectionError as ex:
raise RfbConnectionError(f"Can't write {msg}", ex)
async def _write_reason(self, msg: str, text: str, drain: bool=True) -> None:
encoded = text.encode("utf-8", errors="ignore")
@@ -94,8 +94,8 @@ class RfbClientStream:
self.__writer.write(encoded)
if drain:
await self.__writer.drain()
except ConnectionError as err:
raise RfbConnectionError(f"Can't write {msg}", err)
except ConnectionError as ex:
raise RfbConnectionError(f"Can't write {msg}", ex)
async def _write_fb_update(self, msg: str, width: int, height: int, encoding: int, drain: bool=True) -> None:
await self._write_struct(
@@ -123,8 +123,8 @@ class RfbClientStream:
server_side=True,
ssl_handshake_timeout=ssl_timeout,
)
except ConnectionError as err:
raise RfbConnectionError("Can't start TLS", err)
except ConnectionError as ex:
raise RfbConnectionError("Can't start TLS", ex)
ssl_reader.set_transport(transport) # type: ignore
ssl_writer = asyncio.StreamWriter(

View File

@@ -42,7 +42,7 @@ from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamerError
from ...clients.streamer import StreamerPermError
from ...clients.streamer import StreamFormats
from ...clients.streamer import StreamerFormats
from ...clients.streamer import BaseStreamerClient
from ... import tools
@@ -81,6 +81,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
mouse_output: str,
keymap_name: str,
symmap: dict[int, dict[int, str]],
allow_cut_after: float,
kvmd: KvmdClient,
streamers: list[BaseStreamerClient],
@@ -100,6 +101,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
tls_timeout=tls_timeout,
x509_cert_path=x509_cert_path,
x509_key_path=x509_key_path,
allow_cut_after=allow_cut_after,
vnc_passwds=list(vnc_credentials),
vencrypt=vencrypt,
none_auth_only=none_auth_only,
@@ -175,20 +177,25 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__kvmd_ws = None
async def __process_ws_event(self, event_type: str, event: dict) -> None:
if event_type == "info_meta_state":
try:
host = event["server"]["host"]
except Exception:
host = None
else:
if isinstance(host, str):
name = f"PiKVM: {host}"
if self._encodings.has_rename:
await self._send_rename(name)
self.__shared_params.name = name
if event_type == "info_state":
if "meta" in event:
try:
host = event["meta"]["server"]["host"]
except Exception:
host = None
else:
if isinstance(host, str):
name = f"PiKVM: {host}"
if self._encodings.has_rename:
await self._send_rename(name)
self.__shared_params.name = name
elif event_type == "hid_state":
if self._encodings.has_leds_state:
if (
self._encodings.has_leds_state
and ("keyboard" in event)
and ("leds" in event["keyboard"])
):
await self._send_leds_state(**event["keyboard"]["leds"])
# =====
@@ -210,19 +217,19 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
await self.__queue_frame(frame)
else:
await self.__queue_frame("No signal")
except StreamerError as err:
if isinstance(err, StreamerPermError):
except StreamerError as ex:
if isinstance(ex, StreamerPermError):
streamer = self.__get_default_streamer()
logger.info("%s [streamer]: Permanent error: %s; switching to %s ...", self._remote, err, streamer)
logger.info("%s [streamer]: Permanent error: %s; switching to %s ...", self._remote, ex, streamer)
else:
logger.info("%s [streamer]: Waiting for stream: %s", self._remote, err)
logger.info("%s [streamer]: Waiting for stream: %s", self._remote, ex)
await self.__queue_frame("Waiting for stream ...")
await asyncio.sleep(1)
def __get_preferred_streamer(self) -> BaseStreamerClient:
formats = {
StreamFormats.JPEG: "has_tight",
StreamFormats.H264: "has_h264",
StreamerFormats.JPEG: "has_tight",
StreamerFormats.H264: "has_h264",
}
streamer: (BaseStreamerClient | None) = None
for streamer in self.__streamers:
@@ -248,7 +255,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
"data": (await make_text_jpeg(self._width, self._height, self._encodings.tight_jpeg_quality, text)),
"width": self._width,
"height": self._height,
"format": StreamFormats.JPEG,
"format": StreamerFormats.JPEG,
}
async def __fb_sender_task_loop(self) -> None: # pylint: disable=too-many-branches
@@ -258,21 +265,21 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
frame = await self.__fb_queue.get()
if (
last is None # pylint: disable=too-many-boolean-expressions
or frame["format"] == StreamFormats.JPEG
or frame["format"] == StreamerFormats.JPEG
or last["format"] != frame["format"]
or (frame["format"] == StreamFormats.H264 and (
or (frame["format"] == StreamerFormats.H264 and (
frame["key"]
or last["width"] != frame["width"]
or last["height"] != frame["height"]
or len(last["data"]) + len(frame["data"]) > 4194304
))
):
self.__fb_has_key = (frame["format"] == StreamFormats.H264 and frame["key"])
self.__fb_has_key = (frame["format"] == StreamerFormats.H264 and frame["key"])
last = frame
if self.__fb_queue.qsize() == 0:
break
continue
assert frame["format"] == StreamFormats.H264
assert frame["format"] == StreamerFormats.H264
last["data"] += frame["data"]
if self.__fb_queue.qsize() == 0:
break
@@ -294,9 +301,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
await self._send_fb_allow_again()
continue
if last["format"] == StreamFormats.JPEG:
if last["format"] == StreamerFormats.JPEG:
await self._send_fb_jpeg(last["data"])
elif last["format"] == StreamFormats.H264:
elif last["format"] == StreamerFormats.H264:
if not self._encodings.has_h264:
raise RfbError("The client doesn't want to accept H264 anymore")
if self.__fb_has_key:
@@ -439,6 +446,7 @@ class VncServer: # pylint: disable=too-many-instance-attributes
desired_fps: int,
mouse_output: str,
keymap_path: str,
allow_cut_after: float,
kvmd: KvmdClient,
streamers: list[BaseStreamerClient],
@@ -481,8 +489,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
try:
async with kvmd.make_session("", "") as kvmd_session:
none_auth_only = await kvmd_session.auth.check()
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("%s [entry]: Can't check KVMD auth mode: %s", remote, tools.efmt(err))
except (aiohttp.ClientError, asyncio.TimeoutError) as ex:
logger.error("%s [entry]: Can't check KVMD auth mode: %s", remote, tools.efmt(ex))
return
await _Client(
@@ -496,6 +504,7 @@ class VncServer: # pylint: disable=too-many-instance-attributes
mouse_output=mouse_output,
keymap_name=keymap_name,
symmap=symmap,
allow_cut_after=allow_cut_after,
kvmd=kvmd,
streamers=streamers,
vnc_credentials=(await self.__vnc_auth_manager.read_credentials())[0],

View File

@@ -54,8 +54,8 @@ class VncAuthManager:
if self.__enabled:
try:
return (await self.__inner_read_credentials(), True)
except VncAuthError as err:
get_logger(0).error(str(err))
except VncAuthError as ex:
get_logger(0).error(str(ex))
except Exception:
get_logger(0).exception("Unhandled exception while reading VNCAuth passwd file")
return ({}, (not self.__enabled))

View File

@@ -56,8 +56,8 @@ def _write_int(rtc: int, key: str, value: int) -> None:
def _reset_alarm(rtc: int, timeout: int) -> None:
try:
now = _read_int(rtc, "since_epoch")
except OSError as err:
if err.errno != errno.EINVAL:
except OSError as ex:
if ex.errno != errno.EINVAL:
raise
raise RtcIsNotAvailableError("Can't read since_epoch right now")
if now == 0:
@@ -65,8 +65,8 @@ def _reset_alarm(rtc: int, timeout: int) -> None:
try:
for wake in [0, now + timeout]:
_write_int(rtc, "wakealarm", wake)
except OSError as err:
if err.errno != errno.EIO:
except OSError as ex:
if ex.errno != errno.EIO:
raise
raise RtcIsNotAvailableError("IO error, probably the supercapacitor is not charged")
@@ -80,9 +80,9 @@ def _cmd_run(config: Section) -> None:
while True:
try:
_reset_alarm(config.rtc, config.timeout)
except RtcIsNotAvailableError as err:
except RtcIsNotAvailableError as ex:
if not fail:
logger.error("RTC%d is not available now: %s; waiting ...", config.rtc, err)
logger.error("RTC%d is not available now: %s; waiting ...", config.rtc, ex)
fail = True
else:
if fail: