validators, tests

This commit is contained in:
Devaev Maxim
2019-04-06 05:32:02 +03:00
parent 73e04b71ed
commit 1d75b738a0
44 changed files with 1616 additions and 311 deletions

View File

@@ -29,14 +29,13 @@ import logging.config
from typing import Tuple
from typing import List
from typing import Dict
from typing import Sequence
from typing import Optional
from typing import Union
import pygments
import pygments.lexers.data
import pygments.formatters
from ..yamlconf import ConfigError
from ..yamlconf import make_config
from ..yamlconf import Section
from ..yamlconf import Option
@@ -44,31 +43,59 @@ from ..yamlconf import build_raw_from_options
from ..yamlconf.dumper import make_config_dump
from ..yamlconf.loader import load_yaml_file
from ..validators.basic import valid_bool
from ..validators.basic import valid_number
from ..validators.basic import valid_int_f1
from ..validators.basic import valid_float_f01
from ..validators.fs import valid_abs_path
from ..validators.fs import valid_abs_path_exists
from ..validators.fs import valid_unix_mode
from ..validators.net import valid_ip_or_host
from ..validators.net import valid_port
from ..validators.auth import valid_auth_type
from ..validators.kvm import valid_stream_quality
from ..validators.kvm import valid_stream_fps
from ..validators.hw import valid_tty_speed
from ..validators.hw import valid_gpio_pin
from ..validators.hw import valid_gpio_pin_optional
# =====
def init(
prog: str=sys.argv[0],
prog: Optional[str]=None,
description: Optional[str]=None,
add_help: bool=True,
argv: Optional[List[str]]=None,
) -> Tuple[argparse.ArgumentParser, List[str], Section]:
args_parser = argparse.ArgumentParser(prog=prog, description=description, add_help=add_help)
argv = (argv or sys.argv)
assert len(argv) > 0
args_parser = argparse.ArgumentParser(prog=(prog or argv[0]), description=description, add_help=add_help)
args_parser.add_argument("-c", "--config", dest="config_path", default="/etc/kvmd/main.yaml", metavar="<file>",
help="Set config file path")
type=valid_abs_path_exists, help="Set config file path")
args_parser.add_argument("-o", "--set-options", dest="set_options", default=[], nargs="+",
help="Override config options list (like sec/sub/opt=value)")
args_parser.add_argument("-m", "--dump-config", dest="dump_config", action="store_true",
help="View current configuration (include all overrides)")
(options, remaining) = args_parser.parse_known_args(sys.argv)
(options, remaining) = args_parser.parse_known_args(argv)
raw_config: Dict = {}
options.config_path = os.path.expanduser(options.config_path)
if os.path.exists(options.config_path):
if options.config_path:
options.config_path = os.path.expanduser(options.config_path)
raw_config = load_yaml_file(options.config_path)
else:
raw_config = {}
_merge_dicts(raw_config, build_raw_from_options(options.set_options))
scheme = _get_config_scheme()
config = make_config(raw_config, scheme)
try:
_merge_dicts(raw_config, build_raw_from_options(options.set_options))
config = make_config(raw_config, scheme)
except ConfigError as err:
raise SystemExit("Config error: " + str(err))
if options.dump_config:
dump = make_config_dump(config)
@@ -96,135 +123,93 @@ def _merge_dicts(dest: Dict, src: Dict) -> None:
dest[key] = src[key]
def _as_pin(pin: int) -> int:
if not isinstance(pin, int) or pin <= 0:
raise ValueError("Invalid pin number")
return pin
def _as_optional_pin(pin: int) -> int:
if not isinstance(pin, int) or pin < -1:
raise ValueError("Invalid optional pin number")
return pin
def _as_path(path: str) -> str:
if not isinstance(path, str):
raise ValueError("Invalid path")
path = str(path).strip()
if not path:
raise ValueError("Invalid path")
return path
def _as_optional_path(path: str) -> str:
if not isinstance(path, str):
raise ValueError("Invalid path")
return str(path).strip()
def _as_string_list(values: Union[str, Sequence]) -> List[str]:
if isinstance(values, str):
values = [values]
return list(map(str, values))
def _as_auth_type(auth_type: str) -> str:
if not isinstance(auth_type, str):
raise ValueError("Invalid auth type")
auth_type = str(auth_type).strip()
if auth_type not in ["basic"]:
raise ValueError("Invalid auth type")
return auth_type
def _get_config_scheme() -> Dict:
return {
"kvmd": {
"server": {
"host": Option("localhost"),
"port": Option(0),
"unix": Option("", type=_as_optional_path, rename="unix_path"),
"unix_rm": Option(False),
"unix_mode": Option(0),
"heartbeat": Option(3.0),
"host": Option("localhost", type=valid_ip_or_host),
"port": Option(0, type=valid_port),
"unix": Option("", type=valid_abs_path, only_if="!port", unpack_as="unix_path"),
"unix_rm": Option(False, type=valid_bool),
"unix_mode": Option(0, type=valid_unix_mode),
"heartbeat": Option(3.0, type=valid_float_f01),
"access_log_format": Option("[%P / %{X-Real-IP}i] '%r' => %s; size=%b ---"
" referer='%{Referer}i'; user_agent='%{User-Agent}i'"),
},
"auth": {
"type": Option("basic", type=_as_auth_type, rename="auth_type"),
"type": Option("basic", type=valid_auth_type, unpack_as="auth_type"),
"basic": {
"htpasswd": Option("/etc/kvmd/htpasswd", type=_as_path, rename="htpasswd_path"),
"htpasswd": Option("/etc/kvmd/htpasswd", type=valid_abs_path_exists, unpack_as="htpasswd_path"),
},
},
"info": {
"meta": Option("/etc/kvmd/meta.yaml", type=_as_path, rename="meta_path"),
"extras": Option("/usr/share/kvmd/extras", type=_as_path, rename="extras_path"),
"meta": Option("/etc/kvmd/meta.yaml", type=valid_abs_path_exists, unpack_as="meta_path"),
"extras": Option("/usr/share/kvmd/extras", type=valid_abs_path_exists, unpack_as="extras_path"),
},
"hid": {
"reset_pin": Option(0, type=_as_pin),
"reset_delay": Option(0.1),
"reset_pin": Option(-1, type=valid_gpio_pin),
"reset_delay": Option(0.1, type=valid_float_f01),
"device": Option("", type=_as_path, rename="device_path"),
"speed": Option(115200),
"read_timeout": Option(2.0),
"read_retries": Option(10),
"common_retries": Option(100),
"retries_delay": Option(0.1),
"noop": Option(False),
"device": Option("", type=valid_abs_path_exists, unpack_as="device_path"),
"speed": Option(115200, type=valid_tty_speed),
"read_timeout": Option(2.0, type=valid_float_f01),
"read_retries": Option(10, type=valid_int_f1),
"common_retries": Option(100, type=valid_int_f1),
"retries_delay": Option(0.1, type=valid_float_f01),
"noop": Option(False, type=valid_bool),
"state_poll": Option(0.1),
"state_poll": Option(0.1, type=valid_float_f01),
},
"atx": {
"enabled": Option(True),
"enabled": Option(True, type=valid_bool),
"power_led_pin": Option(-1, type=_as_optional_pin),
"hdd_led_pin": Option(-1, type=_as_optional_pin),
"power_switch_pin": Option(-1, type=_as_optional_pin),
"reset_switch_pin": Option(-1, type=_as_optional_pin),
"power_led_pin": Option(-1, type=valid_gpio_pin, only_if="enabled"),
"hdd_led_pin": Option(-1, type=valid_gpio_pin, only_if="enabled"),
"power_switch_pin": Option(-1, type=valid_gpio_pin, only_if="enabled"),
"reset_switch_pin": Option(-1, type=valid_gpio_pin, only_if="enabled"),
"click_delay": Option(0.1),
"long_click_delay": Option(5.5),
"click_delay": Option(0.1, type=valid_float_f01),
"long_click_delay": Option(5.5, type=valid_float_f01),
"state_poll": Option(0.1),
"state_poll": Option(0.1, type=valid_float_f01),
},
"msd": {
"enabled": Option(True),
"enabled": Option(True, type=valid_bool),
"target_pin": Option(-1, type=_as_optional_pin),
"reset_pin": Option(-1, type=_as_optional_pin),
"target_pin": Option(-1, type=valid_gpio_pin, only_if="enabled"),
"reset_pin": Option(-1, type=valid_gpio_pin, only_if="enabled"),
"device": Option("", type=_as_optional_path, rename="device_path"),
"init_delay": Option(2.0),
"reset_delay": Option(1.0),
"write_meta": Option(True),
"chunk_size": Option(65536),
"device": Option("", type=valid_abs_path, only_if="enabled", unpack_as="device_path"),
"init_delay": Option(2.0, type=valid_float_f01),
"reset_delay": Option(1.0, type=valid_float_f01),
"write_meta": Option(True, type=valid_bool),
"chunk_size": Option(65536, type=(lambda arg: valid_number(arg, min=1024))),
},
"streamer": {
"cap_pin": Option(0, type=_as_optional_pin),
"conv_pin": Option(0, type=_as_optional_pin),
"cap_pin": Option(-1, type=valid_gpio_pin_optional),
"conv_pin": Option(-1, type=valid_gpio_pin_optional),
"sync_delay": Option(1.0),
"init_delay": Option(1.0),
"init_restart_after": Option(0.0),
"shutdown_delay": Option(10.0),
"state_poll": Option(1.0),
"sync_delay": Option(1.0, type=valid_float_f01),
"init_delay": Option(1.0, type=valid_float_f01),
"init_restart_after": Option(0.0, type=(lambda arg: valid_number(arg, min=0.0, type=float))),
"shutdown_delay": Option(10.0, type=valid_float_f01),
"state_poll": Option(1.0, type=valid_float_f01),
"quality": Option(80),
"desired_fps": Option(0),
"quality": Option(80, type=valid_stream_quality),
"desired_fps": Option(0, type=valid_stream_fps),
"host": Option("localhost"),
"port": Option(0),
"unix": Option("", type=_as_optional_path, rename="unix_path"),
"timeout": Option(2.0),
"host": Option("localhost", type=valid_ip_or_host),
"port": Option(0, type=valid_port),
"unix": Option("", type=valid_abs_path, only_if="!port", unpack_as="unix_path"),
"timeout": Option(2.0, type=valid_float_f01),
"cmd": Option(["/bin/true"], type=_as_string_list),
"cmd": Option(["/bin/true"]), # TODO: Validator
},
},

View File

@@ -24,6 +24,9 @@ import os
import subprocess
import time
from typing import List
from typing import Optional
from ...logging import get_logger
from ... import gpio
@@ -32,8 +35,8 @@ from .. import init
# =====
def main() -> None:
config = init("kvmd-cleanup", description="Kill KVMD and clear resources")[2].kvmd
def main(argv: Optional[List[str]]=None) -> None:
config = init("kvmd-cleanup", description="Kill KVMD and clear resources", argv=argv)[2].kvmd
logger = get_logger(0)
logger.info("Cleaning up ...")
@@ -47,7 +50,7 @@ def main() -> None:
("streamer_cap_pin", config.streamer.cap_pin),
("streamer_conv_pin", config.streamer.conv_pin),
]:
if pin > 0:
if pin >= 0:
logger.info("Writing value=0 to pin=%d (%s)", pin, name)
gpio.set_output(pin, initial=False)

