refactoring

This commit is contained in:
Maxim Devaev 2021-07-30 06:53:09 +03:00
parent 1aaea37f54
commit 7d7f4965de

View File

@ -31,8 +31,12 @@ from typing import Dict
from typing import Callable
from typing import Optional
import aiohttp
import aiohttp.web
from aiohttp.web import BaseRequest
from aiohttp.web import Request
from aiohttp.web import Response
from aiohttp.web import StreamResponse
from aiohttp.web import Application
from aiohttp.web import run_app
try:
from aiohttp.web import AccessLogger # type: ignore
@ -138,9 +142,9 @@ def make_json_response(
status: int=200,
set_cookies: Optional[Dict[str, str]]=None,
wrap_result: bool=True,
) -> aiohttp.web.Response:
) -> Response:
response = aiohttp.web.Response(
response = Response(
text=json.dumps(({
"ok": (status == 200),
"result": (result or {}),
@ -154,7 +158,7 @@ def make_json_response(
return response
def make_json_exception(err: Exception, status: Optional[int]=None) -> aiohttp.web.Response:
def make_json_exception(err: Exception, status: Optional[int]=None) -> Response:
name = type(err).__name__
msg = str(err)
if isinstance(err, HttpError):
@ -168,13 +172,13 @@ def make_json_exception(err: Exception, status: Optional[int]=None) -> aiohttp.w
}, status=status)
async def start_streaming(request: aiohttp.web.Request, content_type: str) -> aiohttp.web.StreamResponse:
response = aiohttp.web.StreamResponse(status=200, reason="OK", headers={"Content-Type": content_type})
async def start_streaming(request: Request, content_type: str) -> StreamResponse:
response = StreamResponse(status=200, reason="OK", headers={"Content-Type": content_type})
await response.prepare(request)
return response
async def stream_json(response: aiohttp.web.StreamResponse, result: Dict) -> None:
async def stream_json(response: StreamResponse, result: Dict) -> None:
await response.write(json.dumps(result).encode("utf-8") + b"\r\n")
@ -182,14 +186,14 @@ async def stream_json(response: aiohttp.web.StreamResponse, result: Dict) -> Non
_REQUEST_AUTH_INFO = "_kvmd_auth_info"
def _format_P(request: aiohttp.web.BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name
def _format_P(request: BaseRequest, *_, **__) -> str: # type: ignore # pylint: disable=invalid-name
return (getattr(request, _REQUEST_AUTH_INFO, None) or "-")
AccessLogger._format_P = staticmethod(_format_P) # type: ignore # pylint: disable=protected-access
def set_request_auth_info(request: aiohttp.web.BaseRequest, info: str) -> None:
def set_request_auth_info(request: BaseRequest, info: str) -> None:
setattr(request, _REQUEST_AUTH_INFO, info)
@ -218,7 +222,7 @@ class HttpServer:
else:
socket_kwargs = {"host": host, "port": port}
aiohttp.web.run_app(
run_app(
app=self._make_app(),
shutdown_timeout=1,
access_log_format=access_log_format,
@ -226,7 +230,7 @@ class HttpServer:
**socket_kwargs,
)
async def _make_app(self) -> aiohttp.web.Application:
async def _make_app(self) -> Application:
raise NotImplementedError
def __run_app_print(self, text: str) -> None: