share ClientSession via KvmdClientSession

This commit is contained in:
Devaev Maxim 2020-05-24 03:00:29 +03:00
parent 564c67fdb7
commit d61471d3a3
4 changed files with 165 additions and 124 deletions

View File

@ -21,9 +21,9 @@
import asyncio import asyncio
import functools
from typing import Dict from typing import Dict
from typing import Callable
import aiohttp import aiohttp
@ -95,7 +95,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
session.send_ipmi_response(code=0xC1) session.send_ipmi_response(code=0xC1)
def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None: def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None:
result = self.__make_request(session, "atx.get_state()", self.__kvmd.atx.get_state) result = self.__make_request(session, "atx.get_state()", "atx.get_state")
data = [int(result["leds"]["power"]), 0, 0] data = [int(result["leds"]["power"]), 0, 0]
session.send_ipmi_response(data=data) session.send_ipmi_response(data=data)
@ -107,7 +107,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
5: "off", 5: "off",
}.get(request["data"][0], "") }.get(request["data"][0], "")
if action: if action:
if not self.__make_request(session, f"atx.switch_power({action})", self.__kvmd.atx.switch_power, action=action): if not self.__make_request(session, f"atx.switch_power({action})", "atx.switch_power", action=action):
code = 0xC0 # Try again later code = 0xC0 # Try again later
else: else:
code = 0 code = 0
@ -117,19 +117,18 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
# ===== # =====
def __make_request(self, session: IpmiServerSession, name: str, method: Callable, **kwargs): # type: ignore def __make_request(self, session: IpmiServerSession, name: str, method_path: str, **kwargs): # type: ignore
async def runner(): # type: ignore async def runner(): # type: ignore
logger = get_logger(0) logger = get_logger(0)
credentials = self.__auth_manager.get_credentials(session.username.decode()) credentials = self.__auth_manager.get_credentials(session.username.decode())
logger.info("Performing request %s from user %r (IPMI) as %r (KVMD)", logger.info("Performing request %s from user %r (IPMI) as %r (KVMD)",
name, credentials.ipmi_user, credentials.kvmd_user) name, credentials.ipmi_user, credentials.kvmd_user)
try: try:
return (await method(credentials.kvmd_user, credentials.kvmd_passwd, **kwargs)) async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session:
method = functools.reduce(getattr, method_path.split("."), kvmd_session)
return (await method(**kwargs))
except (aiohttp.ClientError, asyncio.TimeoutError) as err: except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Can't perform request %s: %s", name, str(err)) logger.error("Can't perform request %s: %s", name, str(err))
raise raise
except Exception:
logger.exception("Unexpected exception while performing request %s", name)
raise
return aiotools.run_sync(runner()) return aiotools.run_sync(runner())

View File

