This commit is contained in:
Maxim Devaev 2022-01-18 09:25:17 +03:00
parent 3ee1948f19
commit 3ab43edeb9
8 changed files with 254 additions and 11 deletions

View File

@ -448,6 +448,10 @@ def _get_config_scheme() -> Dict:
"cmd_append": Option([], type=valid_options),
},
"ocr": {
"langs": Option(["eng"], type=valid_string_list, unpack_as="default_langs"),
},
"snapshot": {
"idle_interval": Option(0.0, type=valid_float_f0),
"live_interval": Option(0.0, type=valid_float_f0),

View File

@ -37,6 +37,7 @@ from .logreader import LogReader
from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter
from .tesseract import TesseractOcr
from .server import KvmdServer
@ -86,6 +87,7 @@ def main(argv: Optional[List[str]]=None) -> None:
info_manager=InfoManager(global_config),
log_reader=LogReader(),
user_gpio=UserGpio(config.gpio, global_config.otg.udc),
ocr=TesseractOcr(**config.ocr._unpack()),
hid=hid,
atx=get_atx_class(config.atx.type)(**config.atx._unpack(ignore=["type"])),

View File

@ -112,7 +112,7 @@ class HidApi:
# =====
def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py)
async def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py)
keymaps: Set[str] = set()
for keymap_name in os.listdir(self.__keymaps_dir_path):
path = os.path.join(self.__keymaps_dir_path, keymap_name)
@ -127,7 +127,7 @@ class HidApi:
@exposed_http("GET", "/hid/keymaps")
async def __keymaps_handler(self, _: Request) -> Response:
return make_json_response(self.get_keymaps())
return make_json_response(await self.get_keymaps())
@exposed_http("POST", "/hid/print")
async def __print_handler(self, request: Request) -> Response:

View File

