share ClientSession via KvmdClientSession

This commit is contained in:
Devaev Maxim 2020-05-24 03:00:29 +03:00
parent 564c67fdb7
commit d61471d3a3
4 changed files with 165 additions and 124 deletions

View File

@ -21,9 +21,9 @@
import asyncio import asyncio
import functools
from typing import Dict from typing import Dict
from typing import Callable
import aiohttp import aiohttp
@ -95,7 +95,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
session.send_ipmi_response(code=0xC1) session.send_ipmi_response(code=0xC1)
def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None: def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None:
result = self.__make_request(session, "atx.get_state()", self.__kvmd.atx.get_state) result = self.__make_request(session, "atx.get_state()", "atx.get_state")
data = [int(result["leds"]["power"]), 0, 0] data = [int(result["leds"]["power"]), 0, 0]
session.send_ipmi_response(data=data) session.send_ipmi_response(data=data)
@ -107,7 +107,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
5: "off", 5: "off",
}.get(request["data"][0], "") }.get(request["data"][0], "")
if action: if action:
if not self.__make_request(session, f"atx.switch_power({action})", self.__kvmd.atx.switch_power, action=action): if not self.__make_request(session, f"atx.switch_power({action})", "atx.switch_power", action=action):
code = 0xC0 # Try again later code = 0xC0 # Try again later
else: else:
code = 0 code = 0
@ -117,19 +117,18 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
# ===== # =====
def __make_request(self, session: IpmiServerSession, name: str, method: Callable, **kwargs): # type: ignore def __make_request(self, session: IpmiServerSession, name: str, method_path: str, **kwargs): # type: ignore
async def runner(): # type: ignore async def runner(): # type: ignore
logger = get_logger(0) logger = get_logger(0)
credentials = self.__auth_manager.get_credentials(session.username.decode()) credentials = self.__auth_manager.get_credentials(session.username.decode())
logger.info("Performing request %s from user %r (IPMI) as %r (KVMD)", logger.info("Performing request %s from user %r (IPMI) as %r (KVMD)",
name, credentials.ipmi_user, credentials.kvmd_user) name, credentials.ipmi_user, credentials.kvmd_user)
try: try:
return (await method(credentials.kvmd_user, credentials.kvmd_passwd, **kwargs)) async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session:
method = functools.reduce(getattr, method_path.split("."), kvmd_session)
return (await method(**kwargs))
except (aiohttp.ClientError, asyncio.TimeoutError) as err: except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Can't perform request %s: %s", name, str(err)) logger.error("Can't perform request %s: %s", name, str(err))
raise raise
except Exception:
logger.exception("Unexpected exception while performing request %s", name)
raise
return aiotools.run_sync(runner()) return aiotools.run_sync(runner())

View File