@ -38,6 +38,7 @@ from ...logging import get_logger
from ...keyboard.keysym import SymmapWebKey from ...keyboard.keysym import SymmapWebKey
from ...keyboard.keysym import build_symmap from ...keyboard.keysym import build_symmap
from ...clients.kvmd import KvmdClientSession
from ...clients.kvmd import KvmdClient from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamerError from ...clients.streamer import StreamerError
@ -105,6 +106,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__shared_params = shared_params self.__shared_params = shared_params
self.__kvmd_session: Optional[KvmdClientSession] = None
self.__authorized = asyncio.Future() # type: ignore self.__authorized = asyncio.Future() # type: ignore
self.__ws_connected = asyncio.Future() # type: ignore self.__ws_connected = asyncio.Future() # type: ignore
self.__ws_writer_queue: asyncio.queues.Queue = asyncio.Queue() self.__ws_writer_queue: asyncio.queues.Queue = asyncio.Queue()
@ -123,10 +125,14 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
# ===== # =====
async def run(self) -> None: async def run(self) -> None:
try:
await self._run( await self._run(
kvmd=self.__kvmd_task_loop(), kvmd=self.__kvmd_task_loop(),
streamer=self.__streamer_task_loop(), streamer=self.__streamer_task_loop(),
) )
finally:
if self.__kvmd_session:
await self.__kvmd_session.close()
# ===== # =====
@ -134,9 +140,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
logger = get_logger(0) logger = get_logger(0)
await self.__authorized await self.__authorized
(user, passwd) = self.__authorized.result() assert self.__kvmd_session
async with self.__kvmd.ws(user, passwd) as ws: async with self.__kvmd_session.ws() as ws:
logger.info("[kvmd] Client %s: Connected to KVMD websocket", self._remote) logger.info("[kvmd] Client %s: Connected to KVMD websocket", self._remote)
self.__ws_connected.set_result(None) self.__ws_connected.set_result(None)
@ -238,8 +244,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
# ===== # =====
async def _authorize_userpass(self, user: str, passwd: str) -> bool: async def _authorize_userpass(self, user: str, passwd: str) -> bool:
if (await self.__kvmd.auth.check(user, passwd)): self.__kvmd_session = self.__kvmd.make_session(user, passwd)
self.__authorized.set_result((user, passwd)) if (await self.__kvmd_session.auth.check()):
self.__authorized.set_result(None)
return True return True
return False return False
@ -285,14 +292,12 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def _on_cut_event(self, text: str) -> None: async def _on_cut_event(self, text: str) -> None:
assert self.__authorized.done() assert self.__authorized.done()
(user, passwd) = self.__authorized.result() assert self.__kvmd_session
logger = get_logger(0) logger = get_logger(0)
logger.info("[main] Client %s: Printing %d characters ...", self._remote, len(text)) logger.info("[main] Client %s: Printing %d characters ...", self._remote, len(text))
try: try:
(default, available) = await self.__kvmd.hid.get_keymaps(user, passwd) (default, available) = await self.__kvmd_session.hid.get_keymaps()
await self.__kvmd.hid.print( await self.__kvmd_session.hid.print(
user=user,
passwd=passwd,
text=text, text=text,
limit=0, limit=0,
keymap_name=(self.__keymap_name if self.__keymap_name in available else default), keymap_name=(self.__keymap_name if self.__keymap_name in available else default),
@ -302,10 +307,10 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def _on_set_encodings(self) -> None: async def _on_set_encodings(self) -> None:
assert self.__authorized.done() assert self.__authorized.done()
(user, passwd) = self.__authorized.result() assert self.__kvmd_session
get_logger(0).info("[main] Client %s: Applying streamer params: quality=%d%%; desired_fps=%d ...", get_logger(0).info("[main] Client %s: Applying streamer params: quality=%d%%; desired_fps=%d ...",
self._remote, self._encodings.tight_jpeg_quality, self.__desired_fps) self._remote, self._encodings.tight_jpeg_quality, self.__desired_fps)
await self.__kvmd.streamer.set_params(user, passwd, self._encodings.tight_jpeg_quality, self.__desired_fps) await self.__kvmd_session.streamer.set_params(self._encodings.tight_jpeg_quality, self.__desired_fps)
async def _on_fb_update_request(self) -> None: async def _on_fb_update_request(self) -> None:
async with self.__lock: async with self.__lock:
@ -348,7 +353,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
logger.info("Preparing client %s ...", remote) logger.info("Preparing client %s ...", remote)
try: try:
try: try:
none_auth_only = await kvmd.auth.check("", "") async with kvmd.make_session("", "") as kvmd_session:
none_auth_only = await kvmd_session.auth.check()
except (aiohttp.ClientError, asyncio.TimeoutError) as err: except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Client %s: Can't check KVMD auth mode: %s: %s", remote, type(err).__name__, err) logger.error("Client %s: Can't check KVMD auth mode: %s: %s", remote, type(err).__name__, err)
return return

View File

@ -21,11 +21,15 @@
import contextlib import contextlib
import types
from typing import Tuple from typing import Tuple
from typing import Dict from typing import Dict
from typing import Set from typing import Set
from typing import Callable
from typing import Type
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Optional
import aiohttp import aiohttp
@ -33,45 +37,21 @@ from .. import aiotools
# ===== # =====
class _BaseClientPart: class _BaseApiPart:
def __init__( def __init__(
self, self,
host: str, ensure_http_session: Callable[[], aiohttp.ClientSession],
port: int, make_url: Callable[[str], str],
unix_path: str,
timeout: float,
user_agent: str,
) -> None: ) -> None:
assert port or unix_path self._ensure_http_session = ensure_http_session
self.__host = host self._make_url = make_url
self.__port = port
self.__unix_path = unix_path
self.__timeout = timeout
self.__user_agent = user_agent
def _make_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
kwargs: Dict = {
"headers": {
"X-KVMD-User": user,
"X-KVMD-Passwd": passwd,
"User-Agent": self.__user_agent,
},
"timeout": aiohttp.ClientTimeout(total=self.__timeout),
}
if self.__unix_path:
kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path)
return aiohttp.ClientSession(**kwargs)
def _make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://{self.__host}:{self.__port}/{handle}"
class _AuthClientPart(_BaseClientPart): class _AuthApiPart(_BaseApiPart):
async def check(self, user: str, passwd: str) -> bool: async def check(self) -> bool:
session = self._ensure_http_session()
try: try:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("auth/check")) as response: async with session.get(self._make_url("auth/check")) as response:
aiotools.raise_not_200(response) aiotools.raise_not_200(response)
return True return True
@ -81,9 +61,9 @@ class _AuthClientPart(_BaseClientPart):
raise raise
class _StreamerClientPart(_BaseClientPart): class _StreamerApiPart(_BaseApiPart):
async def set_params(self, user: str, passwd: str, quality: int, desired_fps: int) -> None: async def set_params(self, quality: int, desired_fps: int) -> None:
async with self._make_session(user, passwd) as session: session = self._ensure_http_session()
async with session.post( async with session.post(
url=self._make_url("streamer/set_params"), url=self._make_url("streamer/set_params"),
params={"quality": quality, "desired_fps": desired_fps}, params={"quality": quality, "desired_fps": desired_fps},
@ -91,16 +71,16 @@ class _StreamerClientPart(_BaseClientPart):
aiotools.raise_not_200(response) aiotools.raise_not_200(response)
class _HidClientPart(_BaseClientPart): class _HidApiPart(_BaseApiPart):
async def get_keymaps(self, user: str, passwd: str) -> Tuple[str, Set[str]]: async def get_keymaps(self) -> Tuple[str, Set[str]]:
async with self._make_session(user, passwd) as session: session = self._ensure_http_session()
async with session.get(self._make_url("hid/keymaps")) as response: async with session.get(self._make_url("hid/keymaps")) as response:
aiotools.raise_not_200(response) aiotools.raise_not_200(response)
result = (await response.json())["result"] result = (await response.json())["result"]
return (result["keymaps"]["default"], set(result["keymaps"]["available"])) return (result["keymaps"]["default"], set(result["keymaps"]["available"]))
async def print(self, user: str, passwd: str, text: str, limit: int, keymap_name: str) -> None: async def print(self, text: str, limit: int, keymap_name: str) -> None:
async with self._make_session(user, passwd) as session: session = self._ensure_http_session()
async with session.post( async with session.post(
url=self._make_url("hid/print"), url=self._make_url("hid/print"),
params={"limit": limit, "keymap": keymap_name}, params={"limit": limit, "keymap": keymap_name},
@ -109,16 +89,16 @@ class _HidClientPart(_BaseClientPart):
aiotools.raise_not_200(response) aiotools.raise_not_200(response)
class _AtxClientPart(_BaseClientPart): class _AtxApiPart(_BaseApiPart):
async def get_state(self, user: str, passwd: str) -> Dict: async def get_state(self) -> Dict:
async with self._make_session(user, passwd) as session: session = self._ensure_http_session()
async with session.get(self._make_url("atx")) as response: async with session.get(self._make_url("atx")) as response:
aiotools.raise_not_200(response) aiotools.raise_not_200(response)
return (await response.json())["result"] return (await response.json())["result"]
async def switch_power(self, user: str, passwd: str, action: str) -> bool: async def switch_power(self, action: str) -> bool:
session = self._ensure_http_session()
try: try:
async with self._make_session(user, passwd) as session:
async with session.post( async with session.post(
url=self._make_url("atx/power"), url=self._make_url("atx/power"),
params={"action": action}, params={"action": action},
@ -131,8 +111,54 @@ class _AtxClientPart(_BaseClientPart):
raise raise
# ===== class KvmdClientSession:
class KvmdClient(_BaseClientPart): def __init__(
self,
make_http_session: Callable[[], aiohttp.ClientSession],
make_url: Callable[[str], str],
) -> None:
self.__make_http_session = make_http_session
self.__make_url = make_url
self.__http_session: Optional[aiohttp.ClientSession] = None
args = (self.__ensure_http_session, make_url)
self.auth = _AuthApiPart(*args)
self.streamer = _StreamerApiPart(*args)
self.hid = _HidApiPart(*args)
self.atx = _AtxApiPart(*args)
@contextlib.asynccontextmanager
async def ws(self) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
session = self.__ensure_http_session()
async with session.ws_connect(self.__make_url("ws")) as ws:
yield ws
def __ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
self.__http_session = self.__make_http_session()
return self.__http_session
async def close(self) -> None:
if self.__http_session:
await self.__http_session.close()
self.__http_session = None
async def __aenter__(self) -> "KvmdClientSession":
return self
async def __aexit__(
self,
_exc_type: Type[BaseException],
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
await self.close()
class KvmdClient:
def __init__( def __init__(
self, self,
host: str, host: str,
@ -142,23 +168,31 @@ class KvmdClient(_BaseClientPart):
user_agent: str, user_agent: str,
) -> None: ) -> None:
self.__host = host
self.__port = port
self.__unix_path = unix_path
self.__timeout = timeout
self.__user_agent = user_agent
def make_session(self, user: str, passwd: str) -> KvmdClientSession:
return KvmdClientSession(
make_http_session=(lambda: self.__make_http_session(user, passwd)),
make_url=self.__make_url,
)
def __make_http_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
kwargs: Dict = { kwargs: Dict = {
"host": host, "headers": {
"port": port, "X-KVMD-User": user,
"unix_path": unix_path, "X-KVMD-Passwd": passwd,
"timeout": timeout, "User-Agent": self.__user_agent,
"user_agent": user_agent, },
"timeout": aiohttp.ClientTimeout(total=self.__timeout),
} }
if self.__unix_path:
kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path)
return aiohttp.ClientSession(**kwargs)
super().__init__(**kwargs) def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
self.auth = _AuthClientPart(**kwargs) return f"http://{self.__host}:{self.__port}/{handle}"
self.streamer = _StreamerClientPart(**kwargs)
self.hid = _HidClientPart(**kwargs)
self.atx = _AtxClientPart(**kwargs)
@contextlib.asynccontextmanager
async def ws(self, user: str, passwd: str) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
async with self._make_session(user, passwd) as session:
async with session.ws_connect(self._make_url("ws")) as ws:
yield ws

View File

@ -18,6 +18,8 @@ InotifyMask.UNMOUNT
IpmiServer.handle_raw_request IpmiServer.handle_raw_request
_AtxApiPart.switch_power
fake_rpi.RPi.GPIO fake_rpi.RPi.GPIO
_KeyMapping.web_name _KeyMapping.web_name