mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2026-02-02 11:01:53 +08:00
validators, tests
This commit is contained in:
@@ -45,15 +45,15 @@ def main() -> None:
|
||||
# pylint: disable=protected-access
|
||||
loop = asyncio.get_event_loop()
|
||||
Server(
|
||||
auth_manager=AuthManager(**config.auth._unpack_renamed()),
|
||||
info_manager=InfoManager(loop=loop, **config.info._unpack_renamed()),
|
||||
auth_manager=AuthManager(**config.auth._unpack()),
|
||||
info_manager=InfoManager(loop=loop, **config.info._unpack()),
|
||||
log_reader=LogReader(loop=loop),
|
||||
|
||||
hid=Hid(**config.hid._unpack_renamed()),
|
||||
atx=Atx(**config.atx._unpack_renamed()),
|
||||
msd=MassStorageDevice(loop=loop, **config.msd._unpack_renamed()),
|
||||
streamer=Streamer(loop=loop, **config.streamer._unpack_renamed()),
|
||||
hid=Hid(**config.hid._unpack()),
|
||||
atx=Atx(**config.atx._unpack()),
|
||||
msd=MassStorageDevice(loop=loop, **config.msd._unpack()),
|
||||
streamer=Streamer(loop=loop, **config.streamer._unpack()),
|
||||
|
||||
loop=loop,
|
||||
).run(**config.server._unpack_renamed())
|
||||
).run(**config.server._unpack())
|
||||
get_logger().info("Bye-bye")
|
||||
|
||||
@@ -46,7 +46,7 @@ from ... import gpio
|
||||
|
||||
# =====
|
||||
def _get_keymap() -> Dict[str, int]:
|
||||
return yaml.load(pkgutil.get_data("kvmd", "data/keymap.yaml").decode()) # type: ignore
|
||||
return yaml.safe_load(pkgutil.get_data("kvmd", "data/keymap.yaml").decode()) # type: ignore
|
||||
|
||||
|
||||
_KEYMAP = _get_keymap()
|
||||
|
||||
@@ -21,7 +21,6 @@
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import asyncio
|
||||
@@ -36,7 +35,6 @@ from typing import Dict
|
||||
from typing import Set
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
from typing import Any
|
||||
|
||||
import aiohttp.web
|
||||
import setproctitle
|
||||
@@ -45,6 +43,18 @@ from ...logging import get_logger
|
||||
|
||||
from ...aioregion import RegionIsBusyError
|
||||
|
||||
from ...validators import ValidatorError
|
||||
|
||||
from ...validators.basic import valid_bool
|
||||
from ...validators.auth import valid_user
|
||||
from ...validators.auth import valid_passwd
|
||||
from ...validators.auth import valid_auth_token
|
||||
from ...validators.kvm import valid_atx_button
|
||||
from ...validators.kvm import valid_kvm_target
|
||||
from ...validators.kvm import valid_log_seek
|
||||
from ...validators.kvm import valid_stream_quality
|
||||
from ...validators.kvm import valid_stream_fps
|
||||
|
||||
from ... import __version__
|
||||
|
||||
from .auth import AuthManager
|
||||
@@ -80,10 +90,6 @@ class HttpError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BadRequestError(HttpError):
|
||||
pass
|
||||
|
||||
|
||||
class UnauthorizedError(HttpError):
|
||||
pass
|
||||
|
||||
@@ -138,7 +144,7 @@ def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
|
||||
if auth_required:
|
||||
token = request.cookies.get(_COOKIE_AUTH_TOKEN, "")
|
||||
if token:
|
||||
user = self._auth_manager.check(_valid_token(token))
|
||||
user = self._auth_manager.check(valid_auth_token(token))
|
||||
if not user:
|
||||
raise ForbiddenError("Forbidden")
|
||||
setattr(request, _ATTR_KVMD_USER, user)
|
||||
@@ -149,7 +155,7 @@ def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
|
||||
|
||||
except RegionIsBusyError as err:
|
||||
return _json_exception(err, 409)
|
||||
except (BadRequestError, AtxOperationError, MsdOperationError) as err:
|
||||
except (ValidatorError, AtxOperationError, MsdOperationError) as err:
|
||||
return _json_exception(err, 400)
|
||||
except UnauthorizedError as err:
|
||||
return _json_exception(err, 401)
|
||||
@@ -178,51 +184,6 @@ def _system_task(method: Callable) -> Callable:
|
||||
return wrap
|
||||
|
||||
|
||||
def _valid_user(user: Any) -> str:
|
||||
if isinstance(user, str):
|
||||
stripped = user.strip()
|
||||
if re.match(r"^[a-z_][a-z0-9_-]*$", stripped):
|
||||
return stripped
|
||||
raise BadRequestError("Invalid user characters %r" % (user))
|
||||
|
||||
|
||||
def _valid_passwd(passwd: Any) -> str:
|
||||
if isinstance(passwd, str):
|
||||
if re.match(r"[\x20-\x7e]*$", passwd):
|
||||
return passwd
|
||||
raise BadRequestError("Invalid password characters")
|
||||
|
||||
|
||||
def _valid_token(token: Optional[str]) -> str:
|
||||
if isinstance(token, str):
|
||||
token = token.strip().lower()
|
||||
if re.match(r"^[0-9a-f]{64}$", token):
|
||||
return token
|
||||
raise BadRequestError("Invalid auth token characters")
|
||||
|
||||
|
||||
def _valid_bool(name: str, flag: Optional[str]) -> bool:
|
||||
flag = str(flag).strip().lower()
|
||||
if flag in ["1", "true", "yes"]:
|
||||
return True
|
||||
elif flag in ["0", "false", "no"]:
|
||||
return False
|
||||
raise BadRequestError("Invalid param '%s'" % (name))
|
||||
|
||||
|
||||
def _valid_int(name: str, value: Optional[str], min_value: Optional[int]=None, max_value: Optional[int]=None) -> int:
|
||||
try:
|
||||
value_int = int(value) # type: ignore
|
||||
if (
|
||||
(min_value is not None and value_int < min_value)
|
||||
or (max_value is not None and value_int > max_value)
|
||||
):
|
||||
raise ValueError()
|
||||
return value_int
|
||||
except Exception:
|
||||
raise BadRequestError("Invalid param %r" % (name))
|
||||
|
||||
|
||||
class _Events(Enum):
|
||||
INFO_STATE = "info_state"
|
||||
HID_STATE = "hid_state"
|
||||
@@ -337,8 +298,8 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
async def __auth_login_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
credentials = await request.post()
|
||||
token = self._auth_manager.login(
|
||||
user=_valid_user(credentials.get("user", "")),
|
||||
passwd=_valid_passwd(credentials.get("passwd", "")),
|
||||
user=valid_user(credentials.get("user", "")),
|
||||
passwd=valid_passwd(credentials.get("passwd", "")),
|
||||
)
|
||||
if token:
|
||||
return _json({}, set_cookies={_COOKIE_AUTH_TOKEN: token})
|
||||
@@ -346,7 +307,7 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
@_exposed("POST", "/auth/logout")
|
||||
async def __auth_logout_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
token = _valid_token(request.cookies.get(_COOKIE_AUTH_TOKEN, ""))
|
||||
token = valid_auth_token(request.cookies.get(_COOKIE_AUTH_TOKEN, ""))
|
||||
self._auth_manager.logout(token)
|
||||
return _json({})
|
||||
|
||||
@@ -362,8 +323,8 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
@_exposed("GET", "/log")
|
||||
async def __log_handler(self, request: aiohttp.web.Request) -> aiohttp.web.StreamResponse:
|
||||
seek = _valid_int("seek", request.query.get("seek", "0"), 0)
|
||||
follow = _valid_bool("follow", request.query.get("follow", "false"))
|
||||
seek = valid_log_seek(request.query.get("seek", "0"))
|
||||
follow = valid_bool(request.query.get("follow", "false"))
|
||||
response = aiohttp.web.StreamResponse(status=200, reason="OK", headers={"Content-Type": "text/plain"})
|
||||
await response.prepare(request)
|
||||
async for record in self.__log_reader.poll_log(seek, follow):
|
||||
@@ -460,15 +421,12 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
@_exposed("POST", "/atx/click")
|
||||
async def __atx_click_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
button = request.query.get("button")
|
||||
clicker = {
|
||||
button = valid_atx_button(request.query.get("button"))
|
||||
await ({
|
||||
"power": self.__atx.click_power,
|
||||
"power_long": self.__atx.click_power_long,
|
||||
"reset": self.__atx.click_reset,
|
||||
}.get(button)
|
||||
if not clicker:
|
||||
raise BadRequestError("Invalid param 'button'")
|
||||
await clicker()
|
||||
}[button])()
|
||||
return _json({"clicked": button})
|
||||
|
||||
# ===== MSD
|
||||
@@ -479,13 +437,11 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
@_exposed("POST", "/msd/connect")
|
||||
async def __msd_connect_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
to = request.query.get("to")
|
||||
if to == "kvm":
|
||||
return _json(await self.__msd.connect_to_kvm())
|
||||
elif to == "server":
|
||||
return _json(await self.__msd.connect_to_pc())
|
||||
else:
|
||||
raise BadRequestError("Invalid param 'to'")
|
||||
to = valid_kvm_target(request.query.get("to"))
|
||||
return _json(await ({
|
||||
"kvm": self.__msd.connect_to_kvm,
|
||||
"server": self.__msd.connect_to_pc,
|
||||
}[to])())
|
||||
|
||||
@_exposed("POST", "/msd/write")
|
||||
async def __msd_write_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
@@ -496,12 +452,12 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
async with self.__msd:
|
||||
field = await reader.next()
|
||||
if not field or field.name != "image_name":
|
||||
raise BadRequestError("Missing 'image_name' field")
|
||||
raise ValidatorError("Missing 'image_name' field")
|
||||
image_name = (await field.read()).decode("utf-8")[:256]
|
||||
|
||||
field = await reader.next()
|
||||
if not field or field.name != "image_data":
|
||||
raise BadRequestError("Missing 'image_data' field")
|
||||
raise ValidatorError("Missing 'image_data' field")
|
||||
|
||||
logger.info("Writing image %r to mass-storage device ...", image_name)
|
||||
await self.__msd.write_image_info(image_name, False)
|
||||
@@ -530,8 +486,8 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
@_exposed("POST", "/streamer/set_params")
|
||||
async def __streamer_set_params_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
for (name, validator) in [
|
||||
("quality", lambda arg: _valid_int("quality", arg, 1, 100)),
|
||||
("desired_fps", lambda arg: _valid_int("desired_fps", arg, 0, 30)),
|
||||
("quality", valid_stream_quality),
|
||||
("desired_fps", valid_stream_fps),
|
||||
]:
|
||||
value = request.query.get(name)
|
||||
if value:
|
||||
|
||||
@@ -63,8 +63,8 @@ class Streamer: # pylint: disable=too-many-instance-attributes
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> None:
|
||||
|
||||
self.__cap_pin = (gpio.set_output(cap_pin) if cap_pin > 0 else 0)
|
||||
self.__conv_pin = (gpio.set_output(conv_pin) if conv_pin > 0 else 0)
|
||||
self.__cap_pin = (gpio.set_output(cap_pin) if cap_pin >= 0 else -1)
|
||||
self.__conv_pin = (gpio.set_output(conv_pin) if conv_pin >= 0 else -1)
|
||||
|
||||
self.__sync_delay = sync_delay
|
||||
self.__init_delay = init_delay
|
||||
@@ -179,9 +179,9 @@ class Streamer: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
async def __set_hw_enabled(self, enabled: bool) -> None:
|
||||
# XXX: This sequence is very important to enable converter and cap board
|
||||
if self.__cap_pin > 0:
|
||||
if self.__cap_pin >= 0:
|
||||
gpio.write(self.__cap_pin, enabled)
|
||||
if self.__conv_pin > 0:
|
||||
if self.__conv_pin >= 0:
|
||||
if enabled:
|
||||
await asyncio.sleep(self.__sync_delay)
|
||||
gpio.write(self.__conv_pin, enabled)
|
||||
|
||||
Reference in New Issue
Block a user