mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2026-01-29 00:51:53 +08:00
common websocket code
This commit is contained in:
179
kvmd/htserver.py
179
kvmd/htserver.py
@@ -23,6 +23,7 @@
|
||||
import os
|
||||
import socket
|
||||
import asyncio
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
@@ -31,7 +32,9 @@ from typing import Tuple
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Callable
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
from typing import Any
|
||||
|
||||
from aiohttp.web import BaseRequest
|
||||
from aiohttp.web import Request
|
||||
@@ -103,7 +106,7 @@ def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Calla
|
||||
return set_attrs
|
||||
|
||||
|
||||
def get_exposed_http(obj: object) -> List[HttpExposed]:
|
||||
def _get_exposed_http(obj: object) -> List[HttpExposed]:
|
||||
return [
|
||||
HttpExposed(
|
||||
method=getattr(handler, _HTTP_METHOD),
|
||||
@@ -135,7 +138,7 @@ def exposed_ws(event_type: str) -> Callable:
|
||||
return set_attrs
|
||||
|
||||
|
||||
def get_exposed_ws(obj: object) -> List[WsExposed]:
|
||||
def _get_exposed_ws(obj: object) -> List[WsExposed]:
|
||||
return [
|
||||
WsExposed(
|
||||
event_type=getattr(handler, _WS_EVENT_TYPE),
|
||||
@@ -205,57 +208,6 @@ async def stream_json_exception(response: StreamResponse, err: Exception) -> Non
|
||||
}, False)
|
||||
|
||||
|
||||
# =====
|
||||
async def send_ws_event(ws: WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
|
||||
await ws.send_str(json.dumps({
|
||||
"event_type": event_type,
|
||||
"event": event,
|
||||
}))
|
||||
|
||||
|
||||
async def broadcast_ws_event(wss: List[WebSocketResponse], event_type: str, event: Optional[Dict]) -> None:
|
||||
if wss:
|
||||
await asyncio.gather(*[
|
||||
send_ws_event(ws, event_type, event)
|
||||
for ws in wss
|
||||
if (
|
||||
not ws.closed
|
||||
and ws._req is not None # pylint: disable=protected-access
|
||||
and ws._req.transport is not None # pylint: disable=protected-access
|
||||
)
|
||||
], return_exceptions=True)
|
||||
|
||||
|
||||
def _parse_ws_event(msg: str) -> Tuple[str, Dict]:
|
||||
data = json.loads(msg)
|
||||
if not isinstance(data, dict):
|
||||
raise RuntimeError("Top-level event structure is not a dict")
|
||||
event_type = data.get("event_type")
|
||||
if not isinstance(event_type, str):
|
||||
raise RuntimeError("event_type must be a string")
|
||||
event = data["event"]
|
||||
if not isinstance(event, dict):
|
||||
raise RuntimeError("event must be a dict")
|
||||
return (event_type, event)
|
||||
|
||||
|
||||
async def process_ws_messages(ws: WebSocketResponse, handlers: Dict[str, Callable]) -> None:
|
||||
logger = get_logger(1)
|
||||
async for msg in ws:
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
break
|
||||
try:
|
||||
(event_type, event) = _parse_ws_event(msg.data)
|
||||
except Exception as err:
|
||||
logger.error("Can't parse JSON event from websocket: %r", err)
|
||||
else:
|
||||
handler = handlers.get(event_type)
|
||||
if handler:
|
||||
await handler(ws, event)
|
||||
else:
|
||||
logger.error("Unknown websocket event: %r", msg.data)
|
||||
|
||||
|
||||
# =====
|
||||
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
|
||||
|
||||
@@ -272,7 +224,28 @@ def set_request_auth_info(request: BaseRequest, info: str) -> None:
|
||||
|
||||
|
||||
# =====
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WsSession:
|
||||
wsr: WebSocketResponse
|
||||
kwargs: Dict[str, Any]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"WsSession(id={id(self)}, {self.kwargs})"
|
||||
|
||||
async def send_event(self, event_type: str, event: Optional[Dict]) -> None:
|
||||
await self.wsr.send_str(json.dumps({
|
||||
"event_type": event_type,
|
||||
"event": event,
|
||||
}))
|
||||
|
||||
|
||||
class HttpServer:
|
||||
def __init__(self) -> None:
|
||||
self.__ws_heartbeat: Optional[float] = None
|
||||
self.__ws_handlers: Dict[str, Callable] = {}
|
||||
self.__ws_sessions: List[WsSession] = []
|
||||
self.__ws_sessions_lock = asyncio.Lock()
|
||||
|
||||
def run(
|
||||
self,
|
||||
unix_path: str,
|
||||
@@ -282,7 +255,7 @@ class HttpServer:
|
||||
access_log_format: str,
|
||||
) -> None:
|
||||
|
||||
self.__heartbeat = heartbeat # pylint: disable=attribute-defined-outside-init
|
||||
self.__ws_heartbeat = heartbeat
|
||||
|
||||
if unix_rm and os.path.exists(unix_path):
|
||||
os.remove(unix_path)
|
||||
@@ -302,7 +275,14 @@ class HttpServer:
|
||||
|
||||
# =====
|
||||
|
||||
def _add_exposed(self, exposed: HttpExposed) -> None:
|
||||
def _add_exposed(self, *objs: object) -> None:
|
||||
for obj in objs:
|
||||
for http_exposed in _get_exposed_http(obj):
|
||||
self.__add_exposed_http(http_exposed)
|
||||
for ws_exposed in _get_exposed_ws(obj):
|
||||
self.__add_exposed_ws(ws_exposed)
|
||||
|
||||
def __add_exposed_http(self, exposed: HttpExposed) -> None:
|
||||
async def wrapper(request: Request) -> Response:
|
||||
try:
|
||||
await self._check_request_auth(exposed, request)
|
||||
@@ -315,10 +295,85 @@ class HttpServer:
|
||||
return make_json_exception(err)
|
||||
self.__app.router.add_route(exposed.method, exposed.path, wrapper)
|
||||
|
||||
async def _make_ws_response(self, request: Request) -> WebSocketResponse:
|
||||
ws = WebSocketResponse(heartbeat=self.__heartbeat)
|
||||
await ws.prepare(request)
|
||||
return ws
|
||||
def __add_exposed_ws(self, exposed: WsExposed) -> None:
|
||||
self.__ws_handlers[exposed.event_type] = exposed.handler
|
||||
|
||||
# =====
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _ws_session(self, request: Request, **kwargs: Any) -> AsyncGenerator[WsSession, None]:
|
||||
assert self.__ws_heartbeat is not None
|
||||
wsr = WebSocketResponse(heartbeat=self.__ws_heartbeat)
|
||||
await wsr.prepare(request)
|
||||
ws = WsSession(wsr, kwargs)
|
||||
|
||||
async with self.__ws_sessions_lock:
|
||||
self.__ws_sessions.append(ws)
|
||||
get_logger(2).info("Registered new client session: %s; clients now: %d", ws, len(self.__ws_sessions))
|
||||
|
||||
try:
|
||||
await self._on_ws_opened()
|
||||
yield ws
|
||||
finally:
|
||||
await self.__close_ws(ws)
|
||||
|
||||
async def _ws_loop(self, ws: WsSession) -> WebSocketResponse:
|
||||
logger = get_logger()
|
||||
async for msg in ws.wsr:
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
break
|
||||
try:
|
||||
(event_type, event) = self.__parse_ws_event(msg.data)
|
||||
except Exception as err:
|
||||
logger.error("Can't parse JSON event from websocket: %r", err)
|
||||
else:
|
||||
handler = self.__ws_handlers.get(event_type)
|
||||
if handler:
|
||||
await handler(ws, event)
|
||||
else:
|
||||
logger.error("Unknown websocket event: %r", msg.data)
|
||||
return ws.wsr
|
||||
|
||||
async def _broadcast_ws_event(self, event_type: str, event: Optional[Dict]) -> None:
|
||||
if self.__ws_sessions:
|
||||
await asyncio.gather(*[
|
||||
ws.send_event(event_type, event)
|
||||
for ws in self.__ws_sessions
|
||||
if (
|
||||
not ws.wsr.closed
|
||||
and ws.wsr._req is not None # pylint: disable=protected-access
|
||||
and ws.wsr._req.transport is not None # pylint: disable=protected-access
|
||||
)
|
||||
], return_exceptions=True)
|
||||
|
||||
async def _close_all_wss(self) -> None:
|
||||
for ws in self._get_wss():
|
||||
await self.__close_ws(ws)
|
||||
|
||||
def _get_wss(self) -> List[WsSession]:
|
||||
return list(self.__ws_sessions)
|
||||
|
||||
async def __close_ws(self, ws: WsSession) -> None:
|
||||
async with self.__ws_sessions_lock:
|
||||
try:
|
||||
self.__ws_sessions.remove(ws)
|
||||
get_logger(3).info("Removed client socket: %s; clients now: %d", ws, len(self.__ws_sessions))
|
||||
await ws.wsr.close()
|
||||
except Exception:
|
||||
pass
|
||||
await self._on_ws_closed()
|
||||
|
||||
def __parse_ws_event(self, msg: str) -> Tuple[str, Dict]:
|
||||
data = json.loads(msg)
|
||||
if not isinstance(data, dict):
|
||||
raise RuntimeError("Top-level event structure is not a dict")
|
||||
event_type = data.get("event_type")
|
||||
if not isinstance(event_type, str):
|
||||
raise RuntimeError("event_type must be a string")
|
||||
event = data["event"]
|
||||
if not isinstance(event, dict):
|
||||
raise RuntimeError("event must be a dict")
|
||||
return (event_type, event)
|
||||
|
||||
# =====
|
||||
|
||||
@@ -334,6 +389,12 @@ class HttpServer:
|
||||
async def _on_cleanup(self) -> None:
|
||||
pass
|
||||
|
||||
async def _on_ws_opened(self) -> None:
|
||||
pass
|
||||
|
||||
async def _on_ws_closed(self) -> None:
|
||||
pass
|
||||
|
||||
# =====
|
||||
|
||||
async def __make_app(self) -> Application:
|
||||
|
||||
Reference in New Issue
Block a user