One-KVM/kvmd/apps/kvmd/server.py
2020-09-28 02:41:46 +03:00

444 lines
17 KiB
Python

# ========================================================================== #
# #
# KVMD - The main Pi-KVM daemon. #
# #
# Copyright (C) 2018 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import os
import signal
import asyncio
import operator
import dataclasses
import json
from typing import List
from typing import Dict
from typing import Set
from typing import Callable
from typing import Coroutine
from typing import AsyncGenerator
from typing import Optional
from typing import Any
import aiohttp
import aiohttp.web
from ...logging import get_logger
from ...errors import OperationError
from ...errors import IsBusyError
from ...plugins import BasePlugin
from ...plugins.hid import BaseHid
from ...plugins.atx import BaseAtx
from ...plugins.msd import BaseMsd
from ...validators import ValidatorError
from ...validators.basic import valid_bool
from ...validators.kvm import valid_stream_quality
from ...validators.kvm import valid_stream_fps
from ...validators.kvm import valid_stream_resolution
from ... import aiotools
from ... import aioproc
from .auth import AuthManager
from .info import InfoManager
from .logreader import LogReader
from .wol import WakeOnLan
from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter
from .http import HttpError
from .http import HttpExposed
from .http import exposed_http
from .http import exposed_ws
from .http import get_exposed_http
from .http import get_exposed_ws
from .http import make_json_response
from .http import make_json_exception
from .http import HttpServer
from .api.auth import AuthApi
from .api.auth import check_request_auth
from .api.info import InfoApi
from .api.log import LogApi
from .api.wol import WolApi
from .api.ugpio import UserGpioApi
from .api.hid import HidApi
from .api.atx import AtxApi
from .api.msd import MsdApi
from .api.streamer import StreamerApi
from .api.export import ExportApi
from .api.redfish import RedfishApi
# =====
class StreamerQualityNotSupported(OperationError):
    """Raised when a "quality" change is requested but the streamer has no such parameter."""

    def __init__(self) -> None:
        message = "This streamer does not support quality settings"
        super().__init__(message)
class StreamerResolutionNotSupported(OperationError):
    """Raised when a "resolution" change is requested but the streamer has no such parameter."""

    def __init__(self) -> None:
        message = "This streamer does not support resolution settings"
        super().__init__(message)
# =====
@dataclasses.dataclass(frozen=True)
class _Component:  # pylint: disable=too-many-instance-attributes
    """A server-managed object together with the optional lifecycle hooks it exposes.

    Hooks that are not passed explicitly are looked up by attribute name on
    ``obj`` in ``__post_init__``; a missing attribute leaves the hook as None.
    Explicitly passed hooks take precedence over the looked-up ones.
    """

    name: str  # Human-readable name for logs; plugin objects get their plugin name appended
    event_type: str  # WebSocket event type used for state events ("" when there is no state)
    obj: object  # The wrapped manager/plugin instance
    sysprep: Optional[Callable[[], None]] = None
    systask: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
    get_state: Optional[Callable[[], Coroutine[Any, Any, Dict]]] = None
    poll_state: Optional[Callable[[], AsyncGenerator[Dict, None]]] = None
    # The awaited result is not used by the server, so it is typed like systask.
    cleanup: Optional[Callable[[], Coroutine[Any, Any, None]]] = None

    def __post_init__(self) -> None:
        # The dataclass is frozen, so object.__setattr__ is required to adjust fields.
        if isinstance(self.obj, BasePlugin):
            object.__setattr__(self, "name", f"{self.name} ({self.obj.get_plugin_name()})")

        for hook in ["sysprep", "systask", "get_state", "poll_state", "cleanup"]:
            # Only fill in hooks the caller did not supply; previously explicit
            # constructor arguments were silently overwritten here.
            if getattr(self, hook) is None:
                object.__setattr__(self, hook, getattr(self.obj, hook, None))

        # A component that produces state must have an event type to broadcast it under.
        if self.get_state or self.poll_state:
            assert self.event_type, self
@dataclasses.dataclass(frozen=True)
class _WsClient:
    """A connected WebSocket peer as tracked in the server's client set."""

    ws: aiohttp.web.WebSocketResponse  # The prepared WebSocket response for this peer
    stream: bool  # From the "?stream=" query; counted when deciding if the streamer should run

    def __str__(self) -> str:
        ident = id(self)
        return f"WsClient(id={ident}, stream={self.stream})"
