refactoring

This commit is contained in:
Maxim Devaev 2022-04-12 09:12:21 +03:00
parent 1e5c8535f6
commit a29f9692c8
2 changed files with 31 additions and 24 deletions

View File

@ -42,19 +42,16 @@ import aiohttp.web
from ...logging import get_logger from ...logging import get_logger
from ...errors import OperationError from ...errors import OperationError
from ...errors import IsBusyError
from ... import aiotools from ... import aiotools
from ... import aioproc from ... import aioproc
from ...htserver import HttpError
from ...htserver import HttpExposed from ...htserver import HttpExposed
from ...htserver import exposed_http from ...htserver import exposed_http
from ...htserver import exposed_ws from ...htserver import exposed_ws
from ...htserver import get_exposed_http from ...htserver import get_exposed_http
from ...htserver import get_exposed_ws from ...htserver import get_exposed_ws
from ...htserver import make_json_response from ...htserver import make_json_response
from ...htserver import make_json_exception
from ...htserver import send_ws_event from ...htserver import send_ws_event
from ...htserver import broadcast_ws_event from ...htserver import broadcast_ws_event
from ...htserver import process_ws_messages from ...htserver import process_ws_messages
@ -65,7 +62,6 @@ from ...plugins.hid import BaseHid
from ...plugins.atx import BaseAtx from ...plugins.atx import BaseAtx
from ...plugins.msd import BaseMsd from ...plugins.msd import BaseMsd
from ...validators import ValidatorError
from ...validators.basic import valid_bool from ...validators.basic import valid_bool
from ...validators.kvm import valid_stream_quality from ...validators.kvm import valid_stream_quality
from ...validators.kvm import valid_stream_fps from ...validators.kvm import valid_stream_fps
@ -296,7 +292,10 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
aioproc.rename_process("main") aioproc.rename_process("main")
super().run(**kwargs) super().run(**kwargs)
async def _init_app(self, app: aiohttp.web.Application) -> None: async def _check_request_auth(self, exposed: HttpExposed, request: aiohttp.web.Request) -> None:
await check_request_auth(self.__auth_manager, exposed, request)
async def _init_app(self, _: aiohttp.web.Application) -> None:
self.__run_system_task(self.__stream_controller) self.__run_system_task(self.__stream_controller)
for comp in self.__components: for comp in self.__components:
if comp.systask: if comp.systask:
@ -307,7 +306,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
for api in self.__apis: for api in self.__apis:
for http_exposed in get_exposed_http(api): for http_exposed in get_exposed_http(api):
self.__add_app_route(app, http_exposed) self._add_exposed(http_exposed)
for ws_exposed in get_exposed_ws(api): for ws_exposed in get_exposed_ws(api):
self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler self.__ws_handlers[ws_exposed.event_type] = ws_exposed.handler
@ -324,19 +323,6 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
os.kill(os.getpid(), signal.SIGTERM) os.kill(os.getpid(), signal.SIGTERM)
self.__system_tasks.append(asyncio.create_task(wrapper())) self.__system_tasks.append(asyncio.create_task(wrapper()))
def __add_app_route(self, app: aiohttp.web.Application, exposed: HttpExposed) -> None:
async def wrapper(request: aiohttp.web.Request) -> aiohttp.web.Response:
try:
await check_request_auth(self.__auth_manager, exposed, request)
return (await exposed.handler(request))
except IsBusyError as err:
return make_json_exception(err, 409)
except (ValidatorError, OperationError) as err:
return make_json_exception(err, 400)
except HttpError as err:
return make_json_exception(err)
app.router.add_route(exposed.method, exposed.path, wrapper)
async def _on_shutdown(self, _: aiohttp.web.Application) -> None: async def _on_shutdown(self, _: aiohttp.web.Application) -> None:
logger = get_logger(0) logger = get_logger(0)

View File

@ -50,6 +50,11 @@ except ImportError:
from .logging import get_logger from .logging import get_logger
from .errors import OperationError
from .errors import IsBusyError
from .validators import ValidatorError
# ===== # =====
class HttpError(Exception): class HttpError(Exception):
@ -297,6 +302,19 @@ class HttpServer:
# ===== # =====
def _add_exposed(self, exposed: HttpExposed) -> None:
async def wrapper(request: Request) -> Response:
try:
await self._check_request_auth(exposed, request)
return (await exposed.handler(request))
except IsBusyError as err:
return make_json_exception(err, 409)
except (ValidatorError, OperationError) as err:
return make_json_exception(err, 400)
except HttpError as err:
return make_json_exception(err)
self.__app.router.add_route(exposed.method, exposed.path, wrapper)
async def _make_ws_response(self, request: Request) -> WebSocketResponse: async def _make_ws_response(self, request: Request) -> WebSocketResponse:
ws = WebSocketResponse(heartbeat=self.__heartbeat) ws = WebSocketResponse(heartbeat=self.__heartbeat)
await ws.prepare(request) await ws.prepare(request)
@ -304,6 +322,9 @@ class HttpServer:
# ===== # =====
async def _check_request_auth(self, exposed: HttpExposed, request: Request) -> None:
pass
async def _init_app(self, app: Application) -> None: async def _init_app(self, app: Application) -> None:
raise NotImplementedError raise NotImplementedError
@ -316,15 +337,15 @@ class HttpServer:
# ===== # =====
async def __make_app(self) -> Application: async def __make_app(self) -> Application:
app = Application(middlewares=[normalize_path_middleware( self.__app = Application(middlewares=[normalize_path_middleware( # pylint: disable=attribute-defined-outside-init
append_slash=False, append_slash=False,
remove_slash=True, remove_slash=True,
merge_slashes=True, merge_slashes=True,
)]) )])
app.on_shutdown.append(self._on_shutdown) self.__app.on_shutdown.append(self._on_shutdown)
app.on_cleanup.append(self._on_cleanup) self.__app.on_cleanup.append(self._on_cleanup)
await self._init_app(app) await self._init_app(self.__app)
return app return self.__app
def __run_app_print(self, text: str) -> None: def __run_app_print(self, text: str) -> None:
logger = get_logger(0) logger = get_logger(0)