refactoring

This commit is contained in:
Devaev Maxim 2018-12-13 20:39:01 +03:00
parent 972c4928cc
commit 1c64b7b0e8
2 changed files with 108 additions and 102 deletions

View File

@ -2,6 +2,7 @@ import os
import signal
import socket
import asyncio
import inspect
import json
import time
@ -32,18 +33,6 @@ from .streamer import Streamer
# =====
def _system_task(method: Callable) -> Callable:
async def wrap(self: "Server") -> None:
try:
await method(self)
except asyncio.CancelledError:
pass
except Exception:
get_logger().exception("Unhandled exception, killing myself ...")
os.kill(os.getpid(), signal.SIGTERM)
return wrap
def _json(result: Optional[Dict]=None, status: int=200) -> aiohttp.web.Response:
return aiohttp.web.Response(
text=json.dumps({
@ -55,26 +44,64 @@ def _json(result: Optional[Dict]=None, status: int=200) -> aiohttp.web.Response:
)
def _json_exception(msg: str, err: Exception, status: int) -> aiohttp.web.Response:
msg = "%s: %s" % (msg, err)
get_logger().error(msg)
def _json_exception(err: Exception, status: int) -> aiohttp.web.Response:
name = type(err).__name__
msg = str(err)
get_logger().error("API error: %s: %s", name, msg)
return _json({
"error": type(err).__name__,
"error": name,
"error_msg": msg,
}, status=status)
class BadRequest(Exception):
class BadRequestError(Exception):
pass
_ATTR_EXPOSED = "exposed"
_ATTR_EXPOSED_METHOD = "exposed_method"
_ATTR_EXPOSED_PATH = "exposed_path"
_ATTR_SYSTEM_TASK = "system_task"
def _exposed(http_method: str, path: str) -> Callable:
def make_wrapper(method: Callable) -> Callable:
async def wrap(self: "Server", request: aiohttp.web.Request) -> aiohttp.web.Response:
try:
return (await method(self, request))
except RegionIsBusyError as err:
return _json_exception(err, 409)
except (BadRequestError, MsdOperationError) as err:
return _json_exception(err, 400)
setattr(wrap, _ATTR_EXPOSED, True)
setattr(wrap, _ATTR_EXPOSED_METHOD, http_method)
setattr(wrap, _ATTR_EXPOSED_PATH, path)
return wrap
return make_wrapper
def _system_task(method: Callable) -> Callable:
async def wrap(self: "Server") -> None:
try:
await method(self)
except asyncio.CancelledError:
pass
except Exception:
get_logger().exception("Unhandled exception, killing myself ...")
os.kill(os.getpid(), signal.SIGTERM)
setattr(wrap, _ATTR_SYSTEM_TASK, True)
return wrap
def _valid_bool(name: str, flag: Optional[str]) -> bool:
flag = str(flag).strip().lower()
if flag in ["1", "true", "yes"]:
return True
elif flag in ["0", "false", "no"]:
return False
raise BadRequest("Invalid param '%s'" % (name))
raise BadRequestError("Invalid param '%s'" % (name))
def _valid_int(name: str, value: Optional[str], min_value: Optional[int]=None, max_value: Optional[int]=None) -> int:
@ -87,20 +114,7 @@ def _valid_int(name: str, value: Optional[str], min_value: Optional[int]=None, m
raise ValueError()
return value_int
except Exception:
raise BadRequest("Invalid param %r" % (name))
def _wrap_exceptions_for_web(msg: str) -> Callable:
def make_wrapper(method: Callable) -> Callable:
async def wrap(self: "Server", request: aiohttp.web.Request) -> aiohttp.web.Response:
try:
return (await method(self, request))
except RegionIsBusyError as err:
return _json_exception(msg, err, 409)
except (BadRequest, MsdOperationError) as err:
return _json_exception(msg, err, 400)
return wrap
return make_wrapper
raise BadRequestError("Invalid param %r" % (name))
class _Events(Enum):
@ -156,37 +170,20 @@ class Server: # pylint: disable=too-many-instance-attributes
setproctitle.setproctitle("[main] " + setproctitle.getproctitle())
app = aiohttp.web.Application(loop=self.__loop)
app.router.add_get("/info", self.__info_handler)
app.router.add_get("/log", self.__log_handler)
app.router.add_get("/ws", self.__ws_handler)
app.router.add_post("/hid/reset", self.__hid_reset_handler)
app.router.add_get("/atx", self.__atx_state_handler)
app.router.add_post("/atx/click", self.__atx_click_handler)
app.router.add_get("/msd", self.__msd_state_handler)
app.router.add_post("/msd/connect", self.__msd_connect_handler)
app.router.add_post("/msd/write", self.__msd_write_handler)
app.router.add_post("/msd/reset", self.__msd_reset_handler)
app.router.add_get("/streamer", self.__streamer_state_handler)
app.router.add_post("/streamer/set_params", self.__streamer_set_params_handler)
app.router.add_post("/streamer/reset", self.__streamer_reset_handler)
app.on_shutdown.append(self.__on_shutdown)
app.on_cleanup.append(self.__on_cleanup)
self.__system_tasks.extend([
self.__loop.create_task(self.__hid_watchdog()),
self.__loop.create_task(self.__stream_controller()),
self.__loop.create_task(self.__poll_dead_sockets()),
self.__loop.create_task(self.__poll_atx_state()),
self.__loop.create_task(self.__poll_msd_state()),
self.__loop.create_task(self.__poll_streamer_state()),
])
for name in dir(self):
method = getattr(self, name)
if inspect.ismethod(method):
if getattr(method, _ATTR_SYSTEM_TASK, False):
self.__system_tasks.append(self.__loop.create_task(method()))
elif getattr(method, _ATTR_EXPOSED, False):
app.router.add_route(
getattr(method, _ATTR_EXPOSED_METHOD),
getattr(method, _ATTR_EXPOSED_PATH),
method,
)
assert port or unix_path
if unix_path:
@ -215,10 +212,11 @@ class Server: # pylint: disable=too-many-instance-attributes
# ===== SYSTEM
@_exposed("GET", "/info")
async def __info_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
return _json(await self.__make_info())
@_wrap_exceptions_for_web("Log error")
@_exposed("GET", "/log")
async def __log_handler(self, request: aiohttp.web.Request) -> aiohttp.web.StreamResponse:
seek = _valid_int("seek", request.query.get("seek", "0"), 0)
follow = _valid_bool("follow", request.query.get("follow", "false"))
@ -234,6 +232,7 @@ class Server: # pylint: disable=too-many-instance-attributes
# ===== WEBSOCKET
@_exposed("GET", "/ws")
async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse:
logger = get_logger(0)
ws = aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat)
@ -298,16 +297,18 @@ class Server: # pylint: disable=too-many-instance-attributes
# ===== HID
@_exposed("POST", "/hid/reset")
async def __hid_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
await self.__hid.reset()
return _json()
# ===== ATX
@_exposed("GET", "/atx")
async def __atx_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
return _json(self.__atx.get_state())
@_wrap_exceptions_for_web("Click error")
@_exposed("POST", "/atx/click")
async def __atx_click_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
button = request.query.get("button")
clicker = {
@ -316,16 +317,17 @@ class Server: # pylint: disable=too-many-instance-attributes
"reset": self.__atx.click_reset,
}.get(button)
if not clicker:
raise BadRequest("Invalid param 'button'")
raise BadRequestError("Invalid param 'button'")
await clicker()
return _json({"clicked": button})
# ===== MSD
@_exposed("GET", "/msd")
async def __msd_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
return _json(self.__msd.get_state())
@_wrap_exceptions_for_web("Mass-storage error")
@_exposed("POST", "/msd/connect")
async def __msd_connect_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
to = request.query.get("to")
if to == "kvm":
@ -333,9 +335,9 @@ class Server: # pylint: disable=too-many-instance-attributes
elif to == "server":
return _json(await self.__msd.connect_to_pc())
else:
raise BadRequest("Invalid param 'to'")
raise BadRequestError("Invalid param 'to'")
@_wrap_exceptions_for_web("Can't write data to mass-storage device")
@_exposed("POST", "/msd/write")
async def __msd_write_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
logger = get_logger(0)
reader = await request.multipart()
@ -344,12 +346,12 @@ class Server: # pylint: disable=too-many-instance-attributes
async with self.__msd:
field = await reader.next()
if not field or field.name != "image_name":
raise BadRequest("Missing 'image_name' field")
raise BadRequestError("Missing 'image_name' field")
image_name = (await field.read()).decode("utf-8")[:256]
field = await reader.next()
if not field or field.name != "image_data":
raise BadRequest("Missing 'image_data' field")
raise BadRequestError("Missing 'image_data' field")
logger.info("Writing image %r to mass-storage device ...", image_name)
await self.__msd.write_image_info(image_name, False)
@ -364,17 +366,18 @@ class Server: # pylint: disable=too-many-instance-attributes
logger.info("Written %d bytes to mass-storage device", written)
return _json({"written": written})
@_wrap_exceptions_for_web("Mass-storage error")
@_exposed("POST", "/msd/reset")
async def __msd_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
await self.__msd.reset()
return _json()
# ===== STREAMER
@_exposed("GET", "/streamer")
async def __streamer_state_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
return _json(await self.__streamer.get_state())
@_wrap_exceptions_for_web("Can't set stream params")
@_exposed("POST", "/streamer/set_params")
async def __streamer_set_params_handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
for (name, validator) in [
("quality", lambda arg: _valid_int("quality", arg, 1, 100)),
@ -385,6 +388,7 @@ class Server: # pylint: disable=too-many-instance-attributes
self.__streamer_params[name] = validator(value)
return _json()
@_exposed("POST", "/streamer/reset")
async def __streamer_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
self.__reset_streamer = True
return _json()
@ -413,6 +417,39 @@ class Server: # pylint: disable=too-many-instance-attributes
await self.__msd.cleanup()
await self.__hid.cleanup()
async def __broadcast_event(self, event_type: _Events, event_attrs: Dict) -> None:
if self.__sockets:
await asyncio.gather(*[
ws.send_str(json.dumps({
"msg_type": "event",
"msg": {
"event": event_type.value,
"event_attrs": event_attrs,
},
}))
for ws in list(self.__sockets)
if not ws.closed and ws._req.transport # pylint: disable=protected-access
], return_exceptions=True)
async def __register_socket(self, ws: aiohttp.web.WebSocketResponse) -> None:
async with self.__sockets_lock:
self.__sockets.add(ws)
get_logger().info("Registered new client socket: remote=%s; id=%d; active=%d",
ws._req.remote, id(ws), len(self.__sockets)) # pylint: disable=protected-access
async def __remove_socket(self, ws: aiohttp.web.WebSocketResponse) -> None:
async with self.__sockets_lock:
await self.__hid.clear_events()
try:
self.__sockets.remove(ws)
get_logger().info("Removed client socket: remote=%s; id=%d; active=%d",
ws._req.remote, id(ws), len(self.__sockets)) # pylint: disable=protected-access
await ws.close()
except Exception:
pass
# ===== SYSTEM TASKS
@_system_task
async def __hid_watchdog(self) -> None:
while self.__hid.is_alive():
@ -466,34 +503,3 @@ class Server: # pylint: disable=too-many-instance-attributes
async def __poll_streamer_state(self) -> None:
async for state in self.__streamer.poll_state():
await self.__broadcast_event(_Events.STREAMER_STATE, state)
async def __broadcast_event(self, event_type: _Events, event_attrs: Dict) -> None:
if self.__sockets:
await asyncio.gather(*[
ws.send_str(json.dumps({
"msg_type": "event",
"msg": {
"event": event_type.value,
"event_attrs": event_attrs,
},
}))
for ws in list(self.__sockets)
if not ws.closed and ws._req.transport # pylint: disable=protected-access
], return_exceptions=True)
async def __register_socket(self, ws: aiohttp.web.WebSocketResponse) -> None:
async with self.__sockets_lock:
self.__sockets.add(ws)
get_logger().info("Registered new client socket: remote=%s; id=%d; active=%d",
ws._req.remote, id(ws), len(self.__sockets)) # pylint: disable=protected-access
async def __remove_socket(self, ws: aiohttp.web.WebSocketResponse) -> None:
async with self.__sockets_lock:
await self.__hid.clear_events()
try:
self.__sockets.remove(ws)
get_logger().info("Removed client socket: remote=%s; id=%d; active=%d",
ws._req.remote, id(ws), len(self.__sockets)) # pylint: disable=protected-access
await ws.close()
except Exception:
pass

View File

@ -26,7 +26,7 @@ deps =
-rrequirements.txt
[testenv:vulture]
commands = vulture kvmd genmap.py testenv/vulture-wl.py
commands = vulture --ignore-decorators=@_exposed,@_system_task kvmd genmap.py testenv/vulture-wl.py
deps =
vulture
-rrequirements.txt