changed /msd/write_remote format

This commit is contained in:
Maxim Devaev 2021-08-01 09:26:54 +03:00
parent 4f1c2a97aa
commit d6fd2e3775
2 changed files with 22 additions and 9 deletions

View File

@ -24,6 +24,7 @@ import time
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
from typing import Union
import aiohttp import aiohttp
@ -48,6 +49,7 @@ from ..http import make_json_response
from ..http import make_json_exception from ..http import make_json_exception
from ..http import start_streaming from ..http import start_streaming
from ..http import stream_json from ..http import stream_json
from ..http import stream_json_exception
# ====== # ======
@ -97,7 +99,7 @@ class MsdApi:
return make_json_response(self.__make_write_info(name, size, written)) return make_json_response(self.__make_write_info(name, size, written))
@exposed_http("POST", "/msd/write_remote") @exposed_http("POST", "/msd/write_remote")
async def __write_remote_handler(self, request: Request) -> StreamResponse: # pylint: disable=too-many-locals async def __write_remote_handler(self, request: Request) -> Union[Response, StreamResponse]: # pylint: disable=too-many-locals
url = valid_url(request.query.get("url")) url = valid_url(request.query.get("url"))
insecure = valid_bool(request.query.get("insecure", "0")) insecure = valid_bool(request.query.get("insecure", "0"))
timeout = valid_float_f01(request.query.get("timeout", 10.0)) timeout = valid_float_f01(request.query.get("timeout", 10.0))
@ -106,9 +108,9 @@ class MsdApi:
size = written = 0 size = written = 0
response: Optional[StreamResponse] = None response: Optional[StreamResponse] = None
async def stream_write_info(err: Optional[Exception]=None) -> None: async def stream_write_info() -> None:
assert response is not None assert response is not None
await stream_json(response, self.__make_write_info(name, size, written), err) await stream_json(response, self.__make_write_info(name, size, written))
try: try:
async with htclient.download( async with htclient.download(
@ -127,7 +129,8 @@ class MsdApi:
get_logger(0).info("Downloading image %r as %r to MSD ...", url, name) get_logger(0).info("Downloading image %r as %r to MSD ...", url, name)
async with self.__msd.write_image(name, size) as chunk_size: async with self.__msd.write_image(name, size) as chunk_size:
response = await start_streaming(request, "application/stream+json") response = await start_streaming(request)
await stream_write_info()
last_report_ts = 0 last_report_ts = 0
async for chunk in remote.content.iter_chunked(chunk_size): async for chunk in remote.content.iter_chunked(chunk_size):
written = await self.__msd.write_image_chunk(chunk) written = await self.__msd.write_image_chunk(chunk)
@ -141,7 +144,8 @@ class MsdApi:
except Exception as err: except Exception as err:
if response is not None: if response is not None:
await stream_write_info(err) await stream_write_info()
await stream_json_exception(response, err)
elif isinstance(err, aiohttp.ClientError): elif isinstance(err, aiohttp.ClientError):
return make_json_exception(err, 400) return make_json_exception(err, 400)
raise raise

View File

@ -172,20 +172,29 @@ def make_json_exception(err: Exception, status: Optional[int]=None) -> Response:
}, status=status) }, status=status)
async def start_streaming(request: Request, content_type: str) -> StreamResponse: async def start_streaming(request: Request, content_type: str="application/x-ndjson") -> StreamResponse:
response = StreamResponse(status=200, reason="OK", headers={"Content-Type": content_type}) response = StreamResponse(status=200, reason="OK", headers={"Content-Type": content_type})
await response.prepare(request) await response.prepare(request)
return response return response
async def stream_json(response: StreamResponse, result: Dict, err: Optional[Exception]=None) -> None: async def stream_json(response: StreamResponse, result: Dict, ok: bool=True) -> None:
await response.write(json.dumps({ await response.write(json.dumps({
"ok": ok,
"result": result, "result": result,
"error": ("" if err is None else type(err).__name__),
"error_msg": ("" if err is None else str(err)),
}).encode("utf-8") + b"\r\n") }).encode("utf-8") + b"\r\n")
async def stream_json_exception(response: StreamResponse, err: Exception) -> None:
name = type(err).__name__
msg = str(err)
get_logger().error("API error: %s: %s", name, msg)
await stream_json(response, {
"error": name,
"error_msg": msg,
}, False)
# ===== # =====
_REQUEST_AUTH_INFO = "_kvmd_auth_info" _REQUEST_AUTH_INFO = "_kvmd_auth_info"