refactoring

This commit is contained in:
Devaev Maxim
2020-05-18 13:34:23 +03:00
parent 3947640771
commit 028e0b06ff
10 changed files with 110 additions and 132 deletions

View File

@@ -39,6 +39,8 @@ from typing import TypeVar
from typing import Optional from typing import Optional
from typing import Any from typing import Any
import aiohttp
import aiofiles import aiofiles
import aiofiles.base import aiofiles.base
@@ -93,6 +95,20 @@ async def wait_first(*aws: Awaitable) -> Tuple[Set[asyncio.Future], Set[asyncio.
return (await asyncio.wait(list(aws), return_when=asyncio.FIRST_COMPLETED)) return (await asyncio.wait(list(aws), return_when=asyncio.FIRST_COMPLETED))
# =====
def raise_not_200(response: aiohttp.ClientResponse) -> None:
if response.status != 200:
assert response.reason is not None
response.release()
raise aiohttp.ClientResponseError(
response.request_info,
response.history,
status=response.status,
message=response.reason,
headers=response.headers,
)
# ===== # =====
async def afile_write_now(afile: aiofiles.base.AiofilesContextManager, data: bytes) -> None: async def afile_write_now(afile: aiofiles.base.AiofilesContextManager, data: bytes) -> None:
await afile.write(data) await afile.write(data)

View File

@@ -311,10 +311,10 @@ def _get_config_scheme() -> Dict:
}, },
"kvmd": { "kvmd": {
"host": Option("localhost", type=valid_ip_or_host, unpack_as="kvmd_host"), "host": Option("localhost", type=valid_ip_or_host),
"port": Option(0, type=valid_port, unpack_as="kvmd_port"), "port": Option(0, type=valid_port),
"unix": Option("", type=valid_abs_path, only_if="!port", unpack_as="kvmd_unix_path"), "unix": Option("", type=valid_abs_path, only_if="!port", unpack_as="unix_path"),
"timeout": Option(5.0, type=valid_float_f01, unpack_as="kvmd_timeout"), "timeout": Option(5.0, type=valid_float_f01),
}, },
"auth": { "auth": {

View File

@@ -23,6 +23,10 @@
from typing import List from typing import List
from typing import Optional from typing import Optional
from ...clients.kvmd import KvmdClient
from ... import make_user_agent
from .. import init from .. import init
from .auth import IpmiAuthManager from .auth import IpmiAuthManager
@@ -40,8 +44,9 @@ def main(argv: Optional[List[str]]=None) -> None:
# pylint: disable=protected-access # pylint: disable=protected-access
IpmiServer( IpmiServer(
auth_manager=IpmiAuthManager(**config.auth._unpack()), auth_manager=IpmiAuthManager(**config.auth._unpack()),
**{ # Dirty mypy hack kvmd=KvmdClient(
**config.server._unpack(), user_agent=make_user_agent("KVMD-IPMI"),
**config.kvmd._unpack(), **config.kvmd._unpack(),
}, ),
).run() # type: ignore **config.server._unpack(),
).run()

View File

@@ -20,13 +20,10 @@
# ========================================================================== # # ========================================================================== #
import sys
import asyncio import asyncio
import threading
from typing import Tuple
from typing import Dict from typing import Dict
from typing import Optional from typing import Callable
import aiohttp import aiohttp
@@ -36,7 +33,9 @@ from pyghmi.ipmi.private.serversession import ServerSession as IpmiServerSession
from ...logging import get_logger from ...logging import get_logger
from ... import make_user_agent from ...clients.kvmd import KvmdClient
from ... import aiotools
from .auth import IpmiAuthManager from .auth import IpmiAuthManager
@@ -49,30 +48,22 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
def __init__( def __init__(
self, self,
auth_manager: IpmiAuthManager, auth_manager: IpmiAuthManager,
kvmd: KvmdClient,
host: str, host: str,
port: str, port: str,
timeout: float, timeout: float,
kvmd_host: str,
kvmd_port: int,
kvmd_unix_path: str,
kvmd_timeout: float,
) -> None: ) -> None:
super().__init__(authdata=auth_manager, address=host, port=port) super().__init__(authdata=auth_manager, address=host, port=port)
self.__auth_manager = auth_manager self.__auth_manager = auth_manager
self.__kvmd = kvmd
self.__host = host self.__host = host
self.__port = port self.__port = port
self.__timeout = timeout self.__timeout = timeout
self.__kvmd_host = kvmd_host
self.__kvmd_port = kvmd_port
self.__kvmd_unix_path = kvmd_unix_path
self.__kvmd_timeout = kvmd_timeout
def run(self) -> None: def run(self) -> None:
logger = get_logger(0) logger = get_logger(0)
logger.info("Listening IPMI on UPD [%s]:%d ...", self.__host, self.__port) logger.info("Listening IPMI on UPD [%s]:%d ...", self.__host, self.__port)
@@ -104,19 +95,19 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
session.send_ipmi_response(code=0xC1) session.send_ipmi_response(code=0xC1)
def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None: def __get_chassis_status_handler(self, _: Dict, session: IpmiServerSession) -> None:
result = self.__make_request("GET", "/atx", session)[1] result = self.__make_request(session, "atx.get_state()", self.__kvmd.atx.get_state)
data = [int(result["leds"]["power"]), 0, 0] data = [int(result["leds"]["power"]), 0, 0]
session.send_ipmi_response(data=data) session.send_ipmi_response(data=data)
def __chassis_control_handler(self, request: Dict, session: IpmiServerSession) -> None: def __chassis_control_handler(self, request: Dict, session: IpmiServerSession) -> None:
handle = { action = {
0: "/atx/power?action=off_hard", 0: "off_hard",
1: "/atx/power?action=on", 1: "on",
3: "/atx/power?action=reset_hard", 3: "reset_hard",
5: "/atx/power?action=off", 5: "off",
}.get(request["data"][0], "") }.get(request["data"][0], "")
if handle: if action:
if self.__make_request("POST", handle, session)[0] == 409: if not self.__make_request(session, f"atx.switch_power({action})", self.__kvmd.atx.switch_power, action=action):
code = 0xC0 # Try again later code = 0xC0 # Try again later
else: else:
code = 0 code = 0
@@ -126,65 +117,19 @@ class IpmiServer(BaseIpmiServer): # pylint: disable=too-many-instance-attribute
# ===== # =====
def __make_request(self, method: str, handle: str, ipmi_session: IpmiServerSession) -> Tuple[int, Dict]: def __make_request(self, session: IpmiServerSession, name: str, method: Callable, **kwargs): # type: ignore
result: Optional[Tuple[int, Dict]] = None async def runner(): # type: ignore
exc_info = None logger = get_logger(0)
credentials = self.__auth_manager.get_credentials(session.username.decode())
def make_request() -> None: logger.info("Performing request %s from user %r (IPMI) as %r (KVMD)",
nonlocal result name, credentials.ipmi_user, credentials.kvmd_user)
nonlocal exc_info
loop = asyncio.new_event_loop()
try: try:
result = loop.run_until_complete(self.__make_request_async(method, handle, ipmi_session)) return (await method(credentials.kvmd_user, credentials.kvmd_passwd, **kwargs))
except: # noqa: E722 # pylint: disable=bare-except
exc_info = sys.exc_info()
finally:
loop.close()
thread = threading.Thread(target=make_request, daemon=True)
thread.start()
thread.join()
if exc_info is not None:
raise exc_info[1].with_traceback(exc_info[2]) # type: ignore # pylint: disable=unsubscriptable-object
assert result is not None
# Dirty pylint hack
return (result[0], result[1]) # pylint: disable=unsubscriptable-object
async def __make_request_async(self, method: str, handle: str, ipmi_session: IpmiServerSession) -> Tuple[int, Dict]:
logger = get_logger(0)
assert handle.startswith("/")
url = f"http://{self.__kvmd_host}:{self.__kvmd_port}{handle}"
credentials = self.__auth_manager.get_credentials(ipmi_session.username.decode())
logger.info("Performing %r request to %r from user %r (IPMI) as %r (KVMD)",
method, url, credentials.ipmi_user, credentials.kvmd_user)
async with self.__make_http_session_async() as http_session:
try:
async with http_session.request(
method=method,
url=url,
headers={
"X-KVMD-User": credentials.kvmd_user,
"X-KVMD-Passwd": credentials.kvmd_passwd,
"User-Agent": make_user_agent("KVMD-IPMI"),
},
timeout=self.__kvmd_timeout,
) as response:
if response.status != 409:
response.raise_for_status()
return (response.status, (await response.json())["result"])
except (aiohttp.ClientError, asyncio.TimeoutError) as err: except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Can't perform %r request to %r: %s: %s", method, url, type(err).__name__, str(err)) logger.error("Can't perform request %s: %s", name, str(err))
raise raise
except Exception: except Exception:
logger.exception("Unexpected exception while performing %r request to %r", method, url) logger.exception("Unexpected exception while performing request %s", name)
raise raise
def __make_http_session_async(self) -> aiohttp.ClientSession: return aiotools.run_sync(runner())
if self.__kvmd_unix_path:
return aiohttp.ClientSession(connector=aiohttp.UnixConnector(path=self.__kvmd_unix_path))
else:
return aiohttp.ClientSession()

View File

@@ -183,7 +183,7 @@ class Streamer: # pylint: disable=too-many-instance-attributes
headers={"User-Agent": make_user_agent("KVMD")}, headers={"User-Agent": make_user_agent("KVMD")},
timeout=self.__timeout, timeout=self.__timeout,
) as response: ) as response:
response.raise_for_status() aiotools.raise_not_200(response)
state = (await response.json())["result"] state = (await response.json())["result"]
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError): except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
pass pass

