/msd/write_remote handle

This commit is contained in:
Maxim Devaev 2021-07-27 05:25:54 +03:00
parent 3c421fa94c
commit 6b07a80834
9 changed files with 173 additions and 36 deletions

View File

@ -59,6 +59,18 @@ location /api/ws {
auth_request off; auth_request off;
} }
location /api/msd/write_remote {
rewrite ^/api/msd/write_remote$ /msd/write_remote break;
rewrite ^/api/msd/write_remote\?(.*)$ /msd/write_remote?$1 break;
proxy_pass http://kvmd;
include /etc/kvmd/nginx/loc-proxy.conf;
proxy_read_timeout 7d;
postpone_output 0;
proxy_buffering off;
proxy_ignore_headers X-Accel-Buffering;
auth_request off;
}
location /api/msd/write { location /api/msd/write {
rewrite ^/api/msd/write$ /msd/write break; rewrite ^/api/msd/write$ /msd/write break;
rewrite ^/api/msd/write\?(.*)$ /msd/write?$1 break; rewrite ^/api/msd/write\?(.*)$ /msd/write?$1 break;

View File

@ -20,19 +20,34 @@
# ========================================================================== # # ========================================================================== #
import time
from typing import Dict
from typing import Optional
import aiohttp
from aiohttp.web import Request from aiohttp.web import Request
from aiohttp.web import Response from aiohttp.web import Response
from aiohttp.web import StreamResponse
from ....logging import get_logger from ....logging import get_logger
from .... import htclient
from ....plugins.msd import BaseMsd from ....plugins.msd import BaseMsd
from ....validators.basic import valid_bool from ....validators.basic import valid_bool
from ....validators.basic import valid_int_f0 from ....validators.basic import valid_int_f0
from ....validators.basic import valid_float_f01
from ....validators.net import valid_url
from ....validators.kvm import valid_msd_image_name from ....validators.kvm import valid_msd_image_name
from ..http import exposed_http from ..http import exposed_http
from ..http import make_json_response from ..http import make_json_response
from ..http import make_json_exception
from ..http import start_streaming
from ..http import stream_json
from ..http import get_multipart_reader from ..http import get_multipart_reader
from ..http import get_multipart_reader_str from ..http import get_multipart_reader_str
from ..http import get_multipart_reader_field from ..http import get_multipart_reader_field
@ -67,29 +82,79 @@ class MsdApi:
await self.__msd.set_connected(valid_bool(request.query.get("connected"))) await self.__msd.set_connected(valid_bool(request.query.get("connected")))
return make_json_response() return make_json_response()
# =====
@exposed_http("POST", "/msd/write") @exposed_http("POST", "/msd/write")
async def __write_handler(self, request: Request) -> Response: async def __write_handler(self, request: Request) -> Response:
logger = get_logger(0)
reader = await get_multipart_reader(request) reader = await get_multipart_reader(request)
name = "" name = valid_msd_image_name(await get_multipart_reader_str(reader, "image"))
size = valid_int_f0(await get_multipart_reader_str(reader, "size"))
data_field = await get_multipart_reader_field(reader, "data")
written = 0 written = 0
async with self.__msd.write_image(name, size) as chunk_size:
while True:
chunk = await data_field.read_chunk(chunk_size)
if not chunk:
break
written = await self.__msd.write_image_chunk(chunk)
return make_json_response(self.__make_write_info(name, size, written))
@exposed_http("POST", "/msd/write_remote")
async def __write_remote_handler(self, request: Request) -> StreamResponse: # pylint: disable=too-many-locals
url = valid_url(request.query.get("url"))
insecure = valid_bool(request.query.get("insecure", "0"))
timeout = valid_float_f01(request.query.get("timeout", 10.0))
name = ""
size = written = 0
response: Optional[StreamResponse] = None
async def stream_write_info() -> None:
assert response is not None
await stream_json(response, self.__make_write_info(name, size, written))
try: try:
name = valid_msd_image_name(await get_multipart_reader_str(reader, "image")) async with htclient.download(
size = valid_int_f0(await get_multipart_reader_str(reader, "size")) url=url,
verify=(not insecure),
timeout=timeout,
read_timeout=(7 * 24 * 3600),
) as remote:
data_field = await get_multipart_reader_field(reader, "data") name = str(request.query.get("image", "")).strip()
if len(name) == 0:
name = htclient.get_filename(remote)
name = valid_msd_image_name(name)
async with self.__msd.write_image(name, size): size = htclient.get_content_length(remote)
logger.info("Writing image %r to MSD ...", name)
while True: get_logger(0).info("Downloading image %r as %r to MSD ...", url, name)
chunk = await data_field.read_chunk(self.__msd.get_upload_chunk_size()) async with self.__msd.write_image(name, size) as chunk_size:
if not chunk: response = await start_streaming(request, "application/stream+json")
break last_report_ts = 0
written = await self.__msd.write_image_chunk(chunk) async for chunk in remote.content.iter_chunked(chunk_size):
finally: written = await self.__msd.write_image_chunk(chunk)
if written != 0: now = int(time.time())
logger.info("Written image %r with size=%d bytes to MSD", name, written) if last_report_ts + 1 < now:
return make_json_response({"image": {"name": name, "size": written}}) await stream_write_info()
last_report_ts = now
await stream_write_info()
return response
except Exception as err:
if response is not None:
await stream_write_info()
elif isinstance(err, aiohttp.ClientError):
return make_json_exception(err, 400)
raise
def __make_write_info(self, name: str, size: int, written: int) -> Dict:
return {"image": {"name": name, "size": size, "written": written}}
# =====
@exposed_http("POST", "/msd/remove") @exposed_http("POST", "/msd/remove")
async def __remove_handler(self, request: Request) -> Response: async def __remove_handler(self, request: Request) -> Response:

View File

@ -176,6 +176,10 @@ async def start_streaming(request: aiohttp.web.Request, content_type: str) -> ai
return response return response
async def stream_json(response: aiohttp.web.StreamResponse, result: Dict) -> None:
await response.write(json.dumps(result).encode("utf-8") + b"\r\n")
# ===== # =====
async def get_multipart_reader(request: aiohttp.web.Request) -> aiohttp.MultipartReader: async def get_multipart_reader(request: aiohttp.web.Request) -> aiohttp.MultipartReader:
try: try:

View File

@ -20,7 +20,15 @@
# ========================================================================== # # ========================================================================== #
import os
import contextlib
from typing import Dict
from typing import AsyncGenerator
from typing import Optional
import aiohttp import aiohttp
import aiohttp.multipart
from . import __version__ from . import __version__
@ -41,3 +49,48 @@ def raise_not_200(response: aiohttp.ClientResponse) -> None:
message=response.reason, message=response.reason,
headers=response.headers, headers=response.headers,
) )
def get_content_length(response: aiohttp.ClientResponse) -> int:
try:
value = int(response.headers["Content-Length"])
except Exception:
raise aiohttp.ClientError("Empty or invalid Content-Length")
if value < 0:
raise aiohttp.ClientError("Negative Content-Length")
return value
def get_filename(response: aiohttp.ClientResponse) -> str:
try:
disp = response.headers["Content-Disposition"]
parsed = aiohttp.multipart.parse_content_disposition(disp)
return str(parsed[1]["filename"])
except Exception:
try:
return os.path.basename(response.url.path)
except Exception:
raise aiohttp.ClientError("Can't determine filename")
@contextlib.asynccontextmanager
async def download(
url: str,
verify: bool=True,
timeout: float=10.0,
read_timeout: Optional[float]=None,
app: str="KVMD",
) -> AsyncGenerator[aiohttp.ClientResponse, None]:
kwargs: Dict = {
"headers": {"User-Agent": make_user_agent(app)},
"timeout": aiohttp.ClientTimeout(
connect=timeout,
sock_connect=timeout,
sock_read=(read_timeout if read_timeout is not None else timeout),
),
}
async with aiohttp.ClientSession(**kwargs) as session:
async with session.get(url, verify_ssl=verify) as response:
raise_not_200(response)
yield response

