refactoring

This commit is contained in:
Maxim Devaev 2024-10-22 00:51:03 +03:00
parent cda32a083f
commit 0e4a70e7b9
6 changed files with 234 additions and 232 deletions

View File

@ -20,12 +20,10 @@
# ========================================================================== # # ========================================================================== #
import io
import signal import signal
import asyncio import asyncio
import asyncio.subprocess import asyncio.subprocess
import dataclasses import dataclasses
import functools
import copy import copy
from typing import AsyncGenerator from typing import AsyncGenerator
@ -33,10 +31,12 @@ from typing import Any
import aiohttp import aiohttp
from PIL import Image as PilImage
from ...logging import get_logger from ...logging import get_logger
from ...clients.streamer import StreamerSnapshot
from ...clients.streamer import HttpStreamerClient
from ...clients.streamer import HttpStreamerClientSession
from ... import tools from ... import tools
from ... import aiotools from ... import aiotools
from ... import aioproc from ... import aioproc
@ -44,40 +44,6 @@ from ... import htclient
# ===== # =====
@dataclasses.dataclass(frozen=True)
class StreamerSnapshot:
online: bool
width: int
height: int
headers: tuple[tuple[str, str], ...]
data: bytes
async def make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
assert max_width >= 0
assert max_height >= 0
assert quality > 0
if max_width == 0 and max_height == 0:
max_width = self.width // 5
max_height = self.height // 5
else:
max_width = min((max_width or self.width), self.width)
max_height = min((max_height or self.height), self.height)
if (max_width, max_height) == (self.width, self.height):
return self.data
return (await aiotools.run_async(self.__inner_make_preview, max_width, max_height, quality))
@functools.lru_cache(maxsize=1)
def __inner_make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
with io.BytesIO(self.data) as snapshot_bio:
with io.BytesIO() as preview_bio:
with PilImage.open(snapshot_bio) as image:
image.thumbnail((max_width, max_height), PilImage.Resampling.LANCZOS)
image.save(preview_bio, format="jpeg", quality=quality)
return preview_bio.getvalue()
class _StreamerParams: class _StreamerParams:
__DESIRED_FPS = "desired_fps" __DESIRED_FPS = "desired_fps"
@ -204,7 +170,6 @@ class Streamer: # pylint: disable=too-many-instance-attributes
self.__state_poll = state_poll self.__state_poll = state_poll
self.__unix_path = unix_path self.__unix_path = unix_path
self.__timeout = timeout
self.__snapshot_timeout = snapshot_timeout self.__snapshot_timeout = snapshot_timeout
self.__process_name_prefix = process_name_prefix self.__process_name_prefix = process_name_prefix
@ -221,7 +186,13 @@ class Streamer: # pylint: disable=too-many-instance-attributes
self.__streamer_task: (asyncio.Task | None) = None self.__streamer_task: (asyncio.Task | None) = None
self.__streamer_proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member self.__streamer_proc: (asyncio.subprocess.Process | None) = None # pylint: disable=no-member
self.__http_session: (aiohttp.ClientSession | None) = None self.__client = HttpStreamerClient(
name="jpeg",
unix_path=self.__unix_path,
timeout=timeout,
user_agent=htclient.make_user_agent("KVMD"),
)
self.__client_session: (HttpStreamerClientSession | None) = None
self.__snapshot: (StreamerSnapshot | None) = None self.__snapshot: (StreamerSnapshot | None) = None
@ -300,11 +271,9 @@ class Streamer: # pylint: disable=too-many-instance-attributes
async def get_state(self) -> dict: async def get_state(self) -> dict:
streamer_state = None streamer_state = None
if self.__streamer_task: if self.__streamer_task:
session = self.__ensure_http_session() session = self.__ensure_client_session()
try: try:
async with session.get(self.__make_url("state")) as response: streamer_state = await session.get_state()
htclient.raise_not_200(response)
streamer_state = (await response.json())["result"]
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError): except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError):
pass pass
except Exception: except Exception:
@ -350,39 +319,15 @@ class Streamer: # pylint: disable=too-many-instance-attributes
if load: if load:
return self.__snapshot return self.__snapshot
logger = get_logger() logger = get_logger()
session = self.__ensure_http_session() session = self.__ensure_client_session()
try: try:
async with session.get( snapshot = await session.take_snapshot(self.__snapshot_timeout)
self.__make_url("snapshot"), if snapshot.online or allow_offline:
timeout=aiohttp.ClientTimeout(total=self.__snapshot_timeout), if save:
) as response: self.__snapshot = snapshot
self.__notifier.notify()
htclient.raise_not_200(response) return snapshot
online = (response.headers["X-UStreamer-Online"] == "true") logger.error("Stream is offline, no signal or so")
if online or allow_offline:
snapshot = StreamerSnapshot(
online=online,
width=int(response.headers["X-UStreamer-Width"]),
height=int(response.headers["X-UStreamer-Height"]),
headers=tuple(
(key, value)
for (key, value) in tools.sorted_kvs(dict(response.headers))
if key.lower().startswith("x-ustreamer-") or key.lower() in [
"x-timestamp",
"access-control-allow-origin",
"cache-control",
"pragma",
"expires",
]
),
data=bytes(await response.read()),
)
if save:
self.__snapshot = snapshot
self.__notifier.notify()
return snapshot
logger.error("Stream is offline, no signal or so")
except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as ex: except (aiohttp.ClientConnectionError, aiohttp.ServerConnectionError) as ex:
logger.error("Can't connect to streamer: %s", tools.efmt(ex)) logger.error("Can't connect to streamer: %s", tools.efmt(ex))
except Exception: except Exception:
@ -397,25 +342,14 @@ class Streamer: # pylint: disable=too-many-instance-attributes
@aiotools.atomic_fg @aiotools.atomic_fg
async def cleanup(self) -> None: async def cleanup(self) -> None:
await self.ensure_stop(immediately=True) await self.ensure_stop(immediately=True)
if self.__http_session: if self.__client_session:
await self.__http_session.close() await self.__client_session.close()
self.__http_session = None self.__client_session = None
# ===== def __ensure_client_session(self) -> HttpStreamerClientSession:
if not self.__client_session:
def __ensure_http_session(self) -> aiohttp.ClientSession: self.__client_session = self.__client.make_session()
if not self.__http_session: return self.__client_session
kwargs: dict = {
"headers": {"User-Agent": htclient.make_user_agent("KVMD")},
"connector": aiohttp.UnixConnector(path=self.__unix_path),
"timeout": aiohttp.ClientTimeout(total=self.__timeout),
}
self.__http_session = aiohttp.ClientSession(**kwargs)
return self.__http_session
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://localhost:0/{handle}"
# ===== # =====