View File

@@ -43,6 +43,8 @@ def main(argv: Optional[List[str]]=None) -> None:
argv=argv, argv=argv,
)[2].vnc )[2].vnc
user_agent = make_user_agent("KVMD-VNC")
# pylint: disable=protected-access # pylint: disable=protected-access
VncServer( VncServer(
host=config.server.host, host=config.server.host,
@@ -55,9 +57,12 @@ def main(argv: Optional[List[str]]=None) -> None:
desired_fps=config.desired_fps, desired_fps=config.desired_fps,
symmap=build_symmap(config.keymap), symmap=build_symmap(config.keymap),
kvmd=KvmdClient(**config.kvmd._unpack()), kvmd=KvmdClient(
user_agent=user_agent,
**config.kvmd._unpack(),
),
streamer=StreamerClient( streamer=StreamerClient(
user_agent=make_user_agent("KVMD-VNC"), user_agent=user_agent,
**config.streamer._unpack(), **config.streamer._unpack(),
), ),
vnc_auth_manager=VncAuthManager(**config.auth.vncauth._unpack()), vnc_auth_manager=VncAuthManager(**config.auth.vncauth._unpack()),

View File

@@ -34,7 +34,6 @@ import aiohttp
from ...logging import get_logger from ...logging import get_logger
from ...clients.kvmd import KvmdError
from ...clients.kvmd import KvmdClient from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamerError from ...clients.streamer import StreamerError
@@ -327,8 +326,8 @@ class VncServer: # pylint: disable=too-many-instance-attributes
try: try:
try: try:
none_auth_only = await kvmd.auth.check("", "") none_auth_only = await kvmd.auth.check("", "")
except KvmdError as err: except (aiohttp.ClientError, asyncio.TimeoutError) as err:
logger.error("Client %s: Can't check KVMD auth mode: %s", remote, err) logger.error("Client %s: Can't check KVMD auth mode: %s: %s", remote, type(err).__name__, err)
return return
await _Client( await _Client(

View File

@@ -24,20 +24,10 @@ import contextlib
from typing import Dict from typing import Dict
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Union
import aiohttp import aiohttp
from .. import make_user_agent from .. import aiotools
# =====
class KvmdError(Exception):
def __init__(self, err: Union[Exception, str]) -> None:
if isinstance(err, Exception):
super().__init__(f"{type(err).__name__}: {err}")
else:
super().__init__(err)
# ===== # =====
@@ -48,6 +38,7 @@ class _BaseClientPart:
port: int, port: int,
unix_path: str, unix_path: str,
timeout: float, timeout: float,
user_agent: str,
) -> None: ) -> None:
assert port or unix_path assert port or unix_path
@@ -55,17 +46,14 @@ class _BaseClientPart:
self.__port = port self.__port = port
self.__unix_path = unix_path self.__unix_path = unix_path
self.__timeout = timeout self.__timeout = timeout
self.__user_agent = user_agent
def _make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://{self.__host}:{self.__port}/{handle}"
def _make_session(self, user: str, passwd: str) -> aiohttp.ClientSession: def _make_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
kwargs: Dict = { kwargs: Dict = {
"headers": { "headers": {
"X-KVMD-User": user, "X-KVMD-User": user,
"X-KVMD-Passwd": passwd, "X-KVMD-Passwd": passwd,
"User-Agent": make_user_agent("KVMD-VNC"), "User-Agent": self.__user_agent,
}, },
"timeout": aiohttp.ClientTimeout(total=self.__timeout), "timeout": aiohttp.ClientTimeout(total=self.__timeout),
} }
@@ -73,35 +61,54 @@ class _BaseClientPart:
kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path) kwargs["connector"] = aiohttp.UnixConnector(path=self.__unix_path)
return aiohttp.ClientSession(**kwargs) return aiohttp.ClientSession(**kwargs)
def _make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://{self.__host}:{self.__port}/{handle}"
class _AuthClientPart(_BaseClientPart): class _AuthClientPart(_BaseClientPart):
async def check(self, user: str, passwd: str) -> bool: async def check(self, user: str, passwd: str) -> bool:
try: try:
async with self._make_session(user, passwd) as session: async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("auth/check")) as response: async with session.get(self._make_url("auth/check")) as response:
response.raise_for_status() aiotools.raise_not_200(response)
if response.status == 200: return True
return True
raise KvmdError(f"Invalid OK response: {response.status} {await response.text()}")
except aiohttp.ClientResponseError as err: except aiohttp.ClientResponseError as err:
if err.status in [401, 403]: if err.status in [401, 403]:
return False return False
raise KvmdError(err) raise
except aiohttp.ClientError as err:
raise KvmdError(err)
class _StreamerClientPart(_BaseClientPart): class _StreamerClientPart(_BaseClientPart):
async def set_params(self, user: str, passwd: str, quality: int, desired_fps: int) -> None: async def set_params(self, user: str, passwd: str, quality: int, desired_fps: int) -> None:
async with self._make_session(user, passwd) as session:
async with session.post(
url=self._make_url("streamer/set_params"),
params={"quality": quality, "desired_fps": desired_fps},
) as response:
aiotools.raise_not_200(response)
class _AtxClientPart(_BaseClientPart):
async def get_state(self, user: str, passwd: str) -> Dict:
async with self._make_session(user, passwd) as session:
async with session.get(self._make_url("atx")) as response:
aiotools.raise_not_200(response)
return (await response.json())["result"]
async def switch_power(self, user: str, passwd: str, action: str) -> bool:
try: try:
async with self._make_session(user, passwd) as session: async with self._make_session(user, passwd) as session:
async with session.post( async with session.post(
url=self._make_url("streamer/set_params"), url=self._make_url("atx/power"),
params={"quality": quality, "desired_fps": desired_fps}, params={"action": action},
) as response: ) as response:
response.raise_for_status() aiotools.raise_not_200(response)
except aiohttp.ClientError as err: return True
raise KvmdError(err) except aiohttp.ClientResponseError as err:
if err.status == 409:
return False
raise
# ===== # =====
@@ -112,6 +119,7 @@ class KvmdClient(_BaseClientPart):
port: int, port: int,
unix_path: str, unix_path: str,
timeout: float, timeout: float,
user_agent: str,
) -> None: ) -> None:
kwargs: Dict = { kwargs: Dict = {
@@ -119,18 +127,17 @@ class KvmdClient(_BaseClientPart):
"port": port, "port": port,
"unix_path": unix_path, "unix_path": unix_path,
"timeout": timeout, "timeout": timeout,
"user_agent": user_agent,
} }
super().__init__(**kwargs) super().__init__(**kwargs)
self.auth = _AuthClientPart(**kwargs) self.auth = _AuthClientPart(**kwargs)
self.streamer = _StreamerClientPart(**kwargs) self.streamer = _StreamerClientPart(**kwargs)
self.atx = _AtxClientPart(**kwargs)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def ws(self, user: str, passwd: str) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]: async def ws(self, user: str, passwd: str) -> AsyncGenerator[aiohttp.ClientWebSocketResponse, None]:
try: async with self._make_session(user, passwd) as session:
async with self._make_session(user, passwd) as session: async with session.ws_connect(self._make_url("ws")) as ws:
async with session.ws_connect(self._make_url("ws")) as ws: yield ws
yield ws
except aiohttp.ClientError as err:
raise KvmdError(err)