View File

@@ -22,7 +22,6 @@
import sys
import os
import re
import getpass
import tempfile
import contextlib
@@ -34,12 +33,16 @@ import passlib.apache
from ...yamlconf import Section
from ...validators import ValidatorError
from ...validators.auth import valid_user
from ...validators.auth import valid_passwd
from .. import init
# =====
def _get_htpasswd_path(config: Section) -> str:
if config.kvmd.auth.auth_type != "basic":
if config.kvmd.auth.type != "basic":
print("Warning: KVMD does not use basic auth", file=sys.stderr)
return config.kvmd.auth.basic.htpasswd
@@ -69,13 +72,6 @@ def _get_htpasswd_for_write(config: Section) -> Generator[passlib.apache.Htpassw
os.remove(tmp_path)
def _valid_user(user: str) -> str:
stripped = user.strip()
if re.match(r"^[a-z_][a-z0-9_-]*$", stripped):
return stripped
raise SystemExit("Invalid user %r" % (user))
# ====
def _cmd_list(config: Section, _: argparse.Namespace) -> None:
for user in passlib.apache.HtpasswdFile(_get_htpasswd_path(config)).users():
@@ -85,10 +81,10 @@ def _cmd_list(config: Section, _: argparse.Namespace) -> None:
def _cmd_set(config: Section, options: argparse.Namespace) -> None:
with _get_htpasswd_for_write(config) as htpasswd:
if options.read_stdin:
passwd = input()
passwd = valid_passwd(input())
else:
passwd = getpass.getpass("Password: ", stream=sys.stderr)
if getpass.getpass("Repeat: ", stream=sys.stderr) != passwd:
passwd = valid_passwd(getpass.getpass("Password: ", stream=sys.stderr))
if valid_passwd(getpass.getpass("Repeat: ", stream=sys.stderr)) != passwd:
raise SystemExit("Sorry, passwords do not match")
htpasswd.set_password(options.user, passwd)
@@ -113,13 +109,16 @@ def main() -> None:
cmd_list_parser.set_defaults(cmd=_cmd_list)
cmd_set_parser = subparsers.add_parser("set", help="Create user or change password")
cmd_set_parser.add_argument("user", type=_valid_user)
cmd_set_parser.add_argument("user", type=valid_user)
cmd_set_parser.add_argument("-i", "--read-stdin", action="store_true", help="Read password from stdin")
cmd_set_parser.set_defaults(cmd=_cmd_set)
cmd_delete_parser = subparsers.add_parser("del", help="Delete user")
cmd_delete_parser.add_argument("user", type=_valid_user)
cmd_delete_parser.add_argument("user", type=valid_user)
cmd_delete_parser.set_defaults(cmd=_cmd_delete)
options = parser.parse_args(argv[1:])
options.cmd(config, options)
try:
options.cmd(config, options)
except ValidatorError as err:
raise SystemExit(str(err))

View File

@@ -45,15 +45,15 @@ def main() -> None:
# pylint: disable=protected-access
loop = asyncio.get_event_loop()
Server(
auth_manager=AuthManager(**config.auth._unpack_renamed()),
info_manager=InfoManager(loop=loop, **config.info._unpack_renamed()),
auth_manager=AuthManager(**config.auth._unpack()),
info_manager=InfoManager(loop=loop, **config.info._unpack()),
log_reader=LogReader(loop=loop),
hid=Hid(**config.hid._unpack_renamed()),
atx=Atx(**config.atx._unpack_renamed()),
msd=MassStorageDevice(loop=loop, **config.msd._unpack_renamed()),
streamer=Streamer(loop=loop, **config.streamer._unpack_renamed()),
hid=Hid(**config.hid._unpack()),
atx=Atx(**config.atx._unpack()),
msd=MassStorageDevice(loop=loop, **config.msd._unpack()),
streamer=Streamer(loop=loop, **config.streamer._unpack()),
loop=loop,
).run(**config.server._unpack_renamed())
).run(**config.server._unpack())
get_logger().info("Bye-bye")