View File

@ -21,7 +21,7 @@
from ...clients.kvmd import KvmdClient from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamFormats from ...clients.streamer import StreamerFormats
from ...clients.streamer import BaseStreamerClient from ...clients.streamer import BaseStreamerClient
from ...clients.streamer import HttpStreamerClient from ...clients.streamer import HttpStreamerClient
from ...clients.streamer import MemsinkStreamerClient from ...clients.streamer import MemsinkStreamerClient
@ -51,8 +51,8 @@ def main(argv: (list[str] | None)=None) -> None:
return None return None
streamers: list[BaseStreamerClient] = list(filter(None, [ streamers: list[BaseStreamerClient] = list(filter(None, [
make_memsink_streamer("h264", StreamFormats.H264), make_memsink_streamer("h264", StreamerFormats.H264),
make_memsink_streamer("jpeg", StreamFormats.JPEG), make_memsink_streamer("jpeg", StreamerFormats.JPEG),
HttpStreamerClient(name="JPEG", user_agent=user_agent, **config.streamer._unpack()), HttpStreamerClient(name="JPEG", user_agent=user_agent, **config.streamer._unpack()),
])) ]))

View File

@ -42,7 +42,7 @@ from ...clients.kvmd import KvmdClient
from ...clients.streamer import StreamerError from ...clients.streamer import StreamerError
from ...clients.streamer import StreamerPermError from ...clients.streamer import StreamerPermError
from ...clients.streamer import StreamFormats from ...clients.streamer import StreamerFormats
from ...clients.streamer import BaseStreamerClient from ...clients.streamer import BaseStreamerClient
from ... import tools from ... import tools
@ -222,8 +222,8 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
def __get_preferred_streamer(self) -> BaseStreamerClient: def __get_preferred_streamer(self) -> BaseStreamerClient:
formats = { formats = {
StreamFormats.JPEG: "has_tight", StreamerFormats.JPEG: "has_tight",
StreamFormats.H264: "has_h264", StreamerFormats.H264: "has_h264",
} }
streamer: (BaseStreamerClient | None) = None streamer: (BaseStreamerClient | None) = None
for streamer in self.__streamers: for streamer in self.__streamers:
@ -249,7 +249,7 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
"data": (await make_text_jpeg(self._width, self._height, self._encodings.tight_jpeg_quality, text)), "data": (await make_text_jpeg(self._width, self._height, self._encodings.tight_jpeg_quality, text)),
"width": self._width, "width": self._width,
"height": self._height, "height": self._height,
"format": StreamFormats.JPEG, "format": StreamerFormats.JPEG,
} }
async def __fb_sender_task_loop(self) -> None: # pylint: disable=too-many-branches async def __fb_sender_task_loop(self) -> None: # pylint: disable=too-many-branches
@ -259,21 +259,21 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
frame = await self.__fb_queue.get() frame = await self.__fb_queue.get()
if ( if (
last is None # pylint: disable=too-many-boolean-expressions last is None # pylint: disable=too-many-boolean-expressions
or frame["format"] == StreamFormats.JPEG or frame["format"] == StreamerFormats.JPEG
or last["format"] != frame["format"] or last["format"] != frame["format"]
or (frame["format"] == StreamFormats.H264 and ( or (frame["format"] == StreamerFormats.H264 and (
frame["key"] frame["key"]
or last["width"] != frame["width"] or last["width"] != frame["width"]
or last["height"] != frame["height"] or last["height"] != frame["height"]
or len(last["data"]) + len(frame["data"]) > 4194304 or len(last["data"]) + len(frame["data"]) > 4194304
)) ))
): ):
self.__fb_has_key = (frame["format"] == StreamFormats.H264 and frame["key"]) self.__fb_has_key = (frame["format"] == StreamerFormats.H264 and frame["key"])
last = frame last = frame
if self.__fb_queue.qsize() == 0: if self.__fb_queue.qsize() == 0:
break break
continue continue
assert frame["format"] == StreamFormats.H264 assert frame["format"] == StreamerFormats.H264
last["data"] += frame["data"] last["data"] += frame["data"]
if self.__fb_queue.qsize() == 0: if self.__fb_queue.qsize() == 0:
break break
@ -295,9 +295,9 @@ class _Client(RfbClient): # pylint: disable=too-many-instance-attributes
await self._send_fb_allow_again() await self._send_fb_allow_again()
continue continue
if last["format"] == StreamFormats.JPEG: if last["format"] == StreamerFormats.JPEG:
await self._send_fb_jpeg(last["data"]) await self._send_fb_jpeg(last["data"])
elif last["format"] == StreamFormats.H264: elif last["format"] == StreamerFormats.H264:
if not self._encodings.has_h264: if not self._encodings.has_h264:
raise RfbError("The client doesn't want to accept H264 anymore") raise RfbError("The client doesn't want to accept H264 anymore")
if self.__fb_has_key: if self.__fb_has_key:

