From fb2a9986d8af9c17dd96b27d4ce68e0ce6435f3c Mon Sep 17 00:00:00 2001 From: Devaev Maxim Date: Tue, 10 Dec 2019 06:27:27 +0300 Subject: [PATCH] refactoring --- kvmd/apps/kvmd/http.py | 20 ++++++++++++++++++++ kvmd/apps/kvmd/server.py | 24 ++++-------------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/kvmd/apps/kvmd/http.py b/kvmd/apps/kvmd/http.py index a33eab75..a43d7d57 100644 --- a/kvmd/apps/kvmd/http.py +++ b/kvmd/apps/kvmd/http.py @@ -12,6 +12,11 @@ from typing import Optional import aiohttp import aiohttp.web +try: + from aiohttp.web import AccessLogger # type: ignore +except ImportError: + from aiohttp.helpers import AccessLogger # type: ignore + from ...logging import get_logger from ...validators import ValidatorError @@ -138,6 +143,21 @@ async def get_multipart_field(reader: aiohttp.MultipartReader, name: str) -> aio return field +# ===== +_REQUEST_AUTH_INFO = "_kvd_auth_info" + + +def _format_P(request: aiohttp.web.BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name + return (getattr(request, _REQUEST_AUTH_INFO, None) or "-") + + +AccessLogger._format_P = staticmethod(_format_P) # type: ignore # pylint: disable=protected-access + + +def set_request_auth_info(request: aiohttp.web.BaseRequest, info: str) -> None: + setattr(request, _REQUEST_AUTH_INFO, info) + + # ===== class HttpServer: def run( diff --git a/kvmd/apps/kvmd/server.py b/kvmd/apps/kvmd/server.py index d4d35601..2ad491b9 100644 --- a/kvmd/apps/kvmd/server.py +++ b/kvmd/apps/kvmd/server.py @@ -79,6 +79,7 @@ from .http import get_exposed_http from .http import get_exposed_ws from .http import make_json_response from .http import make_json_exception +from .http import set_request_auth_info from .http import HttpServer from .api.log import LogApi @@ -88,23 +89,6 @@ from .api.atx import AtxApi from .api.msd import MsdApi -# ===== -try: - from aiohttp.web import AccessLogger # type: ignore # pylint: disable=ungrouped-imports -except ImportError: - from aiohttp.helpers import AccessLogger # type: ignore # pylint: disable=ungrouped-imports - - -_ATTR_KVMD_AUTH_INFO = "kvmd_auth_info" - - -def _format_P(request: aiohttp.web.BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name - return (getattr(request, _ATTR_KVMD_AUTH_INFO, None) or "-") - - -AccessLogger._format_P = staticmethod(_format_P) # type: ignore # pylint: disable=protected-access - - # ===== _HEADER_AUTH_USER = "X-KVMD-User" _HEADER_AUTH_PASSWD = "X-KVMD-Passwd" @@ -318,16 +302,16 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins if user: user = valid_user(user) - setattr(request, _ATTR_KVMD_AUTH_INFO, f"{user} (xhdr)") + set_request_auth_info(request, f"{user} (xhdr)") if not (await self.__auth_manager.authorize(user, valid_passwd(passwd))): raise ForbiddenError("Forbidden") elif token: user = self.__auth_manager.check(valid_auth_token(token)) if not user: - setattr(request, _ATTR_KVMD_AUTH_INFO, "- (token)") + set_request_auth_info(request, "- (token)") raise ForbiddenError("Forbidden") - setattr(request, _ATTR_KVMD_AUTH_INFO, f"{user} (token)") + set_request_auth_info(request, f"{user} (token)") else: raise UnauthorizedError("Unauthorized")