mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2025-12-13 17:50:29 +08:00
common websocket code
This commit is contained in:
parent
37e5118fff
commit
88c7796551
@ -32,7 +32,6 @@ from typing import Callable
|
|||||||
|
|
||||||
from aiohttp.web import Request
|
from aiohttp.web import Request
|
||||||
from aiohttp.web import Response
|
from aiohttp.web import Response
|
||||||
from aiohttp.web import WebSocketResponse
|
|
||||||
|
|
||||||
from ....mouse import MouseRange
|
from ....mouse import MouseRange
|
||||||
|
|
||||||
@ -42,6 +41,7 @@ from ....keyboard.printer import text_to_web_keys
|
|||||||
from ....htserver import exposed_http
|
from ....htserver import exposed_http
|
||||||
from ....htserver import exposed_ws
|
from ....htserver import exposed_ws
|
||||||
from ....htserver import make_json_response
|
from ....htserver import make_json_response
|
||||||
|
from ....htserver import WsSession
|
||||||
|
|
||||||
from ....plugins.hid import BaseHid
|
from ....plugins.hid import BaseHid
|
||||||
|
|
||||||
@ -158,7 +158,7 @@ class HidApi:
|
|||||||
# =====
|
# =====
|
||||||
|
|
||||||
@exposed_ws("key")
|
@exposed_ws("key")
|
||||||
async def __ws_key_handler(self, _: WebSocketResponse, event: Dict) -> None:
|
async def __ws_key_handler(self, _: WsSession, event: Dict) -> None:
|
||||||
try:
|
try:
|
||||||
key = valid_hid_key(event["key"])
|
key = valid_hid_key(event["key"])
|
||||||
state = valid_bool(event["state"])
|
state = valid_bool(event["state"])
|
||||||
@ -168,7 +168,7 @@ class HidApi:
|
|||||||
self.__hid.send_key_events([(key, state)])
|
self.__hid.send_key_events([(key, state)])
|
||||||
|
|
||||||
@exposed_ws("mouse_button")
|
@exposed_ws("mouse_button")
|
||||||
async def __ws_mouse_button_handler(self, _: WebSocketResponse, event: Dict) -> None:
|
async def __ws_mouse_button_handler(self, _: WsSession, event: Dict) -> None:
|
||||||
try:
|
try:
|
||||||
button = valid_hid_mouse_button(event["button"])
|
button = valid_hid_mouse_button(event["button"])
|
||||||
state = valid_bool(event["state"])
|
state = valid_bool(event["state"])
|
||||||
@ -177,7 +177,7 @@ class HidApi:
|
|||||||
self.__hid.send_mouse_button_event(button, state)
|
self.__hid.send_mouse_button_event(button, state)
|
||||||
|
|
||||||
@exposed_ws("mouse_move")
|
@exposed_ws("mouse_move")
|
||||||
async def __ws_mouse_move_handler(self, _: WebSocketResponse, event: Dict) -> None:
|
async def __ws_mouse_move_handler(self, _: WsSession, event: Dict) -> None:
|
||||||
try:
|
try:
|
||||||
to_x = valid_hid_mouse_move(event["to"]["x"])
|
to_x = valid_hid_mouse_move(event["to"]["x"])
|
||||||
to_y = valid_hid_mouse_move(event["to"]["y"])
|
to_y = valid_hid_mouse_move(event["to"]["y"])
|
||||||
@ -186,11 +186,11 @@ class HidApi:
|
|||||||
self.__send_mouse_move_event_remapped(to_x, to_y)
|
self.__send_mouse_move_event_remapped(to_x, to_y)
|
||||||
|
|
||||||
@exposed_ws("mouse_relative")
|
@exposed_ws("mouse_relative")
|
||||||
async def __ws_mouse_relative_handler(self, _: WebSocketResponse, event: Dict) -> None:
|
async def __ws_mouse_relative_handler(self, _: WsSession, event: Dict) -> None:
|
||||||
self.__process_delta_ws_request(event, self.__hid.send_mouse_relative_event)
|
self.__process_delta_ws_request(event, self.__hid.send_mouse_relative_event)
|
||||||
|
|
||||||
@exposed_ws("mouse_wheel")
|
@exposed_ws("mouse_wheel")
|
||||||
async def __ws_mouse_wheel_handler(self, _: WebSocketResponse, event: Dict) -> None:
|
async def __ws_mouse_wheel_handler(self, _: WsSession, event: Dict) -> None:
|
||||||
self.__process_delta_ws_request(event, self.__hid.send_mouse_wheel_event)
|
self.__process_delta_ws_request(event, self.__hid.send_mouse_wheel_event)
|
||||||
|
|
||||||
def __process_delta_ws_request(self, event: Dict, handler: Callable[[int, int], None]) -> None:
|
def __process_delta_ws_request(self, event: Dict, handler: Callable[[int, int], None]) -> None:
|
||||||
|
|||||||
@ -27,7 +27,6 @@ import dataclasses
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Set
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from typing import Coroutine
|
from typing import Coroutine
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
@ -48,12 +47,8 @@ from ... import aioproc
|
|||||||
from ...htserver import HttpExposed
|
from ...htserver import HttpExposed
|
||||||
from ...htserver import exposed_http
|
from ...htserver import exposed_http
|
||||||
from ...htserver import exposed_ws
|
from ...htserver import exposed_ws
|
||||||
from ...htserver import get_exposed_http
|
|
||||||
from ...htserver import get_exposed_ws
|
|
||||||
from ...htserver import make_json_response
|
from ...htserver import make_json_response
|
||||||
from ...htserver import send_ws_event
|
from ...htserver import WsSession
|
||||||
from ...htserver import broadcast_ws_event
|
|
||||||
from ...htserver import process_ws_messages
|
|
||||||
from ...htserver import HttpServer
|
from ...htserver import HttpServer
|
||||||
|
|
||||||
from ...plugins import BasePlugin
|
from ...plugins import BasePlugin
|
||||||
@ -128,15 +123,6 @@ class _Component: # pylint: disable=too-many-instance-attributes
|
|||||||
assert self.event_type, self
|
assert self.event_type, self
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
|
||||||
class _WsClient:
|
|
||||||
ws: WebSocketResponse
|
|
||||||
stream: bool
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"WsClient(id={id(self)}, stream={self.stream})"
|
|
||||||
|
|
||||||
|
|
||||||
class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes
|
class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-instance-attributes
|
||||||
def __init__( # pylint: disable=too-many-arguments,too-many-locals
|
def __init__( # pylint: disable=too-many-arguments,too-many-locals
|
||||||
self,
|
self,
|
||||||
@ -160,6 +146,8 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
|||||||
stream_forever: bool,
|
stream_forever: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.__auth_manager = auth_manager
|
self.__auth_manager = auth_manager
|
||||||
self.__hid = hid
|
self.__hid = hid
|
||||||
self.__streamer = streamer
|
self.__streamer = streamer
|
||||||
@ -201,11 +189,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
|||||||
RedfishApi(info_manager, atx),
|
RedfishApi(info_manager, atx),
|
||||||
]
|
]
|
||||||
|
|
||||||
self.__ws_handlers: Dict[str, Callable] = {}
|
|
||||||
|
|
||||||
self.__ws_clients: Set[_WsClient] = set()
|
|
||||||
self.__ws_clients_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
self.__streamer_notifier = aiotools.AioNotifier()
|
self.__streamer_notifier = aiotools.AioNotifier()
|
||||||
self.__reset_streamer = False
|
self.__reset_streamer = False
|
||||||
self.__new_streamer_params: Dict = {}
|
self.__new_streamer_params: Dict = {}
|
||||||
@ -244,11 +227,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
|||||||
@exposed_http("GET", "/ws")
|
@exposed_http("GET", "/ws")
|
||||||
async def __ws_handler(self, request: Request) -> WebSocketResponse:
|
async def __ws_handler(self, request: Request) -> WebSocketResponse:
|
||||||
stream = valid_bool(request.query.get("stream", "true"))
|
stream = valid_bool(request.query.get("stream", "true"))
|
||||||
ws = await self._make_ws_response(request)
|
async with self._ws_session(request, stream=stream) as ws:
|
||||||
client = _WsClient(ws, stream)
|
|
||||||
await self.__register_ws_client(client)
|
|
||||||
|
|
||||||
try:
|
|
||||||
stage1 = [
|
stage1 = [
|
||||||
("gpio_model_state", self.__user_gpio.get_model()),
|
("gpio_model_state", self.__user_gpio.get_model()),
|
||||||
("hid_keymaps_state", self.__hid_api.get_keymaps()),
|
("hid_keymaps_state", self.__hid_api.get_keymaps()),
|
||||||
@ -266,19 +245,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
|||||||
))
|
))
|
||||||
for stage in [stage1, stage2]:
|
for stage in [stage1, stage2]:
|
||||||
await asyncio.gather(*[
|
await asyncio.gather(*[
|
||||||
send_ws_event(ws, event_type, events.pop(event_type))
|
ws.send_event(event_type, events.pop(event_type))
|
||||||
for (event_type, _) in stage
|
for (event_type, _) in stage
|
||||||
])
|
])
|
||||||
|
await ws.send_event("loop", {})
|
||||||
await send_ws_event(ws, "loop", {})
|
return (await self._ws_loop(ws))
|
||||||
await process_ws_messages(ws, self.__ws_handlers)
|
|
||||||
return ws
|
|
||||||
finally:
|
|
||||||
await self.__remove_ws_client(client)
|
|
||||||
|
|
||||||
@exposed_ws("ping")
|
@exposed_ws("ping")
|
||||||
async def __ws_ping_handler(self, ws: WebSocketResponse, _: Dict) -> None:
|
async def __ws_ping_handler(self, ws: WsSession, _: Dict) -> None:
|
||||||
await send_ws_event(ws, "pong", {})
|
await ws.send_event("pong", {})
|
||||||
|
|
||||||
# ===== SYSTEM STUFF
|
# ===== SYSTEM STUFF
|
||||||
|
|
||||||
@ -300,26 +275,16 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
|||||||
if comp.poll_state:
|
if comp.poll_state:
|
||||||
aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state()))
|
aiotools.create_deadly_task(f"{comp.name} [poller]", self.__poll_state(comp.event_type, comp.poll_state()))
|
||||||
aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter())
|
aiotools.create_deadly_task("Stream snapshoter", self.__stream_snapshoter())
|
||||||
|
self._add_exposed(*self.__apis)
|
||||||
for api in self.__apis:
|
|
||||||
for http_exposed in get_exposed_http(api):
|
|
||||||
self._add_exposed(http_exposed)
|
|
||||||
for ws_exposed in get_exposed_ws(api):
|
|
||||||
self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler
|
|
||||||
|
|
||||||
async def _on_shutdown(self) -> None:
|
async def _on_shutdown(self) -> None:
|
||||||
logger = get_logger(0)
|
logger = get_logger(0)
|
||||||
|
|
||||||
logger.info("Waiting short tasks ...")
|
logger.info("Waiting short tasks ...")
|
||||||
await aiotools.wait_all_short_tasks()
|
await aiotools.wait_all_short_tasks()
|
||||||
|
|
||||||
logger.info("Stopping system tasks ...")
|
logger.info("Stopping system tasks ...")
|
||||||
await aiotools.stop_all_deadly_tasks()
|
await aiotools.stop_all_deadly_tasks()
|
||||||
|
|
||||||
logger.info("Disconnecting clients ...")
|
logger.info("Disconnecting clients ...")
|
||||||
for client in list(self.__ws_clients):
|
await self._close_all_wss()
|
||||||
await self.__remove_ws_client(client)
|
|
||||||
|
|
||||||
logger.info("On-Shutdown complete")
|
logger.info("On-Shutdown complete")
|
||||||
|
|
||||||
async def _on_cleanup(self) -> None:
|
async def _on_cleanup(self) -> None:
|
||||||
@ -333,25 +298,18 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
|||||||
logger.exception("Cleanup error on %s", comp.name)
|
logger.exception("Cleanup error on %s", comp.name)
|
||||||
logger.info("On-Cleanup complete")
|
logger.info("On-Cleanup complete")
|
||||||
|
|
||||||
async def __register_ws_client(self, client: _WsClient) -> None:
|
async def _on_ws_opened(self) -> None:
|
||||||
async with self.__ws_clients_lock:
|
|
||||||
self.__ws_clients.add(client)
|
|
||||||
get_logger().info("Registered new client socket: %s; clients now: %d", client, len(self.__ws_clients))
|
|
||||||
await self.__streamer_notifier.notify()
|
await self.__streamer_notifier.notify()
|
||||||
|
|
||||||
async def __remove_ws_client(self, client: _WsClient) -> None:
|
async def _on_ws_closed(self) -> None:
|
||||||
async with self.__ws_clients_lock:
|
self.__hid.clear_events()
|
||||||
self.__hid.clear_events()
|
|
||||||
try:
|
|
||||||
self.__ws_clients.remove(client)
|
|
||||||
get_logger().info("Removed client socket: %s; clients now: %d", client, len(self.__ws_clients))
|
|
||||||
await client.ws.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await self.__streamer_notifier.notify()
|
await self.__streamer_notifier.notify()
|
||||||
|
|
||||||
def __has_stream_clients(self) -> bool:
|
def __has_stream_clients(self) -> bool:
|
||||||
return bool(sum(map(operator.attrgetter("stream"), self.__ws_clients)))
|
return bool(sum(map(
|
||||||
|
(lambda ws: ws.kwargs["stream"]),
|
||||||
|
self._get_wss(),
|
||||||
|
)))
|
||||||
|
|
||||||
# ===== SYSTEM TASKS
|
# ===== SYSTEM TASKS
|
||||||
|
|
||||||
@ -379,10 +337,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
|||||||
|
|
||||||
async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
|
async def __poll_state(self, event_type: str, poller: AsyncGenerator[Dict, None]) -> None:
|
||||||
async for state in poller:
|
async for state in poller:
|
||||||
await broadcast_ws_event([
|
await self._broadcast_ws_event(event_type, state)
|
||||||
client.ws
|
|
||||||
for client in list(self.__ws_clients)
|
|
||||||
], event_type, state)
|
|
||||||
|
|
||||||
async def __stream_snapshoter(self) -> None:
|
async def __stream_snapshoter(self) -> None:
|
||||||
await self.__snapshoter.run(
|
await self.__snapshoter.run(
|
||||||
|
|||||||
179
kvmd/htserver.py
179
kvmd/htserver.py
@ -23,6 +23,7 @@
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@ -31,7 +32,9 @@ from typing import Tuple
|
|||||||
from typing import List
|
from typing import List
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from typing import AsyncGenerator
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from aiohttp.web import BaseRequest
|
from aiohttp.web import BaseRequest
|
||||||
from aiohttp.web import Request
|
from aiohttp.web import Request
|
||||||
@ -103,7 +106,7 @@ def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Calla
|
|||||||
return set_attrs
|
return set_attrs
|
||||||
|
|
||||||
|
|
||||||
def get_exposed_http(obj: object) -> List[HttpExposed]:
|
def _get_exposed_http(obj: object) -> List[HttpExposed]:
|
||||||
return [
|
return [
|
||||||
HttpExposed(
|
HttpExposed(
|
||||||
method=getattr(handler, _HTTP_METHOD),
|
method=getattr(handler, _HTTP_METHOD),
|
||||||
@ -135,7 +138,7 @@ def exposed_ws(event_type: str) -> Callable:
|
|||||||
return set_attrs
|
return set_attrs
|
||||||
|
|
||||||
|
|
||||||
def get_exposed_ws(obj: object) -> List[WsExposed]:
|
def _get_exposed_ws(obj: object) -> List[WsExposed]:
|
||||||
return [
|
return [
|
||||||
WsExposed(
|
WsExposed(
|
||||||
event_type=getattr(handler, _WS_EVENT_TYPE),
|
event_type=getattr(handler, _WS_EVENT_TYPE),
|
||||||
@ -205,57 +208,6 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
|
|||||||
}, False)
|
}, False)
|
||||||
|
|
||||||
|
|
||||||
# =====
|
|
||||||
async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
|
|
||||||
await ws.send_str(json.dumps({
|
|
||||||
"event_type": event_type,
|
|
||||||
"event": event,
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
|
||||||
async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
|
|
||||||
if wss:
|
|
||||||
await asyncio.gather(*[
|
|
||||||
send_ws_event(ws, event_type, event)
|
|
||||||
for ws in wss
|
|
||||||
if (
|
|
||||||
not ws.closed
|
|
||||||
and ws._req is not None # pylint: disable=protected-access
|
|
||||||
and ws._req.transport is not None # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
], return_exceptions=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_ws_event(msg: str) -> Tuple[str, Dict]:
|
|
||||||
data = json.loads(msg)
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
raise RuntimeError("Top-level event structure is not a dict")
|
|
||||||
event_type = data.get("event_type")
|
|
||||||
if not isinstance(event_type, str):
|
|
||||||
raise RuntimeError("event_type must be a string")
|
|
||||||
event = data["event"]
|
|
||||||
if not isinstance(event, dict):
|
|
||||||
raise RuntimeError("event must be a dict")
|
|
||||||
return (event_type, event)
|
|
||||||
|
|
||||||
|
|
||||||
async def process_ws_messages(ws: WebSocketResponse, handlers: Dict[str, Callable]) -> None:
|
|
||||||
logger = get_logger(1)
|
|
||||||
async for msg in ws:
|
|
||||||
if msg.type != WSMsgType.TEXT:
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
(event_type, event) = _parse_ws_event(msg.data)
|
|
||||||
except Exception as err:
|
|
||||||
logger.error("Can't parse JSON event from websocket: %r", err)
|
|
||||||
else:
|
|
||||||
handler = handlers.get(event_type)
|
|
||||||
if handler:
|
|
||||||
await handler(ws, event)
|
|
||||||
else:
|
|
||||||
logger.error("Unknown websocket event: %r", msg.data)
|
|
||||||
|
|
||||||
|
|
||||||
# =====
|
# =====
|
||||||
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
|
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
|
||||||
|
|
||||||
@ -272,7 +224,28 @@ def set_request_auth_info(request: BaseRequest, info: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
# =====
|
# =====
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class WsSession:
|
||||||
|
wsr: WebSocketResponse
|
||||||
|
kwargs: Dict[str, Any]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"WsSession(id={id(self)}, {self.kwargs})"
|
||||||
|
|
||||||
|
async def send_event(self, event_type: str, event: Optional[Dict]) -> None:
|
||||||
|
await self.wsr.send_str(json.dumps({
|
||||||
|
"event_type": event_type,
|
||||||
|
"event": event,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
class HttpServer:
|
class HttpServer:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.__ws_heartbeat: Optional[float] = None
|
||||||
|
self.__ws_handlers: Dict[str, Callable] = {}
|
||||||
|
self.__ws_sessions: List[WsSession] = []
|
||||||
|
self.__ws_sessions_lock = asyncio.Lock()
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
unix_path: str,
|
unix_path: str,
|
||||||
@ -282,7 +255,7 @@ class HttpServer:
|
|||||||
access_log_format: str,
|
access_log_format: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.__heartbeat = heartbeat # pylint: disable=attribute-defined-outside-init
|
self.__ws_heartbeat = heartbeat
|
||||||
|
|
||||||
if unix_rm and os.path.exists(unix_path):
|
if unix_rm and os.path.exists(unix_path):
|
||||||
os.remove(unix_path)
|
os.remove(unix_path)
|
||||||
@ -302,7 +275,14 @@ class HttpServer:
|
|||||||
|
|
||||||
# =====
|
# =====
|
||||||
|
|
||||||
def _add_exposed(self, exposed: HttpExposed) -> None:
|
def _add_exposed(self, *objs: object) -> None:
|
||||||
|
for obj in objs:
|
||||||
|
for http_exposed in _get_exposed_http(obj):
|
||||||
|
self.__add_exposed_http(http_exposed)
|
||||||
|
for ws_exposed in _get_exposed_ws(obj):
|
||||||
|
self.__add_exposed_ws(ws_exposed)
|
||||||
|
|
||||||
|
def __add_exposed_http(self, exposed: HttpExposed) -> None:
|
||||||
async def wrapper(request: Request) -> Response:
|
async def wrapper(request: Request) -> Response:
|
||||||
try:
|
try:
|
||||||
await self._check_request_auth(exposed, request)
|
await self._check_request_auth(exposed, request)
|
||||||
@ -315,10 +295,85 @@ class HttpServer:
|
|||||||
return make_json_exception(err)
|
return make_json_exception(err)
|
||||||
self.__app.router.add_route(exposed.method, exposed.path, wrapper)
|
self.__app.router.add_route(exposed.method, exposed.path, wrapper)
|
||||||
|
|
||||||
async def _make_ws_response(self, request: Request) -> WebSocketResponse:
|
def __add_exposed_ws(self, exposed: WsExposed) -> None:
|
||||||
ws = WebSocketResponse(heartbeat=self.__heartbeat)
|
self.__ws_handlers[exposed.event_type] = exposed.handler
|
||||||
await ws.prepare(request)
|
|
||||||
return ws
|
# =====
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]:
|
||||||
|
assert self.__ws_heartbeat is not None
|
||||||
|
wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat)
|
||||||
|
await wsr.prepare(request)
|
||||||
|
ws = WsSession(wsr, kwargs)
|
||||||
|
|
||||||
|
async with self.__ws_sessions_lock:
|
||||||
|
self.__ws_sessions.append(ws)
|
||||||
|
get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._on_ws_opened()
|
||||||
|
yield ws
|
||||||
|
finally:
|
||||||
|
await self.__close_ws(ws)
|
||||||
|
|
||||||
|
async def _ws_loop(self, ws: WsSession) -> WebSocketResponse:
|
||||||
|
logger = get_logger()
|
||||||
|
async for msg in ws.wsr:
|
||||||
|
if msg.type != WSMsgType.TEXT:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
(event_type, event) = self.__parse_ws_event(msg.data)
|
||||||
|
except Exception as err:
|
||||||
|
logger.error("Can't parse JSON event from websocket: %r", err)
|
||||||
|
else:
|
||||||
|
handler = self.__ws_handlers.get(event_type)
|
||||||
|
if handler:
|
||||||
|
await handler(ws, event)
|
||||||
|
else:
|
||||||
|
logger.error("Unknown websocket event: %r", msg.data)
|
||||||
|
return ws.wsr
|
||||||
|
|
||||||
|
async def _broadcast_ws_event(self, event_type: str, event: Optional[Dict]) -> None:
|
||||||
|
if self.__ws_sessions:
|
||||||
|
await asyncio.gather(*[
|
||||||
|
ws.send_event(event_type, event)
|
||||||
|
for ws in self.__ws_sessions
|
||||||
|
if (
|
||||||
|
not ws.wsr.closed
|
||||||
|
and ws.wsr._req is not None # pylint: disable=protected-access
|
||||||
|
and ws.wsr._req.transport is not None # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
], return_exceptions=True)
|
||||||
|
|
||||||
|
async def _close_all_wss(self) -> None:
|
||||||
|
for ws in self._get_wss():
|
||||||
|
await self.__close_ws(ws)
|
||||||
|
|
||||||
|
def _get_wss(self) -> List[WsSession]:
|
||||||
|
return list(self.__ws_sessions)
|
||||||
|
|
||||||
|
async def __close_ws(self, ws: WsSession) -> None:
|
||||||
|
async with self.__ws_sessions_lock:
|
||||||
|
try:
|
||||||
|
self.__ws_sessions.remove(ws)
|
||||||
|
get_logger(3).info("Removed client socket: %s; clients now: %d", ws, len(self.__ws_sessions))
|
||||||
|
await ws.wsr.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await self._on_ws_closed()
|
||||||
|
|
||||||
|
def __parse_ws_event(self, msg: str) -> Tuple[str, Dict]:
|
||||||
|
data = json.loads(msg)
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise RuntimeError("Top-level event structure is not a dict")
|
||||||
|
event_type = data.get("event_type")
|
||||||
|
if not isinstance(event_type, str):
|
||||||
|
raise RuntimeError("event_type must be a string")
|
||||||
|
event = data["event"]
|
||||||
|
if not isinstance(event, dict):
|
||||||
|
raise RuntimeError("event must be a dict")
|
||||||
|
return (event_type, event)
|
||||||
|
|
||||||
# =====
|
# =====
|
||||||
|
|
||||||
@ -334,6 +389,12 @@ class HttpServer:
|
|||||||
async def _on_cleanup(self) -> None:
|
async def _on_cleanup(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def _on_ws_opened(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _on_ws_closed(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
# =====
|
# =====
|
||||||
|
|
||||||
async def __make_app(self) -> Application:
|
async def __make_app(self) -> Application:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user