common websocket code

This commit is contained in:
Maxim Devaev 2022-06-14 18:18:21 +03:00
parent 37e5118fff
commit 88c7796551
3 changed files with 145 additions and 129 deletions

View File

@ -32,7 +32,6 @@ from typing import Callable
from aiohttp.web import Request from aiohttp.web import Request
from aiohttp.web import Response from aiohttp.web import Response
from aiohttp.web import WebSocketResponse
from ....mouse import MouseRange from ....mouse import MouseRange
@ -42,6 +41,7 @@ from ....keyboard.printer import text_to_web_keys
from ....htserver import exposed_http from ....htserver import exposed_http
from ....htserver import exposed_ws from ....htserver import exposed_ws
from ....htserver import make_json_response from ....htserver import make_json_response
from ....htserver import WsSession
from ....plugins.hid import BaseHid from ....plugins.hid import BaseHid
@ -158,7 +158,7 @@ class HidApi:
# ===== # =====
@exposed_ws("key") @exposed_ws("key")
async def __ws_key_handler(self, _: WebSocketResponse, event: Dict) -> None: async def __ws_key_handler(self, _: WsSession, event: Dict) -> None:
try: try:
key = valid_hid_key(event["key"]) key = valid_hid_key(event["key"])
state = valid_bool(event["state"]) state = valid_bool(event["state"])
@ -168,7 +168,7 @@ class HidApi:
self.__hid.send_key_events([(key, state)]) self.__hid.send_key_events([(key, state)])
@exposed_ws("mouse_button") @exposed_ws("mouse_button")
async def __ws_mouse_button_handler(self, _: WebSocketResponse, event: Dict) -> None: async def __ws_mouse_button_handler(self, _: WsSession, event: Dict) -> None:
try: try:
button = valid_hid_mouse_button(event["button"]) button = valid_hid_mouse_button(event["button"])
state = valid_bool(event["state"]) state = valid_bool(event["state"])
@ -177,7 +177,7 @@ class HidApi:
self.__hid.send_mouse_button_event(button, state) self.__hid.send_mouse_button_event(button, state)
@exposed_ws("mouse_move") @exposed_ws("mouse_move")
async def __ws_mouse_move_handler(self, _: WebSocketResponse, event: Dict) -> None: async def __ws_mouse_move_handler(self, _: WsSession, event: Dict) -> None:
try: try:
to_x = valid_hid_mouse_move(event["to"]["x"]) to_x = valid_hid_mouse_move(event["to"]["x"])
to_y = valid_hid_mouse_move(event["to"]["y"]) to_y = valid_hid_mouse_move(event["to"]["y"])
@ -186,11 +186,11 @@ class HidApi:
self.__send_mouse_move_event_remapped(to_x, to_y) self.__send_mouse_move_event_remapped(to_x, to_y)
@exposed_ws("mouse_relative") @exposed_ws("mouse_relative")
async def __ws_mouse_relative_handler(self, _: WebSocketResponse, event: Dict) -> None: async def __ws_mouse_relative_handler(self, _: WsSession, event: Dict) -> None:
self.__process_delta_ws_request(event, self.__hid.send_mouse_relative_event) self.__process_delta_ws_request(event, self.__hid.send_mouse_relative_event)
@exposed_ws("mouse_wheel") @exposed_ws("mouse_wheel")
async def __ws_mouse_wheel_handler(self, _: WebSocketResponse, event: Dict) -> None: async def __ws_mouse_wheel_handler(self, _: WsSession, event: Dict) -> None:
self.__process_delta_ws_request(event, self.__hid.send_mouse_wheel_event) self.__process_delta_ws_request(event, self.__hid.send_mouse_wheel_event)
def __process_delta_ws_request(self, event: Dict, handler: Callable[[int, int], None]) -> None: def __process_delta_ws_request(self, event: Dict, handler: Callable[[int, int], None]) -> None:

View File

