refactoring

This commit is contained in:
Maxim Devaev
2022-04-06 00:39:16 +03:00
parent 8ce08fb456
commit 6f6772a6b6
16 changed files with 68 additions and 69 deletions

View File

@@ -23,15 +23,15 @@
from aiohttp.web import Request
from aiohttp.web import Response
from ....htserver import exposed_http
from ....htserver import make_json_response
from ....plugins.atx import BaseAtx
from ....validators.basic import valid_bool
from ....validators.kvm import valid_atx_power_action
from ....validators.kvm import valid_atx_button
from ..http import exposed_http
from ..http import make_json_response
# =====
class AtxApi:

View File

@@ -25,17 +25,17 @@ import base64
from aiohttp.web import Request
from aiohttp.web import Response
from ....htserver import UnauthorizedError
from ....htserver import ForbiddenError
from ....htserver import HttpExposed
from ....htserver import exposed_http
from ....htserver import make_json_response
from ....htserver import set_request_auth_info
from ....validators.auth import valid_user
from ....validators.auth import valid_passwd
from ....validators.auth import valid_auth_token
from ..http import UnauthorizedError
from ..http import ForbiddenError
from ..http import HttpExposed
from ..http import exposed_http
from ..http import make_json_response
from ..http import set_request_auth_info
from ..auth import AuthManager

View File

@@ -30,14 +30,14 @@ from aiohttp.web import Response
from .... import tools
from ....htserver import exposed_http
from ....plugins.atx import BaseAtx
from ....plugins.ugpio import UserGpioModes
from ..info import InfoManager
from ..ugpio import UserGpio
from ..http import exposed_http
# =====
class ExportApi:

View File

@@ -36,6 +36,13 @@ from aiohttp.web import WebSocketResponse
from ....mouse import MouseRange
from ....keyboard.keysym import build_symmap
from ....keyboard.printer import text_to_web_keys
from ....htserver import exposed_http
from ....htserver import exposed_ws
from ....htserver import make_json_response
from ....plugins.hid import BaseHid
from ....validators import raise_error
@@ -49,13 +56,6 @@ from ....validators.hid import valid_hid_mouse_move
from ....validators.hid import valid_hid_mouse_button
from ....validators.hid import valid_hid_mouse_delta
from ....keyboard.keysym import build_symmap
from ....keyboard.printer import text_to_web_keys
from ..http import exposed_http
from ..http import exposed_ws
from ..http import make_json_response
# =====
class HidApi:

View File

@@ -27,13 +27,13 @@ from typing import List
from aiohttp.web import Request
from aiohttp.web import Response
from ....htserver import exposed_http
from ....htserver import make_json_response
from ....validators.kvm import valid_info_fields
from ..info import InfoManager
from ..http import exposed_http
from ..http import make_json_response
# =====
class InfoApi:

View File

@@ -23,14 +23,14 @@
from aiohttp.web import Request
from aiohttp.web import StreamResponse
from ....htserver import exposed_http
from ....htserver import start_streaming
from ....validators.basic import valid_bool
from ....validators.kvm import valid_log_seek
from ..logreader import LogReader
from ..http import exposed_http
from ..http import start_streaming
# =====
class LogApi:

View File

@@ -36,6 +36,13 @@ from ....logging import get_logger
from .... import htclient
from ....htserver import exposed_http
from ....htserver import make_json_response
from ....htserver import make_json_exception
from ....htserver import start_streaming
from ....htserver import stream_json
from ....htserver import stream_json_exception
from ....plugins.msd import BaseMsd
from ....validators.basic import valid_bool
@@ -44,13 +51,6 @@ from ....validators.basic import valid_float_f01
from ....validators.net import valid_url
from ....validators.kvm import valid_msd_image_name
from ..http import exposed_http
from ..http import make_json_response
from ..http import make_json_exception
from ..http import start_streaming
from ..http import stream_json
from ..http import stream_json_exception
# ======
class MsdApi:

View File

@@ -25,6 +25,10 @@ import asyncio
from aiohttp.web import Request
from aiohttp.web import Response
from ....htserver import HttpError
from ....htserver import exposed_http
from ....htserver import make_json_response
from ....plugins.atx import BaseAtx
from ....validators import ValidatorError
@@ -32,10 +36,6 @@ from ....validators import check_string_in_list
from ..info import InfoManager
from ..http import HttpError
from ..http import exposed_http
from ..http import make_json_response
# =====
class RedfishApi:

View File

@@ -26,6 +26,10 @@ from typing import Dict
from aiohttp.web import Request
from aiohttp.web import Response
from ....htserver import UnavailableError
from ....htserver import exposed_http
from ....htserver import make_json_response
from ....validators import check_string_in_list
from ....validators.basic import valid_bool
from ....validators.basic import valid_number
@@ -33,10 +37,6 @@ from ....validators.basic import valid_int_f0
from ....validators.basic import valid_string_list
from ....validators.kvm import valid_stream_quality
from ..http import UnavailableError
from ..http import exposed_http
from ..http import make_json_response
from ..streamer import Streamer
from ..tesseract import TesseractOcr

