refactoring

This commit is contained in:
Maxim Devaev 2025-05-03 21:37:09 +03:00
parent 334b9f7d7b
commit 59eff99dcc
2 changed files with 72 additions and 11 deletions

View File

@ -54,18 +54,18 @@ async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, re
user = valid_user(user) user = valid_user(user)
passwd = req.headers.get("X-KVMD-Passwd", "") passwd = req.headers.get("X-KVMD-Passwd", "")
set_request_auth_info(req, f"{user} (xhdr)") set_request_auth_info(req, f"{user} (xhdr)")
if not (await auth_manager.authorize(user, valid_passwd(passwd))): if (await auth_manager.authorize(user, valid_passwd(passwd))):
raise ForbiddenError() return
return raise ForbiddenError()
token = req.cookies.get(_COOKIE_AUTH_TOKEN, "") token = req.cookies.get(_COOKIE_AUTH_TOKEN, "")
if token: if token:
user = auth_manager.check(valid_auth_token(token)) # type: ignore user = auth_manager.check(valid_auth_token(token)) # type: ignore
if not user: if user:
set_request_auth_info(req, "- (token)") set_request_auth_info(req, f"{user} (token)")
raise ForbiddenError() return
set_request_auth_info(req, f"{user} (token)") set_request_auth_info(req, "- (token)")
return raise ForbiddenError()
basic_auth = req.headers.get("Authorization", "") basic_auth = req.headers.get("Authorization", "")
if basic_auth and basic_auth[:6].lower() == "basic ": if basic_auth and basic_auth[:6].lower() == "basic ":
@ -75,9 +75,9 @@ async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, re
raise UnauthorizedError() raise UnauthorizedError()
user = valid_user(user) user = valid_user(user)
set_request_auth_info(req, f"{user} (basic)") set_request_auth_info(req, f"{user} (basic)")
if not (await auth_manager.authorize(user, valid_passwd(passwd))): if (await auth_manager.authorize(user, valid_passwd(passwd))):
raise ForbiddenError() return
return raise ForbiddenError()
if exposed.allow_usc: if exposed.allow_usc:
creds = get_request_unix_credentials(req) creds = get_request_unix_credentials(req)
@ -86,6 +86,7 @@ async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, re
if user: if user:
set_request_auth_info(req, f"{user}[{creds.uid}] (unix)") set_request_auth_info(req, f"{user}[{creds.uid}] (unix)")
return return
raise UnauthorizedError()
raise UnauthorizedError() raise UnauthorizedError()

View File

@ -22,15 +22,24 @@
import os import os
import asyncio import asyncio
import base64
import contextlib import contextlib
from typing import AsyncGenerator from typing import AsyncGenerator
from aiohttp.test_utils import make_mocked_request
import pytest import pytest
from kvmd.validators import ValidatorError
from kvmd.yamlconf import make_config from kvmd.yamlconf import make_config
from kvmd.apps.kvmd.auth import AuthManager from kvmd.apps.kvmd.auth import AuthManager
from kvmd.apps.kvmd.api.auth import check_request_auth
from kvmd.htserver import UnauthorizedError
from kvmd.htserver import ForbiddenError
from kvmd.plugins.auth import get_auth_service_class from kvmd.plugins.auth import get_auth_service_class
@ -83,6 +92,57 @@ async def _get_configured_manager(
# ===== # =====
@pytest.mark.asyncio
async def test_ok__request(tmpdir) -> None: # type: ignore
path = os.path.abspath(str(tmpdir.join("htpasswd")))
htpasswd = KvmdHtpasswdFile(path, new=True)
htpasswd.set_password("admin", "pass")
htpasswd.save()
async with _get_configured_manager([], path) as manager:
async def check(exposed: HttpExposed, **kwargs) -> None: # type: ignore
await check_request_auth(manager, exposed, make_mocked_request(exposed.method, exposed.path, **kwargs))
await check(_E_FREE)
with pytest.raises(UnauthorizedError):
await check(_E_AUTH)
# ===
with pytest.raises(ForbiddenError):
await check(_E_AUTH, headers={"X-KVMD-User": "admin", "X-KVMD-Passwd": "foo"})
with pytest.raises(ForbiddenError):
await check(_E_AUTH, headers={"X-KVMD-User": "adminx", "X-KVMD-Passwd": "pass"})
await check(_E_AUTH, headers={"X-KVMD-User": "admin", "X-KVMD-Passwd": "pass"})
# ===
with pytest.raises(UnauthorizedError):
await check(_E_AUTH, headers={"Cookie": "auth_token="})
with pytest.raises(ValidatorError):
await check(_E_AUTH, headers={"Cookie": "auth_token=0"})
with pytest.raises(ForbiddenError):
await check(_E_AUTH, headers={"Cookie": f"auth_token={'0' * 64}"})
token = await manager.login("admin", "pass", 0)
assert token
await check(_E_AUTH, headers={"Cookie": f"auth_token={token}"})
manager.logout(token)
with pytest.raises(ForbiddenError):
await check(_E_AUTH, headers={"Cookie": f"auth_token={token}"})
# ===
with pytest.raises(ForbiddenError):
await check(_E_AUTH, headers={"Authorization": "basic " + base64.b64encode(b"admin:foo").decode()})
with pytest.raises(ForbiddenError):
await check(_E_AUTH, headers={"Authorization": "basic " + base64.b64encode(b"adminx:pass").decode()})
await check(_E_AUTH, headers={"Authorization": "basic " + base64.b64encode(b"admin:pass").decode()})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ok__expire(tmpdir) -> None: # type: ignore async def test_ok__expire(tmpdir) -> None: # type: ignore
path = os.path.abspath(str(tmpdir.join("htpasswd"))) path = os.path.abspath(str(tmpdir.join("htpasswd")))