@ -38,6 +38,7 @@ from ...logging import get_logger
from ...keyboard.keysym import SymmapWebKey from ...keyboard.keysym import SymmapWebKey
from ...keyboard.keysym import build_symmap from ...keyboard.keysym import build_symmap
from ...clients.kvmd import KvmdClientSession
from ...clients.kvmd import KvmdClient from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamerError from ...clients.streamer import StreamerError
@ -105,6 +106,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__shared_params = shared_params self.__shared_params = shared_params
self.__kvmd_session: Optional[KvmdClientSession] = None
self.__authorized = asyncio.Future() # type: ignore self.__authorized = asyncio.Future() # type: ignore
self.__ws_connected = asyncio.Future() # type: ignore self.__ws_connected = asyncio.Future() # type: ignore
self.__ws_writer_queue: asyncio.queues.Queue = asyncio.Queue() self.__ws_writer_queue: asyncio.queues.Queue = asyncio.Queue()
@ -123,10 +125,14 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
# ===== # =====
async def run(self) -> None: async def run(self) -> None:
await self._run( try:
kvmd=self.__kvmd_task_loop(), await self._run(
streamer=self.__streamer_task_loop(), kvmd=self.__kvmd_task_loop(),
) streamer=self.__streamer_task_loop(),
)
finally:
if self.__kvmd_session:
await self.__kvmd_session.close()
# ===== # =====
@ -134,9 +140,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
logger = get_logger(0) logger = get_logger(0)
await self.__authorized await self.__authorized
(user, passwd) = self.__authorized.result() assert self.__kvmd_session
async with self.__kvmd.ws(user, passwd) as ws: async with self.__kvmd_session.ws() as ws:
logger.info("[kvmd] Client %s: Connected to KVMD websocket", self._remote) logger.info("[kvmd] Client %s: Connected to KVMD websocket", self._remote)
self.__ws_connected.set_result(None) self.__ws_connected.set_result(None)
@ -238,8 +244,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
# ===== # =====
async def _authorize_userpass(self, user: str, passwd: str) -> bool: async def _authorize_userpass(self, user: str, passwd: str) -> bool:
if (await self.__kvmd.auth.check(user, passwd)): self.__kvmd_session = self.__kvmd.make_session(user, passwd)
self.__authorized.set_result((user, passwd)) if (await self.__kvmd_session.auth.check()):
self.__authorized.set_result(None)
return True return True
return False return False
@ -285,14 +292,12 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def _on_cut_event(self, text: str) -> None: async def _on_cut_event(self, text: str) -> None:
assert self.__authorized.done() assert self.__authorized.done()
(user, passwd) = self.__authorized.result() assert self.__kvmd_session
logger = get_logger(0) logger = get_logger(0)
logger.info("[main] Client %s: Printing %d characters ...", self._remote, len(text)) logger.info("[main] Client %s: Printing %d characters ...", self._remote, len(text))
try: try:
(default, available) = await self.__kvmd.hid.get_keymaps(user, passwd) (default, available) = await self.__kvmd_session.hid.get_keymaps()
await self.__kvmd.hid.print( await self.__kvmd_session.hid.print(
user=user,
passwd=passwd,
text=text, text=text,
limit=0, limit=0,
keymap_name=(self.__keymap_name if self.__keymap_name in available else default), keymap_name=(self.__keymap_name if self.__keymap_name in available else default),
@ -302,10 +307,10 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def _on_set_encodings(self) -> None: async def _on_set_encodings(self) -> None:
assert self.__authorized.done() assert self.__authorized.done()
(user, passwd) = self.__authorized.result() assert self.__kvmd_session
get_logger(0).info("[main] Client %s: Applying streamer params: quality=%d%%; desired_fps=%d ...", get_logger(0).info("[main] Client %s: Applying streamer params: quality=%d%%; desired_fps=%d ...",
self._remote, self._encodings.tight_jpeg_quality, self.__desired_fps) self._remote, self._encodings.tight_jpeg_quality, self.__desired_fps)
await self.__kvmd.streamer.set_params(user, passwd, self._encodings.tight_jpeg_quality, self.__desired_fps) await self.__kvmd_session.streamer.set_params(self._encodings.tight_jpeg_quality, self.__desired_fps)
async def _on_fb_update_request(self) -> None: async def _on_fb_update_request(self) -> None:
async with self.__lock: async with self.__lock:
@ -348,7 +353,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
logger.info("Preparing client %s ...", remote) logger.info("Preparing client %s ...", remote)
try: try:
try: try:
none_auth_only = await kvmd.auth.check("", "") async with kvmd.make_session("", "") as kvmd_session:
none_auth_only = await kvmd_session.auth.check()
except (aiohttp.ClientError, asyncio.TimeoutError) as err: except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Client %s: Can't check KVMD auth mode: %s: %s", remote, type(err).__name__, err) logger.error("Client %s: Can't check KVMD auth mode: %s: %s", remote, type(err).__name__, err)
return return

View File

@ -21,11 +21,15 @@
import contextlib import contextlib
import types
from typing import Tuple from typing import Tuple
from typing import Dict from typing import Dict
from typing import Set from typing import Set
from typing import Callable
from typing import Type
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Optional
import aiohttp import aiohttp
@ -33,7 +37,128 @@ from .. import aiotools
# ===== # =====
class _BaseClientPart: class _BaseApiPart:
def __init__(
self,
ensure_http_session: Callable[[], aiohttp.ClientSession],
make_url: Callable[[str], str],
) -> None:
self._ensure_http_session = ensure_http_session
self._make_url = make_url
class _AuthApiPart(_BaseApiPart):
async def check(self) -> bool:
session = self._ensure_http_session()
try:
async with session.get(self._make_url("auth/check")) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status in [401, 403]:
return False
raise
class _StreamerApiPart(_BaseApiPart):
async def set_params(self, quality: int, desired_fps: int) -> None:
session = self._ensure_http_session()
async with session.post(
url=self._make_url("streamer/set_params"),
params={"quality": quality, "desired_fps": desired_fps},
) as response:
aiotools.raise_not_200(response)
class _HidApiPart(_BaseApiPart):
async def get_keymaps(self) -> Tuple[str, Set[str]]:
session = self._ensure_http_session()
async with session.get(self._make_url("hid/keymaps")) as response:
aiotools.raise_not_200(response)
result = (await response.json())["result"]
return (result["keymaps"]["default"], set(result["keymaps"]["available"]))
async def print(self, text: str, limit: int, keymap_name: str) -> None:
session = self._ensure_http_session()
async with session.post(
url=self._make_url("hid/print"),
params={"limit": limit, "keymap": keymap_name},
data=text,
) as response:
aiotools.raise_not_200(response)
class _AtxApiPart(_BaseApiPart):
async def get_state(self) -> Dict:
session = self._ensure_http_session()
async with session.get(self._make_url("atx")) as response:
aiotools.raise_not_200(response)
return (await response.json())["result"]
async def switch_power(self, action: str) -> bool:
session = self._ensure_http_session()
try:
async with session.post(
url=self._make_url("atx/power"),
params={"action": action},
) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status == 409:
return False
raise
class KvmdClientSession:
def __init__(
self,
make_http_session: Callable[[], aiohttp.ClientSession],
make_url: Callable[[str], str],
) -> None:
self.__make_http_session = make_http_session
self.__make_url = make_url
self.__http_session: Optional[aiohttp.ClientSession] = None
args = (self.__ensure_http_session, make_url)
self.auth = _AuthApiPart(*args)
self.streamer = _StreamerApiPart(*args)
self.hid = _HidApiPart(*args)
self.atx = _AtxApiPart(*args)
@contextlib.asynccontextmanager
async def ws(self) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
session = self.__ensure_http_session()
async with session.ws_connect(self.__make_url("ws")) as ws:
yield ws
def __ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
self.__http_session = self.__make_http_session()
return self.__http_session
async def close(self) -> None:
if self.__http_session:
await self.__http_session.close()
self.__http_session = None
async def __aenter__(self) -> "KvmdClientSession":
return self
async def __aexit__(
self,
_exc_type: Type[BaseException],
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
await self.close()
class KvmdClient:
def __init__( def __init__(
self, self,
host: str, host: str,
@ -43,14 +168,19 @@ class _BaseClientPart:
user_agent: str, user_agent: str,
) -> None: ) -> None:
assert port or unix_path
self.__host = host self.__host = host
self.__port = port self.__port = port
self.__unix_path = unix_path self.__unix_path = unix_path
self.__timeout = timeout self.__timeout = timeout
self.__user_agent = user_agent self.__user_agent = user_agent
def _make_session(self, user: str, passwd: str) -> aiohttp.ClientSession: def make_session(self, user: str, passwd: str) -> KvmdClientSession:
return KvmdClientSession(
make_http_session=(lambda: self.__make_http_session(user, passwd)),
make_url=self.__make_url,
)
def __make_http_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
kwargs: Dict = { kwargs: Dict = {
"headers": { "headers": {
"X-KVMD-User": user, "X-KVMD-User": user,
@ -63,102 +193,6 @@ class _BaseClientPart:
kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path) kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path)
return aiohttp.ClientSession(**kwargs) return aiohttp.ClientSession(**kwargs)
def _make_url(self, handle: str) -> str: def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle assert not handle.startswith("/"), handle
return f"http://{self.__host}:{self.__port}/{handle}" return f"http://{self.__host}:{self.__port}/{handle}"
class _AuthClientPart(_BaseClientPart):
async def check(self, user: str, passwd: str) -> bool:
try:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("auth/check")) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status in [401, 403]:
return False
raise
class _StreamerClientPart(_BaseClientPart):
async def set_params(self, user: str, passwd: str, quality: int, desired_fps: int) -> None:
async with self._make_session(user, passwd) as session:
async with session.post(
url=self._make_url("streamer/set_params"),
params={"quality": quality, "desired_fps": desired_fps},
) as response:
aiotools.raise_not_200(response)
class _HidClientPart(_BaseClientPart):
async def get_keymaps(self, user: str, passwd: str) -> Tuple[str, Set[str]]:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("hid/keymaps")) as response:
aiotools.raise_not_200(response)
result = (await response.json())["result"]
return (result["keymaps"]["default"], set(result["keymaps"]["available"]))
async def print(self, user: str, passwd: str, text: str, limit: int, keymap_name: str) -> None:
async with self._make_session(user, passwd) as session:
async with session.post(
url=self._make_url("hid/print"),
params={"limit": limit, "keymap": keymap_name},
data=text,
) as response:
aiotools.raise_not_200(response)
class _AtxClientPart(_BaseClientPart):
async def get_state(self, user: str, passwd: str) -> Dict:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("atx")) as response:
aiotools.raise_not_200(response)
return (await response.json())["result"]
async def switch_power(self, user: str, passwd: str, action: str) -> bool:
try:
async with self._make_session(user, passwd) as session:
async with session.post(
url=self._make_url("atx/power"),
params={"action": action},
) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status == 409:
return False
raise
# =====
class KvmdClient(_BaseClientPart):
def __init__(
self,
host: str,
port: int,
unix_path: str,
timeout: float,
user_agent: str,
) -> None:
kwargs: Dict = {
"host": host,
"port": port,
"unix_path": unix_path,
"timeout": timeout,
"user_agent": user_agent,
}
super().__init__(**kwargs)
self.auth = _AuthClientPart(**kwargs)
self.streamer = _StreamerClientPart(**kwargs)
self.hid = _HidClientPart(**kwargs)
self.atx = _AtxClientPart(**kwargs)
@contextlib.asynccontextmanager
async def ws(self, user: str, passwd: str) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
async with self._make_session(user, passwd) as session:
async with session.ws_connect(self._make_url("ws")) as ws:
yield ws

View File

@ -18,6 +18,8 @@ InotifyMask.UNMOUNT
IpmiServer.handle_raw_request IpmiServer.handle_raw_request
_AtxApiPart.switch_power
fake_rpi.RPi.GPIO fake_rpi.RPi.GPIO
_KeyMapping.web_name _KeyMapping.web_name