View File

@@ -23,15 +23,15 @@
from aiohttp.web import Request
from aiohttp.web import Response
from ....htserver import exposed_http
from ....htserver import make_json_response
from ....validators.basic import valid_bool
from ....validators.basic import valid_float_f0
from ....validators.ugpio import valid_ugpio_channel
from ..ugpio import UserGpio
from ..http import exposed_http
from ..http import make_json_response
# =====
class UserGpioApi:

View File

@@ -28,11 +28,11 @@ from typing import Optional
from ...logging import get_logger
from ... import aiotools
from ...plugins.auth import BaseAuthService
from ...plugins.auth import get_auth_service_class
from ... import aiotools
# =====
class AuthManager:

View File

@@ -1,264 +0,0 @@
# ========================================================================== #
# #
# KVMD - The main PiKVM daemon. #
# #
# Copyright (C) 2018-2022 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import os
import socket
import asyncio
import dataclasses
import inspect
import json
from typing import List
from typing import Dict
from typing import Callable
from typing import Optional
from aiohttp.web import BaseRequest
from aiohttp.web import Request
from aiohttp.web import Response
from aiohttp.web import StreamResponse
from aiohttp.web import Application
from aiohttp.web import run_app
from aiohttp.web import normalize_path_middleware
try:
from aiohttp.web import AccessLogger # type: ignore
except ImportError:
from aiohttp.helpers import AccessLogger # type: ignore
from ...logging import get_logger
# =====
class HttpError(Exception):
def __init__(self, msg: str, status: int) -> None:
super().__init__(msg)
self.status = status
class UnauthorizedError(HttpError):
def __init__(self) -> None:
super().__init__("Unauthorized", 401)
class ForbiddenError(HttpError):
def __init__(self) -> None:
super().__init__("Forbidden", 403)
class UnavailableError(HttpError):
def __init__(self) -> None:
super().__init__("Service Unavailable", 503)
# =====
@dataclasses.dataclass(frozen=True)
class HttpExposed:
method: str
path: str
auth_required: bool
handler: Callable
_HTTP_EXPOSED = "_http_exposed"
_HTTP_METHOD = "_http_method"
_HTTP_PATH = "_http_path"
_HTTP_AUTH_REQUIRED = "_http_auth_required"
def exposed_http(http_method: str, path: str, auth_required: bool=True) -> Callable:
def set_attrs(handler: Callable) -> Callable:
setattr(handler, _HTTP_EXPOSED, True)
setattr(handler, _HTTP_METHOD, http_method)
setattr(handler, _HTTP_PATH, path)
setattr(handler, _HTTP_AUTH_REQUIRED, auth_required)
return handler
return set_attrs
def get_exposed_http(obj: object) -> List[HttpExposed]:
return [
HttpExposed(
method=getattr(handler, _HTTP_METHOD),
path=getattr(handler, _HTTP_PATH),
auth_required=getattr(handler, _HTTP_AUTH_REQUIRED),
handler=handler,
)
for handler in [getattr(obj, name) for name in dir(obj)]
if inspect.ismethod(handler) and getattr(handler, _HTTP_EXPOSED, False)
]
# =====
@dataclasses.dataclass(frozen=True)
class WsExposed:
event_type: str
handler: Callable
_WS_EXPOSED = "_ws_exposed"
_WS_EVENT_TYPE = "_ws_event_type"
def exposed_ws(event_type: str) -> Callable:
def set_attrs(handler: Callable) -> Callable:
setattr(handler, _WS_EXPOSED, True)
setattr(handler, _WS_EVENT_TYPE, event_type)
return handler
return set_attrs
def get_exposed_ws(obj: object) -> List[WsExposed]:
return [
WsExposed(
event_type=getattr(handler, _WS_EVENT_TYPE),
handler=handler,
)
for handler in [getattr(obj, name) for name in dir(obj)]
if inspect.ismethod(handler) and getattr(handler, _WS_EXPOSED, False)
]
# =====
def make_json_response(
result: Optional[Dict]=None,
status: int=200,
set_cookies: Optional[Dict[str, str]]=None,
wrap_result: bool=True,
) -> Response:
response = Response(
text=json.dumps(({
"ok": (status == 200),
"result": (result or {}),
} if wrap_result else result), sort_keys=True, indent=4),
status=status,
content_type="application/json",
)
if set_cookies:
for (key, value) in set_cookies.items():
response.set_cookie(key, value)
return response
def make_json_exception(err: Exception, status: Optional[int]=None) -> Response:
name = type(err).__name__
msg = str(err)
if isinstance(err, HttpError):
status = err.status
else:
get_logger().error("API error: %s: %s", name, msg)
assert status is not None, err
return make_json_response({
"error": name,
"error_msg": msg,
}, status=status)
async def start_streaming(request: Request, content_type: str="application/x-ndjson") -> StreamResponse:
response = StreamResponse(status=200, reason="OK", headers={"Content-Type": content_type})
await response.prepare(request)
return response
async def stream_json(response: StreamResponse, result: Dict, ok: bool=True) -> None:
await response.write(json.dumps({
"ok": ok,
"result": result,
}).encode("utf-8") + b"\r\n")
async def stream_json_exception(response: StreamResponse, err: Exception) -> None:
name = type(err).__name__
msg = str(err)
get_logger().error("API error: %s: %s", name, msg)
await stream_json(response, {
"error": name,
"error_msg": msg,
}, False)
# =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
def _format_P(request: BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name
return (getattr(request, _REQUEST_AUTH_INFO, None) or "-")
AccessLogger._format_P = staticmethod(_format_P) # type: ignore # pylint: disable=protected-access
def set_request_auth_info(request: BaseRequest, info: str) -> None:
setattr(request, _REQUEST_AUTH_INFO, info)
# =====
class HttpServer:
def run(
self,
unix_path: str,
unix_rm: bool,
unix_mode: int,
access_log_format: str,
) -> None:
if unix_rm and os.path.exists(unix_path):
os.remove(unix_path)
server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_socket.bind(unix_path)
if unix_mode:
os.chmod(unix_path, unix_mode)
run_app(
sock=server_socket,
app=self.__make_app(),
shutdown_timeout=1,
access_log_format=access_log_format,
print=self.__run_app_print,
loop=asyncio.get_event_loop(),
)
async def _init_app(self, app: Application) -> None:
raise NotImplementedError
async def _on_shutdown(self, app: Application) -> None:
_ = app
async def _on_cleanup(self, app: Application) -> None:
_ = app
async def __make_app(self) -> Application:
app = Application(middlewares=[normalize_path_middleware(
append_slash=False,
remove_slash=True,
merge_slashes=True,
)])
app.on_shutdown.append(self._on_shutdown)
app.on_cleanup.append(self._on_cleanup)
await self._init_app(app)
return app
def __run_app_print(self, text: str) -> None:
logger = get_logger(0)
for line in text.strip().splitlines():
logger.info(line.strip())

View File

@@ -45,8 +45,20 @@ from ...logging import get_logger
from ...errors import OperationError
from ...errors import IsBusyError
from ...plugins import BasePlugin
from ... import aiotools
from ... import aioproc
from ...htserver import HttpError
from ...htserver import HttpExposed
from ...htserver import exposed_http
from ...htserver import exposed_ws
from ...htserver import get_exposed_http
from ...htserver import get_exposed_ws
from ...htserver import make_json_response
from ...htserver import make_json_exception
from ...htserver import HttpServer
from ...plugins import BasePlugin
from ...plugins.hid import BaseHid
from ...plugins.atx import BaseAtx
from ...plugins.msd import BaseMsd
@@ -59,9 +71,6 @@ from ...validators.kvm import valid_stream_resolution
from ...validators.kvm import valid_stream_h264_bitrate
from ...validators.kvm import valid_stream_h264_gop
from ... import aiotools
from ... import aioproc
from .auth import AuthManager
from .info import InfoManager
from .logreader import LogReader
@@ -70,16 +79,6 @@ from .streamer import Streamer
from .snapshoter import Snapshoter
from .tesseract import TesseractOcr
from .http import HttpError
from .http import HttpExposed
from .http import exposed_http
from .http import exposed_ws
from .http import get_exposed_http
from .http import get_exposed_ws
from .http import make_json_response
from .http import make_json_exception
from .http import HttpServer
from .api.auth import AuthApi
from .api.auth import check_request_auth

View File

@@ -27,10 +27,10 @@ from typing import Callable
from ...logging import get_logger
from ...plugins.hid import BaseHid
from ... import aiotools
from ...plugins.hid import BaseHid
from .streamer import Streamer

View File

@@ -31,6 +31,11 @@ from typing import Any
from ...logging import get_logger
from ...errors import IsBusyError
from ... import tools
from ... import aiotools
from ...plugins.ugpio import GpioError
from ...plugins.ugpio import GpioOperationError
from ...plugins.ugpio import GpioDriverOfflineError
@@ -38,13 +43,8 @@ from ...plugins.ugpio import UserGpioModes
from ...plugins.ugpio import BaseUserGpioDriver
from ...plugins.ugpio import get_ugpio_driver_class
from ... import tools
from ... import aiotools
from ...yamlconf import Section
from ...errors import IsBusyError
# =====
class GpioChannelNotFoundError(GpioOperationError):