common websocket code

This commit is contained in:
Maxim Devaev 2022-06-14 18:18:21 +03:00
parent 37e5118fff
commit 88c7796551
3 changed files with 145 additions and 129 deletions

View File

@ -32,7 +32,6 @@ from typing import Callable
from aiohttp.web import Request
from aiohttp.web import Response
from aiohttp.web import WebSocketResponse
from ....mouse import MouseRange
@ -42,6 +41,7 @@ from ....keyboard.printer import text_to_web_keys
from ....htserver import exposed_http
from ....htserver import exposed_ws
from ....htserver import make_json_response
from ....htserver import WsSession
from ....plugins.hid import BaseHid
@ -158,7 +158,7 @@ class HidApi:
# =====
@exposed_ws("key")
async def __ws_key_handler(self, _: WebSocketResponse, event: Dict) -> None:
async def __ws_key_handler(self, _: WsSession, event: Dict) -> None:
try:
key = valid_hid_key(event["key"])
state = valid_bool(event["state"])
@ -168,7 +168,7 @@ class HidApi:
self.__hid.send_key_events([(key, state)])
@exposed_ws("mouse_button")
async def __ws_mouse_button_handler(self, _: WebSocketResponse, event: Dict) -> None:
async def __ws_mouse_button_handler(self, _: WsSession, event: Dict) -> None:
try:
button = valid_hid_mouse_button(event["button"])
state = valid_bool(event["state"])
@ -177,7 +177,7 @@ class HidApi:
self.__hid.send_mouse_button_event(button, state)
@exposed_ws("mouse_move")
async def __ws_mouse_move_handler(self, _: WebSocketResponse, event: Dict) -> None:
async def __ws_mouse_move_handler(self, _: WsSession, event: Dict) -> None:
try:
to_x = valid_hid_mouse_move(event["to"]["x"])
to_y = valid_hid_mouse_move(event["to"]["y"])
@ -186,11 +186,11 @@ class HidApi:
self.__send_mouse_move_event_remapped(to_x, to_y)
@exposed_ws("mouse_relative")
async def __ws_mouse_relative_handler(self, _: WebSocketResponse, event: Dict) -> None:
async def __ws_mouse_relative_handler(self, _: WsSession, event: Dict) -> None:
self.__process_delta_ws_request(event, self.__hid.send_mouse_relative_event)
@exposed_ws("mouse_wheel")
async def __ws_mouse_wheel_handler(self, _: WebSocketResponse, event: Dict) -> None:
async def __ws_mouse_wheel_handler(self, _: WsSession, event: Dict) -> None:
self.__process_delta_ws_request(event, self.__hid.send_mouse_wheel_event)
def __process_delta_ws_request(self, event: Dict, handler: Callable[[int, int], None]) -> None:

View File

