common websocket code

This commit is contained in:
Maxim Devaev
2022-06-14 18:18:21 +03:00
parent 37e5118fff
commit 88c7796551
3 changed files with 145 additions and 129 deletions

View File

@@ -23,6 +23,7 @@
import os
import socket
import asyncio
import contextlib
import dataclasses
import inspect
import json
@@ -31,7 +32,9 @@ from typing import Tuple
from typing import List
from typing import Dict
from typing import Callable
from typing import AsyncGenerator
from typing import Optional
from typing import Any
from aiohttp.web import BaseRequest
from aiohttp.web import Request
@@ -103,7 +106,7 @@ def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Calla
return set_attrs
def get_exposed_http(obj: object) -> List[HttpExposed]:
def _get_exposed_http(obj: object) -> List[HttpExposed]:
return [
HttpExposed(
method=getattr(handler, _HTTP_METHOD),
@@ -135,7 +138,7 @@ def exposed_ws(event_type: str) -> Callable:
return set_attrs
def get_exposed_ws(obj: object) -> List[WsExposed]:
def _get_exposed_ws(obj: object) -> List[WsExposed]:
return [
WsExposed(
event_type=getattr(handler, _WS_EVENT_TYPE),
@@ -205,57 +208,6 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
}, False)
# =====
async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
if wss:
await asyncio.gather(*[
send_ws_event(ws, event_type, event)
for ws in wss
if (
not ws.closed
and ws._req is not None # pylint: disable=protected-access
and ws._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
def _parse_ws_event(msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
async def process_ws_messages(ws: WebSocketResponse, handlers: Dict[str, Callable]) -> None:
logger = get_logger(1)
async for msg in ws:
if msg.type != WSMsgType.TEXT:
break
try:
(event_type, event) = _parse_ws_event(msg.data)
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = handlers.get(event_type)
if handler:
await handler(ws, event)
else:
logger.error("Unknown websocket event: %r", msg.data)
# =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
@@ -272,7 +224,28 @@ def set_request_auth_info(request: BaseRequest, info: str) -> None:
# =====
@dataclasses.dataclass(frozen=True)
class WsSession:
wsr: WebSocketResponse
kwargs: Dict[str, Any]
def __str__(self) -> str:
return f"WsSession(id={id(self)}, {self.kwargs})"
async def send_event(self, event_type: str, event: Optional[Dict]) -> None:
await self.wsr.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
class HttpServer:
def __init__(self) -> None:
self.__ws_heartbeat: Optional[float] = None
self.__ws_handlers: Dict[str, Callable] = {}
self.__ws_sessions: List[WsSession] = []
self.__ws_sessions_lock = asyncio.Lock()
def run(
self,
unix_path: str,
@@ -282,7 +255,7 @@ class HttpServer:
access_log_format: str,
) -> None:
self.__heartbeat = heartbeat # pylint: disable=attribute-defined-outside-init
self.__ws_heartbeat = heartbeat
if unix_rm and os.path.exists(unix_path):
os.remove(unix_path)
@@ -302,7 +275,14 @@ class HttpServer:
# =====
def _add_exposed(self, exposed: HttpExposed) -> None:
def _add_exposed(self, *objs: object) -> None:
for obj in objs:
for http_exposed in _get_exposed_http(obj):
self.__add_exposed_http(http_exposed)
for ws_exposed in _get_exposed_ws(obj):
self.__add_exposed_ws(ws_exposed)
def __add_exposed_http(self, exposed: HttpExposed) -> None:
async def wrapper(request: Request) -> Response:
try:
await self._check_request_auth(exposed, request)
@@ -315,10 +295,85 @@ class HttpServer:
return make_json_exception(err)
self.__app.router.add_route(exposed.method, exposed.path, wrapper)
async def _make_ws_response(self, request: Request) -> WebSocketResponse:
ws = WebSocketResponse(heartbeat=self.__heartbeat)
await ws.prepare(request)
return ws
def __add_exposed_ws(self, exposed: WsExposed) -> None:
self.__ws_handlers[exposed.event_type] = exposed.handler
# =====
@contextlib.asynccontextmanager
async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]:
assert self.__ws_heartbeat is not None
wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat)
await wsr.prepare(request)
ws = WsSession(wsr, kwargs)
async with self.__ws_sessions_lock:
self.__ws_sessions.append(ws)
get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions))
try:
await self._on_ws_opened()
yield ws
finally:
await self.__close_ws(ws)
async def _ws_loop(self, ws: WsSession) -> WebSocketResponse:
logger = get_logger()
async for msg in ws.wsr:
if msg.type != WSMsgType.TEXT:
break
try:
(event_type, event) = self.__parse_ws_event(msg.data)
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = self.__ws_handlers.get(event_type)
if handler:
await handler(ws, event)
else:
logger.error("Unknown websocket event: %r", msg.data)
return ws.wsr
async def _broadcast_ws_event(self, event_type: str, event: Optional[Dict]) -> None:
if self.__ws_sessions:
await asyncio.gather(*[
ws.send_event(event_type, event)
for ws in self.__ws_sessions
if (
not ws.wsr.closed
and ws.wsr._req is not None # pylint: disable=protected-access
and ws.wsr._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
async def _close_all_wss(self) -> None:
for ws in self._get_wss():
await self.__close_ws(ws)
def _get_wss(self) -> List[WsSession]:
return list(self.__ws_sessions)
async def __close_ws(self, ws: WsSession) -> None:
async with self.__ws_sessions_lock:
try:
self.__ws_sessions.remove(ws)
get_logger(3).info("Removed client socket: %s; clients now: %d", ws, len(self.__ws_sessions))
await ws.wsr.close()
except Exception:
pass
await self._on_ws_closed()
def __parse_ws_event(self, msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
# =====
@@ -334,6 +389,12 @@ class HttpServer:
async def _on_cleanup(self) -> None:
pass
async def _on_ws_opened(self) -> None:
pass
async def _on_ws_closed(self) -> None:
pass
# =====
async def __make_app(self) -> Application: