diff --git a/kvmd/apps/kvmd/api/auth.py b/kvmd/apps/kvmd/api/auth.py
index dee4a85d..da4b0be9 100644
--- a/kvmd/apps/kvmd/api/auth.py
+++ b/kvmd/apps/kvmd/api/auth.py
@@ -34,6 +34,7 @@ from ....htserver import set_request_auth_info
from ....validators.auth import valid_user
from ....validators.auth import valid_passwd
+from ....validators.auth import valid_expire
from ....validators.auth import valid_auth_token
from ..auth import AuthManager
@@ -91,6 +92,7 @@ class AuthApi:
token = await self.__auth_manager.login(
user=valid_user(credentials.get("user", "")),
passwd=valid_passwd(credentials.get("passwd", "")),
+ expire=valid_expire(credentials.get("expire", "0")),
)
if token:
return make_json_response(set_cookies={_COOKIE_AUTH_TOKEN: token})
diff --git a/kvmd/apps/kvmd/auth.py b/kvmd/apps/kvmd/auth.py
index bf979836..706e770b 100644
--- a/kvmd/apps/kvmd/auth.py
+++ b/kvmd/apps/kvmd/auth.py
@@ -20,6 +20,9 @@
# ========================================================================== #
+import dataclasses
+import time
+
import secrets
import pyotp
@@ -34,6 +37,17 @@ from ...htserver import HttpExposed
# =====
+@dataclasses.dataclass(frozen=True)
+class _Session:
+ user: str
+ expire_ts: int
+
+ def __post_init__(self) -> None:
+ assert self.user.strip()
+ assert self.user
+ assert self.expire_ts >= 0
+
+
class AuthManager:
def __init__(
self,
@@ -72,7 +86,7 @@ class AuthManager:
self.__totp_secret_path = totp_secret_path
- self.__tokens: dict[str, str] = {} # {token: user}
+ self.__sessions: dict[str, _Session] = {} # {token: session}
def is_auth_enabled(self) -> bool:
return self.__enabled
@@ -106,20 +120,26 @@ class AuthManager:
service = self.__internal_service
ok = (await service.authorize(user, passwd))
+ pname = service.get_plugin_name()
if ok:
- get_logger().info("Authorized user %r via auth service %r", user, service.get_plugin_name())
+ get_logger().info("Authorized user %r via auth service %r", user, pname)
else:
- get_logger().error("Got access denied for user %r from auth service %r", user, service.get_plugin_name())
+ get_logger().error("Got access denied for user %r from auth service %r", user, pname)
return ok
- async def login(self, user: str, passwd: str) -> (str | None):
+ async def login(self, user: str, passwd: str, expire: int) -> (str | None):
assert user == user.strip()
assert user
+ assert expire >= 0
assert self.__enabled
if (await self.authorize(user, passwd)):
token = self.__make_new_token()
- self.__tokens[token] = user
- get_logger().info("Logged in user %r", user)
+ session = _Session(
+ user=user,
+ expire_ts=(0 if expire <= 0 else (self.__get_now_ts() + expire)),
+ )
+ self.__sessions[token] = session
+ get_logger().info("Logged in user %r (expire_ts=%d)", session.user, session.expire_ts)
return token
else:
return None
@@ -127,24 +147,40 @@ class AuthManager:
def __make_new_token(self) -> str:
for _ in range(10):
token = secrets.token_hex(32)
- if token not in self.__tokens:
+ if token not in self.__sessions:
return token
raise AssertionError("Can't generate new unique token")
+ def __get_now_ts(self) -> int:
+ return int(time.monotonic())
+
def logout(self, token: str) -> None:
assert self.__enabled
- if token in self.__tokens:
- user = self.__tokens[token]
+ if token in self.__sessions:
+ user = self.__sessions[token].user
count = 0
- for (r_token, r_user) in list(self.__tokens.items()):
- if r_user == user:
+ for (key_t, session) in list(self.__sessions.items()):
+ if session.user == user:
count += 1
- del self.__tokens[r_token]
- get_logger().info("Logged out user %r (%d)", user, count)
+ del self.__sessions[key_t]
+ get_logger().info("Logged out user %r (was=%d)", user, count)
def check(self, token: str) -> (str | None):
assert self.__enabled
- return self.__tokens.get(token)
+ session = self.__sessions.get(token)
+ if session is not None:
+ if session.expire_ts <= 0:
+ # Infinite session
+ assert session.user
+ return session.user
+ else:
+ # Limited session
+ if self.__get_now_ts() < session.expire_ts:
+ assert session.user
+ return session.user
+ else:
+ del self.__sessions[token]
+ return None
@aiotools.atomic_fg
async def cleanup(self) -> None:
diff --git a/kvmd/validators/auth.py b/kvmd/validators/auth.py
index 33cad456..d07a3d63 100644
--- a/kvmd/validators/auth.py
+++ b/kvmd/validators/auth.py
@@ -23,6 +23,7 @@
from typing import Any
from .basic import valid_string_list
+from .basic import valid_number
from . import check_re_match
@@ -40,5 +41,9 @@ def valid_passwd(arg: Any) -> str:
return check_re_match(arg, "passwd characters", r"^[\x20-\x7e]*\Z$", strip=False, hide=True)
+def valid_expire(arg: Any) -> int:
+ return int(valid_number(arg, min=0, name="expiration time"))
+
+
def valid_auth_token(arg: Any) -> str:
return check_re_match(arg, "auth token", r"^[0-9a-f]{64}$", hide=True)
diff --git a/testenv/tests/apps/kvmd/test_auth.py b/testenv/tests/apps/kvmd/test_auth.py
index 4fa1c8ae..d6183a39 100644
--- a/testenv/tests/apps/kvmd/test_auth.py
+++ b/testenv/tests/apps/kvmd/test_auth.py
@@ -21,6 +21,7 @@
import os
+import asyncio
import contextlib
from typing import AsyncGenerator
@@ -79,6 +80,64 @@ async def _get_configured_manager(
# =====
+@pytest.mark.asyncio
+async def test_ok__expire(tmpdir) -> None: # type: ignore
+ path = os.path.abspath(str(tmpdir.join("htpasswd")))
+
+ htpasswd = passlib.apache.HtpasswdFile(path, new=True)
+ htpasswd.set_password("admin", "pass")
+ htpasswd.save()
+
+ async with _get_configured_manager([], path) as manager:
+ assert manager.is_auth_enabled()
+ assert manager.is_auth_required(_E_AUTH)
+ assert manager.is_auth_required(_E_UNAUTH)
+ assert not manager.is_auth_required(_E_FREE)
+
+ assert manager.check("xxx") is None
+ manager.logout("xxx")
+
+ assert (await manager.login("user", "foo", 3)) is None
+ assert (await manager.login("admin", "foo", 3)) is None
+ assert (await manager.login("user", "pass", 3)) is None
+
+ token1 = await manager.login("admin", "pass", 3)
+ assert isinstance(token1, str)
+ assert len(token1) == 64
+
+ token2 = await manager.login("admin", "pass", 3)
+ assert isinstance(token2, str)
+ assert len(token2) == 64
+ assert token1 != token2
+
+ assert manager.check(token1) == "admin"
+ assert manager.check(token2) == "admin"
+ assert manager.check("foobar") is None
+
+ manager.logout(token1)
+
+ assert manager.check(token1) is None
+ assert manager.check(token2) is None
+ assert manager.check("foobar") is None
+
+ token3 = await manager.login("admin", "pass", 3)
+ assert isinstance(token3, str)
+ assert len(token3) == 64
+ assert token1 != token3
+ assert token2 != token3
+
+ await asyncio.sleep(4)
+
+ assert manager.check(token1) is None
+ assert manager.check(token2) is None
+ assert manager.check(token3) is None
+
+ # Check for removed token
+ assert manager.check(token1) is None
+ assert manager.check(token2) is None
+ assert manager.check(token3) is None
+
+
@pytest.mark.asyncio
async def test_ok__internal(tmpdir) -> None: # type: ignore
path = os.path.abspath(str(tmpdir.join("htpasswd")))
@@ -96,15 +155,15 @@ async def test_ok__internal(tmpdir) -> None: # type: ignore
assert manager.check("xxx") is None
manager.logout("xxx")
- assert (await manager.login("user", "foo")) is None
- assert (await manager.login("admin", "foo")) is None
- assert (await manager.login("user", "pass")) is None
+ assert (await manager.login("user", "foo", 0)) is None
+ assert (await manager.login("admin", "foo", 0)) is None
+ assert (await manager.login("user", "pass", 0)) is None
- token1 = await manager.login("admin", "pass")
+ token1 = await manager.login("admin", "pass", 0)
assert isinstance(token1, str)
assert len(token1) == 64
- token2 = await manager.login("admin", "pass")
+ token2 = await manager.login("admin", "pass", 0)
assert isinstance(token2, str)
assert len(token2) == 64
assert token1 != token2
@@ -119,7 +178,7 @@ async def test_ok__internal(tmpdir) -> None: # type: ignore
assert manager.check(token2) is None
assert manager.check("foobar") is None
- token3 = await manager.login("admin", "pass")
+ token3 = await manager.login("admin", "pass", 0)
assert isinstance(token3, str)
assert len(token3) == 64
assert token1 != token3
@@ -147,17 +206,17 @@ async def test_ok__external(tmpdir) -> None: # type: ignore
assert manager.is_auth_required(_E_UNAUTH)
assert not manager.is_auth_required(_E_FREE)
- assert (await manager.login("local", "foobar")) is None
- assert (await manager.login("admin", "pass2")) is None
+ assert (await manager.login("local", "foobar", 0)) is None
+ assert (await manager.login("admin", "pass2", 0)) is None
- token = await manager.login("admin", "pass1")
+ token = await manager.login("admin", "pass1", 0)
assert token is not None
assert manager.check(token) == "admin"
manager.logout(token)
assert manager.check(token) is None
- token = await manager.login("user", "foobar")
+ token = await manager.login("user", "foobar", 0)
assert token is not None
assert manager.check(token) == "user"
@@ -212,7 +271,7 @@ async def test_ok__disabled() -> None:
await manager.authorize("admin", "admin")
with pytest.raises(AssertionError):
- await manager.login("admin", "admin")
+ await manager.login("admin", "admin", 0)
with pytest.raises(AssertionError):
manager.logout("xxx")
diff --git a/testenv/tests/validators/test_auth.py b/testenv/tests/validators/test_auth.py
index d84e029b..0f57889c 100644
--- a/testenv/tests/validators/test_auth.py
+++ b/testenv/tests/validators/test_auth.py
@@ -28,6 +28,7 @@ from kvmd.validators import ValidatorError
from kvmd.validators.auth import valid_user
from kvmd.validators.auth import valid_users_list
from kvmd.validators.auth import valid_passwd
+from kvmd.validators.auth import valid_expire
from kvmd.validators.auth import valid_auth_token
@@ -109,6 +110,20 @@ def test_fail__valid_passwd(arg: Any) -> None:
print(valid_passwd(arg))
+# =====
+@pytest.mark.parametrize("arg", ["0 ", 0, 1, 13])
+def test_ok__valid_expire(arg: Any) -> None:
+ value = valid_expire(arg)
+ assert type(value) is int # pylint: disable=unidiomatic-typecheck
+ assert value == int(str(arg).strip())
+
+
+@pytest.mark.parametrize("arg", ["test", "", None, -1, -13, 1.1])
+def test_fail__valid_expire(arg: Any) -> None:
+ with pytest.raises(ValidatorError):
+ print(valid_expire(arg))
+
+
# =====
@pytest.mark.parametrize("arg", [
("0" * 64) + " ",
diff --git a/web/login/index.html b/web/login/index.html
index a8cbedd9..ca9a3bf0 100644
--- a/web/login/index.html
+++ b/web/login/index.html
@@ -37,6 +37,7 @@
+