mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2026-01-29 00:51:53 +08:00
own auth
This commit is contained in:
@@ -5,8 +5,9 @@ from ...logging import get_logger
|
||||
|
||||
from ... import gpio
|
||||
|
||||
from .logreader import LogReader
|
||||
from .auth import AuthManager
|
||||
from .info import InfoManager
|
||||
from .logreader import LogReader
|
||||
from .hid import Hid
|
||||
from .atx import Atx
|
||||
from .msd import MassStorageDevice
|
||||
@@ -20,6 +21,10 @@ def main() -> None:
|
||||
with gpio.bcm():
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
auth_manager = AuthManager(
|
||||
htpasswd_path=str(config["auth"]["htpasswd"]),
|
||||
)
|
||||
|
||||
info_manager = InfoManager(
|
||||
meta_path=str(config["info"]["meta"]),
|
||||
extras_path=str(config["info"]["extras"]),
|
||||
@@ -80,6 +85,7 @@ def main() -> None:
|
||||
)
|
||||
|
||||
Server(
|
||||
auth_manager=auth_manager,
|
||||
info_manager=info_manager,
|
||||
log_reader=log_reader,
|
||||
|
||||
|
||||
37
kvmd/apps/kvmd/auth.py
Normal file
37
kvmd/apps/kvmd/auth.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import secrets
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
import passlib.apache
|
||||
|
||||
from ...logging import get_logger
|
||||
|
||||
|
||||
# =====
|
||||
class AuthManager:
|
||||
def __init__(self, htpasswd_path: str) -> None:
|
||||
self.__htpasswd_path = htpasswd_path
|
||||
self.__tokens: Dict[str, str] = {} # {token: user}
|
||||
|
||||
def login(self, user: str, passwd: str) -> Optional[str]:
|
||||
htpasswd = passlib.apache.HtpasswdFile(self.__htpasswd_path)
|
||||
if htpasswd.check_password(user, passwd):
|
||||
for (token, token_user) in self.__tokens.items():
|
||||
if user == token_user:
|
||||
return token
|
||||
token = secrets.token_hex(32)
|
||||
self.__tokens[token] = user
|
||||
get_logger().info("Logged in user %r", user)
|
||||
return token
|
||||
else:
|
||||
get_logger().error("Access denied for user %r", user)
|
||||
return None
|
||||
|
||||
def logout(self, token: str) -> None:
|
||||
user = self.__tokens.pop(token, "")
|
||||
if user:
|
||||
get_logger().info("Logged out user %r", user)
|
||||
|
||||
def check(self, token: str) -> bool:
|
||||
return (token in self.__tokens)
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import asyncio
|
||||
@@ -23,6 +24,7 @@ from ...aioregion import RegionIsBusyError
|
||||
|
||||
from ... import __version__
|
||||
|
||||
from .auth import AuthManager
|
||||
from .info import InfoManager
|
||||
from .logreader import LogReader
|
||||
from .hid import Hid
|
||||
@@ -33,8 +35,29 @@ from .streamer import Streamer
|
||||
|
||||
|
||||
# =====
|
||||
def _json(result: Optional[Dict]=None, status: int=200) -> aiohttp.web.Response:
|
||||
return aiohttp.web.Response(
|
||||
class HttpError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BadRequestError(HttpError):
|
||||
pass
|
||||
|
||||
|
||||
class UnauthorizedError(HttpError):
|
||||
pass
|
||||
|
||||
|
||||
class ForbiddenError(HttpError):
|
||||
pass
|
||||
|
||||
|
||||
def _json(
|
||||
result: Optional[Dict]=None,
|
||||
status: int=200,
|
||||
set_cookies: Optional[Dict[str, str]]=None,
|
||||
) -> aiohttp.web.Response:
|
||||
|
||||
response = aiohttp.web.Response(
|
||||
text=json.dumps({
|
||||
"ok": (status == 200),
|
||||
"result": (result or {}),
|
||||
@@ -42,37 +65,53 @@ def _json(result: Optional[Dict]=None, status: int=200) -> aiohttp.web.Response:
|
||||
status=status,
|
||||
content_type="application/json",
|
||||
)
|
||||
if set_cookies:
|
||||
for (key, value) in set_cookies.items():
|
||||
response.set_cookie(key, value)
|
||||
return response
|
||||
|
||||
|
||||
def _json_exception(err: Exception, status: int) -> aiohttp.web.Response:
|
||||
name = type(err).__name__
|
||||
msg = str(err)
|
||||
get_logger().error("API error: %s: %s", name, msg)
|
||||
if not isinstance(err, (UnauthorizedError, ForbiddenError)):
|
||||
get_logger().error("API error: %s: %s", name, msg)
|
||||
return _json({
|
||||
"error": name,
|
||||
"error_msg": msg,
|
||||
}, status=status)
|
||||
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
_ATTR_EXPOSED = "exposed"
|
||||
_ATTR_EXPOSED_METHOD = "exposed_method"
|
||||
_ATTR_EXPOSED_PATH = "exposed_path"
|
||||
_ATTR_SYSTEM_TASK = "system_task"
|
||||
|
||||
_COOKIE_AUTH_TOKEN = "auth_token"
|
||||
|
||||
def _exposed(http_method: str, path: str) -> Callable:
|
||||
|
||||
def _exposed(http_method: str, path: str, auth_required: bool=True) -> Callable:
|
||||
def make_wrapper(method: Callable) -> Callable:
|
||||
async def wrap(self: "Server", request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
try:
|
||||
if auth_required:
|
||||
token = request.cookies.get(_COOKIE_AUTH_TOKEN, "")
|
||||
if token:
|
||||
if not self._auth_manager.check(_valid_token(token)):
|
||||
raise ForbiddenError("Forbidden")
|
||||
else:
|
||||
raise UnauthorizedError("Unauthorized")
|
||||
|
||||
return (await method(self, request))
|
||||
|
||||
except RegionIsBusyError as err:
|
||||
return _json_exception(err, 409)
|
||||
except (BadRequestError, MsdOperationError) as err:
|
||||
return _json_exception(err, 400)
|
||||
except UnauthorizedError as err:
|
||||
return _json_exception(err, 401)
|
||||
except ForbiddenError as err:
|
||||
return _json_exception(err, 403)
|
||||
|
||||
setattr(wrap, _ATTR_EXPOSED, True)
|
||||
setattr(wrap, _ATTR_EXPOSED_METHOD, http_method)
|
||||
@@ -95,6 +134,29 @@ def _system_task(method: Callable) -> Callable:
|
||||
return wrap
|
||||
|
||||
|
||||
def _valid_user(user: Optional[str]) -> str:
|
||||
if isinstance(user, str):
|
||||
stripped = user.strip()
|
||||
if re.match(r"^[a-z_][a-z0-9_-]*$", stripped):
|
||||
return stripped
|
||||
raise BadRequestError("Invalid user characters %r" % (user))
|
||||
|
||||
|
||||
def _valid_passwd(passwd: Optional[str]) -> str:
|
||||
if isinstance(passwd, str):
|
||||
if re.match(r"[\x20-\x7e]*$", passwd):
|
||||
return passwd
|
||||
raise BadRequestError("Invalid password characters")
|
||||
|
||||
|
||||
def _valid_token(token: Optional[str]) -> str:
|
||||
if isinstance(token, str):
|
||||
token = token.strip().lower()
|
||||
if re.match(r"^[0-9a-f]{64}$", token):
|
||||
return token
|
||||
raise BadRequestError("Invalid auth token characters")
|
||||
|
||||
|
||||
def _valid_bool(name: str, flag: Optional[str]) -> bool:
|
||||
flag = str(flag).strip().lower()
|
||||
if flag in ["1", "true", "yes"]:
|
||||
@@ -127,6 +189,7 @@ class _Events(Enum):
|
||||
class Server: # pylint: disable=too-many-instance-attributes
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
auth_manager: AuthManager,
|
||||
info_manager: InfoManager,
|
||||
log_reader: LogReader,
|
||||
|
||||
@@ -142,6 +205,7 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> None:
|
||||
|
||||
self._auth_manager = auth_manager
|
||||
self.__info_manager = info_manager
|
||||
self.__log_reader = log_reader
|
||||
|
||||
@@ -210,6 +274,29 @@ class Server: # pylint: disable=too-many-instance-attributes
|
||||
"extras": await self.__info_manager.get_extras(),
|
||||
}
|
||||
|
||||
# ===== AUTH
|
||||
|
||||
@_exposed("POST", "/auth/login", auth_required=False)
|
||||
async def __auth_login_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
credentials = await request.post()
|
||||
token = self._auth_manager.login(
|
||||
user=_valid_user(credentials.get("user", "")),
|
||||
passwd=_valid_passwd(credentials.get("passwd", "")),
|
||||
)
|
||||
if token:
|
||||
return _json({}, set_cookies={_COOKIE_AUTH_TOKEN: token})
|
||||
raise ForbiddenError("Forbidden")
|
||||
|
||||
@_exposed("POST", "/auth/logout")
|
||||
async def __auth_logout_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
token = _valid_token(request.cookies.get(_COOKIE_AUTH_TOKEN, ""))
|
||||
self._auth_manager.logout(token)
|
||||
return _json({})
|
||||
|
||||
@_exposed("GET", "/auth/check")
|
||||
async def __auth_check_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
return _json({})
|
||||
|
||||
# ===== SYSTEM
|
||||
|
||||
@_exposed("GET", "/info")
|
||||
|
||||
Reference in New Issue
Block a user