View File

@ -18,3 +18,67 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # # along with this program. If not, see <https://www.gnu.org/licenses/>. #
# # # #
# ========================================================================== # # ========================================================================== #
import types
from typing import Callable
from typing import Self
import aiohttp
# =====
class BaseHttpClientSession:
def __init__(self, make_http_session: Callable[[], aiohttp.ClientSession]) -> None:
self._make_http_session = make_http_session
self.__http_session: (aiohttp.ClientSession | None) = None
def _ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
self.__http_session = self._make_http_session()
return self.__http_session
async def close(self) -> None:
if self.__http_session:
await self.__http_session.close()
self.__http_session = None
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
_exc_type: type[BaseException],
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
await self.close()
class BaseHttpClient:
def __init__(
self,
unix_path: str,
timeout: float,
user_agent: str,
) -> None:
self.__unix_path = unix_path
self.__timeout = timeout
self.__user_agent = user_agent
def make_session(self) -> BaseHttpClientSession:
raise NotImplementedError
def _make_http_session(self, headers: (dict[str, str] | None)=None) -> aiohttp.ClientSession:
return aiohttp.ClientSession(
base_url="http://localhost:0",
headers={
"User-Agent": self.__user_agent,
**(headers or {}),
},
connector=aiohttp.UnixConnector(path=self.__unix_path),
timeout=aiohttp.ClientTimeout(total=self.__timeout),
)

View File