View File

@@ -26,6 +26,8 @@ from typing import AsyncGenerator
import aiohttp import aiohttp
from .. import aiotools
# ===== # =====
class StreamerError(Exception): class StreamerError(Exception):
@@ -59,7 +61,7 @@ class StreamerClient:
params={"extra_headers": "1"}, params={"extra_headers": "1"},
headers={"User-Agent": self.__user_agent}, headers={"User-Agent": self.__user_agent},
) as response: ) as response:
response.raise_for_status() aiotools.raise_not_200(response)
reader = aiohttp.MultipartReader.from_response(response) reader = aiohttp.MultipartReader.from_response(response)
while True: while True:
frame = await reader.next() # pylint: disable=not-callable frame = await reader.next() # pylint: disable=not-callable

View File

@@ -34,6 +34,7 @@ from ...validators.basic import valid_float_f01
from ...logging import get_logger from ...logging import get_logger
from ... import make_user_agent from ... import make_user_agent
from ... import aiotools
from . import BaseAuthService from . import BaseAuthService
@@ -89,10 +90,8 @@ class Plugin(BaseAuthService):
"X-KVMD-User": user, "X-KVMD-User": user,
}, },
) as response: ) as response:
response.raise_for_status() aiotools.raise_not_200(response)
if response.status == 200: return True
return True
raise RuntimeError(f"Invalid OK response: {response.status} {await response.text()}; expected 200")
except Exception: except Exception:
get_logger().exception("Failed HTTP auth request for user %r", user) get_logger().exception("Failed HTTP auth request for user %r", user)
return False return False