share ClientSession via KvmdClientSession

This commit is contained in:
Devaev Maxim 2020-05-24 03:00:29 +03:00
parent 564c67fdb7
commit d61471d3a3
4 changed files with 165 additions and 124 deletions

View File

@ -21,9 +21,9 @@
import asyncio
import functools
from typing import Dict
from typing import Callable
import aiohttp
@ -95,7 +95,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
session.send_ipmi_response(code=0xC1)
def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None:
result = self.__make_request(session, "atx.get_state()", self.__kvmd.atx.get_state)
result = self.__make_request(session, "atx.get_state()", "atx.get_state")
data = [int(result["leds"]["power"]), 0, 0]
session.send_ipmi_response(data=data)
@ -107,7 +107,7 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
5: "off",
}.get(request["data"][0], "")
if action:
if not self.__make_request(session, f"atx.switch_power({action})", self.__kvmd.atx.switch_power, action=action):
if not self.__make_request(session, f"atx.switch_power({action})", "atx.switch_power", action=action):
code = 0xC0 # Try again later
else:
code = 0
@ -117,19 +117,18 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
# =====
def __make_request(self, session: IpmiServerSession, name: str, method: Callable, **kwargs): # type: ignore
def __make_request(self, session: IpmiServerSession, name: str, method_path: str, **kwargs): # type: ignore
async def runner(): # type: ignore
logger = get_logger(0)
credentials = self.__auth_manager.get_credentials(session.username.decode())
logger.info("Performing request %s from user %r (IPMI) as %r (KVMD)",
name, credentials.ipmi_user, credentials.kvmd_user)
try:
return (await method(credentials.kvmd_user, credentials.kvmd_passwd, **kwargs))
async with self.__kvmd.make_session(credentials.kvmd_user, credentials.kvmd_passwd) as kvmd_session:
method = functools.reduce(getattr, method_path.split("."), kvmd_session)
return (await method(**kwargs))
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Can't perform request %s: %s", name, str(err))
raise
except Exception:
logger.exception("Unexpected exception while performing request %s", name)
raise
return aiotools.run_sync(runner())

View File

@ -38,6 +38,7 @@ from ...logging import get_logger
from ...keyboard.keysym import SymmapWebKey
from ...keyboard.keysym import build_symmap
from ...clients.kvmd import KvmdClientSession
from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamerError
@ -105,6 +106,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
self.__shared_params = shared_params
self.__kvmd_session: Optional[KvmdClientSession] = None
self.__authorized = asyncio.Future() # type: ignore
self.__ws_connected = asyncio.Future() # type: ignore
self.__ws_writer_queue: asyncio.queues.Queue = asyncio.Queue()
@ -123,10 +125,14 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
# =====
async def run(self) -> None:
await self._run(
kvmd=self.__kvmd_task_loop(),
streamer=self.__streamer_task_loop(),
)
try:
await self._run(
kvmd=self.__kvmd_task_loop(),
streamer=self.__streamer_task_loop(),
)
finally:
if self.__kvmd_session:
await self.__kvmd_session.close()
# =====
@ -134,9 +140,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
logger = get_logger(0)
await self.__authorized
(user, passwd) = self.__authorized.result()
assert self.__kvmd_session
async with self.__kvmd.ws(user, passwd) as ws:
async with self.__kvmd_session.ws() as ws:
logger.info("[kvmd] Client %s: Connected to KVMD websocket", self._remote)
self.__ws_connected.set_result(None)
@ -238,8 +244,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
# =====
async def _authorize_userpass(self, user: str, passwd: str) -> bool:
if (await self.__kvmd.auth.check(user, passwd)):
self.__authorized.set_result((user, passwd))
self.__kvmd_session = self.__kvmd.make_session(user, passwd)
if (await self.__kvmd_session.auth.check()):
self.__authorized.set_result(None)
return True
return False
@ -285,14 +292,12 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def _on_cut_event(self, text: str) -> None:
assert self.__authorized.done()
(user, passwd) = self.__authorized.result()
assert self.__kvmd_session
logger = get_logger(0)
logger.info("[main] Client %s: Printing %d characters ...", self._remote, len(text))
try:
(default, available) = await self.__kvmd.hid.get_keymaps(user, passwd)
await self.__kvmd.hid.print(
user=user,
passwd=passwd,
(default, available) = await self.__kvmd_session.hid.get_keymaps()
await self.__kvmd_session.hid.print(
text=text,
limit=0,
keymap_name=(self.__keymap_name if self.__keymap_name in available else default),
@ -302,10 +307,10 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
async def _on_set_encodings(self) -> None:
assert self.__authorized.done()
(user, passwd) = self.__authorized.result()
assert self.__kvmd_session
get_logger(0).info("[main] Client %s: Applying streamer params: quality=%d%%; desired_fps=%d ...",
self._remote, self._encodings.tight_jpeg_quality, self.__desired_fps)
await self.__kvmd.streamer.set_params(user, passwd, self._encodings.tight_jpeg_quality, self.__desired_fps)
await self.__kvmd_session.streamer.set_params(self._encodings.tight_jpeg_quality, self.__desired_fps)
async def _on_fb_update_request(self) -> None:
async with self.__lock:
@ -348,7 +353,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
logger.info("Preparing client %s ...", remote)
try:
try:
none_auth_only = await kvmd.auth.check("", "")
async with kvmd.make_session("", "") as kvmd_session:
none_auth_only = await kvmd_session.auth.check()
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Client %s: Can't check KVMD auth mode: %s: %s", remote, type(err).__name__, err)
return

View File

@ -21,11 +21,15 @@
import contextlib
import types
from typing import Tuple
from typing import Dict
from typing import Set
from typing import Callable
from typing import Type
from typing import AsyncGenerator
from typing import Optional
import aiohttp
@ -33,7 +37,128 @@ from .. import aiotools
# =====
class _BaseClientPart:
class _BaseApiPart:
def __init__(
self,
ensure_http_session: Callable[[], aiohttp.ClientSession],
make_url: Callable[[str], str],
) -> None:
self._ensure_http_session = ensure_http_session
self._make_url = make_url
class _AuthApiPart(_BaseApiPart):
async def check(self) -> bool:
session = self._ensure_http_session()
try:
async with session.get(self._make_url("auth/check")) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status in [401, 403]:
return False
raise
class _StreamerApiPart(_BaseApiPart):
async def set_params(self, quality: int, desired_fps: int) -> None:
session = self._ensure_http_session()
async with session.post(
url=self._make_url("streamer/set_params"),
params={"quality": quality, "desired_fps": desired_fps},
) as response:
aiotools.raise_not_200(response)
class _HidApiPart(_BaseApiPart):
async def get_keymaps(self) -> Tuple[str, Set[str]]:
session = self._ensure_http_session()
async with session.get(self._make_url("hid/keymaps")) as response:
aiotools.raise_not_200(response)
result = (await response.json())["result"]
return (result["keymaps"]["default"], set(result["keymaps"]["available"]))
async def print(self, text: str, limit: int, keymap_name: str) -> None:
session = self._ensure_http_session()
async with session.post(
url=self._make_url("hid/print"),
params={"limit": limit, "keymap": keymap_name},
data=text,
) as response:
aiotools.raise_not_200(response)
class _AtxApiPart(_BaseApiPart):
async def get_state(self) -> Dict:
session = self._ensure_http_session()
async with session.get(self._make_url("atx")) as response:
aiotools.raise_not_200(response)
return (await response.json())["result"]
async def switch_power(self, action: str) -> bool:
session = self._ensure_http_session()
try:
async with session.post(
url=self._make_url("atx/power"),
params={"action": action},
) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status == 409:
return False
raise
class KvmdClientSession:
def __init__(
self,
make_http_session: Callable[[], aiohttp.ClientSession],
make_url: Callable[[str], str],
) -> None:
self.__make_http_session = make_http_session
self.__make_url = make_url
self.__http_session: Optional[aiohttp.ClientSession] = None
args = (self.__ensure_http_session, make_url)
self.auth = _AuthApiPart(*args)
self.streamer = _StreamerApiPart(*args)
self.hid = _HidApiPart(*args)
self.atx = _AtxApiPart(*args)
@contextlib.asynccontextmanager
async def ws(self) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
session = self.__ensure_http_session()
async with session.ws_connect(self.__make_url("ws")) as ws:
yield ws
def __ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
self.__http_session = self.__make_http_session()
return self.__http_session
async def close(self) -> None:
if self.__http_session:
await self.__http_session.close()
self.__http_session = None
async def __aenter__(self) -> "KvmdClientSession":
return self
async def __aexit__(
self,
_exc_type: Type[BaseException],
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
await self.close()
class KvmdClient:
def __init__(
self,
host: str,
@ -43,14 +168,19 @@ class _BaseClientPart:
user_agent: str,
) -> None:
assert port or unix_path
self.__host = host
self.__port = port
self.__unix_path = unix_path
self.__timeout = timeout
self.__user_agent = user_agent
def _make_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
def make_session(self, user: str, passwd: str) -> KvmdClientSession:
return KvmdClientSession(
make_http_session=(lambda: self.__make_http_session(user, passwd)),
make_url=self.__make_url,
)
def __make_http_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
kwargs: Dict = {
"headers": {
"X-KVMD-User": user,
@ -63,102 +193,6 @@ class _BaseClientPart:
kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path)
return aiohttp.ClientSession(**kwargs)
def _make_url(self, handle: str) -> str:
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://{self.__host}:{self.__port}/{handle}"
class _AuthClientPart(_BaseClientPart):
async def check(self, user: str, passwd: str) -> bool:
try:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("auth/check")) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status in [401, 403]:
return False
raise
class _StreamerClientPart(_BaseClientPart):
async def set_params(self, user: str, passwd: str, quality: int, desired_fps: int) -> None:
async with self._make_session(user, passwd) as session:
async with session.post(
url=self._make_url("streamer/set_params"),
params={"quality": quality, "desired_fps": desired_fps},
) as response:
aiotools.raise_not_200(response)
class _HidClientPart(_BaseClientPart):
async def get_keymaps(self, user: str, passwd: str) -> Tuple[str, Set[str]]:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("hid/keymaps")) as response:
aiotools.raise_not_200(response)
result = (await response.json())["result"]
return (result["keymaps"]["default"], set(result["keymaps"]["available"]))
async def print(self, user: str, passwd: str, text: str, limit: int, keymap_name: str) -> None:
async with self._make_session(user, passwd) as session:
async with session.post(
url=self._make_url("hid/print"),
params={"limit": limit, "keymap": keymap_name},
data=text,
) as response:
aiotools.raise_not_200(response)
class _AtxClientPart(_BaseClientPart):
async def get_state(self, user: str, passwd: str) -> Dict:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("atx")) as response:
aiotools.raise_not_200(response)
return (await response.json())["result"]
async def switch_power(self, user: str, passwd: str, action: str) -> bool:
try:
async with self._make_session(user, passwd) as session:
async with session.post(
url=self._make_url("atx/power"),
params={"action": action},
) as response:
aiotools.raise_not_200(response)
return True
except aiohttp.ClientResponseError as err:
if err.status == 409:
return False
raise
# =====
class KvmdClient(_BaseClientPart):
def __init__(
self,
host: str,
port: int,
unix_path: str,
timeout: float,
user_agent: str,
) -> None:
kwargs: Dict = {
"host": host,
"port": port,
"unix_path": unix_path,
"timeout": timeout,
"user_agent": user_agent,
}
super().__init__(**kwargs)
self.auth = _AuthClientPart(**kwargs)
self.streamer = _StreamerClientPart(**kwargs)
self.hid = _HidClientPart(**kwargs)
self.atx = _AtxClientPart(**kwargs)
@contextlib.asynccontextmanager
async def ws(self, user: str, passwd: str) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
async with self._make_session(user, passwd) as session:
async with session.ws_connect(self._make_url("ws")) as ws:
yield ws

View File

@ -18,6 +18,8 @@ InotifyMask.UNMOUNT
IpmiServer.handle_raw_request
_AtxApiPart.switch_power
fake_rpi.RPi.GPIO
_KeyMapping.web_name