View File

@ -31,6 +31,8 @@ from typing import Optional
import aiofiles import aiofiles
import aiofiles.base import aiofiles.base
from ...logging import get_logger
from ... import aiofs from ... import aiofs
from ...errors import OperationError from ...errors import OperationError
@ -119,13 +121,10 @@ class BaseMsd(BasePlugin):
raise NotImplementedError() raise NotImplementedError()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]: # pylint: disable=unused-argument async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]: # pylint: disable=unused-argument
if self is not None: # XXX: Vulture and pylint hack if self is not None: # XXX: Vulture and pylint hack
raise NotImplementedError() raise NotImplementedError()
yield yield 1
def get_upload_chunk_size(self) -> int:
raise NotImplementedError()
async def write_image_chunk(self, chunk: bytes) -> int: async def write_image_chunk(self, chunk: bytes) -> int:
raise NotImplementedError() raise NotImplementedError()
@ -158,6 +157,7 @@ class MsdImageWriter:
async def open(self) -> "MsdImageWriter": async def open(self) -> "MsdImageWriter":
assert self.__file is None assert self.__file is None
get_logger(1).info("Writing %r image (%d bytes) to MSD ...", self.__name, self.__size)
self.__file = await aiofiles.open(self.__path, mode="w+b", buffering=0) # type: ignore self.__file = await aiofiles.open(self.__path, mode="w+b", buffering=0) # type: ignore
return self return self
@ -176,6 +176,13 @@ class MsdImageWriter:
async def close(self) -> None: async def close(self) -> None:
assert self.__file is not None assert self.__file is not None
if self.__written == self.__size:
(log, result) = (get_logger().info, "OK")
elif self.__written < self.__size:
(log, result) = (get_logger().error, "INCOMPLETE")
else: # written > size
(log, result) = (get_logger().warning, "OVERFLOW")
log("Written %d of %d bytes to MSD image %r: %s", self.__written, self.__size, self.__name, result)
await aiofs.afile_sync(self.__file) await aiofs.afile_sync(self.__file)
await self.__file.close() # type: ignore await self.__file.close() # type: ignore

