modular auth

This commit is contained in:
Devaev Maxim 2019-04-01 10:30:30 +03:00
parent 70e526b773
commit 73e04b71ed
4 changed files with 48 additions and 13 deletions

View File

@ -129,6 +129,15 @@ def _as_string_list(values: Union[str, Sequence]) -> List[str]:
return list(map(str, values)) return list(map(str, values))
def _as_auth_type(auth_type: str) -> str:
if not isinstance(auth_type, str):
raise ValueError("Invalid auth type")
auth_type = str(auth_type).strip()
if auth_type not in ["basic"]:
raise ValueError("Invalid auth type")
return auth_type
def _get_config_scheme() -> Dict: def _get_config_scheme() -> Dict:
return { return {
"kvmd": { "kvmd": {
@ -144,8 +153,11 @@ def _get_config_scheme() -> Dict:
}, },
"auth": { "auth": {
"type": Option("basic", type=_as_auth_type, rename="auth_type"),
"basic": {
"htpasswd": Option("/etc/kvmd/htpasswd", type=_as_path, rename="htpasswd_path"), "htpasswd": Option("/etc/kvmd/htpasswd", type=_as_path, rename="htpasswd_path"),
}, },
},
"info": { "info": {
"meta": Option("/etc/kvmd/meta.yaml", type=_as_path, rename="meta_path"), "meta": Option("/etc/kvmd/meta.yaml", type=_as_path, rename="meta_path"),

View File

@ -38,9 +38,15 @@ from .. import init
# ===== # =====
def _get_htpasswd_path(config: Section) -> str:
if config.kvmd.auth.auth_type != "basic":
print("Warning: KVMD does not use basic auth", file=sys.stderr)
return config.kvmd.auth.basic.htpasswd
@contextlib.contextmanager @contextlib.contextmanager
def _get_htpasswd_for_write(config: Section) -> Generator[passlib.apache.HtpasswdFile, None, None]: def _get_htpasswd_for_write(config: Section) -> Generator[passlib.apache.HtpasswdFile, None, None]:
path = config.kvmd.auth.htpasswd path = _get_htpasswd_path(config)
(tmp_fd, tmp_path) = tempfile.mkstemp( (tmp_fd, tmp_path) = tempfile.mkstemp(
prefix=".%s." % (os.path.basename(path)), prefix=".%s." % (os.path.basename(path)),
dir=os.path.dirname(path), dir=os.path.dirname(path),
@ -72,7 +78,7 @@ def _valid_user(user: str) -> str:
# ==== # ====
def _cmd_list(config: Section, _: argparse.Namespace) -> None: def _cmd_list(config: Section, _: argparse.Namespace) -> None:
for user in passlib.apache.HtpasswdFile(config.kvmd.auth.htpasswd).users(): for user in passlib.apache.HtpasswdFile(_get_htpasswd_path(config)).users():
print(user) print(user)
@ -97,7 +103,7 @@ def main() -> None:
(parent_parser, argv, config) = init(add_help=False) (parent_parser, argv, config) = init(add_help=False)
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="kvmd-htpasswd", prog="kvmd-htpasswd",
description="Manage KVMD users", description="Manage KVMD users (basic auth only)",
parents=[parent_parser], parents=[parent_parser],
) )
parser.set_defaults(cmd=(lambda *_: parser.print_help())) parser.set_defaults(cmd=(lambda *_: parser.print_help()))

View File

@ -32,13 +32,14 @@ from ...logging import get_logger
# ===== # =====
class AuthManager: class AuthManager:
def __init__(self, htpasswd_path: str) -> None: def __init__(self, auth_type: str, basic: Dict) -> None:
self.__htpasswd_path = htpasswd_path self.__login = {
"basic": lambda: _BasicLogin(**basic),
}[auth_type]().login
self.__tokens: Dict[str, str] = {} # {token: user} self.__tokens: Dict[str, str] = {} # {token: user}
def login(self, user: str, passwd: str) -> Optional[str]: def login(self, user: str, passwd: str) -> Optional[str]:
htpasswd = passlib.apache.HtpasswdFile(self.__htpasswd_path) if self.__login(user, passwd):
if htpasswd.check_password(user, passwd):
for (token, token_user) in self.__tokens.items(): for (token, token_user) in self.__tokens.items():
if user == token_user: if user == token_user:
return token return token
@ -57,3 +58,13 @@ class AuthManager:
def check(self, token: str) -> Optional[str]: def check(self, token: str) -> Optional[str]:
return self.__tokens.get(token) return self.__tokens.get(token)
class _BasicLogin:
def __init__(self, htpasswd_path: str) -> None:
get_logger().info("Using basic auth %r", htpasswd_path)
self.__htpasswd_path = htpasswd_path
def login(self, user: str, passwd: str) -> bool:
htpasswd = passlib.apache.HtpasswdFile(self.__htpasswd_path)
return htpasswd.check_password(user, passwd)

View File

@ -66,12 +66,15 @@ class Section(dict):
dict.__init__(self) dict.__init__(self)
self.__meta: Dict[str, Dict[str, Any]] = {} self.__meta: Dict[str, Dict[str, Any]] = {}
def _unpack_renamed(self) -> Dict[str, Any]: def _unpack_renamed(self, _section: Optional["Section"]=None) -> Dict[str, Any]:
if _section is None:
_section = self
unpacked: Dict[str, Any] = {} unpacked: Dict[str, Any] = {}
for (key, value) in self.items(): for (key, value) in _section.items():
assert not isinstance(value, Section), (key, value) if isinstance(value, Section):
key = (self.__meta[key]["rename"] or key) unpacked[key] = value._unpack_renamed() # pylint: disable=protected-access
unpacked[key] = value else: # Option
unpacked[_section._get_rename(key)] = value # pylint: disable=protected-access
return unpacked return unpacked
def _set_meta(self, key: str, default: Any, help: str, rename: str) -> None: # pylint: disable=redefined-builtin def _set_meta(self, key: str, default: Any, help: str, rename: str) -> None: # pylint: disable=redefined-builtin
@ -87,6 +90,9 @@ class Section(dict):
def _get_help(self, key: str) -> str: def _get_help(self, key: str) -> str:
return self.__meta[key]["help"] return self.__meta[key]["help"]
def _get_rename(self, key: str) -> str:
return (self.__meta[key]["rename"] or key)
def __getattribute__(self, key: str) -> Any: def __getattribute__(self, key: str) -> Any:
if key in self: if key in self:
return self[key] return self[key]