diff --git a/kvmd/apps/kvmd/msd.py b/kvmd/apps/kvmd/msd.py index 2cdb9050..79295109 100644 --- a/kvmd/apps/kvmd/msd.py +++ b/kvmd/apps/kvmd/msd.py @@ -1,12 +1,14 @@ import os import struct import asyncio +import asyncio.queues import types from typing import Dict from typing import NamedTuple from typing import Callable from typing import Type +from typing import AsyncGenerator from typing import Optional from typing import Any @@ -192,6 +194,8 @@ class MassStorageDevice: # pylint: disable=too-many-instance-attributes self.__device_file: Optional[aiofiles.base.AiofilesContextManager] = None self.__written = 0 + self.__state_queue: asyncio.queues.Queue = asyncio.Queue() + logger = get_logger(0) if self._device_path: logger.info("Using %r as mass-storage device", self._device_path) @@ -208,33 +212,6 @@ class MassStorageDevice: # pylint: disable=too-many-instance-attributes else: logger.warning("Mass-storage device is not operational") - @_msd_operated - async def connect_to_kvm(self, no_delay: bool=False) -> None: - with self.__region: - if self.__device_info: - raise MsdAlreadyConnectedToKvmError() - gpio.write(self.__target, False) - if not no_delay: - await asyncio.sleep(self.__init_delay) - await self.__load_device_info() - get_logger().info("Mass-storage device switched to KVM: %s", self.__device_info) - - @_msd_operated - async def connect_to_pc(self) -> None: - with self.__region: - if not self.__device_info: - raise MsdAlreadyConnectedToPcError() - gpio.write(self.__target, True) - self.__device_info = None - get_logger().info("Mass-storage device switched to Server") - - @_msd_operated - async def reset(self) -> None: - with self.__region: - gpio.write(self.__reset, True) - await asyncio.sleep(self.__reset_delay) - gpio.write(self.__reset, False) - def get_state(self) -> Dict: info = (self.__saved_device_info._asdict() if self.__saved_device_info else None) if info: @@ -253,11 +230,50 @@ class MassStorageDevice: # pylint: disable=too-many-instance-attributes "info": info, } + async def poll_state(self) -> AsyncGenerator[Dict, None]: + while True: + yield (await self.__state_queue.get()) + async def cleanup(self) -> None: await self.__close_device_file() gpio.write(self.__target, False) gpio.write(self.__reset, False) + @_msd_operated + async def connect_to_kvm(self, no_delay: bool=False) -> Dict: + with self.__region: + if self.__device_info: + raise MsdAlreadyConnectedToKvmError() + gpio.write(self.__target, False) + if not no_delay: + await asyncio.sleep(self.__init_delay) + await self.__load_device_info() + state = self.get_state() + await self.__state_queue.put(state) + get_logger().info("Mass-storage device switched to KVM: %s", self.__device_info) + return state + + @_msd_operated + async def connect_to_pc(self) -> Dict: + with self.__region: + if not self.__device_info: + raise MsdAlreadyConnectedToPcError() + gpio.write(self.__target, True) + self.__device_info = None + state = self.get_state() + await self.__state_queue.put(state) + get_logger().info("Mass-storage device switched to Server") + return state + + @_msd_operated + async def reset(self) -> None: + with self.__region: + get_logger().info("Mass-storage device reset") + gpio.write(self.__reset, True) + await asyncio.sleep(self.__reset_delay) + gpio.write(self.__reset, False) + await self.__state_queue.put(self.get_state()) + @_msd_operated async def __aenter__(self) -> "MassStorageDevice": self.__region.enter() @@ -268,6 +284,7 @@ class MassStorageDevice: # pylint: disable=too-many-instance-attributes self.__written = 0 return self finally: + await self.__state_queue.put(self.get_state()) self.__region.exit() async def write_image_info(self, name: str, complete: bool) -> None: @@ -296,6 +313,7 @@ class MassStorageDevice: # pylint: disable=too-many-instance-attributes try: await self.__close_device_file() finally: + await self.__state_queue.put(self.get_state()) self.__region.exit() async def __write_to_device_file(self, data: bytes) -> None: diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index af9e6765..7b3192ae 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -183,6 +183,7 @@ class Server: # pylint: disable=too-many-instance-attributes self.__loop.create_task(self.__stream_controller()), self.__loop.create_task(self.__poll_dead_sockets()), self.__loop.create_task(self.__poll_atx_state()), + self.__loop.create_task(self.__poll_msd_state()), self.__loop.create_task(self.__poll_streamer_state()), ]) @@ -303,9 +304,7 @@ class Server: # pylint: disable=too-many-instance-attributes }.get(button) if not clicker: raise BadRequest("Invalid param 'button'") - await self.__broadcast_event(_Events.ATX_STATE, self.__atx.get_state()) await clicker() - await self.__broadcast_event(_Events.ATX_STATE, self.__atx.get_state()) return _json({"clicked": button}) # ===== MSD @@ -317,16 +316,11 @@ class Server: # pylint: disable=too-many-instance-attributes async def __msd_connect_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: to = request.query.get("to") if to == "kvm": - await self.__msd.connect_to_kvm() - state = self.__msd.get_state() - await self.__broadcast_event(_Events.MSD_STATE, state) + return _json(await self.__msd.connect_to_kvm()) elif to == "server": - await self.__msd.connect_to_pc() - state = self.__msd.get_state() - await self.__broadcast_event(_Events.MSD_STATE, state) + return _json(await self.__msd.connect_to_pc()) else: raise BadRequest("Invalid param 'to'") - return _json(state) @_wrap_exceptions_for_web("Can't write data to mass-storage device") async def __msd_write_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: @@ -334,17 +328,16 @@ class Server: # pylint: disable=too-many-instance-attributes reader = await request.multipart() written = 0 try: - field = await reader.next() - if not field or field.name != "image_name": - raise BadRequest("Missing 'image_name' field") - image_name = (await field.read()).decode("utf-8")[:256] - - field = await reader.next() - if not field or field.name != "image_data": - raise BadRequest("Missing 'image_data' field") - async with self.__msd: - await self.__broadcast_event(_Events.MSD_STATE, self.__msd.get_state()) + field = await reader.next() + if not field or field.name != "image_name": + raise BadRequest("Missing 'image_name' field") + image_name = (await field.read()).decode("utf-8")[:256] + + field = await reader.next() + if not field or field.name != "image_data": + raise BadRequest("Missing 'image_data' field") + logger.info("Writing image %r to mass-storage device ...", image_name) await self.__msd.write_image_info(image_name, False) while True: @@ -354,7 +347,6 @@ class Server: # pylint: disable=too-many-instance-attributes written = await self.__msd.write_image_chunk(chunk) await self.__msd.write_image_info(image_name, True) finally: - await self.__broadcast_event(_Events.MSD_STATE, self.__msd.get_state()) if written != 0: logger.info("Written %d bytes to mass-storage device", written) return _json({"written": written}) @@ -450,27 +442,31 @@ class Server: # pylint: disable=too-many-instance-attributes @_system_task async def __poll_atx_state(self) -> None: async for state in self.__atx.poll_state(): - if self.__sockets: - await self.__broadcast_event(_Events.ATX_STATE, state) + await self.__broadcast_event(_Events.ATX_STATE, state) + + @_system_task + async def __poll_msd_state(self) -> None: + async for state in self.__msd.poll_state(): + await self.__broadcast_event(_Events.MSD_STATE, state) @_system_task async def __poll_streamer_state(self) -> None: async for state in self.__streamer.poll_state(): - if self.__sockets: - await self.__broadcast_event(_Events.STREAMER_STATE, state) + await self.__broadcast_event(_Events.STREAMER_STATE, state) async def __broadcast_event(self, event_type: _Events, event_attrs: Dict) -> None: - await asyncio.gather(*[ - ws.send_str(json.dumps({ - "msg_type": "event", - "msg": { - "event": event_type.value, - "event_attrs": event_attrs, - }, - })) - for ws in list(self.__sockets) - if not ws.closed and ws._req.transport # pylint: disable=protected-access - ], return_exceptions=True) + if self.__sockets: + await asyncio.gather(*[ + ws.send_str(json.dumps({ + "msg_type": "event", + "msg": { + "event": event_type.value, + "event_attrs": event_attrs, + }, + })) + for ws in list(self.__sockets) + if not ws.closed and ws._req.transport # pylint: disable=protected-access + ], return_exceptions=True) async def __register_socket(self, ws: aiohttp.web.WebSocketResponse) -> None: async with self.__sockets_lock: