pikvm/pikvm#1204: Expire user session

This commit is contained in:
Maxim Devaev
2025-02-08 23:30:52 +02:00
parent abbd65a9a0
commit a7c3cdc1ea
8 changed files with 225 additions and 57 deletions

View File

@@ -34,6 +34,7 @@ from ....htserver import set_request_auth_info
from ....validators.auth import valid_user
from ....validators.auth import valid_passwd
from ....validators.auth import valid_expire
from ....validators.auth import valid_auth_token
from ..auth import AuthManager
@@ -91,6 +92,7 @@ class AuthApi:
token = await self.__auth_manager.login(
user=valid_user(credentials.get("user", "")),
passwd=valid_passwd(credentials.get("passwd", "")),
expire=valid_expire(credentials.get("expire", "0")),
)
if token:
return make_json_response(set_cookies={_COOKIE_AUTH_TOKEN: token})

View File

@@ -20,6 +20,9 @@
# ========================================================================== #
import dataclasses
import time
import secrets
import pyotp
@@ -34,6 +37,17 @@ from ...htserver import HttpExposed
# =====
@dataclasses.dataclass(frozen=True)
class _Session:
user: str
expire_ts: int
def __post_init__(self) -> None:
assert self.user.strip()
assert self.user
assert self.expire_ts >= 0
class AuthManager:
def __init__(
self,
@@ -72,7 +86,7 @@ class AuthManager:
self.__totp_secret_path = totp_secret_path
self.__tokens: dict[str, str] = {} # {token: user}
self.__sessions: dict[str, _Session] = {} # {token: session}
def is_auth_enabled(self) -> bool:
return self.__enabled
@@ -106,20 +120,26 @@ class AuthManager:
service = self.__internal_service
ok = (await service.authorize(user, passwd))
pname = service.get_plugin_name()
if ok:
get_logger().info("Authorized user %r via auth service %r", user, service.get_plugin_name())
get_logger().info("Authorized user %r via auth service %r", user, pname)
else:
get_logger().error("Got access denied for user %r from auth service %r", user, service.get_plugin_name())
get_logger().error("Got access denied for user %r from auth service %r", user, pname)
return ok
async def login(self, user: str, passwd: str) -> (str | None):
async def login(self, user: str, passwd: str, expire: int) -> (str | None):
assert user == user.strip()
assert user
assert expire >= 0
assert self.__enabled
if (await self.authorize(user, passwd)):
token = self.__make_new_token()
self.__tokens[token] = user
get_logger().info("Logged in user %r", user)
session = _Session(
user=user,
expire_ts=(0 if expire <= 0 else (self.__get_now_ts() + expire)),
)
self.__sessions[token] = session
get_logger().info("Logged in user %r (expire_ts=%d)", session.user, session.expire_ts)
return token
else:
return None
@@ -127,24 +147,40 @@ class AuthManager:
def __make_new_token(self) -> str:
for _ in range(10):
token = secrets.token_hex(32)
if token not in self.__tokens:
if token not in self.__sessions:
return token
raise AssertionError("Can't generate new unique token")
def __get_now_ts(self) -> int:
return int(time.monotonic())
def logout(self, token: str) -> None:
assert self.__enabled
if token in self.__tokens:
user = self.__tokens[token]
if token in self.__sessions:
user = self.__sessions[token].user
count = 0
for (r_token, r_user) in list(self.__tokens.items()):
if r_user == user:
for (key_t, session) in list(self.__sessions.items()):
if session.user == user:
count += 1
del self.__tokens[r_token]
get_logger().info("Logged out user %r (%d)", user, count)
del self.__sessions[key_t]
get_logger().info("Logged out user %r (was=%d)", user, count)
def check(self, token: str) -> (str | None):
assert self.__enabled
return self.__tokens.get(token)
session = self.__sessions.get(token)
if session is not None:
if session.expire_ts <= 0:
# Infinite session
assert session.user
return session.user
else:
# Limited session
if self.__get_now_ts() < session.expire_ts:
assert session.user
return session.user
else:
del self.__sessions[token]
return None
@aiotools.atomic_fg
async def cleanup(self) -> None:

View File

@@ -23,6 +23,7 @@
from typing import Any
from .basic import valid_string_list
from .basic import valid_number
from . import check_re_match
@@ -40,5 +41,9 @@ def valid_passwd(arg: Any) -> str:
return check_re_match(arg, "passwd characters", r"^[\x20-\x7e]*\Z$", strip=False, hide=True)
def valid_expire(arg: Any) -> int:
return int(valid_number(arg, min=0, name="expiration time"))
def valid_auth_token(arg: Any) -> str:
return check_re_match(arg, "auth token", r"^[0-9a-f]{64}$", hide=True)