@ -23,13 +23,19 @@
import io
import functools
from typing import List
from typing import Dict
from aiohttp.web import Request
from aiohttp.web import Response
from PIL import Image as PilImage
from ....validators import check_string_in_list
from ....validators.basic import valid_bool
from ....validators.basic import valid_number
from ....validators.basic import valid_int_f0
from ....validators.basic import valid_string_list
from ....validators.kvm import valid_stream_quality
from .... import aiotools
@ -41,11 +47,14 @@ from ..http import make_json_response
from ..streamer import StreamerSnapshot
from ..streamer import Streamer
from ..tesseract import TesseractOcr
# =====
class StreamerApi:
def __init__(self, streamer: Streamer) -> None:
def __init__(self, streamer: Streamer, ocr: TesseractOcr) -> None:
self.__streamer = streamer
self.__ocr = ocr
# =====
@ -61,7 +70,25 @@ class StreamerApi:
allow_offline=valid_bool(request.query.get("allow_offline", "false")),
)
if snapshot:
if valid_bool(request.query.get("preview", "false")):
if valid_bool(request.query.get("ocr", "false")):
langs = await self.__ocr.get_available_langs()
return Response(
body=(await self.__ocr.recognize(
data=snapshot.data,
langs=valid_string_list(
arg=str(request.query.get("ocr_langs", "")).strip(),
subval=(lambda lang: check_string_in_list(lang, "OCR lang", langs)),
name="OCR langs list",
),
left=int(valid_number(request.query.get("ocr_left", "-1"))),
top=int(valid_number(request.query.get("ocr_top", "-1"))),
right=int(valid_number(request.query.get("ocr_right", "-1"))),
bottom=int(valid_number(request.query.get("ocr_bottom", "-1"))),
)),
headers=dict(snapshot.headers),
content_type="text/plain",
)
elif valid_bool(request.query.get("preview", "false")):
data = await self.__make_preview(
snapshot=snapshot,
max_width=valid_int_f0(request.query.get("preview_max_width", "0")),
@ -84,6 +111,29 @@ class StreamerApi:
# =====
async def get_ocr(self) -> Dict: # XXX: Ugly hack
enabled = self.__ocr.is_available()
default: List[str] = []
available: List[str] = []
if enabled:
default = await self.__ocr.get_default_langs()
available = await self.__ocr.get_available_langs()
return {
"ocr": {
"enabled": enabled,
"langs": {
"default": default,
"available": available,
},
},
}
@exposed_http("GET", "/streamer/ocr")
async def __ocr_handler(self, _: Request) -> Response:
return make_json_response(await self.get_ocr())
# =====
async def __make_preview(self, snapshot: StreamerSnapshot, max_width: int, max_height: int, quality: int) -> bytes:
if max_width == 0 and max_height == 0:
max_width = snapshot.width // 5

View File

@ -32,6 +32,7 @@ from typing import List
from typing import Dict
from typing import Set
from typing import Callable
from typing import Awaitable
from typing import Coroutine
from typing import AsyncGenerator
from typing import Optional
@ -68,6 +69,7 @@ from .logreader import LogReader
from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter
from .tesseract import TesseractOcr
from .http import HttpError
from .http import HttpExposed
@ -147,6 +149,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
info_manager: InfoManager,
log_reader: LogReader,
user_gpio: UserGpio,
ocr: TesseractOcr,
hid: BaseHid,
atx: BaseAtx,
@ -192,6 +195,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
]
self.__hid_api = HidApi(hid, keymap_path, ignore_keys, mouse_x_range, mouse_y_range) # Ugly hack to get keymaps state
self.__streamer_api = StreamerApi(streamer, ocr) # Same hack to get ocr langs state
self.__apis: List[object] = [
self,
AuthApi(auth_manager),
@ -201,7 +205,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
self.__hid_api,
AtxApi(atx),
MsdApi(msd),
StreamerApi(streamer),
self.__streamer_api,
ExportApi(info_manager, atx, user_gpio),
RedfishApi(info_manager, atx),
]
@ -251,21 +255,27 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
@exposed_http("GET", "/ws")
async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse:
logger = get_logger(0)
client = _WsClient(
ws=aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat),
stream=valid_bool(request.query.get("stream", "true")),
)
await client.ws.prepare(request)
await self.__register_ws_client(client)
try:
await self.__send_event(client.ws, "gpio_model_state", await self.__user_gpio.get_model())
await self.__send_event(client.ws, "hid_keymaps_state", self.__hid_api.get_keymaps())
await asyncio.gather(*[
self.__send_event(client.ws, component.event_type, await component.get_state())
for component in self.__components
if component.get_state
await self.__send_events_aws(client.ws, [
("gpio_model_state", self.__user_gpio.get_model()),
("hid_keymaps_state", self.__hid_api.get_keymaps()),
("streamer_ocr_state", self.__streamer_api.get_ocr()),
])
await self.__send_events_aws(client.ws, [
(comp.event_type, comp.get_state())
for comp in self.__components
if comp.get_state
])
await self.__send_event(client.ws, "loop", {})
async for msg in client.ws:
if msg.type == aiohttp.web.WSMsgType.TEXT:
try:
@ -282,6 +292,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
logger.error("Unknown websocket event: %r", data)
else:
break
return client.ws
finally:
await self.__remove_ws_client(client)
@ -380,6 +391,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
logger.exception("Cleanup error on %s", comp.name)
logger.info("On-Cleanup complete")
async def __send_events_aws(self, ws: aiohttp.web.WebSocketResponse, sources: List[Tuple[str, Awaitable]]) -> None:
await asyncio.gather(*[
self.__send_event(ws, event_type, state)
for (event_type, state) in zip(
map(operator.itemgetter(0), sources),
await asyncio.gather(*map(operator.itemgetter(1), sources)),
)
])
async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,

161
kvmd/apps/kvmd/tesseract.py Normal file
View File

@ -0,0 +1,161 @@
# ========================================================================== #
# #
# KVMD - The main PiKVM daemon. #
# #
# Copyright (C) 2018-2022 Maxim Devaev <mdevaev@gmail.com> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #
import io
import ctypes
import ctypes.util
import contextlib
import warnings
from ctypes import POINTER
from ctypes import Structure
from ctypes import c_int
from ctypes import c_bool
from ctypes import c_char_p
from ctypes import c_void_p
from ctypes import c_char
from typing import List
from typing import Set
from typing import Generator
from typing import Optional
from PIL import Image as PilImage
from ...errors import OperationError
from ... import libc
from ... import aiotools
# =====
class OcrError(OperationError):
pass
# =====
class _TessBaseAPI(Structure):
pass
def _load_libtesseract() -> Optional[ctypes.CDLL]:
try:
path = ctypes.util.find_library("tesseract")
if not path:
raise RuntimeError("Can't find libtesseract")
lib = ctypes.CDLL(path)
for (name, restype, argtypes) in [
("TessBaseAPICreate", POINTER(_TessBaseAPI), []),
("TessBaseAPIInit3", c_int, [POINTER(_TessBaseAPI), c_char_p, c_char_p]),
("TessBaseAPISetImage", None, [POINTER(_TessBaseAPI), c_void_p, c_int, c_int, c_int, c_int]),
("TessBaseAPIGetUTF8Text", POINTER(c_char), [POINTER(_TessBaseAPI)]),
("TessBaseAPISetVariable", c_bool, [POINTER(_TessBaseAPI), c_char_p, c_char_p]),
("TessBaseAPIGetAvailableLanguagesAsVector", POINTER(POINTER(c_char)), [POINTER(_TessBaseAPI)]),
]:
func = getattr(lib, name)
if not func:
raise RuntimeError(f"Can't find libtesseract.{name}")
setattr(func, "restype", restype)
setattr(func, "argtypes", argtypes)
return lib
except Exception as err:
warnings.warn(f"Can't load libtesseract: {err}", RuntimeWarning)
return None
_libtess = _load_libtesseract()
@contextlib.contextmanager
def _tess_api(langs: List[str]) -> Generator[_TessBaseAPI, None, None]:
if not _libtess:
raise OcrError("Tesseract is not available")
api = _libtess.TessBaseAPICreate()
try:
if _libtess.TessBaseAPIInit3(api, None, "+".join(langs).encode()) != 0:
raise OcrError("Can't initialize Tesseract")
if not _libtess.TessBaseAPISetVariable(api, b"debug_file", b"/dev/null"):
raise OcrError("Can't set debug_file=/dev/null")
yield api
finally:
_libtess.TessBaseAPIDelete(api)
# =====
class TesseractOcr:
def __init__(self, default_langs: List[str]) -> None:
self.__default_langs = default_langs
def is_available(self) -> bool:
return bool(_libtess)
async def get_default_langs(self) -> List[str]:
return list(self.__default_langs)
async def get_available_langs(self) -> List[str]:
return (await aiotools.run_async(self.__inner_get_available_langs))
def __inner_get_available_langs(self) -> List[str]:
with _tess_api(["osd"]) as api:
assert _libtess
langs: Set[str] = set()
langs_ptr = _libtess.TessBaseAPIGetAvailableLanguagesAsVector(api)
if langs_ptr is not None:
index = 0
while langs_ptr[index]:
lang = ctypes.cast(langs_ptr[index], c_char_p).value
if lang is not None:
langs.add(lang.decode())
libc.free(langs_ptr[index])
index += 1
libc.free(langs_ptr)
return sorted(langs)
async def recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str:
if not langs:
langs = self.__default_langs
return (await aiotools.run_async(self.__inner_recognize, data, langs, left, top, right, bottom))
def __inner_recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str:
with _tess_api(langs) as api:
assert _libtess
with io.BytesIO(data) as bio:
with PilImage.open(bio) as image:
if left >= 0 or top >= 0 or right >= 0 or bottom >= 0:
left = (0 if left < 0 else min(image.width, left))
top = (0 if top < 0 else min(image.height, top))
right = (image.width if right < 0 else min(image.width, right))
bottom = (image.height if bottom < 0 else min(image.height, bottom))
if left < right and top < bottom:
image.crop((left, top, right, bottom))
_libtess.TessBaseAPISetImage(api, image.tobytes("raw", "RGB"), image.width, image.height, 3, image.width * 3)
text_ptr = None
try:
text_ptr = _libtess.TessBaseAPIGetUTF8Text(api)
text = ctypes.cast(text_ptr, c_char_p).value
if text is None:
raise OcrError("Can't recognize image")
return text.decode("utf-8")
finally:
if text_ptr is not None:
libc.free(text_ptr)

View File

@ -28,6 +28,7 @@ import ctypes.util
from ctypes import c_int
from ctypes import c_uint32
from ctypes import c_char_p
from ctypes import c_void_p
# =====
@ -41,6 +42,7 @@ def _load_libc() -> ctypes.CDLL:
("inotify_init", c_int, []),
("inotify_add_watch", c_int, [c_int, c_char_p, c_uint32]),
("inotify_rm_watch", c_int, [c_int, c_uint32]),
("free", c_int, [c_void_p]),
]:
func = getattr(lib, name)
if not func:
@ -57,3 +59,4 @@ _libc = _load_libc()
inotify_init = _libc.inotify_init
inotify_add_watch = _libc.inotify_add_watch
inotify_rm_watch = _libc.inotify_rm_watch
free = _libc.free

View File

@ -50,6 +50,9 @@ RUN pacman --noconfirm --ask=4 -Syy \
python-hidapi \
freetype2 \
nginx-mainline \
tesseract \
tesseract-data-eng \
tesseract-data-rus \
ipmitool \
socat \
eslint \