View File

@@ -46,7 +46,7 @@ from ... import gpio
# =====
def _get_keymap() -> Dict[str, int]:
return yaml.load(pkgutil.get_data("kvmd", "data/keymap.yaml").decode()) # type: ignore
return yaml.safe_load(pkgutil.get_data("kvmd", "data/keymap.yaml").decode()) # type: ignore
_KEYMAP = _get_keymap()

View File

@@ -21,7 +21,6 @@
import os
import re
import signal
import socket
import asyncio
@@ -36,7 +35,6 @@ from typing import Dict
from typing import Set
from typing import Callable
from typing import Optional
from typing import Any
import aiohttp.web
import setproctitle
@@ -45,6 +43,18 @@ from ...logging import get_logger
from ...aioregion import RegionIsBusyError
from ...validators import ValidatorError
from ...validators.basic import valid_bool
from ...validators.auth import valid_user
from ...validators.auth import valid_passwd
from ...validators.auth import valid_auth_token
from ...validators.kvm import valid_atx_button
from ...validators.kvm import valid_kvm_target
from ...validators.kvm import valid_log_seek
from ...validators.kvm import valid_stream_quality
from ...validators.kvm import valid_stream_fps
from ... import __version__
from .auth import AuthManager
@@ -80,10 +90,6 @@ class HttpError(Exception):
pass
class BadRequestError(HttpError):
pass
class UnauthorizedError(HttpError):
pass
@@ -138,7 +144,7 @@ def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
if auth_required:
token = request.cookies.get(_COOKIE_AUTH_TOKEN, "")
if token:
user = self._auth_manager.check(_valid_token(token))
user = self._auth_manager.check(valid_auth_token(token))
if not user:
raise ForbiddenError("Forbidden")
setattr(request, _ATTR_KVMD_USER, user)
@@ -149,7 +155,7 @@ def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
except RegionIsBusyError as err:
return _json_exception(err, 409)
except (BadRequestError, AtxOperationError, MsdOperationError) as err:
except (ValidatorError, AtxOperationError, MsdOperationError) as err:
return _json_exception(err, 400)
except UnauthorizedError as err:
return _json_exception(err, 401)
@@ -178,51 +184,6 @@ def _system_task(method: Callable) -> Callable:
return wrap
def _valid_user(user: Any) -> str:
if isinstance(user, str):
stripped = user.strip()
if re.match(r"^[a-z_][a-z0-9_-]*$", stripped):
return stripped
raise BadRequestError("Invalid user characters %r" % (user))
def _valid_passwd(passwd: Any) -> str:
if isinstance(passwd, str):
if re.match(r"[\x20-\x7e]*$", passwd):
return passwd
raise BadRequestError("Invalid password characters")
def _valid_token(token: Optional[str]) -> str:
if isinstance(token, str):
token = token.strip().lower()
if re.match(r"^[0-9a-f]{64}$", token):
return token
raise BadRequestError("Invalid auth token characters")
def _valid_bool(name: str, flag: Optional[str]) -> bool:
flag = str(flag).strip().lower()
if flag in ["1", "true", "yes"]:
return True
elif flag in ["0", "false", "no"]:
return False
raise BadRequestError("Invalid param '%s'" % (name))
def _valid_int(name: str, value: Optional[str], min_value: Optional[int]=None, max_value: Optional[int]=None) -> int:
try:
value_int = int(value) # type: ignore
if (
(min_value is not None and value_int < min_value)
or (max_value is not None and value_int > max_value)
):
raise ValueError()
return value_int
except Exception:
raise BadRequestError("Invalid param %r" % (name))
class _Events(Enum):
INFO_STATE = "info_state"
HID_STATE = "hid_state"
@@ -337,8 +298,8 @@ class Server: # pylint: disable=too-many-instance-attributes
async def __auth_login_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
credentials = await request.post()
token = self._auth_manager.login(
user=_valid_user(credentials.get("user", "")),
passwd=_valid_passwd(credentials.get("passwd", "")),
user=valid_user(credentials.get("user", "")),
passwd=valid_passwd(credentials.get("passwd", "")),
)
if token:
return _json({}, set_cookies={_COOKIE_AUTH_TOKEN: token})
@@ -346,7 +307,7 @@ class Server: # pylint: disable=too-many-instance-attributes
@_exposed("POST", "/auth/logout")
async def __auth_logout_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
token = _valid_token(request.cookies.get(_COOKIE_AUTH_TOKEN, ""))
token = valid_auth_token(request.cookies.get(_COOKIE_AUTH_TOKEN, ""))
self._auth_manager.logout(token)
return _json({})
@@ -362,8 +323,8 @@ class Server: # pylint: disable=too-many-instance-attributes
@_exposed("GET", "/log")
async def __log_handler(self, request: aiohttp.web.Request) -> aiohttp.web.StreamResponse:
seek = _valid_int("seek", request.query.get("seek", "0"), 0)
follow = _valid_bool("follow", request.query.get("follow", "false"))
seek = valid_log_seek(request.query.get("seek", "0"))
follow = valid_bool(request.query.get("follow", "false"))
response = aiohttp.web.StreamResponse(status=200, reason="OK", headers={"Content-Type": "text/plain"})
await response.prepare(request)
async for record in self.__log_reader.poll_log(seek, follow):
@@ -460,15 +421,12 @@ class Server: # pylint: disable=too-many-instance-attributes
@_exposed("POST", "/atx/click")
async def __atx_click_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
button = request.query.get("button")
clicker = {
button = valid_atx_button(request.query.get("button"))
await ({
"power": self.__atx.click_power,
"power_long": self.__atx.click_power_long,
"reset": self.__atx.click_reset,
}.get(button)
if not clicker:
raise BadRequestError("Invalid param 'button'")
await clicker()
}[button])()
return _json({"clicked": button})
# ===== MSD
@@ -479,13 +437,11 @@ class Server: # pylint: disable=too-many-instance-attributes
@_exposed("POST", "/msd/connect")
async def __msd_connect_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
to = request.query.get("to")
if to == "kvm":
return _json(await self.__msd.connect_to_kvm())
elif to == "server":
return _json(await self.__msd.connect_to_pc())
else:
raise BadRequestError("Invalid param 'to'")
to = valid_kvm_target(request.query.get("to"))
return _json(await ({
"kvm": self.__msd.connect_to_kvm,
"server": self.__msd.connect_to_pc,
}[to])())
@_exposed("POST", "/msd/write")
async def __msd_write_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
@@ -496,12 +452,12 @@ class Server: # pylint: disable=too-many-instance-attributes
async with self.__msd:
field = await reader.next()
if not field or field.name != "image_name":
raise BadRequestError("Missing 'image_name' field")
raise ValidatorError("Missing 'image_name' field")
image_name = (await field.read()).decode("utf-8")[:256]
field = await reader.next()
if not field or field.name != "image_data":
raise BadRequestError("Missing 'image_data' field")
raise ValidatorError("Missing 'image_data' field")
logger.info("Writing image %r to mass-storage device ...", image_name)
await self.__msd.write_image_info(image_name, False)
@@ -530,8 +486,8 @@ class Server: # pylint: disable=too-many-instance-attributes
@_exposed("POST", "/streamer/set_params")
async def __streamer_set_params_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
for (name, validator) in [
("quality", lambda arg: _valid_int("quality", arg, 1, 100)),
("desired_fps", lambda arg: _valid_int("desired_fps", arg, 0, 30)),
("quality", valid_stream_quality),
("desired_fps", valid_stream_fps),
]:
value = request.query.get(name)
if value:

View File

@@ -63,8 +63,8 @@ class Streamer: # pylint: disable=too-many-instance-attributes
loop: asyncio.AbstractEventLoop,
) -> None:
self.__cap_pin = (gpio.set_output(cap_pin) if cap_pin > 0 else 0)
self.__conv_pin = (gpio.set_output(conv_pin) if conv_pin > 0 else 0)
self.__cap_pin = (gpio.set_output(cap_pin) if cap_pin >= 0 else -1)
self.__conv_pin = (gpio.set_output(conv_pin) if conv_pin >= 0 else -1)
self.__sync_delay = sync_delay
self.__init_delay = init_delay
@@ -179,9 +179,9 @@ class Streamer: # pylint: disable=too-many-instance-attributes
async def __set_hw_enabled(self, enabled: bool) -> None:
# XXX: This sequence is very important to enable converter and cap board
if self.__cap_pin > 0:
if self.__cap_pin >= 0:
gpio.write(self.__cap_pin, enabled)
if self.__conv_pin > 0:
if self.__conv_pin >= 0:
if enabled:
await asyncio.sleep(self.__sync_delay)
gpio.write(self.__conv_pin, enabled)