@ -27,7 +27,6 @@ import dataclasses
from typing import Tuple from typing import Tuple
from typing import List from typing import List
from typing import Dict from typing import Dict
from typing import Set
from typing import Callable from typing import Callable
from typing import Coroutine from typing import Coroutine
from typing import AsyncGenerator from typing import AsyncGenerator
@ -48,12 +47,8 @@ from ... import aioproc
from ...htserver import HttpExposed from ...htserver import HttpExposed
from ...htserver import exposed_http from ...htserver import exposed_http
from ...htserver import exposed_ws from ...htserver import exposed_ws
from ...htserver import get_exposed_http
from ...htserver import get_exposed_ws
from ...htserver import make_json_response from ...htserver import make_json_response
from ...htserver import send_ws_event from ...htserver import WsSession
from ...htserver import broadcast_ws_event
from ...htserver import process_ws_messages
from ...htserver import HttpServer from ...htserver import HttpServer
from ...plugins import BasePlugin from ...plugins import BasePlugin
@ -128,15 +123,6 @@ class _Component: # pylint: disable=too-many-instance-attributes
assert self.event_type, self assert self.event_type, self
@dataclasses.dataclass(frozen=True)
class _WsClient:
ws: WebSocketResponse
stream: bool
def __str__(self) -> str:
return f"WsClient(id={id(self)}, stream={self.stream})"
class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes
def __init__( # pylint: disable=too-many-arguments,too-many-locals def __init__( # pylint: disable=too-many-arguments,too-many-locals
self, self,
@ -160,6 +146,8 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
stream_forever: bool, stream_forever: bool,
) -> None: ) -> None:
super().__init__()
self.__auth_manager = auth_manager self.__auth_manager = auth_manager
self.__hid = hid self.__hid = hid
self.__streamer = streamer self.__streamer = streamer
@ -201,11 +189,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
RedfishApi(info_manager, atx), RedfishApi(info_manager, atx),
] ]
self.__ws_handlers: Dict[str, Callable] = {}
self.__ws_clients: Set[_WsClient] = set()
self.__ws_clients_lock = asyncio.Lock()
self.__streamer_notifier = aiotools.AioNotifier() self.__streamer_notifier = aiotools.AioNotifier()
self.__reset_streamer = False self.__reset_streamer = False
self.__new_streamer_params: Dict = {} self.__new_streamer_params: Dict = {}
@ -244,11 +227,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
@exposed_http("GET", "/ws") @exposed_http("GET", "/ws")
async def __ws_handler(self, request: Request) -> WebSocketResponse: async def __ws_handler(self, request: Request) -> WebSocketResponse:
stream = valid_bool(request.query.get("stream", "true")) stream = valid_bool(request.query.get("stream", "true"))
ws = await self._make_ws_response(request) async with self._ws_session(request, stream=stream) as ws:
client = _WsClient(ws, stream)
await self.__register_ws_client(client)
try:
stage1 = [ stage1 = [
("gpio_model_state", self.__user_gpio.get_model()), ("gpio_model_state", self.__user_gpio.get_model()),
("hid_keymaps_state", self.__hid_api.get_keymaps()), ("hid_keymaps_state", self.__hid_api.get_keymaps()),
@ -266,19 +245,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
)) ))
for stage in [stage1, stage2]: for stage in [stage1, stage2]:
await asyncio.gather(*[ await asyncio.gather(*[
send_ws_event(ws, event_type, events.pop(event_type)) ws.send_event(event_type, events.pop(event_type))
for (event_type, _) in stage for (event_type, _) in stage
]) ])
await ws.send_event("loop", {})
await send_ws_event(ws, "loop", {}) return (await self._ws_loop(ws))
await process_ws_messages(ws, self.__ws_handlers)
return ws
finally:
await self.__remove_ws_client(client)
@exposed_ws("ping") @exposed_ws("ping")
async def __ws_ping_handler(self, ws: WebSocketResponse, _: Dict) -> None: async def __ws_ping_handler(self, ws: WsSession, _: Dict) -> None:
await send_ws_event(ws, "pong", {}) await ws.send_event("pong", {})
# ===== SYSTEM STUFF # ===== SYSTEM STUFF
@ -300,26 +275,16 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
if comp.poll_state: if comp.poll_state:
aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state())) aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state()))
aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter()) aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter())
self._add_exposed(*self.__apis)
for api in self.__apis:
for http_exposed in get_exposed_http(api):
self._add_exposed(http_exposed)
for ws_exposed in get_exposed_ws(api):
self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler
async def _on_shutdown(self) -> None: async def _on_shutdown(self) -> None:
logger = get_logger(0) logger = get_logger(0)
logger.info("Waiting short tasks ...") logger.info("Waiting short tasks ...")
await aiotools.wait_all_short_tasks() await aiotools.wait_all_short_tasks()
logger.info("Stopping system tasks ...") logger.info("Stopping system tasks ...")
await aiotools.stop_all_deadly_tasks() await aiotools.stop_all_deadly_tasks()
logger.info("Disconnecting clients ...") logger.info("Disconnecting clients ...")
for client in list(self.__ws_clients): await self._close_all_wss()
await self.__remove_ws_client(client)
logger.info("On-Shutdown complete") logger.info("On-Shutdown complete")
async def _on_cleanup(self) -> None: async def _on_cleanup(self) -> None:
@ -333,25 +298,18 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
logger.exception("Cleanup error on %s", comp.name) logger.exception("Cleanup error on %s", comp.name)
logger.info("On-Cleanup complete") logger.info("On-Cleanup complete")
async def __register_ws_client(self, client: _WsClient) -> None: async def _on_ws_opened(self) -> None:
async with self.__ws_clients_lock:
self.__ws_clients.add(client)
get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients))
await self.__streamer_notifier.notify() await self.__streamer_notifier.notify()
async def __remove_ws_client(self, client: _WsClient) -> None: async def _on_ws_closed(self) -> None:
async with self.__ws_clients_lock: self.__hid.clear_events()
self.__hid.clear_events()
try:
self.__ws_clients.remove(client)
get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients))
await client.ws.close()
except Exception:
pass
await self.__streamer_notifier.notify() await self.__streamer_notifier.notify()
def __has_stream_clients(self) -> bool: def __has_stream_clients(self) -> bool:
return bool(sum(map(operator.attrgetter("stream"), self.__ws_clients))) return bool(sum(map(
(lambda ws: ws.kwargs["stream"]),
self._get_wss(),
)))
# ===== SYSTEM TASKS # ===== SYSTEM TASKS
@ -379,10 +337,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None: async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
async for state in poller: async for state in poller:
await broadcast_ws_event([ await self._broadcast_ws_event(event_type, state)
client.ws
for client in list(self.__ws_clients)
], event_type, state)
async def __stream_snapshoter(self) -> None: async def __stream_snapshoter(self) -> None:
await self.__snapshoter.run( await self.__snapshoter.run(

View File

@ -23,6 +23,7 @@
import os import os
import socket import socket
import asyncio import asyncio
import contextlib
import dataclasses import dataclasses
import inspect import inspect
import json import json
@ -31,7 +32,9 @@ from typing import Tuple
from typing import List from typing import List
from typing import Dict from typing import Dict
from typing import Callable from typing import Callable
from typing import AsyncGenerator
from typing import Optional from typing import Optional
from typing import Any
from aiohttp.web import BaseRequest from aiohttp.web import BaseRequest
from aiohttp.web import Request from aiohttp.web import Request
@ -103,7 +106,7 @@ def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Calla
return set_attrs return set_attrs
def get_exposed_http(obj: object) -> List[HttpExposed]: def _get_exposed_http(obj: object) -> List[HttpExposed]:
return [ return [
HttpExposed( HttpExposed(
method=getattr(handler, _HTTP_METHOD), method=getattr(handler, _HTTP_METHOD),
@ -135,7 +138,7 @@ def exposed_ws(event_type: str) -> Callable:
return set_attrs return set_attrs
def get_exposed_ws(obj: object) -> List[WsExposed]: def _get_exposed_ws(obj: object) -> List[WsExposed]:
return [ return [
WsExposed( WsExposed(
event_type=getattr(handler, _WS_EVENT_TYPE), event_type=getattr(handler, _WS_EVENT_TYPE),
@ -205,57 +208,6 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
}, False) }, False)
# =====
async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
if wss:
await asyncio.gather(*[
send_ws_event(ws, event_type, event)
for ws in wss
if (
not ws.closed
and ws._req is not None # pylint: disable=protected-access
and ws._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
def _parse_ws_event(msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
async def process_ws_messages(ws: WebSocketResponse, handlers: Dict[str, Callable]) -> None:
logger = get_logger(1)
async for msg in ws:
if msg.type != WSMsgType.TEXT:
break
try:
(event_type, event) = _parse_ws_event(msg.data)
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = handlers.get(event_type)
if handler:
await handler(ws, event)
else:
logger.error("Unknown websocket event: %r", msg.data)
# ===== # =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info" _REQUEST_AUTH_INFO = "_kvmd_auth_info"
@ -272,7 +224,28 @@ def set_request_auth_info(request: BaseRequest, info: str) -> None:
# ===== # =====
@dataclasses.dataclass(frozen=True)
class WsSession:
wsr: WebSocketResponse
kwargs: Dict[str, Any]
def __str__(self) -> str:
return f"WsSession(id={id(self)}, {self.kwargs})"
async def send_event(self, event_type: str, event: Optional[Dict]) -> None:
await self.wsr.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
class HttpServer: class HttpServer:
def __init__(self) -> None:
self.__ws_heartbeat: Optional[float] = None
self.__ws_handlers: Dict[str, Callable] = {}
self.__ws_sessions: List[WsSession] = []
self.__ws_sessions_lock = asyncio.Lock()
def run( def run(
self, self,
unix_path: str, unix_path: str,
@ -282,7 +255,7 @@ class HttpServer:
access_log_format: str, access_log_format: str,
) -> None: ) -> None:
self.__heartbeat = heartbeat # pylint: disable=attribute-defined-outside-init self.__ws_heartbeat = heartbeat
if unix_rm and os.path.exists(unix_path): if unix_rm and os.path.exists(unix_path):
os.remove(unix_path) os.remove(unix_path)
@ -302,7 +275,14 @@ class HttpServer:
# ===== # =====
def _add_exposed(self, exposed: HttpExposed) -> None: def _add_exposed(self, *objs: object) -> None:
for obj in objs:
for http_exposed in _get_exposed_http(obj):
self.__add_exposed_http(http_exposed)
for ws_exposed in _get_exposed_ws(obj):
self.__add_exposed_ws(ws_exposed)
def __add_exposed_http(self, exposed: HttpExposed) -> None:
async def wrapper(request: Request) -> Response: async def wrapper(request: Request) -> Response:
try: try:
await self._check_request_auth(exposed, request) await self._check_request_auth(exposed, request)
@ -315,10 +295,85 @@ class HttpServer:
return make_json_exception(err) return make_json_exception(err)
self.__app.router.add_route(exposed.method, exposed.path, wrapper) self.__app.router.add_route(exposed.method, exposed.path, wrapper)
async def _make_ws_response(self, request: Request) -> WebSocketResponse: def __add_exposed_ws(self, exposed: WsExposed) -> None:
ws = WebSocketResponse(heartbeat=self.__heartbeat) self.__ws_handlers[exposed.event_type] = exposed.handler
await ws.prepare(request)
return ws # =====
@contextlib.asynccontextmanager
async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]:
assert self.__ws_heartbeat is not None
wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat)
await wsr.prepare(request)
ws = WsSession(wsr, kwargs)
async with self.__ws_sessions_lock:
self.__ws_sessions.append(ws)
get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions))
try:
await self._on_ws_opened()
yield ws
finally:
await self.__close_ws(ws)
async def _ws_loop(self, ws: WsSession) -> WebSocketResponse:
logger = get_logger()
async for msg in ws.wsr:
if msg.type != WSMsgType.TEXT:
break
try:
(event_type, event) = self.__parse_ws_event(msg.data)
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = self.__ws_handlers.get(event_type)
if handler:
await handler(ws, event)
else:
logger.error("Unknown websocket event: %r", msg.data)
return ws.wsr
async def _broadcast_ws_event(self, event_type: str, event: Optional[Dict]) -> None:
if self.__ws_sessions:
await asyncio.gather(*[
ws.send_event(event_type, event)
for ws in self.__ws_sessions
if (
not ws.wsr.closed
and ws.wsr._req is not None # pylint: disable=protected-access
and ws.wsr._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
async def _close_all_wss(self) -> None:
for ws in self._get_wss():
await self.__close_ws(ws)
def _get_wss(self) -> List[WsSession]:
return list(self.__ws_sessions)
async def __close_ws(self, ws: WsSession) -> None:
async with self.__ws_sessions_lock:
try:
self.__ws_sessions.remove(ws)
get_logger(3).info("Removed client socket: %s; clients now: %d", ws, len(self.__ws_sessions))
await ws.wsr.close()
except Exception:
pass
await self._on_ws_closed()
def __parse_ws_event(self, msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
# ===== # =====
@ -334,6 +389,12 @@ class HttpServer:
async def _on_cleanup(self) -> None: async def _on_cleanup(self) -> None:
pass pass
async def _on_ws_opened(self) -> None:
pass
async def _on_ws_closed(self) -> None:
pass
# ===== # =====
async def __make_app(self) -> Application: async def __make_app(self) -> Application: