# ========================================================================== #
#                                                                            #
#    KVMD - The main Pi-KVM daemon.                                          #
#                                                                            #
#    Copyright (C) 2018 Maxim Devaev                                         #
#                                                                            #
#    This program is free software: you can redistribute it and/or modify    #
#    it under the terms of the GNU General Public License as published by    #
#    the Free Software Foundation, either version 3 of the License, or       #
#    (at your option) any later version.                                     #
#                                                                            #
#    This program is distributed in the hope that it will be useful,         #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of          #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the            #
#    GNU General Public License for more details.                            #
#                                                                            #
#    You should have received a copy of the GNU General Public License       #
#    along with this program. If not, see <https://www.gnu.org/licenses/>.   #
#                                                                            #
# ========================================================================== #


import os
import signal
import asyncio
import operator
import dataclasses
import json

from typing import List
from typing import Dict
from typing import Set
from typing import Callable
from typing import Coroutine
from typing import AsyncGenerator
from typing import Optional
from typing import Any

import aiohttp
import aiohttp.web

from ...logging import get_logger

from ...errors import OperationError
from ...errors import IsBusyError

from ...plugins import BasePlugin
from ...plugins.hid import BaseHid
from ...plugins.atx import BaseAtx
from ...plugins.msd import BaseMsd

from ...validators import ValidatorError
from ...validators.basic import valid_bool
from ...validators.kvm import valid_stream_quality
from ...validators.kvm import valid_stream_fps
from ...validators.kvm import valid_stream_resolution

from ... import aiotools
from ... import aioproc

from .auth import AuthManager
from .info import InfoManager
from .logreader import LogReader
from .wol import WakeOnLan
from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter

from .http import HttpError
from .http import HttpExposed
from .http import exposed_http
from .http import exposed_ws
from .http import get_exposed_http
from .http import get_exposed_ws
from .http import make_json_response
from .http import make_json_exception
from .http import HttpServer

from .api.auth import AuthApi
from .api.auth import check_request_auth

from .api.info import InfoApi
from .api.log import LogApi
from .api.wol import WolApi
from .api.ugpio import UserGpioApi
from .api.hid import HidApi
from .api.atx import AtxApi
from .api.msd import MsdApi
from .api.streamer import StreamerApi
from .api.export import ExportApi
from .api.redfish import RedfishApi


# =====
class StreamerQualityNotSupported(OperationError):
    """Raised when a quality change is requested but the streamer has no "quality" parameter."""

    def __init__(self) -> None:
        super().__init__("This streamer does not support quality settings")


class StreamerResolutionNotSupported(OperationError):
    """Raised when a resolution change is requested but the streamer has no "resolution" parameter."""

    def __init__(self) -> None:
        super().__init__("This streamer does not support resolution settings")


# =====
@dataclasses.dataclass(frozen=True)
class _Component:  # pylint: disable=too-many-instance-attributes
    """A server subsystem plus the optional lifecycle hooks found on its object.

    The hook fields (sysprep, systask, get_state, poll_state, cleanup) are always
    re-populated in __post_init__ from the attributes of the same name on ``obj``
    (None when absent), so values passed to the constructor for them are ignored.
    A component that exposes state must have a non-empty ``event_type``, which is
    the websocket event name its state is published under.
    """

    name: str
    event_type: str
    obj: object
    sysprep: Optional[Callable[[], None]] = None
    systask: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
    get_state: Optional[Callable[[], Coroutine[Any, Any, Dict]]] = None
    poll_state: Optional[Callable[[], AsyncGenerator[Dict, None]]] = None
    cleanup: Optional[Callable[[], Coroutine[Any, Any, Dict]]] = None

    def __post_init__(self) -> None:
        # The dataclass is frozen, so fields are assigned via object.__setattr__.
        if isinstance(self.obj, BasePlugin):
            object.__setattr__(self, "name", f"{self.name} ({self.obj.get_plugin_name()})")

        for field in ["sysprep", "systask", "get_state", "poll_state", "cleanup"]:
            object.__setattr__(self, field, getattr(self.obj, field, None))

        if self.get_state or self.poll_state:
            assert self.event_type, self


@dataclasses.dataclass(frozen=True)
class _WsClient:
    """A connected websocket plus whether it counts as a consumer of the video stream."""

    ws: aiohttp.web.WebSocketResponse
    stream: bool

    def __str__(self) -> str:
        return f"WsClient(id={id(self)}, stream={self.stream})"


class KvmdServer(HttpServer):  # pylint: disable=too-many-arguments,too-many-instance-attributes
    """HTTP/websocket front-end of KVMD.

    Wires the components and API objects into an aiohttp application, pushes
    component state to websocket clients, and runs background "system tasks"
    (state pollers, component tasks, the streamer controller and the snapshoter).
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        auth_manager: AuthManager,
        info_manager: InfoManager,
        log_reader: LogReader,
        wol: WakeOnLan,
        user_gpio: UserGpio,

        hid: BaseHid,
        atx: BaseAtx,
        msd: BaseMsd,
        streamer: Streamer,
        snapshoter: Snapshoter,

        heartbeat: float,
        sync_chunk_size: int,
        keymap_path: str,

        stream_forever: bool,
    ) -> None:

        self.__auth_manager = auth_manager
        self.__hid = hid
        self.__streamer = streamer
        self.__snapshoter = snapshoter  # Not a component: No state or cleanup
        self.__user_gpio = user_gpio  # Has extra state "gpio_model_state"

        self.__heartbeat = heartbeat

        self.__stream_forever = stream_forever

        self.__components = [
            *[
                _Component("Auth manager", "", auth_manager),
            ],
            *[
                _Component(f"Info manager ({sub})", f"info_{sub}_state", info_manager.get_submanager(sub))
                for sub in sorted(info_manager.get_subs())
            ],
            *[
                _Component("Wake-on-LAN", "wol_state", wol),
                _Component("User-GPIO",   "gpio_state", user_gpio),
                _Component("HID",         "hid_state", hid),
                _Component("ATX",         "atx_state", atx),
                _Component("MSD",         "msd_state", msd),
                _Component("Streamer",    "streamer_state", streamer),
            ],
        ]

        # Every object here is scanned for exposed HTTP routes and websocket handlers.
        self.__apis: List[object] = [
            self,
            AuthApi(auth_manager),
            InfoApi(info_manager),
            LogApi(log_reader),
            WolApi(wol),
            UserGpioApi(user_gpio),
            HidApi(hid, keymap_path),
            AtxApi(atx),
            MsdApi(msd, sync_chunk_size),
            StreamerApi(streamer),
            ExportApi(info_manager, atx, user_gpio),
            RedfishApi(info_manager, atx),
        ]

        self.__ws_handlers: Dict[str, Callable] = {}
        self.__ws_clients: Set[_WsClient] = set()
        self.__ws_clients_lock = asyncio.Lock()

        self.__system_tasks: List[asyncio.Task] = []

        # Requests to the streamer controller: set by the HTTP handlers, consumed
        # by __stream_controller after the notifier wakes it up.
        self.__streamer_notifier = aiotools.AioNotifier()
        self.__reset_streamer = False
        self.__new_streamer_params: Dict = {}

    # ===== STREAMER CONTROLLER

    @exposed_http("POST", "/streamer/set_params")
    async def __streamer_set_params_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Validate streamer parameters from the query string and queue the changed ones.

        The changes are applied asynchronously by __stream_controller. Raises
        StreamerQualityNotSupported / StreamerResolutionNotSupported when the
        streamer's current parameters lack that key; invalid values are rejected
        by the validators (presumably via ValidatorError, which the route wrapper
        maps to HTTP 400).
        """

        current_params = self.__streamer.get_params()
        for (name, validator, exc_cls) in [
            ("quality", valid_stream_quality, StreamerQualityNotSupported),
            ("desired_fps", valid_stream_fps, None),
            ("resolution", valid_stream_resolution, StreamerResolutionNotSupported),
        ]:
            value = request.query.get(name)
            if value:
                if name not in current_params:
                    assert exc_cls is not None, name
                    raise exc_cls()
                value = validator(value)
                if current_params[name] != value:
                    self.__new_streamer_params[name] = value
                else:
                    # The request restores the active value: drop any change for this
                    # parameter still pending from an earlier request, otherwise the
                    # older value would be applied despite this newer request.
                    self.__new_streamer_params.pop(name, None)
        await self.__streamer_notifier.notify()
        return make_json_response()

    @exposed_http("POST", "/streamer/reset")
    async def __streamer_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
        """Ask __stream_controller to restart the streamer with reset=True."""

        self.__reset_streamer = True
        await self.__streamer_notifier.notify()
        return make_json_response()

    # ===== WEBSOCKET

    @exposed_http("GET", "/ws")
    async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse:
        """Serve one websocket client.

        Sends the GPIO model and every component's current state, then a "loop"
        event, then dispatches incoming JSON text messages of the form
        {"event_type": ..., "event": ...} to the registered websocket handlers.
        Any non-text message ends the session. The client is always unregistered
        on exit.
        """

        logger = get_logger(0)
        client = _WsClient(
            ws=aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat),
            stream=valid_bool(request.query.get("stream", "true")),
        )
        await client.ws.prepare(request)
        await self.__register_ws_client(client)
        try:
            await self.__send_event(client.ws, "gpio_model_state", await self.__user_gpio.get_model())
            await asyncio.gather(*[
                self.__send_event(client.ws, component.event_type, await component.get_state())
                for component in self.__components
                if component.get_state
            ])
            await self.__send_event(client.ws, "loop", {})
            async for msg in client.ws:
                if msg.type == aiohttp.web.WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                        event_type = data.get("event_type")
                        event = data["event"]
                    except Exception as err:
                        logger.error("Can't parse JSON event from websocket: %r", err)
                    else:
                        handler = self.__ws_handlers.get(event_type)
                        if handler:
                            await handler(client.ws, event)
                        else:
                            logger.error("Unknown websocket event: %r", data)
                else:
                    break
            return client.ws
        finally:
            await self.__remove_ws_client(client)

    @exposed_ws("ping")
    async def __ws_ping_handler(self, ws: aiohttp.web.WebSocketResponse, _: Dict) -> None:
        """Answer a client "ping" event with a "pong" event."""

        await self.__send_event(ws, "pong", {})

    # ===== SYSTEM STUFF

    def run(self, **kwargs: Any) -> None:  # type: ignore  # pylint: disable=arguments-differ
        """Run every component's synchronous sysprep hook, then start the HTTP server."""

        for component in self.__components:
            if component.sysprep:
                component.sysprep()
        aioproc.rename_process("main")
        super().run(**kwargs)

    async def _make_app(self) -> aiohttp.web.Application:
        """Build the aiohttp application, start the system tasks and register all routes."""

        app = aiohttp.web.Application(middlewares=[aiohttp.web.normalize_path_middleware(
            append_slash=False,
            remove_slash=True,
            merge_slashes=True,
        )])

        app.on_shutdown.append(self.__on_shutdown)
        app.on_cleanup.append(self.__on_cleanup)

        self.__run_system_task(self.__stream_controller)
        for component in self.__components:
            if component.systask:
                self.__run_system_task(component.systask)
            if component.poll_state:
                self.__run_system_task(self.__poll_state, component.event_type, component.poll_state())
        self.__run_system_task(self.__stream_snapshoter)

        for api in self.__apis:
            for http_exposed in get_exposed_http(api):
                self.__add_app_route(app, http_exposed)
            for ws_exposed in get_exposed_ws(api):
                self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler

        return app

    def __run_system_task(self, method: Callable, *args: Any) -> None:
        """Start ``method(*args)`` as a background task that must never finish.

        If the task returns or raises anything other than cancellation, the
        error is logged and the process sends itself SIGTERM.
        """

        async def wrapper() -> None:
            try:
                await method(*args)
                # Returning normally is also a failure: system tasks are endless loops.
                raise RuntimeError(f"Dead system task: {method}"
                                   f"({', '.join(getattr(arg, '__name__', str(arg)) for arg in args)})")
            except asyncio.CancelledError:
                pass
            except Exception:
                get_logger().exception("Unhandled exception, killing myself ...")
                os.kill(os.getpid(), signal.SIGTERM)

        self.__system_tasks.append(asyncio.create_task(wrapper()))

    def __add_app_route(self, app: aiohttp.web.Application, exposed: HttpExposed) -> None:
        """Register an exposed HTTP handler behind the auth check and the error-to-JSON mapping."""

        async def wrapper(request: aiohttp.web.Request) -> aiohttp.web.Response:
            try:
                await check_request_auth(self.__auth_manager, exposed, request)
                return (await exposed.handler(request))
            # IsBusyError is checked before OperationError; the order matters
            # if it is a subclass of it (not visible from this module).
            except IsBusyError as err:
                return make_json_exception(err, 409)
            except (ValidatorError, OperationError) as err:
                return make_json_exception(err, 400)
            except HttpError as err:
                return make_json_exception(err)

        app.router.add_route(exposed.method, exposed.path, wrapper)

    async def __on_shutdown(self, _: aiohttp.web.Application) -> None:
        """Wait for short tasks, stop the system tasks and disconnect all websocket clients."""

        logger = get_logger(0)

        logger.info("Waiting short tasks ...")
        await asyncio.gather(*aiotools.get_short_tasks(), return_exceptions=True)

        logger.info("Cancelling system tasks ...")
        for task in self.__system_tasks:
            task.cancel()

        logger.info("Waiting system tasks ...")
        await asyncio.gather(*self.__system_tasks, return_exceptions=True)

        logger.info("Disconnecting clients ...")
        for client in list(self.__ws_clients):
            await self.__remove_ws_client(client)

    async def __on_cleanup(self, _: aiohttp.web.Application) -> None:
        """Run every component's cleanup hook; a failing hook is logged and does not stop the others."""

        logger = get_logger(0)
        for component in self.__components:
            if component.cleanup:
                logger.info("Cleaning up %s ...", component.name)
                try:
                    await component.cleanup()  # type: ignore
                except Exception:
                    logger.exception("Cleanup error on %s", component.name)

    async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
        """Send one event to a single websocket as JSON {"event_type": ..., "event": ...}."""

        await ws.send_str(json.dumps({
            "event_type": event_type,
            "event": event,
        }))

    async def __broadcast_event(self, event_type: str, event: Optional[Dict]) -> None:
        """Send an event to every client whose socket still looks open; send errors are ignored."""

        if self.__ws_clients:
            await asyncio.gather(*[
                self.__send_event(client.ws, event_type, event)
                for client in list(self.__ws_clients)
                if (
                    not client.ws.closed
                    and client.ws._req is not None  # pylint: disable=protected-access
                    and client.ws._req.transport is not None  # pylint: disable=protected-access
                )
            ], return_exceptions=True)

    async def __register_ws_client(self, client: _WsClient) -> None:
        """Add a client and wake the streamer controller (the set of stream consumers changed)."""

        async with self.__ws_clients_lock:
            self.__ws_clients.add(client)
            get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients))
        await self.__streamer_notifier.notify()

    async def __remove_ws_client(self, client: _WsClient) -> None:
        """Drop a client, close its socket and wake the streamer controller.

        Safe to call more than once for the same client: __on_shutdown removes
        every client, and the websocket handler's ``finally`` removes it again.
        HID events are cleared on every call.
        """

        async with self.__ws_clients_lock:
            self.__hid.clear_events()
            if client in self.__ws_clients:
                self.__ws_clients.remove(client)
                get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients))
                try:
                    await client.ws.close()
                except Exception:
                    pass  # Best effort: the peer may already be gone
        await self.__streamer_notifier.notify()

    def __has_stream_clients(self) -> bool:
        """Return True if at least one connected client asked for the video stream."""

        return any(map(operator.attrgetter("stream"), self.__ws_clients))

    # ===== SYSTEM TASKS

    async def __stream_controller(self) -> None:
        """Start/stop the streamer on demand and apply queued resets and parameter changes.

        The streamer is needed while there are stream clients, while the
        snapshoter is taking snapshots, or always when stream_forever is set.
        Each pass ends by waiting on the streamer notifier.
        """

        prev = False
        while True:
            cur = (
                self.__has_stream_clients()
                or self.__snapshoter.snapshoting()
                or self.__stream_forever
            )
            if not prev and cur:
                await self.__streamer.ensure_start(reset=False)
            elif prev and not cur:
                await self.__streamer.ensure_stop(immediately=False)

            if self.__reset_streamer or self.__new_streamer_params:
                start = self.__streamer.is_working()
                await self.__streamer.ensure_stop(immediately=True)
                if self.__new_streamer_params:
                    self.__streamer.set_params(self.__new_streamer_params)
                    self.__new_streamer_params = {}
                # Consume the reset request BEFORE the awaited restart. If the flag
                # were cleared after ensure_start(), a reset posted while the
                # streamer was restarting would be silently lost.
                reset = self.__reset_streamer
                self.__reset_streamer = False
                if start:
                    await self.__streamer.ensure_start(reset=reset)

            prev = cur
            await self.__streamer_notifier.wait()

    async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
        """Broadcast every state yielded by a component's poll_state generator."""

        async for state in poller:
            await self.__broadcast_event(event_type, state)

    async def __stream_snapshoter(self) -> None:
        """Run the snapshoter, telling it whether live stream clients exist."""

        await self.__snapshoter.run(
            is_live=self.__has_stream_clients,
            notifier=self.__streamer_notifier,
        )