From a7c3cdc1ea8613bba786d3b0656d38d0e2ca8b6d Mon Sep 17 00:00:00 2001 From: Maxim Devaev Date: Sat, 8 Feb 2025 23:30:52 +0200 Subject: [PATCH] pikvm/pikvm#1204: Expire user session --- kvmd/apps/kvmd/api/auth.py | 2 + kvmd/apps/kvmd/auth.py | 64 ++++++++++++++++----- kvmd/validators/auth.py | 5 ++ testenv/tests/apps/kvmd/test_auth.py | 81 +++++++++++++++++++++++---- testenv/tests/validators/test_auth.py | 15 +++++ web/login/index.html | 19 +++++++ web/login/index.pug | 23 +++++++- web/share/js/login/main.js | 73 ++++++++++++++---------- 8 files changed, 225 insertions(+), 57 deletions(-) diff --git a/kvmd/apps/kvmd/api/auth.py b/kvmd/apps/kvmd/api/auth.py index dee4a85d..da4b0be9 100644 --- a/kvmd/apps/kvmd/api/auth.py +++ b/kvmd/apps/kvmd/api/auth.py @@ -34,6 +34,7 @@ from ....htserver import set_request_auth_info from ....validators.auth import valid_user from ....validators.auth import valid_passwd +from ....validators.auth import valid_expire from ....validators.auth import valid_auth_token from ..auth import AuthManager @@ -91,6 +92,7 @@ class AuthApi: token = await self.__auth_manager.login( user=valid_user(credentials.get("user", "")), passwd=valid_passwd(credentials.get("passwd", "")), + expire=valid_expire(credentials.get("expire", "0")), ) if token: return make_json_response(set_cookies={_COOKIE_AUTH_TOKEN: token}) diff --git a/kvmd/apps/kvmd/auth.py b/kvmd/apps/kvmd/auth.py index bf979836..706e770b 100644 --- a/kvmd/apps/kvmd/auth.py +++ b/kvmd/apps/kvmd/auth.py @@ -20,6 +20,9 @@ # ========================================================================== # +import dataclasses +import time + import secrets import pyotp @@ -34,6 +37,17 @@ from ...htserver import HttpExposed # ===== +@dataclasses.dataclass(frozen=True) +class _Session: + user: str + expire_ts: int + + def __post_init__(self) -> None: + assert self.user.strip() + assert self.user + assert self.expire_ts >= 0 + + class AuthManager: def __init__( self, @@ -72,7 +86,7 @@ class AuthManager: self.__totp_secret_path = totp_secret_path - self.__tokens: dict[str, str] = {} # {token: user} + self.__sessions: dict[str, _Session] = {} # {token: session} def is_auth_enabled(self) -> bool: return self.__enabled @@ -106,20 +120,26 @@ class AuthManager: service = self.__internal_service ok = (await service.authorize(user, passwd)) + pname = service.get_plugin_name() if ok: - get_logger().info("Authorized user %r via auth service %r", user, service.get_plugin_name()) + get_logger().info("Authorized user %r via auth service %r", user, pname) else: - get_logger().error("Got access denied for user %r from auth service %r", user, service.get_plugin_name()) + get_logger().error("Got access denied for user %r from auth service %r", user, pname) return ok - async def login(self, user: str, passwd: str) -> (str | None): + async def login(self, user: str, passwd: str, expire: int) -> (str | None): assert user == user.strip() assert user + assert expire >= 0 assert self.__enabled if (await self.authorize(user, passwd)): token = self.__make_new_token() - self.__tokens[token] = user - get_logger().info("Logged in user %r", user) + session = _Session( + user=user, + expire_ts=(0 if expire <= 0 else (self.__get_now_ts() + expire)), + ) + self.__sessions[token] = session + get_logger().info("Logged in user %r (expire_ts=%d)", session.user, session.expire_ts) return token else: return None @@ -127,24 +147,40 @@ class AuthManager: def __make_new_token(self) -> str: for _ in range(10): token = secrets.token_hex(32) - if token not in self.__tokens: + if token not in self.__sessions: return token raise AssertionError("Can't generate new unique token") + def __get_now_ts(self) -> int: + return int(time.monotonic()) + def logout(self, token: str) -> None: assert self.__enabled - if token in self.__tokens: - user = self.__tokens[token] + if token in self.__sessions: + user = self.__sessions[token].user count = 0 - for (r_token, r_user) in list(self.__tokens.items()): - if r_user == user: + for (key_t, session) in list(self.__sessions.items()): + if session.user == user: count += 1 - del self.__tokens[r_token] - get_logger().info("Logged out user %r (%d)", user, count) + del self.__sessions[key_t] + get_logger().info("Logged out user %r (was=%d)", user, count) def check(self, token: str) -> (str | None): assert self.__enabled - return self.__tokens.get(token) + session = self.__sessions.get(token) + if session is not None: + if session.expire_ts <= 0: + # Infinite session + assert session.user + return session.user + else: + # Limited session + if self.__get_now_ts() < session.expire_ts: + assert session.user + return session.user + else: + del self.__sessions[token] + return None @aiotools.atomic_fg async def cleanup(self) -> None: diff --git a/kvmd/validators/auth.py b/kvmd/validators/auth.py index 33cad456..d07a3d63 100644 --- a/kvmd/validators/auth.py +++ b/kvmd/validators/auth.py @@ -23,6 +23,7 @@ from typing import Any from .basic import valid_string_list +from .basic import valid_number from . import check_re_match @@ -40,5 +41,9 @@ def valid_passwd(arg: Any) -> str: return check_re_match(arg, "passwd characters", r"^[\x20-\x7e]*\Z$", strip=False, hide=True) +def valid_expire(arg: Any) -> int: + return int(valid_number(arg, min=0, name="expiration time")) + + def valid_auth_token(arg: Any) -> str: return check_re_match(arg, "auth token", r"^[0-9a-f]{64}$", hide=True) diff --git a/testenv/tests/apps/kvmd/test_auth.py b/testenv/tests/apps/kvmd/test_auth.py index 4fa1c8ae..d6183a39 100644 --- a/testenv/tests/apps/kvmd/test_auth.py +++ b/testenv/tests/apps/kvmd/test_auth.py @@ -21,6 +21,7 @@ import os +import asyncio import contextlib from typing import AsyncGenerator @@ -79,6 +80,64 @@ async def _get_configured_manager( # ===== +@pytest.mark.asyncio +async def test_ok__expire(tmpdir) -> None: # type: ignore + path = os.path.abspath(str(tmpdir.join("htpasswd"))) + + htpasswd = passlib.apache.HtpasswdFile(path, new=True) + htpasswd.set_password("admin", "pass") + htpasswd.save() + + async with _get_configured_manager([], path) as manager: + assert manager.is_auth_enabled() + assert manager.is_auth_required(_E_AUTH) + assert manager.is_auth_required(_E_UNAUTH) + assert not manager.is_auth_required(_E_FREE) + + assert manager.check("xxx") is None + manager.logout("xxx") + + assert (await manager.login("user", "foo", 3)) is None + assert (await manager.login("admin", "foo", 3)) is None + assert (await manager.login("user", "pass", 3)) is None + + token1 = await manager.login("admin", "pass", 3) + assert isinstance(token1, str) + assert len(token1) == 64 + + token2 = await manager.login("admin", "pass", 3) + assert isinstance(token2, str) + assert len(token2) == 64 + assert token1 != token2 + + assert manager.check(token1) == "admin" + assert manager.check(token2) == "admin" + assert manager.check("foobar") is None + + manager.logout(token1) + + assert manager.check(token1) is None + assert manager.check(token2) is None + assert manager.check("foobar") is None + + token3 = await manager.login("admin", "pass", 3) + assert isinstance(token3, str) + assert len(token3) == 64 + assert token1 != token3 + assert token2 != token3 + + await asyncio.sleep(4) + + assert manager.check(token1) is None + assert manager.check(token2) is None + assert manager.check(token3) is None + + # Check for removed token + assert manager.check(token1) is None + assert manager.check(token2) is None + assert manager.check(token3) is None + + @pytest.mark.asyncio async def test_ok__internal(tmpdir) -> None: # type: ignore path = os.path.abspath(str(tmpdir.join("htpasswd"))) @@ -96,15 +155,15 @@ async def test_ok__internal(tmpdir) -> None: # type: ignore assert manager.check("xxx") is None manager.logout("xxx") - assert (await manager.login("user", "foo")) is None - assert (await manager.login("admin", "foo")) is None - assert (await manager.login("user", "pass")) is None + assert (await manager.login("user", "foo", 0)) is None + assert (await manager.login("admin", "foo", 0)) is None + assert (await manager.login("user", "pass", 0)) is None - token1 = await manager.login("admin", "pass") + token1 = await manager.login("admin", "pass", 0) assert isinstance(token1, str) assert len(token1) == 64 - token2 = await manager.login("admin", "pass") + token2 = await manager.login("admin", "pass", 0) assert isinstance(token2, str) assert len(token2) == 64 assert token1 != token2 @@ -119,7 +178,7 @@ async def test_ok__internal(tmpdir) -> None: # type: ignore assert manager.check(token2) is None assert manager.check("foobar") is None - token3 = await manager.login("admin", "pass") + token3 = await manager.login("admin", "pass", 0) assert isinstance(token3, str) assert len(token3) == 64 assert token1 != token3 @@ -147,17 +206,17 @@ async def test_ok__external(tmpdir) -> None: # type: ignore assert manager.is_auth_required(_E_UNAUTH) assert not manager.is_auth_required(_E_FREE) - assert (await manager.login("local", "foobar")) is None - assert (await manager.login("admin", "pass2")) is None + assert (await manager.login("local", "foobar", 0)) is None + assert (await manager.login("admin", "pass2", 0)) is None - token = await manager.login("admin", "pass1") + token = await manager.login("admin", "pass1", 0) assert token is not None assert manager.check(token) == "admin" manager.logout(token) assert manager.check(token) is None - token = await manager.login("user", "foobar") + token = await manager.login("user", "foobar", 0) assert token is not None assert manager.check(token) == "user" @@ -212,7 +271,7 @@ async def test_ok__disabled() -> None: await manager.authorize("admin", "admin") with pytest.raises(AssertionError): - await manager.login("admin", "admin") + await manager.login("admin", "admin", 0) with pytest.raises(AssertionError): manager.logout("xxx") diff --git a/testenv/tests/validators/test_auth.py b/testenv/tests/validators/test_auth.py index d84e029b..0f57889c 100644 --- a/testenv/tests/validators/test_auth.py +++ b/testenv/tests/validators/test_auth.py @@ -28,6 +28,7 @@ from kvmd.validators import ValidatorError from kvmd.validators.auth import valid_user from kvmd.validators.auth import valid_users_list from kvmd.validators.auth import valid_passwd +from kvmd.validators.auth import valid_expire from kvmd.validators.auth import valid_auth_token @@ -109,6 +110,20 @@ def test_fail__valid_passwd(arg: Any) -> None: print(valid_passwd(arg)) +# ===== +@pytest.mark.parametrize("arg", ["0 ", 0, 1, 13]) +def test_ok__valid_expire(arg: Any) -> None: + value = valid_expire(arg) + assert type(value) is int # pylint: disable=unidiomatic-typecheck + assert value == int(str(arg).strip()) + + +@pytest.mark.parametrize("arg", ["test", "", None, -1, -13, 1.1]) +def test_fail__valid_expire(arg: Any) -> None: + with pytest.raises(ValidatorError): + print(valid_expire(arg)) + + # ===== @pytest.mark.parametrize("arg", [ ("0" * 64) + " ", diff --git a/web/login/index.html b/web/login/index.html index a8cbedd9..ca9a3bf0 100644 --- a/web/login/index.html +++ b/web/login/index.html @@ -37,6 +37,7 @@ +