@ -23,7 +23,6 @@
import asyncio import asyncio
import contextlib import contextlib
import struct import struct
import types
from typing import Callable from typing import Callable
from typing import AsyncGenerator from typing import AsyncGenerator
@ -34,22 +33,19 @@ from .. import aiotools
from .. import htclient from .. import htclient
from .. import htserver from .. import htserver
from . import BaseHttpClient
from . import BaseHttpClientSession
# ===== # =====
class _BaseApiPart: class _BaseApiPart:
def __init__( def __init__(self, ensure_http_session: Callable[[], aiohttp.ClientSession]) -> None:
self,
ensure_http_session: Callable[[], aiohttp.ClientSession],
make_url: Callable[[str], str],
) -> None:
self._ensure_http_session = ensure_http_session self._ensure_http_session = ensure_http_session
self._make_url = make_url
async def _set_params(self, handle: str, **params: (int | str | None)) -> None: async def _set_params(self, handle: str, **params: (int | str | None)) -> None:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.post( async with session.post(
url=self._make_url(handle), url=handle,
params={ params={
key: value key: value
for (key, value) in params.items() for (key, value) in params.items()
@ -63,7 +59,7 @@ class _AuthApiPart(_BaseApiPart):
async def check(self) -> bool: async def check(self) -> bool:
session = self._ensure_http_session() session = self._ensure_http_session()
try: try:
async with session.get(self._make_url("auth/check")) as response: async with session.get("/auth/check") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return True return True
except aiohttp.ClientResponseError as ex: except aiohttp.ClientResponseError as ex:
@ -75,13 +71,13 @@ class _AuthApiPart(_BaseApiPart):
class _StreamerApiPart(_BaseApiPart): class _StreamerApiPart(_BaseApiPart):
async def get_state(self) -> dict: async def get_state(self) -> dict:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.get(self._make_url("streamer")) as response: async with session.get("/streamer") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return (await response.json())["result"] return (await response.json())["result"]
async def set_params(self, quality: (int | None)=None, desired_fps: (int | None)=None) -> None: async def set_params(self, quality: (int | None)=None, desired_fps: (int | None)=None) -> None:
await self._set_params( await self._set_params(
"streamer/set_params", "/streamer/set_params",
quality=quality, quality=quality,
desired_fps=desired_fps, desired_fps=desired_fps,
) )
@ -90,7 +86,7 @@ class _StreamerApiPart(_BaseApiPart):
class _HidApiPart(_BaseApiPart): class _HidApiPart(_BaseApiPart):
async def get_keymaps(self) -> tuple[str, set[str]]: async def get_keymaps(self) -> tuple[str, set[str]]:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.get(self._make_url("hid/keymaps")) as response: async with session.get("/hid/keymaps") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
result = (await response.json())["result"] result = (await response.json())["result"]
return (result["keymaps"]["default"], set(result["keymaps"]["available"])) return (result["keymaps"]["default"], set(result["keymaps"]["available"]))
@ -98,7 +94,7 @@ class _HidApiPart(_BaseApiPart):
async def print(self, text: str, limit: int, keymap_name: str) -> None: async def print(self, text: str, limit: int, keymap_name: str) -> None:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.post( async with session.post(
url=self._make_url("hid/print"), url="/hid/print",
params={"limit": limit, "keymap": keymap_name}, params={"limit": limit, "keymap": keymap_name},
data=text, data=text,
) as response: ) as response:
@ -106,7 +102,7 @@ class _HidApiPart(_BaseApiPart):
async def set_params(self, keyboard_output: (str | None)=None, mouse_output: (str | None)=None) -> None: async def set_params(self, keyboard_output: (str | None)=None, mouse_output: (str | None)=None) -> None:
await self._set_params( await self._set_params(
"hid/set_params", "/hid/set_params",
keyboard_output=keyboard_output, keyboard_output=keyboard_output,
mouse_output=mouse_output, mouse_output=mouse_output,
) )
@ -115,7 +111,7 @@ class _HidApiPart(_BaseApiPart):
class _AtxApiPart(_BaseApiPart): class _AtxApiPart(_BaseApiPart):
async def get_state(self) -> dict: async def get_state(self) -> dict:
session = self._ensure_http_session() session = self._ensure_http_session()
async with session.get(self._make_url("atx")) as response: async with session.get("/atx") as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
return (await response.json())["result"] return (await response.json())["result"]
@ -123,7 +119,7 @@ class _AtxApiPart(_BaseApiPart):
session = self._ensure_http_session() session = self._ensure_http_session()
try: try:
async with session.post( async with session.post(
url=self._make_url("atx/power"), url="/atx/power",
params={"action": action}, params={"action": action},
) as response: ) as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
@ -138,7 +134,6 @@ class _AtxApiPart(_BaseApiPart):
class KvmdClientWs: class KvmdClientWs:
def __init__(self, ws: aiohttp.ClientWebSocketResponse) -> None: def __init__(self, ws: aiohttp.ClientWebSocketResponse) -> None:
self.__ws = ws self.__ws = ws
self.__writer_queue: "asyncio.Queue[tuple[str, dict] | bytes]" = asyncio.Queue() self.__writer_queue: "asyncio.Queue[tuple[str, dict] | bytes]" = asyncio.Queue()
self.__communicated = False self.__communicated = False
@ -200,83 +195,25 @@ class KvmdClientWs:
await self.__writer_queue.put(struct.pack(">bbbb", 5, 0, delta_x, delta_y)) await self.__writer_queue.put(struct.pack(">bbbb", 5, 0, delta_x, delta_y))
class KvmdClientSession: class KvmdClientSession(BaseHttpClientSession):
def __init__( def __init__(self, make_http_session: Callable[[], aiohttp.ClientSession]) -> None:
self, super().__init__(make_http_session)
make_http_session: Callable[[], aiohttp.ClientSession], self.auth = _AuthApiPart(self._ensure_http_session)
make_url: Callable[[str], str], self.streamer = _StreamerApiPart(self._ensure_http_session)
) -> None: self.hid = _HidApiPart(self._ensure_http_session)
self.atx = _AtxApiPart(self._ensure_http_session)
self.__make_http_session = make_http_session
self.__make_url = make_url
self.__http_session: (aiohttp.ClientSession | None) = None
args = (self.__ensure_http_session, make_url)
self.auth = _AuthApiPart(*args)
self.streamer = _StreamerApiPart(*args)
self.hid = _HidApiPart(*args)
self.atx = _AtxApiPart(*args)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def ws(self) -> AsyncGenerator[KvmdClientWs, None]: async def ws(self) -> AsyncGenerator[KvmdClientWs, None]:
session = self.__ensure_http_session() session = self._ensure_http_session()
async with session.ws_connect(self.__make_url("ws"), params={"legacy": 0}) as ws: async with session.ws_connect("/ws", params={"legacy": "0"}) as ws:
yield KvmdClientWs(ws) yield KvmdClientWs(ws)
def __ensure_http_session(self) -> aiohttp.ClientSession:
if not self.__http_session:
self.__http_session = self.__make_http_session()
return self.__http_session
async def close(self) -> None: class KvmdClient(BaseHttpClient):
if self.__http_session: def make_session(self, user: str="", passwd: str="") -> KvmdClientSession:
await self.__http_session.close() headers = {
self.__http_session = None "X-KVMD-User": user,
"X-KVMD-Passwd": passwd,
async def __aenter__(self) -> "KvmdClientSession": }
return self return KvmdClientSession(lambda: self._make_http_session(headers))
async def __aexit__(
self,
_exc_type: type[BaseException],
_exc: BaseException,
_tb: types.TracebackType,
) -> None:
await self.close()
class KvmdClient:
def __init__(
self,
unix_path: str,
timeout: float,
user_agent: str,
) -> None:
self.__unix_path = unix_path
self.__timeout = timeout
self.__user_agent = user_agent
def make_session(self, user: str, passwd: str) -> KvmdClientSession:
return KvmdClientSession(
make_http_session=(lambda: self.__make_http_session(user, passwd)),
make_url=self.__make_url,
)
def __make_http_session(self, user: str, passwd: str) -> aiohttp.ClientSession:
return aiohttp.ClientSession(
headers={
"X-KVMD-User": user,
"X-KVMD-Passwd": passwd,
"User-Agent": self.__user_agent,
},
connector=aiohttp.UnixConnector(path=self.__unix_path),
timeout=aiohttp.ClientTimeout(total=self.__timeout),
)
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://localhost:0/{handle}"

View File

@ -20,7 +20,10 @@
# ========================================================================== # # ========================================================================== #
import io
import contextlib import contextlib
import dataclasses
import functools
import types import types
from typing import Callable from typing import Callable
@ -31,10 +34,15 @@ from typing import AsyncGenerator
import aiohttp import aiohttp
import ustreamer import ustreamer
from PIL import Image as PilImage
from .. import tools from .. import tools
from .. import aiotools from .. import aiotools
from .. import htclient from .. import htclient
from . import BaseHttpClient
from . import BaseHttpClientSession
# ===== # =====
class StreamerError(Exception): class StreamerError(Exception):
@ -50,7 +58,7 @@ class StreamerPermError(StreamerError):
# ===== # =====
class StreamFormats: class StreamerFormats:
JPEG = 1195724874 # V4L2_PIX_FMT_JPEG JPEG = 1195724874 # V4L2_PIX_FMT_JPEG
H264 = 875967048 # V4L2_PIX_FMT_H264 H264 = 875967048 # V4L2_PIX_FMT_H264
_MJPEG = 1196444237 # V4L2_PIX_FMT_MJPEG _MJPEG = 1196444237 # V4L2_PIX_FMT_MJPEG
@ -68,8 +76,76 @@ class BaseStreamerClient:
# ===== # =====
@dataclasses.dataclass(frozen=True)
class StreamerSnapshot:
online: bool
width: int
height: int
headers: tuple[tuple[str, str], ...]
data: bytes
async def make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
assert max_width >= 0
assert max_height >= 0
assert quality > 0
if max_width == 0 and max_height == 0:
max_width = self.width // 5
max_height = self.height // 5
else:
max_width = min((max_width or self.width), self.width)
max_height = min((max_height or self.height), self.height)
if (max_width, max_height) == (self.width, self.height):
return self.data
return (await aiotools.run_async(self.__inner_make_preview, max_width, max_height, quality))
@functools.lru_cache(maxsize=1)
def __inner_make_preview(self, max_width: int, max_height: int, quality: int) -> bytes:
with io.BytesIO(self.data) as snapshot_bio:
with io.BytesIO() as preview_bio:
with PilImage.open(snapshot_bio) as image:
image.thumbnail((max_width, max_height), PilImage.Resampling.LANCZOS)
image.save(preview_bio, format="jpeg", quality=quality)
return preview_bio.getvalue()
class HttpStreamerClientSession(BaseHttpClientSession):
async def get_state(self) -> dict:
session = self._ensure_http_session()
async with session.get("/state") as response:
htclient.raise_not_200(response)
return (await response.json())["result"]
async def take_snapshot(self, timeout: float) -> StreamerSnapshot:
session = self._ensure_http_session()
async with session.get(
url="/snapshot",
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
htclient.raise_not_200(response)
return StreamerSnapshot(
online=(response.headers["X-UStreamer-Online"] == "true"),
width=int(response.headers["X-UStreamer-Width"]),
height=int(response.headers["X-UStreamer-Height"]),
headers=tuple(
(key, value)
for (key, value) in tools.sorted_kvs(dict(response.headers))
if key.lower().startswith("x-ustreamer-") or key.lower() in [
"x-timestamp",
"access-control-allow-origin",
"cache-control",
"pragma",
"expires",
]
),
data=bytes(await response.read()),
)
@contextlib.contextmanager @contextlib.contextmanager
def _http_handle_errors() -> Generator[None, None, None]: def _http_reading_handle_errors() -> Generator[None, None, None]:
try: try:
yield yield
except Exception as ex: # Тут бывают и ассерты, и KeyError, и прочая херня except Exception as ex: # Тут бывают и ассерты, и KeyError, и прочая херня
@ -78,7 +154,7 @@ def _http_handle_errors() -> Generator[None, None, None]:
raise StreamerTempError(tools.efmt(ex)) raise StreamerTempError(tools.efmt(ex))
class HttpStreamerClient(BaseStreamerClient): class HttpStreamerClient(BaseHttpClient, BaseStreamerClient):
def __init__( def __init__(
self, self,
name: str, name: str,
@ -87,29 +163,35 @@ class HttpStreamerClient(BaseStreamerClient):
user_agent: str, user_agent: str,
) -> None: ) -> None:
super().__init__(unix_path, timeout, user_agent)
self.__name = name self.__name = name
self.__unix_path = unix_path
self.__timeout = timeout def make_session(self) -> HttpStreamerClientSession:
self.__user_agent = user_agent return HttpStreamerClientSession(self._make_http_session)
def get_format(self) -> int: def get_format(self) -> int:
return StreamFormats.JPEG return StreamerFormats.JPEG
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]: async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]:
with _http_handle_errors(): with _http_reading_handle_errors():
async with self.__make_http_session() as session: async with self._make_http_session() as session:
async with session.get( async with session.get(
url=self.__make_url("stream"), url="/stream",
params={"extra_headers": "1"}, params={"extra_headers": "1"},
timeout=aiohttp.ClientTimeout(
connect=session.timeout.total,
sock_read=session.timeout.total,
),
) as response: ) as response:
htclient.raise_not_200(response) htclient.raise_not_200(response)
reader = aiohttp.MultipartReader.from_response(response) reader = aiohttp.MultipartReader.from_response(response)
self.__patch_stream_reader(reader.resp.content) self.__patch_stream_reader(reader.resp.content)
async def read_frame(key_required: bool) -> dict: async def read_frame(key_required: bool) -> dict:
_ = key_required _ = key_required
with _http_handle_errors(): with _http_reading_handle_errors():
frame = await reader.next() # pylint: disable=not-callable frame = await reader.next() # pylint: disable=not-callable
if not isinstance(frame, aiohttp.BodyPartReader): if not isinstance(frame, aiohttp.BodyPartReader):
raise StreamerTempError("Expected body part") raise StreamerTempError("Expected body part")
@ -123,26 +205,11 @@ class HttpStreamerClient(BaseStreamerClient):
"width": int(frame.headers["X-UStreamer-Width"]), "width": int(frame.headers["X-UStreamer-Width"]),
"height": int(frame.headers["X-UStreamer-Height"]), "height": int(frame.headers["X-UStreamer-Height"]),
"data": data, "data": data,
"format": StreamFormats.JPEG, "format": StreamerFormats.JPEG,
} }
yield read_frame yield read_frame
def __make_http_session(self) -> aiohttp.ClientSession:
kwargs: dict = {
"headers": {"User-Agent": self.__user_agent},
"connector": aiohttp.UnixConnector(path=self.__unix_path),
"timeout": aiohttp.ClientTimeout(
connect=self.__timeout,
sock_read=self.__timeout,
),
}
return aiohttp.ClientSession(**kwargs)
def __make_url(self, handle: str) -> str:
assert not handle.startswith("/"), handle
return f"http://localhost:0/{handle}"
def __patch_stream_reader(self, reader: aiohttp.StreamReader) -> None: def __patch_stream_reader(self, reader: aiohttp.StreamReader) -> None:
# https://github.com/pikvm/pikvm/issues/92 # https://github.com/pikvm/pikvm/issues/92
# Infinite looping in BodyPartReader.read() because _at_eof flag. # Infinite looping in BodyPartReader.read() because _at_eof flag.
@ -162,7 +229,7 @@ class HttpStreamerClient(BaseStreamerClient):
# ===== # =====
@contextlib.contextmanager @contextlib.contextmanager
def _memsink_handle_errors() -> Generator[None, None, None]: def _memsink_reading_handle_errors() -> Generator[None, None, None]:
try: try:
yield yield
except StreamerPermError: except StreamerPermError:
@ -198,11 +265,11 @@ class MemsinkStreamerClient(BaseStreamerClient):
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]: async def reading(self) -> AsyncGenerator[Callable[[bool], Awaitable[dict]], None]:
with _memsink_handle_errors(): with _memsink_reading_handle_errors():
with ustreamer.Memsink(**self.__kwargs) as sink: with ustreamer.Memsink(**self.__kwargs) as sink:
async def read_frame(key_required: bool) -> dict: async def read_frame(key_required: bool) -> dict:
key_required = (key_required and self.__fmt == StreamFormats.H264) key_required = (key_required and self.__fmt == StreamerFormats.H264)
with _memsink_handle_errors(): with _memsink_reading_handle_errors():
while True: while True:
frame = await aiotools.run_async(sink.wait_frame, key_required) frame = await aiotools.run_async(sink.wait_frame, key_required)
if frame is not None: if frame is not None:
@ -211,8 +278,8 @@ class MemsinkStreamerClient(BaseStreamerClient):
yield read_frame yield read_frame
def __check_format(self, fmt: int) -> None: def __check_format(self, fmt: int) -> None:
if fmt == StreamFormats._MJPEG: # pylint: disable=protected-access if fmt == StreamerFormats._MJPEG: # pylint: disable=protected-access
fmt = StreamFormats.JPEG fmt = StreamerFormats.JPEG
if fmt != self.__fmt: if fmt != self.__fmt:
raise StreamerPermError("Invalid sink format") raise StreamerPermError("Invalid sink format")