mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2025-12-12 01:00:29 +08:00
refactoring
This commit is contained in:
parent
334b9f7d7b
commit
59eff99dcc
@ -54,18 +54,18 @@ async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, re
|
|||||||
user = valid_user(user)
|
user = valid_user(user)
|
||||||
passwd = req.headers.get("X-KVMD-Passwd", "")
|
passwd = req.headers.get("X-KVMD-Passwd", "")
|
||||||
set_request_auth_info(req, f"{user} (xhdr)")
|
set_request_auth_info(req, f"{user} (xhdr)")
|
||||||
if not (await auth_manager.authorize(user, valid_passwd(passwd))):
|
if (await auth_manager.authorize(user, valid_passwd(passwd))):
|
||||||
raise ForbiddenError()
|
return
|
||||||
return
|
raise ForbiddenError()
|
||||||
|
|
||||||
token = req.cookies.get(_COOKIE_AUTH_TOKEN, "")
|
token = req.cookies.get(_COOKIE_AUTH_TOKEN, "")
|
||||||
if token:
|
if token:
|
||||||
user = auth_manager.check(valid_auth_token(token)) # type: ignore
|
user = auth_manager.check(valid_auth_token(token)) # type: ignore
|
||||||
if not user:
|
if user:
|
||||||
set_request_auth_info(req, "- (token)")
|
set_request_auth_info(req, f"{user} (token)")
|
||||||
raise ForbiddenError()
|
return
|
||||||
set_request_auth_info(req, f"{user} (token)")
|
set_request_auth_info(req, "- (token)")
|
||||||
return
|
raise ForbiddenError()
|
||||||
|
|
||||||
basic_auth = req.headers.get("Authorization", "")
|
basic_auth = req.headers.get("Authorization", "")
|
||||||
if basic_auth and basic_auth[:6].lower() == "basic ":
|
if basic_auth and basic_auth[:6].lower() == "basic ":
|
||||||
@ -75,9 +75,9 @@ async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, re
|
|||||||
raise UnauthorizedError()
|
raise UnauthorizedError()
|
||||||
user = valid_user(user)
|
user = valid_user(user)
|
||||||
set_request_auth_info(req, f"{user} (basic)")
|
set_request_auth_info(req, f"{user} (basic)")
|
||||||
if not (await auth_manager.authorize(user, valid_passwd(passwd))):
|
if (await auth_manager.authorize(user, valid_passwd(passwd))):
|
||||||
raise ForbiddenError()
|
return
|
||||||
return
|
raise ForbiddenError()
|
||||||
|
|
||||||
if exposed.allow_usc:
|
if exposed.allow_usc:
|
||||||
creds = get_request_unix_credentials(req)
|
creds = get_request_unix_credentials(req)
|
||||||
@ -86,6 +86,7 @@ async def check_request_auth(auth_manager: AuthManager, exposed: HttpExposed, re
|
|||||||
if user:
|
if user:
|
||||||
set_request_auth_info(req, f"{user}[{creds.uid}] (unix)")
|
set_request_auth_info(req, f"{user}[{creds.uid}] (unix)")
|
||||||
return
|
return
|
||||||
|
raise UnauthorizedError()
|
||||||
|
|
||||||
raise UnauthorizedError()
|
raise UnauthorizedError()
|
||||||
|
|
||||||
|
|||||||
@ -22,15 +22,24 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from aiohttp.test_utils import make_mocked_request
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from kvmd.validators import ValidatorError
|
||||||
|
|
||||||
from kvmd.yamlconf import make_config
|
from kvmd.yamlconf import make_config
|
||||||
|
|
||||||
from kvmd.apps.kvmd.auth import AuthManager
|
from kvmd.apps.kvmd.auth import AuthManager
|
||||||
|
from kvmd.apps.kvmd.api.auth import check_request_auth
|
||||||
|
|
||||||
|
from kvmd.htserver import UnauthorizedError
|
||||||
|
from kvmd.htserver import ForbiddenError
|
||||||
|
|
||||||
from kvmd.plugins.auth import get_auth_service_class
|
from kvmd.plugins.auth import get_auth_service_class
|
||||||
|
|
||||||
@ -83,6 +92,57 @@ async def _get_configured_manager(
|
|||||||
|
|
||||||
|
|
||||||
# =====
|
# =====
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ok__request(tmpdir) -> None: # type: ignore
|
||||||
|
path = os.path.abspath(str(tmpdir.join("htpasswd")))
|
||||||
|
|
||||||
|
htpasswd = KvmdHtpasswdFile(path, new=True)
|
||||||
|
htpasswd.set_password("admin", "pass")
|
||||||
|
htpasswd.save()
|
||||||
|
|
||||||
|
async with _get_configured_manager([], path) as manager:
|
||||||
|
async def check(exposed: HttpExposed, **kwargs) -> None: # type: ignore
|
||||||
|
await check_request_auth(manager, exposed, make_mocked_request(exposed.method, exposed.path, **kwargs))
|
||||||
|
|
||||||
|
await check(_E_FREE)
|
||||||
|
with pytest.raises(UnauthorizedError):
|
||||||
|
await check(_E_AUTH)
|
||||||
|
|
||||||
|
# ===
|
||||||
|
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await check(_E_AUTH, headers={"X-KVMD-User": "admin", "X-KVMD-Passwd": "foo"})
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await check(_E_AUTH, headers={"X-KVMD-User": "adminx", "X-KVMD-Passwd": "pass"})
|
||||||
|
|
||||||
|
await check(_E_AUTH, headers={"X-KVMD-User": "admin", "X-KVMD-Passwd": "pass"})
|
||||||
|
|
||||||
|
# ===
|
||||||
|
|
||||||
|
with pytest.raises(UnauthorizedError):
|
||||||
|
await check(_E_AUTH, headers={"Cookie": "auth_token="})
|
||||||
|
with pytest.raises(ValidatorError):
|
||||||
|
await check(_E_AUTH, headers={"Cookie": "auth_token=0"})
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await check(_E_AUTH, headers={"Cookie": f"auth_token={'0' * 64}"})
|
||||||
|
|
||||||
|
token = await manager.login("admin", "pass", 0)
|
||||||
|
assert token
|
||||||
|
await check(_E_AUTH, headers={"Cookie": f"auth_token={token}"})
|
||||||
|
manager.logout(token)
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await check(_E_AUTH, headers={"Cookie": f"auth_token={token}"})
|
||||||
|
|
||||||
|
# ===
|
||||||
|
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await check(_E_AUTH, headers={"Authorization": "basic " + base64.b64encode(b"admin:foo").decode()})
|
||||||
|
with pytest.raises(ForbiddenError):
|
||||||
|
await check(_E_AUTH, headers={"Authorization": "basic " + base64.b64encode(b"adminx:pass").decode()})
|
||||||
|
|
||||||
|
await check(_E_AUTH, headers={"Authorization": "basic " + base64.b64encode(b"admin:pass").decode()})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ok__expire(tmpdir) -> None: # type: ignore
|
async def test_ok__expire(tmpdir) -> None: # type: ignore
|
||||||
path = os.path.abspath(str(tmpdir.join("htpasswd")))
|
path = os.path.abspath(str(tmpdir.join("htpasswd")))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user