refactoring

This commit is contained in:
Maxim Devaev 2022-04-11 15:56:19 +03:00
parent c7f2564364
commit 6bd2b9c680
2 changed files with 57 additions and 37 deletions

View File

@ -25,7 +25,6 @@ import signal
import asyncio
import operator
import dataclasses
import json
from typing import Tuple
from typing import List
@ -56,6 +55,9 @@ from ...htserver import get_exposed_http
from ...htserver import get_exposed_ws
from ...htserver import make_json_response
from ...htserver import make_json_exception
from ...htserver import send_ws_event
from ...htserver import broadcast_ws_event
from ...htserver import parse_ws_event
from ...htserver import HttpServer
from ...plugins import BasePlugin
@ -279,28 +281,25 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
))
for stage in [stage1, stage2]:
await asyncio.gather(*[
self.__send_event(client.ws, event_type, events.pop(event_type))
send_ws_event(client.ws, event_type, events.pop(event_type))
for (event_type, _) in stage
])
await self.__send_event(client.ws, "loop", {})
await send_ws_event(client.ws, "loop", {})
async for msg in client.ws:
if msg.type == aiohttp.web.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
event_type = data.get("event_type")
event = data["event"]
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = self.__ws_handlers.get(event_type)
if handler:
await handler(client.ws, event)
else:
logger.error("Unknown websocket event: %r", data)
else:
if msg.type != aiohttp.web.WSMsgType.TEXT:
break
try:
(event_type, event) = parse_ws_event(msg.data)
except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err)
else:
handler = self.__ws_handlers.get(event_type)
if handler:
await handler(client.ws, event)
else:
logger.error("Unknown websocket event: %r", msg.data)
return client.ws
finally:
@ -308,7 +307,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
@exposed_ws("ping")
async def __ws_ping_handler(self, ws: aiohttp.web.WebSocketResponse, _: Dict) -> None:
await self.__send_event(ws, "pong", {})
await send_ws_event(ws, "pong", {})
# ===== SYSTEM STUFF
@ -390,24 +389,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
logger.exception("Cleanup error on %s", comp.name)
logger.info("On-Cleanup complete")
async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
async def __broadcast_event(self, event_type: str, event: Optional[Dict]) -> None:
if self.__ws_clients:
await asyncio.gather(*[
self.__send_event(client.ws, event_type, event)
for client in list(self.__ws_clients)
if (
not client.ws.closed
and client.ws._req is not None # pylint: disable=protected-access
and client.ws._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
async def __register_ws_client(self, client: _WsClient) -> None:
async with self.__ws_clients_lock:
self.__ws_clients.add(client)
@ -454,7 +435,10 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
async for state in poller:
await self.__broadcast_event(event_type, state)
await broadcast_ws_event([
client.ws
for client in list(self.__ws_clients)
], event_type, state)
async def __stream_snapshoter(self) -> None:
await self.__snapshoter.run(

View File

@ -27,6 +27,7 @@ import dataclasses
import inspect
import json
from typing import Tuple
from typing import List
from typing import Dict
from typing import Callable
@ -36,6 +37,7 @@ from aiohttp.web import BaseRequest
from aiohttp.web import Request
from aiohttp.web import Response
from aiohttp.web import StreamResponse
from aiohttp.web import WebSocketResponse
from aiohttp.web import Application
from aiohttp.web import run_app
from aiohttp.web import normalize_path_middleware
@ -197,6 +199,40 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
}, False)
# =====
async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
if wss:
await asyncio.gather(*[
send_ws_event(ws, event_type, event)
for ws in wss
if (
not ws.closed
and ws._req is not None # pylint: disable=protected-access
and ws._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
def parse_ws_event(msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
# =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info"