View File

@ -70,13 +70,10 @@ class Plugin(BaseMsd):
raise MsdDisabledError() raise MsdDisabledError()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]: async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]:
if self is not None: # XXX: Vulture and pylint hack if self is not None: # XXX: Vulture and pylint hack
raise MsdDisabledError() raise MsdDisabledError()
yield yield 1
def get_upload_chunk_size(self) -> int:
raise MsdDisabledError()
async def write_image_chunk(self, chunk: bytes) -> int: async def write_image_chunk(self, chunk: bytes) -> int:
raise MsdDisabledError() raise MsdDisabledError()

View File

@ -306,7 +306,7 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__state.vd.connected = connected self.__state.vd.connected = connected
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]: async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]:
try: try:
async with self.__state._region: # pylint: disable=protected-access async with self.__state._region: # pylint: disable=protected-access
try: try:
@ -328,7 +328,7 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__new_writer = await MsdImageWriter(path, size, self.__sync_chunk_size).open() self.__new_writer = await MsdImageWriter(path, size, self.__sync_chunk_size).open()
await self.__notifier.notify() await self.__notifier.notify()
yield yield self.__upload_chunk_size
self.__set_image_complete(name, True) self.__set_image_complete(name, True)
finally: finally:
@ -343,9 +343,6 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
await self.__reload_state() await self.__reload_state()
await self.__notifier.notify() await self.__notifier.notify()
def get_upload_chunk_size(self) -> int:
return self.__upload_chunk_size
async def write_image_chunk(self, chunk: bytes) -> int: async def write_image_chunk(self, chunk: bytes) -> int:
assert self.__new_writer assert self.__new_writer
written = await self.__new_writer.write(chunk) written = await self.__new_writer.write(chunk)

View File

@ -208,7 +208,7 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
self.__connected = connected self.__connected = connected
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def write_image(self, name: str, size: int) -> AsyncGenerator[None, None]: async def write_image(self, name: str, size: int) -> AsyncGenerator[int, None]:
async with self.__working(): async with self.__working():
async with self.__region: async with self.__region:
try: try:
@ -220,15 +220,12 @@ class Plugin(BaseMsd): # pylint: disable=too-many-instance-attributes
await self.__write_image_info(False) await self.__write_image_info(False)
await self.__notifier.notify() await self.__notifier.notify()
yield yield self.__upload_chunk_size
await self.__write_image_info(True) await self.__write_image_info(True)
finally: finally:
await self.__close_device_writer() await self.__close_device_writer()
await self.__load_device_info() await self.__load_device_info()
def get_upload_chunk_size(self) -> int:
return self.__upload_chunk_size
async def write_image_chunk(self, chunk: bytes) -> int: async def write_image_chunk(self, chunk: bytes) -> int:
assert self.__device_writer assert self.__device_writer
return (await self.__device_writer.write(chunk)) return (await self.__device_writer.write(chunk))

View File

@ -116,3 +116,8 @@ def valid_ssl_ciphers(arg: Any) -> str:
except Exception as err: except Exception as err:
raise ValidatorError(f"The argument {arg!r} is not a valid {name}: {err}") raise ValidatorError(f"The argument {arg!r} is not a valid {name}: {err}")
return arg return arg
def valid_url(arg: Any) -> str:
# XXX: VERY primitive
return check_re_match(arg, "HTTP(S) URL", r"^https?://[\[\w]+\S*")