diff --git a/kvmd/clients/kvmd.py b/kvmd/clients/kvmd.py index 497cdd3e..ce07953c 100644 --- a/kvmd/clients/kvmd.py +++ b/kvmd/clients/kvmd.py @@ -20,7 +20,6 @@ # ========================================================================== # -import asyncio import contextlib import struct @@ -161,47 +160,23 @@ class _SwitchApiPart(_BaseApiPart): class KvmdClientWs: def __init__(self, ws: aiohttp.ClientWebSocketResponse) -> None: self.__ws = ws - self.__writer_queue: "asyncio.Queue[tuple[str, dict] | bytes]" = asyncio.Queue() self.__communicated = False async def communicate(self) -> AsyncGenerator[tuple[str, dict], None]: # pylint: disable=too-many-branches assert not self.__communicated self.__communicated = True - recv_task: (asyncio.Task | None) = None - writer_task: (asyncio.Task | None) = None try: - while True: - if recv_task is None: - recv_task = asyncio.create_task(self.__ws.receive()) - if writer_task is None: - writer_task = asyncio.create_task(self.__writer_queue.get()) - - done = (await aiotools.wait_first(recv_task, writer_task))[0] - - if recv_task in done: - msg = recv_task.result() - if msg.type == aiohttp.WSMsgType.TEXT: + async for msg in self.__ws: + match msg.type: + case aiohttp.WSMsgType.TEXT: yield htserver.parse_ws_event(msg.data) - elif msg.type == aiohttp.WSMsgType.CLOSE: + case aiohttp.WSMsgType.CLOSE: await self.__ws.close() - elif msg.type == aiohttp.WSMsgType.CLOSED: + case aiohttp.WSMsgType.CLOSED: break - else: + case _: raise RuntimeError(f"Unhandled WS message type: {msg!r}") - recv_task = None - - if writer_task in done: - payload = writer_task.result() - if isinstance(payload, bytes): - await self.__ws.send_bytes(payload) - else: - await htserver.send_ws_event(self.__ws, *payload) - writer_task = None finally: - if recv_task: - recv_task.cancel() - if writer_task: - writer_task.cancel() try: await aiotools.shield_fg(self.__ws.close()) except Exception: @@ -211,20 +186,31 @@ class KvmdClientWs: async def send_key_event(self, key: int, state: bool) -> None: mask = (0b10000000 | int(bool(state))) - await self.__writer_queue.put(struct.pack(">BBH", 1, mask, key)) + await self.__send_struct(">BBH", 1, mask, key) async def send_mouse_button_event(self, button: int, state: bool) -> None: mask = (0b10000000 | int(bool(state))) - await self.__writer_queue.put(struct.pack(">BBH", 2, mask, button)) + await self.__send_struct(">BBH", 2, mask, button) async def send_mouse_move_event(self, to_x: int, to_y: int) -> None: - await self.__writer_queue.put(struct.pack(">Bhh", 3, to_x, to_y)) + await self.__send_struct(">Bhh", 3, to_x, to_y) async def send_mouse_relative_event(self, delta_x: int, delta_y: int) -> None: - await self.__writer_queue.put(struct.pack(">BBbb", 4, 0, delta_x, delta_y)) + await self.__send_struct(">BBbb", 4, 0, delta_x, delta_y) async def send_mouse_wheel_event(self, delta_x: int, delta_y: int) -> None: - await self.__writer_queue.put(struct.pack(">BBbb", 5, 0, delta_x, delta_y)) + await self.__send_struct(">BBbb", 5, 0, delta_x, delta_y) + + async def __send_struct(self, fmt: str, *values: int) -> None: + if not self.__communicated: + return + data = struct.pack(fmt, *values) + try: + await self.__ws.send_bytes(data) + except Exception: + # XXX: We don't care about any connection errors + # since they will be handled with communication() + pass class KvmdClientSession(BaseHttpClientSession):