refactoring

This commit is contained in:
Devaev Maxim
2021-01-28 20:36:46 +03:00
parent 1442515e5c
commit 0538a6828f
2 changed files with 41 additions and 20 deletions

View File

@@ -105,6 +105,30 @@ class AioNotifier:
await self.__queue.get() await self.__queue.get()
# =====
class AioStage:
def __init__(self) -> None:
self.__fut = asyncio.Future() # type: ignore
def set_passed(self, multi: bool=False) -> None:
if multi and self.__fut.done():
return
self.__fut.set_result(None)
def is_passed(self) -> bool:
return self.__fut.done()
async def wait_passed(self, timeout: float=-1) -> bool:
if timeout >= 0:
try:
await asyncio.wait_for(self.__fut, timeout=timeout)
except asyncio.TimeoutError:
return False
else:
await self.__fut
return True
# ===== # =====
class AioExclusiveRegion: class AioExclusiveRegion:
def __init__( def __init__(

View File

@@ -52,6 +52,7 @@ from ...clients.streamer import StreamFormats
from ...clients.streamer import BaseStreamerClient from ...clients.streamer import BaseStreamerClient
from ... import tools from ... import tools
from ... import aiotools
from .rfb import RfbClient from .rfb import RfbClient
from .rfb.stream import rfb_format_remote from .rfb.stream import rfb_format_remote
@@ -113,9 +114,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__shared_params = shared_params self.__shared_params = shared_params
self.__stage1_authorized = asyncio.Future() # type: ignore self.__stage1_authorized = aiotools.AioStage()
self.__stage2_encodings_accepted = asyncio.Future() # type: ignore self.__stage2_encodings_accepted = aiotools.AioStage()
self.__stage3_ws_connected = asyncio.Future() # type: ignore self.__stage3_ws_connected = aiotools.AioStage()
self.__kvmd_session: Optional[KvmdClientSession] = None self.__kvmd_session: Optional[KvmdClientSession] = None
self.__kvmd_ws: Optional[KvmdClientWs] = None self.__kvmd_ws: Optional[KvmdClientWs] = None
@@ -149,19 +150,17 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def __kvmd_task_loop(self) -> None: async def __kvmd_task_loop(self) -> None:
logger = get_logger(0) logger = get_logger(0)
await self.__stage1_authorized await self.__stage1_authorized.wait_passed()
logger.info("[kvmd] %s: Waiting for the SetEncodings message ...", self._remote) logger.info("[kvmd] %s: Waiting for the SetEncodings message ...", self._remote)
try: if not (await self.__stage2_encodings_accepted.wait_passed(timeout=5)):
await asyncio.wait_for(self.__stage2_encodings_accepted, timeout=5)
except asyncio.TimeoutError:
raise RfbError("No SetEncodings message recieved from the client in 5 secs") raise RfbError("No SetEncodings message recieved from the client in 5 secs")
assert self.__kvmd_session assert self.__kvmd_session
try: try:
async with self.__kvmd_session.ws() as self.__kvmd_ws: async with self.__kvmd_session.ws() as self.__kvmd_ws:
logger.info("[kvmd] %s: Connected to KVMD websocket", self._remote) logger.info("[kvmd] %s: Connected to KVMD websocket", self._remote)
self.__stage3_ws_connected.set_result(None) self.__stage3_ws_connected.set_passed()
async for event in self.__kvmd_ws.communicate(): async for event in self.__kvmd_ws.communicate():
await self.__process_ws_event(event) await self.__process_ws_event(event)
raise RfbError("KVMD closes the websocket (the server may have been stopped)") raise RfbError("KVMD closes the websocket (the server may have been stopped)")
@@ -191,7 +190,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def __streamer_task_loop(self) -> None: async def __streamer_task_loop(self) -> None:
logger = get_logger(0) logger = get_logger(0)
await self.__stage3_ws_connected await self.__stage3_ws_connected.wait_passed()
streamer = self.__get_preferred_streamer() streamer = self.__get_preferred_streamer()
while True: while True:
try: try:
@@ -272,7 +271,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def _authorize_userpass(self, user: str, passwd: str) -> bool: async def _authorize_userpass(self, user: str, passwd: str) -> bool:
self.__kvmd_session = self.__kvmd.make_session(user, passwd) self.__kvmd_session = self.__kvmd.make_session(user, passwd)
if (await self.__kvmd_session.auth.check()): if (await self.__kvmd_session.auth.check()):
self.__stage1_authorized.set_result(None) self.__stage1_authorized.set_passed()
return True return True
return False return False
@@ -339,25 +338,23 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__mouse_move = move self.__mouse_move = move
async def _on_cut_event(self, text: str) -> None: async def _on_cut_event(self, text: str) -> None:
assert self.__stage1_authorized.done() assert self.__stage1_authorized.is_passed()
assert self.__kvmd_session assert self.__kvmd_session
logger = get_logger(0) logger = get_logger(0)
logger.info("[main] %s: Printing %d characters ...", self._remote, len(text)) logger.info("[main] %s: Printing %d characters ...", self._remote, len(text))
try: try:
(default, available) = await self.__kvmd_session.hid.get_keymaps() (keymap_name, available) = await self.__kvmd_session.hid.get_keymaps()
await self.__kvmd_session.hid.print( if self.__keymap_name in available:
text=text, keymap_name = self.__keymap_name
limit=0, await self.__kvmd_session.hid.print(text, 0, keymap_name)
keymap_name=(self.__keymap_name if self.__keymap_name in available else default),
)
except Exception: except Exception:
logger.exception("[main] %s: Can't print characters", self._remote) logger.exception("[main] %s: Can't print characters", self._remote)
async def _on_set_encodings(self) -> None: async def _on_set_encodings(self) -> None:
assert self.__stage1_authorized.done() assert self.__stage1_authorized.is_passed()
assert self.__kvmd_session assert self.__kvmd_session
if not self.__stage2_encodings_accepted.done(): self.__stage2_encodings_accepted.set_passed(multi=True)
self.__stage2_encodings_accepted.set_result(None)
has_quality = (await self.__kvmd_session.streamer.get_state())["features"]["quality"] has_quality = (await self.__kvmd_session.streamer.get_state())["features"]["quality"]
quality = (self._encodings.tight_jpeg_quality if has_quality else None) quality = (self._encodings.tight_jpeg_quality if has_quality else None)
get_logger(0).info("[main] %s: Applying streamer params: jpeg_quality=%s; desired_fps=%d ...", get_logger(0).info("[main] %s: Applying streamer params: jpeg_quality=%s; desired_fps=%d ...",