refactoring

This commit is contained in:
Maxim Devaev 2022-04-11 15:56:19 +03:00
parent c7f2564364
commit 6bd2b9c680
2 changed files with 57 additions and 37 deletions

View File

@ -25,7 +25,6 @@ import signal
import asyncio import asyncio
import operator import operator
import dataclasses import dataclasses
import json
from typing import Tuple from typing import Tuple
from typing import List from typing import List
@ -56,6 +55,9 @@ from ...htserver import get_exposed_http
from ...htserver import get_exposed_ws from ...htserver import get_exposed_ws
from ...htserver import make_json_response from ...htserver import make_json_response
from ...htserver import make_json_exception from ...htserver import make_json_exception
from ...htserver import send_ws_event
from ...htserver import broadcast_ws_event
from ...htserver import parse_ws_event
from ...htserver import HttpServer from ...htserver import HttpServer
from ...plugins import BasePlugin from ...plugins import BasePlugin
@ -279,18 +281,17 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
)) ))
for stage in [stage1, stage2]: for stage in [stage1, stage2]:
await asyncio.gather(*[ await asyncio.gather(*[
self.__send_event(client.ws, event_type, events.pop(event_type)) send_ws_event(client.ws, event_type, events.pop(event_type))
for (event_type, _) in stage for (event_type, _) in stage
]) ])
await self.__send_event(client.ws, "loop", {}) await send_ws_event(client.ws, "loop", {})
async for msg in client.ws: async for msg in client.ws:
if msg.type == aiohttp.web.WSMsgType.TEXT: if msg.type != aiohttp.web.WSMsgType.TEXT:
break
try: try:
data = json.loads(msg.data) (event_type, event) = parse_ws_event(msg.data)
event_type = data.get("event_type")
event = data["event"]
except Exception as err: except Exception as err:
logger.error("Can't parse JSON event from websocket: %r", err) logger.error("Can't parse JSON event from websocket: %r", err)
else: else:
@ -298,9 +299,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
if handler: if handler:
await handler(client.ws, event) await handler(client.ws, event)
else: else:
logger.error("Unknown websocket event: %r", data) logger.error("Unknown websocket event: %r", msg.data)
else:
break
return client.ws return client.ws
finally: finally:
@ -308,7 +307,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
@exposed_ws("ping") @exposed_ws("ping")
async def __ws_ping_handler(self, ws: aiohttp.web.WebSocketResponse, _: Dict) -> None: async def __ws_ping_handler(self, ws: aiohttp.web.WebSocketResponse, _: Dict) -> None:
await self.__send_event(ws, "pong", {}) await send_ws_event(ws, "pong", {})
# ===== SYSTEM STUFF # ===== SYSTEM STUFF
@ -390,24 +389,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
logger.exception("Cleanup error on %s", comp.name) logger.exception("Cleanup error on %s", comp.name)
logger.info("On-Cleanup complete") logger.info("On-Cleanup complete")
async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
async def __broadcast_event(self, event_type: str, event: Optional[Dict]) -> None:
if self.__ws_clients:
await asyncio.gather(*[
self.__send_event(client.ws, event_type, event)
for client in list(self.__ws_clients)
if (
not client.ws.closed
and client.ws._req is not None # pylint: disable=protected-access
and client.ws._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
async def __register_ws_client(self, client: _WsClient) -> None: async def __register_ws_client(self, client: _WsClient) -> None:
async with self.__ws_clients_lock: async with self.__ws_clients_lock:
self.__ws_clients.add(client) self.__ws_clients.add(client)
@ -454,7 +435,10 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None: async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
async for state in poller: async for state in poller:
await self.__broadcast_event(event_type, state) await broadcast_ws_event([
client.ws
for client in list(self.__ws_clients)
], event_type, state)
async def __stream_snapshoter(self) -> None: async def __stream_snapshoter(self) -> None:
await self.__snapshoter.run( await self.__snapshoter.run(

View File

@ -27,6 +27,7 @@ import dataclasses
import inspect import inspect
import json import json
from typing import Tuple
from typing import List from typing import List
from typing import Dict from typing import Dict
from typing import Callable from typing import Callable
@ -36,6 +37,7 @@ from aiohttp.web import BaseRequest
from aiohttp.web import Request from aiohttp.web import Request
from aiohttp.web import Response from aiohttp.web import Response
from aiohttp.web import StreamResponse from aiohttp.web import StreamResponse
from aiohttp.web import WebSocketResponse
from aiohttp.web import Application from aiohttp.web import Application
from aiohttp.web import run_app from aiohttp.web import run_app
from aiohttp.web import normalize_path_middleware from aiohttp.web import normalize_path_middleware
@ -197,6 +199,40 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
}, False) }, False)
# =====
async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
"event": event,
}))
async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
if wss:
await asyncio.gather(*[
send_ws_event(ws, event_type, event)
for ws in wss
if (
not ws.closed
and ws._req is not None # pylint: disable=protected-access
and ws._req.transport is not None # pylint: disable=protected-access
)
], return_exceptions=True)
def parse_ws_event(msg: str) -> Tuple[str, Dict]:
data = json.loads(msg)
if not isinstance(data, dict):
raise RuntimeError("Top-level event structure is not a dict")
event_type = data.get("event_type")
if not isinstance(event_type, str):
raise RuntimeError("event_type must be a string")
event = data["event"]
if not isinstance(event, dict):
raise RuntimeError("event must be a dict")
return (event_type, event)
# ===== # =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info" _REQUEST_AUTH_INFO = "_kvmd_auth_info"