class KvmdServer(HttpServer):  # pylint: disable=too-many-arguments,too-many-instance-attributes
    """KVMD HTTP/WebSocket server.

    Wraps the managers and plugins passed to the constructor into ``_Component``
    records (lifecycle hooks and state events) and into API objects whose exposed
    HTTP/WS handlers are registered in ``_make_app()``. It also owns the background
    system tasks, including the controller that starts, stops and restarts the streamer.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        auth_manager: AuthManager,
        info_manager: InfoManager,
        log_reader: LogReader,
        wol: WakeOnLan,
        user_gpio: UserGpio,
        hid: BaseHid,
        atx: BaseAtx,
        msd: BaseMsd,
        streamer: Streamer,
        snapshoter: Snapshoter,
        heartbeat: float,
        sync_chunk_size: int,
        keymap_path: str,
        stream_forever: bool,
    ) -> None:

        self.__auth_manager = auth_manager
        self.__hid = hid
        self.__streamer = streamer
        self.__snapshoter = snapshoter  # Not a component: No state or cleanup
        self.__user_gpio = user_gpio  # Also sends "gpio_model_state" to each new WS client

        self.__heartbeat = heartbeat
        self.__stream_forever = stream_forever

        self.__components = [
            *[
                _Component("Auth manager", "", auth_manager),
            ],
            *[
                _Component(f"Info manager ({sub})", f"info_{sub}_state", info_manager.get_submanager(sub))
                for sub in sorted(info_manager.get_subs())
            ],
            *[
                _Component("Wake-on-LAN", "wol_state", wol),
                _Component("User-GPIO", "gpio_state", user_gpio),
                _Component("HID", "hid_state", hid),
                _Component("ATX", "atx_state", atx),
                _Component("MSD", "msd_state", msd),
                _Component("Streamer", "streamer_state", streamer),
            ],
        ]

        self.__apis: List[object] = [
            self,
            AuthApi(auth_manager),
            InfoApi(info_manager),
            LogApi(log_reader),
            WolApi(wol),
            UserGpioApi(user_gpio),
            HidApi(hid, keymap_path),
            AtxApi(atx),
            MsdApi(msd, sync_chunk_size),
            StreamerApi(streamer),
            ExportApi(info_manager, atx, user_gpio),
            RedfishApi(info_manager, atx),
        ]

        self.__ws_handlers: Dict[str, Callable] = {}
        self.__ws_clients: Set[_WsClient] = set()
        self.__ws_clients_lock = asyncio.Lock()

        self.__system_tasks: List[asyncio.Task] = []

        # Shared with __stream_controller(): handlers set these, then notify.
        self.__streamer_notifier = aiotools.AioNotifier()
        self.__reset_streamer = False
        self.__new_streamer_params: Dict = {}

    # ===== STREAMER CONTROLLER

    @exposed_http("POST", "/streamer/set_params")
    async def __streamer_set_params_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Queue streamer parameter changes taken from the query string.

        All supplied parameters are validated first and committed together, so a
        request rejected with an error leaves no partial changes queued.

        Raises:
            StreamerQualityNotSupported / StreamerResolutionNotSupported: the
                streamer does not expose the requested parameter.
            ValidatorError: a value failed validation.
        """
        current_params = self.__streamer.get_params()
        changed: Dict = {}
        for (name, validator, exc_cls) in [
            ("quality", valid_stream_quality, StreamerQualityNotSupported),
            ("desired_fps", valid_stream_fps, None),
            ("resolution", valid_stream_resolution, StreamerResolutionNotSupported),
        ]:
            value = request.query.get(name)
            if value:
                if name not in current_params:
                    assert exc_cls is not None, name
                    raise exc_cls()
                value = validator(value)
                if current_params[name] != value:
                    changed[name] = value
        # Commit only after every parameter has been validated successfully.
        self.__new_streamer_params.update(changed)
        await self.__streamer_notifier.notify()
        return make_json_response()

    @exposed_http("POST", "/streamer/reset")
    async def __streamer_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
        """Request a streamer restart with reset; performed by __stream_controller()."""
        self.__reset_streamer = True
        await self.__streamer_notifier.notify()
        return make_json_response()

    # ===== WEBSOCKET

    @exposed_http("GET", "/ws")
    async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse:
        """Serve one WebSocket client.

        Sends the GPIO model, the current state of every stateful component and a
        "loop" event, then dispatches incoming TEXT events to the registered WS
        handlers. Any non-TEXT message ends the session.
        """
        logger = get_logger(0)
        client = _WsClient(
            ws=aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat),
            stream=valid_bool(request.query.get("stream", "true")),
        )
        await client.ws.prepare(request)
        await self.__register_ws_client(client)
        try:
            await self.__send_event(client.ws, "gpio_model_state", await self.__user_gpio.get_model())
            await asyncio.gather(*[
                self.__send_event(client.ws, component.event_type, await component.get_state())
                for component in self.__components
                if component.get_state
            ])
            await self.__send_event(client.ws, "loop", {})
            async for msg in client.ws:
                if msg.type == aiohttp.web.WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                        event_type = data.get("event_type")
                        event = data["event"]
                    except Exception as err:
                        logger.error("Can't parse JSON event from websocket: %r", err)
                    else:
                        handler = self.__ws_handlers.get(event_type)
                        if handler:
                            await handler(client.ws, event)
                        else:
                            logger.error("Unknown websocket event: %r", data)
                else:
                    break
            return client.ws
        finally:
            await self.__remove_ws_client(client)

    @exposed_ws("ping")
    async def __ws_ping_handler(self, ws: aiohttp.web.WebSocketResponse, _: Dict) -> None:
        """Answer a "ping" event with a "pong" event."""
        await self.__send_event(ws, "pong", {})

    # ===== SYSTEM STUFF

    def run(self, **kwargs: Any) -> None:  # type: ignore  # pylint: disable=arguments-differ
        """Run component sysprep hooks, rename the process and start the HTTP server."""
        for component in self.__components:
            if component.sysprep:
                component.sysprep()
        aioproc.rename_process("main")
        super().run(**kwargs)

    async def _make_app(self) -> aiohttp.web.Application:
        """Build the application: path middleware, lifecycle hooks, system tasks and routes."""
        app = aiohttp.web.Application(middlewares=[aiohttp.web.normalize_path_middleware(
            append_slash=False,
            remove_slash=True,
            merge_slashes=True,
        )])

        app.on_shutdown.append(self.__on_shutdown)
        app.on_cleanup.append(self.__on_cleanup)

        self.__run_system_task(self.__stream_controller)
        for component in self.__components:
            if component.systask:
                self.__run_system_task(component.systask)
            if component.poll_state:
                self.__run_system_task(self.__poll_state, component.event_type, component.poll_state())
        self.__run_system_task(self.__stream_snapshoter)

        for api in self.__apis:
            for http_exposed in get_exposed_http(api):
                self.__add_app_route(app, http_exposed)
            for ws_exposed in get_exposed_ws(api):
                self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler

        return app

    def __run_system_task(self, method: Callable, *args: Any) -> None:
        """Start ``method(*args)`` as a system task.

        System tasks are expected to run forever: if one returns or raises (anything
        other than cancellation), the exception is logged and the process sends
        SIGTERM to itself.
        """
        async def wrapper() -> None:
            try:
                await method(*args)
                raise RuntimeError(f"Dead system task: {method}"
                                   f"({', '.join(getattr(arg, '__name__', str(arg)) for arg in args)})")
            except asyncio.CancelledError:
                pass
            except Exception:
                get_logger().exception("Unhandled exception, killing myself ...")
                os.kill(os.getpid(), signal.SIGTERM)

        self.__system_tasks.append(asyncio.create_task(wrapper()))

    def __add_app_route(self, app: aiohttp.web.Application, exposed: HttpExposed) -> None:
        """Register an exposed HTTP handler behind auth checking and error-to-JSON mapping."""
        async def wrapper(request: aiohttp.web.Request) -> aiohttp.web.Response:
            try:
                await check_request_auth(self.__auth_manager, exposed, request)
                return (await exposed.handler(request))
            except IsBusyError as err:
                return make_json_exception(err, 409)
            except (ValidatorError, OperationError) as err:
                return make_json_exception(err, 400)
            except HttpError as err:
                return make_json_exception(err)

        app.router.add_route(exposed.method, exposed.path, wrapper)

    async def __on_shutdown(self, _: aiohttp.web.Application) -> None:
        """Wait for short tasks, cancel and await system tasks, then drop all WS clients."""
        logger = get_logger(0)

        logger.info("Waiting short tasks ...")
        await asyncio.gather(*aiotools.get_short_tasks(), return_exceptions=True)

        logger.info("Cancelling system tasks ...")
        for task in self.__system_tasks:
            task.cancel()

        logger.info("Waiting system tasks ...")
        await asyncio.gather(*self.__system_tasks, return_exceptions=True)

        logger.info("Disconnecting clients ...")
        for client in list(self.__ws_clients):
            await self.__remove_ws_client(client)

    async def __on_cleanup(self, _: aiohttp.web.Application) -> None:
        """Run every component's cleanup hook, logging (not raising) any failure."""
        logger = get_logger(0)
        for component in self.__components:
            if component.cleanup:
                logger.info("Cleaning up %s ...", component.name)
                try:
                    await component.cleanup()  # type: ignore
                except Exception:
                    logger.exception("Cleanup error on %s", component.name)

    async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
        """Send one JSON-encoded event to a single WebSocket."""
        await ws.send_str(json.dumps({
            "event_type": event_type,
            "event": event,
        }))

    async def __broadcast_event(self, event_type: str, event: Optional[Dict]) -> None:
        """Send an event to every client whose socket still looks usable.

        Closed sockets and sockets without a live request transport are skipped;
        send errors are collected by gather() instead of aborting the broadcast.
        """
        if self.__ws_clients:
            await asyncio.gather(*[
                self.__send_event(client.ws, event_type, event)
                for client in list(self.__ws_clients)
                if (
                    not client.ws.closed
                    and client.ws._req is not None  # pylint: disable=protected-access
                    and client.ws._req.transport is not None  # pylint: disable=protected-access
                )
            ], return_exceptions=True)

    async def __register_ws_client(self, client: _WsClient) -> None:
        """Add a client to the set and wake the stream controller."""
        async with self.__ws_clients_lock:
            self.__ws_clients.add(client)
            get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients))
        await self.__streamer_notifier.notify()

    async def __remove_ws_client(self, client: _WsClient) -> None:
        """Remove and close a client, clear HID events and wake the stream controller."""
        async with self.__ws_clients_lock:
            self.__hid.clear_events()
            try:
                self.__ws_clients.remove(client)
                get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients))
                await client.ws.close()
            except Exception:
                # Best-effort: this runs both from __ws_handler's finally and from
                # __on_shutdown, so the client may already be gone (KeyError).
                pass
        await self.__streamer_notifier.notify()

    def __has_stream_clients(self) -> bool:
        """Return True if any connected client asked for the video stream."""
        return any(map(operator.attrgetter("stream"), self.__ws_clients))

    # ===== SYSTEM TASKS

    async def __stream_controller(self) -> None:
        """Start/stop the streamer on demand and apply queued resets/parameter changes.

        Runs one pass per notifier wake-up. The streamer is wanted while there are
        streaming clients, the snapshoter is taking a snapshot, or stream_forever is set.
        """
        prev = False
        while True:
            cur = (self.__has_stream_clients() or self.__snapshoter.snapshoting() or self.__stream_forever)
            if not prev and cur:
                await self.__streamer.ensure_start(reset=False)
            elif prev and not cur:
                await self.__streamer.ensure_stop(immediately=False)

            if self.__reset_streamer or self.__new_streamer_params:
                start = self.__streamer.is_working()
                await self.__streamer.ensure_stop(immediately=True)
                if self.__new_streamer_params:
                    self.__streamer.set_params(self.__new_streamer_params)
                    self.__new_streamer_params = {}
                # Consume the reset request BEFORE awaiting ensure_start(): a reset
                # requested while the streamer is starting must survive until the
                # next pass instead of being cleared after the await.
                reset = self.__reset_streamer
                self.__reset_streamer = False
                if start:
                    await self.__streamer.ensure_start(reset=reset)

            prev = cur
            await self.__streamer_notifier.wait()

    async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
        """Broadcast every state yielded by a component's poll_state() generator."""
        async for state in poller:
            await self.__broadcast_event(event_type, state)

    async def __stream_snapshoter(self) -> None:
        """Run the snapshoter, letting it know whether live stream clients exist."""
        await self.__snapshoter.run(
            is_live=self.__has_stream_clients,
            notifier=self.__streamer_notifier,
        )