@ -27,7 +27,6 @@ import dataclasses
from typing import Tuple
from typing import List
from typing import Dict
from typing import Set
from typing import Callable
from typing import Coroutine
from typing import AsyncGenerator
@ -48,12 +47,8 @@ from ... import aioproc
from ...htserver import HttpExposed
from ...htserver import exposed_http
from ...htserver import exposed_ws
from ...htserver import get_exposed_http
from ...htserver import get_exposed_ws
from ...htserver import make_json_response
from ...htserver import send_ws_event
from ...htserver import broadcast_ws_event
from ...htserver import process_ws_messages
from ...htserver import WsSession
from ...htserver import HttpServer
from ...plugins import BasePlugin
@ -128,15 +123,6 @@ class _Component: # pylint: disable=too-many-instance-attributes
assert self.event_type, self
@dataclasses.dataclass(frozen=True)
class _WsClient:
ws: WebSocketResponse
stream: bool
def __str__(self) -> str:
return f"WsClient(id={id(self)}, stream={self.stream})"
class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
@ -160,6 +146,8 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
stream_forever: bool,
) -> None:
super().__init__()
self.__auth_manager = auth_manager
self.__hid = hid
self.__streamer = streamer
@ -201,11 +189,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
RedfishApi(info_manager, atx),
]
self.__ws_handlers: Dict[str, Callable] = {}
self.__ws_clients: Set[_WsClient] = set()
self.__ws_clients_lock = asyncio.Lock()
self.__streamer_notifier = aiotools.AioNotifier()
self.__reset_streamer = False
self.__new_streamer_params: Dict = {}
@ -244,11 +227,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
@exposed_http("GET", "/ws")
async def __ws_handler(self, request: Request) -> WebSocketResponse:
stream = valid_bool(request.query.get("stream", "true"))
ws = await self._make_ws_response(request)
client = _WsClient(ws, stream)
await self.__register_ws_client(client)
try:
async with self._ws_session(request, stream=stream) as ws:
stage1 = [
("gpio_model_state", self.__user_gpio.get_model()),
("hid_keymaps_state", self.__hid_api.get_keymaps()),
@ -266,19 +245,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
))
for stage in [stage1, stage2]:
await asyncio.gather(*[
send_ws_event(ws, event_type, events.pop(event_type))
ws.send_event(event_type, events.pop(event_type))
for (event_type, _) in stage
])
await send_ws_event(ws, "loop", {})
await process_ws_messages(ws, self.__ws_handlers)
return ws
finally:
await self.__remove_ws_client(client)
await ws.send_event("loop", {})
return (await self._ws_loop(ws))
@exposed_ws("ping")
async def __ws_ping_handler(self, ws: WebSocketResponse, _: Dict) -> None:
await send_ws_event(ws, "pong", {})
async def __ws_ping_handler(self, ws: WsSession, _: Dict) -> None:
await ws.send_event("pong", {})
# ===== SYSTEM STUFF
@ -300,26 +275,16 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
if comp.poll_state:
aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state()))
aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter())
for api in self.__apis:
for http_exposed in get_exposed_http(api):
self._add_exposed(http_exposed)
for ws_exposed in get_exposed_ws(api):
self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler
self._add_exposed(*self.__apis)
async def _on_shutdown(self) -> None:
logger = get_logger(0)
logger.info("Waiting short tasks ...")
await aiotools.wait_all_short_tasks()
logger.info("Stopping system tasks ...")
await aiotools.stop_all_deadly_tasks()
logger.info("Disconnecting clients ...")
for client in list(self.__ws_clients):
await self.__remove_ws_client(client)
await self._close_all_wss()
logger.info("On-Shutdown complete")
async def _on_cleanup(self) -> None:
@ -333,25 +298,18 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
logger.exception("Cleanup error on %s", comp.name)
logger.info("On-Cleanup complete")
async def __register_ws_client(self, client: _WsClient) -> None:
async with self.__ws_clients_lock:
self.__ws_clients.add(client)
get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients))
async def _on_ws_opened(self) -> None:
await self.__streamer_notifier.notify()
async def __remove_ws_client(self, client: _WsClient) -> None:
async with self.__ws_clients_lock:
self.__hid.clear_events()
try:
self.__ws_clients.remove(client)
get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients))
await client.ws.close()
except Exception:
pass
async def _on_ws_closed(self) -> None:
self.__hid.clear_events()
await self.__streamer_notifier.notify()
def __has_stream_clients(self) -> bool:
return bool(sum(map(operator.attrgetter("stream"), self.__ws_clients)))
return bool(sum(map(
(lambda ws: ws.kwargs["stream"]),
self._get_wss(),
)))
# ===== SYSTEM TASKS
@ -379,10 +337,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
async for state in poller:
await broadcast_ws_event([
client.ws
for client in list(self.__ws_clients)
], event_type, state)
await self._broadcast_ws_event(event_type, state)
async def __stream_snapshoter(self) -> None:
await self.__snapshoter.run(

View File

@ -23,6 +23,7 @@
import os
import socket
import asyncio
import contextlib
import dataclasses
import inspect
import json
@ -31,7 +32,9 @@ from typing import Tuple
from typing import List
from typing import Dict
from typing import Callable
from typing import AsyncGenerator
from typing import Optional
from typing import Any
from aiohttp.web import BaseRequest
from aiohttp.web import Request
@ -103,7 +106,7 @@ def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Calla
return set_attrs
def get_exposed_http(obj: object) -> List[HttpExposed]:
def _get_exposed_http(obj: object) -> List[HttpExposed]:
return [
HttpExposed(
method=getattr(handler, _HTTP_METHOD),
@ -135,7 +138,7 @@ def exposed_ws(event_type: str) -> Callable:
return set_attrs
def get_exposed_ws(obj: object) -> List[WsExposed]:
def _get_exposed_ws(obj: object) -> List[WsExposed]:
return [
WsExposed(
event_type=getattr(handler, _WS_EVENT_TYPE),
@ -205,57 +208,6 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
}, False)
# =====
async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
if wss:
await asyncio.gather(*[
send_ws_event(ws, event_type, event)
for ws in wss
if (
not ws.closed
and ws._req is not None # pylint: disable=protected-access
and ws._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
def _parse_ws_event(msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
async def process_ws_messages(ws: WebSocketResponse, handlers: Dict[str, Callable]) -> None:
logger = get_logger(1)
async for msg in ws:
if msg.type != WSMsgType.TEXT:
break
try:
(event_type, event) = _parse_ws_event(msg.data)
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = handlers.get(event_type)
if handler:
await handler(ws, event)
else:
logger.error("Unknown websocket event: %r", msg.data)
# =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
@ -272,7 +224,28 @@ def set_request_auth_info(request: BaseRequest, info: str) -> None:
# =====
@dataclasses.dataclass(frozen=True)
class WsSession:
wsr: WebSocketResponse
kwargs: Dict[str, Any]
def __str__(self) -> str:
return f"WsSession(id={id(self)}, {self.kwargs})"
async def send_event(self, event_type: str, event: Optional[Dict]) -> None:
await self.wsr.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
class HttpServer:
def __init__(self) -> None:
self.__ws_heartbeat: Optional[float] = None
self.__ws_handlers: Dict[str, Callable] = {}
self.__ws_sessions: List[WsSession] = []
self.__ws_sessions_lock = asyncio.Lock()
def run(
self,
unix_path: str,
@ -282,7 +255,7 @@ class HttpServer:
access_log_format: str,
) -> None:
self.__heartbeat = heartbeat # pylint: disable=attribute-defined-outside-init
self.__ws_heartbeat = heartbeat
if unix_rm and os.path.exists(unix_path):
os.remove(unix_path)
@ -302,7 +275,14 @@ class HttpServer:
# =====
def _add_exposed(self, exposed: HttpExposed) -> None:
def _add_exposed(self, *objs: object) -> None:
for obj in objs:
for http_exposed in _get_exposed_http(obj):
self.__add_exposed_http(http_exposed)
for ws_exposed in _get_exposed_ws(obj):
self.__add_exposed_ws(ws_exposed)
def __add_exposed_http(self, exposed: HttpExposed) -> None:
async def wrapper(request: Request) -> Response:
try:
await self._check_request_auth(exposed, request)
@ -315,10 +295,85 @@ class HttpServer:
return make_json_exception(err)
self.__app.router.add_route(exposed.method, exposed.path, wrapper)
async def _make_ws_response(self, request: Request) -> WebSocketResponse:
ws = WebSocketResponse(heartbeat=self.__heartbeat)
await ws.prepare(request)
return ws
def __add_exposed_ws(self, exposed: WsExposed) -> None:
self.__ws_handlers[exposed.event_type] = exposed.handler
# =====
@contextlib.asynccontextmanager
async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]:
assert self.__ws_heartbeat is not None
wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat)
await wsr.prepare(request)
ws = WsSession(wsr, kwargs)
async with self.__ws_sessions_lock:
self.__ws_sessions.append(ws)
get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions))
try:
await self._on_ws_opened()
yield ws
finally:
await self.__close_ws(ws)
async def _ws_loop(self, ws: WsSession) -> WebSocketResponse:
logger = get_logger()
async for msg in ws.wsr:
if msg.type != WSMsgType.TEXT:
break
try:
(event_type, event) = self.__parse_ws_event(msg.data)
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = self.__ws_handlers.get(event_type)
if handler:
await handler(ws, event)
else:
logger.error("Unknown websocket event: %r", msg.data)
return ws.wsr
async def _broadcast_ws_event(self, event_type: str, event: Optional[Dict]) -> None:
if self.__ws_sessions:
await asyncio.gather(*[
ws.send_event(event_type, event)
for ws in self.__ws_sessions
if (
not ws.wsr.closed
and ws.wsr._req is not None # pylint: disable=protected-access
and ws.wsr._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
async def _close_all_wss(self) -> None:
for ws in self._get_wss():
await self.__close_ws(ws)
def _get_wss(self) -> List[WsSession]:
return list(self.__ws_sessions)
async def __close_ws(self, ws: WsSession) -> None:
async with self.__ws_sessions_lock:
try:
self.__ws_sessions.remove(ws)
get_logger(3).info("Removed client socket: %s; clients now: %d", ws, len(self.__ws_sessions))
await ws.wsr.close()
except Exception:
pass
await self._on_ws_closed()
def __parse_ws_event(self, msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
# =====
@ -334,6 +389,12 @@ class HttpServer:
async def _on_cleanup(self) -> None:
pass
async def _on_ws_opened(self) -> None:
pass
async def _on_ws_closed(self) -> None:
pass
# =====
async def __make_app(self) -> Application: