mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2025-12-12 09:10:30 +08:00
pikvm/kvmd#66: OCR API
This commit is contained in:
parent
3ee1948f19
commit
3ab43edeb9
@ -448,6 +448,10 @@ def _get_config_scheme() -> Dict:
|
||||
"cmd_append": Option([], type=valid_options),
|
||||
},
|
||||
|
||||
"ocr": {
|
||||
"langs": Option(["eng"], type=valid_string_list, unpack_as="default_langs"),
|
||||
},
|
||||
|
||||
"snapshot": {
|
||||
"idle_interval": Option(0.0, type=valid_float_f0),
|
||||
"live_interval": Option(0.0, type=valid_float_f0),
|
||||
|
||||
@ -37,6 +37,7 @@ from .logreader import LogReader
|
||||
from .ugpio import UserGpio
|
||||
from .streamer import Streamer
|
||||
from .snapshoter import Snapshoter
|
||||
from .tesseract import TesseractOcr
|
||||
from .server import KvmdServer
|
||||
|
||||
|
||||
@ -86,6 +87,7 @@ def main(argv: Optional[List[str]]=None) -> None:
|
||||
info_manager=InfoManager(global_config),
|
||||
log_reader=LogReader(),
|
||||
user_gpio=UserGpio(config.gpio, global_config.otg.udc),
|
||||
ocr=TesseractOcr(**config.ocr._unpack()),
|
||||
|
||||
hid=hid,
|
||||
atx=get_atx_class(config.atx.type)(**config.atx._unpack(ignore=["type"])),
|
||||
|
||||
@ -112,7 +112,7 @@ class HidApi:
|
||||
|
||||
# =====
|
||||
|
||||
def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py)
|
||||
async def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py)
|
||||
keymaps: Set[str] = set()
|
||||
for keymap_name in os.listdir(self.__keymaps_dir_path):
|
||||
path = os.path.join(self.__keymaps_dir_path, keymap_name)
|
||||
@ -127,7 +127,7 @@ class HidApi:
|
||||
|
||||
@exposed_http("GET", "/hid/keymaps")
|
||||
async def __keymaps_handler(self, _: Request) -> Response:
|
||||
return make_json_response(self.get_keymaps())
|
||||
return make_json_response(await self.get_keymaps())
|
||||
|
||||
@exposed_http("POST", "/hid/print")
|
||||
async def __print_handler(self, request: Request) -> Response:
|
||||
|
||||
@ -23,13 +23,19 @@
|
||||
import io
|
||||
import functools
|
||||
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
|
||||
from aiohttp.web import Request
|
||||
from aiohttp.web import Response
|
||||
|
||||
from PIL import Image as PilImage
|
||||
|
||||
from ....validators import check_string_in_list
|
||||
from ....validators.basic import valid_bool
|
||||
from ....validators.basic import valid_number
|
||||
from ....validators.basic import valid_int_f0
|
||||
from ....validators.basic import valid_string_list
|
||||
from ....validators.kvm import valid_stream_quality
|
||||
|
||||
from .... import aiotools
|
||||
@ -41,11 +47,14 @@ from ..http import make_json_response
|
||||
from ..streamer import StreamerSnapshot
|
||||
from ..streamer import Streamer
|
||||
|
||||
from ..tesseract import TesseractOcr
|
||||
|
||||
|
||||
# =====
|
||||
class StreamerApi:
|
||||
def __init__(self, streamer: Streamer) -> None:
|
||||
def __init__(self, streamer: Streamer, ocr: TesseractOcr) -> None:
|
||||
self.__streamer = streamer
|
||||
self.__ocr = ocr
|
||||
|
||||
# =====
|
||||
|
||||
@ -61,7 +70,25 @@ class StreamerApi:
|
||||
allow_offline=valid_bool(request.query.get("allow_offline", "false")),
|
||||
)
|
||||
if snapshot:
|
||||
if valid_bool(request.query.get("preview", "false")):
|
||||
if valid_bool(request.query.get("ocr", "false")):
|
||||
langs = await self.__ocr.get_available_langs()
|
||||
return Response(
|
||||
body=(await self.__ocr.recognize(
|
||||
data=snapshot.data,
|
||||
langs=valid_string_list(
|
||||
arg=str(request.query.get("ocr_langs", "")).strip(),
|
||||
subval=(lambda lang: check_string_in_list(lang, "OCR lang", langs)),
|
||||
name="OCR langs list",
|
||||
),
|
||||
left=int(valid_number(request.query.get("ocr_left", "-1"))),
|
||||
top=int(valid_number(request.query.get("ocr_top", "-1"))),
|
||||
right=int(valid_number(request.query.get("ocr_right", "-1"))),
|
||||
bottom=int(valid_number(request.query.get("ocr_bottom", "-1"))),
|
||||
)),
|
||||
headers=dict(snapshot.headers),
|
||||
content_type="text/plain",
|
||||
)
|
||||
elif valid_bool(request.query.get("preview", "false")):
|
||||
data = await self.__make_preview(
|
||||
snapshot=snapshot,
|
||||
max_width=valid_int_f0(request.query.get("preview_max_width", "0")),
|
||||
@ -84,6 +111,29 @@ class StreamerApi:
|
||||
|
||||
# =====
|
||||
|
||||
async def get_ocr(self) -> Dict: # XXX: Ugly hack
|
||||
enabled = self.__ocr.is_available()
|
||||
default: List[str] = []
|
||||
available: List[str] = []
|
||||
if enabled:
|
||||
default = await self.__ocr.get_default_langs()
|
||||
available = await self.__ocr.get_available_langs()
|
||||
return {
|
||||
"ocr": {
|
||||
"enabled": enabled,
|
||||
"langs": {
|
||||
"default": default,
|
||||
"available": available,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@exposed_http("GET", "/streamer/ocr")
|
||||
async def __ocr_handler(self, _: Request) -> Response:
|
||||
return make_json_response(await self.get_ocr())
|
||||
|
||||
# =====
|
||||
|
||||
async def __make_preview(self, snapshot: StreamerSnapshot, max_width: int, max_height: int, quality: int) -> bytes:
|
||||
if max_width == 0 and max_height == 0:
|
||||
max_width = snapshot.width // 5
|
||||
|
||||
@ -32,6 +32,7 @@ from typing import List
|
||||
from typing import Dict
|
||||
from typing import Set
|
||||
from typing import Callable
|
||||
from typing import Awaitable
|
||||
from typing import Coroutine
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
@ -68,6 +69,7 @@ from .logreader import LogReader
|
||||
from .ugpio import UserGpio
|
||||
from .streamer import Streamer
|
||||
from .snapshoter import Snapshoter
|
||||
from .tesseract import TesseractOcr
|
||||
|
||||
from .http import HttpError
|
||||
from .http import HttpExposed
|
||||
@ -147,6 +149,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
info_manager: InfoManager,
|
||||
log_reader: LogReader,
|
||||
user_gpio: UserGpio,
|
||||
ocr: TesseractOcr,
|
||||
|
||||
hid: BaseHid,
|
||||
atx: BaseAtx,
|
||||
@ -192,6 +195,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
]
|
||||
|
||||
self.__hid_api = HidApi(hid, keymap_path, ignore_keys, mouse_x_range, mouse_y_range) # Ugly hack to get keymaps state
|
||||
self.__streamer_api = StreamerApi(streamer, ocr) # Same hack to get ocr langs state
|
||||
self.__apis: List[object] = [
|
||||
self,
|
||||
AuthApi(auth_manager),
|
||||
@ -201,7 +205,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
self.__hid_api,
|
||||
AtxApi(atx),
|
||||
MsdApi(msd),
|
||||
StreamerApi(streamer),
|
||||
self.__streamer_api,
|
||||
ExportApi(info_manager, atx, user_gpio),
|
||||
RedfishApi(info_manager, atx),
|
||||
]
|
||||
@ -251,21 +255,27 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
@exposed_http("GET", "/ws")
|
||||
async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse:
|
||||
logger = get_logger(0)
|
||||
|
||||
client = _WsClient(
|
||||
ws=aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat),
|
||||
stream=valid_bool(request.query.get("stream", "true")),
|
||||
)
|
||||
await client.ws.prepare(request)
|
||||
await self.__register_ws_client(client)
|
||||
|
||||
try:
|
||||
await self.__send_event(client.ws, "gpio_model_state", await self.__user_gpio.get_model())
|
||||
await self.__send_event(client.ws, "hid_keymaps_state", self.__hid_api.get_keymaps())
|
||||
await asyncio.gather(*[
|
||||
self.__send_event(client.ws, component.event_type, await component.get_state())
|
||||
for component in self.__components
|
||||
if component.get_state
|
||||
await self.__send_events_aws(client.ws, [
|
||||
("gpio_model_state", self.__user_gpio.get_model()),
|
||||
("hid_keymaps_state", self.__hid_api.get_keymaps()),
|
||||
("streamer_ocr_state", self.__streamer_api.get_ocr()),
|
||||
])
|
||||
await self.__send_events_aws(client.ws, [
|
||||
(comp.event_type, comp.get_state())
|
||||
for comp in self.__components
|
||||
if comp.get_state
|
||||
])
|
||||
await self.__send_event(client.ws, "loop", {})
|
||||
|
||||
async for msg in client.ws:
|
||||
if msg.type == aiohttp.web.WSMsgType.TEXT:
|
||||
try:
|
||||
@ -282,6 +292,7 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
logger.error("Unknown websocket event: %r", data)
|
||||
else:
|
||||
break
|
||||
|
||||
return client.ws
|
||||
finally:
|
||||
await self.__remove_ws_client(client)
|
||||
@ -380,6 +391,15 @@ class KvmdServer(HttpServer): # pylint: disable=too-many-arguments,too-many-ins
|
||||
logger.exception("Cleanup error on %s", comp.name)
|
||||
logger.info("On-Cleanup complete")
|
||||
|
||||
async def __send_events_aws(self, ws: aiohttp.web.WebSocketResponse, sources: List[Tuple[str, Awaitable]]) -> None:
|
||||
await asyncio.gather(*[
|
||||
self.__send_event(ws, event_type, state)
|
||||
for (event_type, state) in zip(
|
||||
map(operator.itemgetter(0), sources),
|
||||
await asyncio.gather(*map(operator.itemgetter(1), sources)),
|
||||
)
|
||||
])
|
||||
|
||||
async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
|
||||
await ws.send_str(json.dumps({
|
||||
"event_type": event_type,
|
||||
|
||||
161
kvmd/apps/kvmd/tesseract.py
Normal file
161
kvmd/apps/kvmd/tesseract.py
Normal file
@ -0,0 +1,161 @@
|
||||
# ========================================================================== #
|
||||
# #
|
||||
# KVMD - The main PiKVM daemon. #
|
||||
# #
|
||||
# Copyright (C) 2018-2022 Maxim Devaev <mdevaev@gmail.com> #
|
||||
# #
|
||||
# This program is free software: you can redistribute it and/or modify #
|
||||
# it under the terms of the GNU General Public License as published by #
|
||||
# the Free Software Foundation, either version 3 of the License, or #
|
||||
# (at your option) any later version. #
|
||||
# #
|
||||
# This program is distributed in the hope that it will be useful, #
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
|
||||
# GNU General Public License for more details. #
|
||||
# #
|
||||
# You should have received a copy of the GNU General Public License #
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
|
||||
# #
|
||||
# ========================================================================== #
|
||||
|
||||
|
||||
import io
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import contextlib
|
||||
import warnings
|
||||
|
||||
from ctypes import POINTER
|
||||
from ctypes import Structure
|
||||
from ctypes import c_int
|
||||
from ctypes import c_bool
|
||||
from ctypes import c_char_p
|
||||
from ctypes import c_void_p
|
||||
from ctypes import c_char
|
||||
|
||||
from typing import List
|
||||
from typing import Set
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image as PilImage
|
||||
|
||||
from ...errors import OperationError
|
||||
|
||||
from ... import libc
|
||||
from ... import aiotools
|
||||
|
||||
|
||||
# =====
|
||||
class OcrError(OperationError):
|
||||
pass
|
||||
|
||||
|
||||
# =====
|
||||
class _TessBaseAPI(Structure):
|
||||
pass
|
||||
|
||||
|
||||
def _load_libtesseract() -> Optional[ctypes.CDLL]:
|
||||
try:
|
||||
path = ctypes.util.find_library("tesseract")
|
||||
if not path:
|
||||
raise RuntimeError("Can't find libtesseract")
|
||||
lib = ctypes.CDLL(path)
|
||||
for (name, restype, argtypes) in [
|
||||
("TessBaseAPICreate", POINTER(_TessBaseAPI), []),
|
||||
("TessBaseAPIInit3", c_int, [POINTER(_TessBaseAPI), c_char_p, c_char_p]),
|
||||
("TessBaseAPISetImage", None, [POINTER(_TessBaseAPI), c_void_p, c_int, c_int, c_int, c_int]),
|
||||
("TessBaseAPIGetUTF8Text", POINTER(c_char), [POINTER(_TessBaseAPI)]),
|
||||
("TessBaseAPISetVariable", c_bool, [POINTER(_TessBaseAPI), c_char_p, c_char_p]),
|
||||
("TessBaseAPIGetAvailableLanguagesAsVector", POINTER(POINTER(c_char)), [POINTER(_TessBaseAPI)]),
|
||||
]:
|
||||
func = getattr(lib, name)
|
||||
if not func:
|
||||
raise RuntimeError(f"Can't find libtesseract.{name}")
|
||||
setattr(func, "restype", restype)
|
||||
setattr(func, "argtypes", argtypes)
|
||||
return lib
|
||||
except Exception as err:
|
||||
warnings.warn(f"Can't load libtesseract: {err}", RuntimeWarning)
|
||||
return None
|
||||
|
||||
|
||||
_libtess = _load_libtesseract()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _tess_api(langs: List[str]) -> Generator[_TessBaseAPI, None, None]:
|
||||
if not _libtess:
|
||||
raise OcrError("Tesseract is not available")
|
||||
api = _libtess.TessBaseAPICreate()
|
||||
try:
|
||||
if _libtess.TessBaseAPIInit3(api, None, "+".join(langs).encode()) != 0:
|
||||
raise OcrError("Can't initialize Tesseract")
|
||||
if not _libtess.TessBaseAPISetVariable(api, b"debug_file", b"/dev/null"):
|
||||
raise OcrError("Can't set debug_file=/dev/null")
|
||||
yield api
|
||||
finally:
|
||||
_libtess.TessBaseAPIDelete(api)
|
||||
|
||||
|
||||
# =====
|
||||
class TesseractOcr:
|
||||
def __init__(self, default_langs: List[str]) -> None:
|
||||
self.__default_langs = default_langs
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return bool(_libtess)
|
||||
|
||||
async def get_default_langs(self) -> List[str]:
|
||||
return list(self.__default_langs)
|
||||
|
||||
async def get_available_langs(self) -> List[str]:
|
||||
return (await aiotools.run_async(self.__inner_get_available_langs))
|
||||
|
||||
def __inner_get_available_langs(self) -> List[str]:
|
||||
with _tess_api(["osd"]) as api:
|
||||
assert _libtess
|
||||
langs: Set[str] = set()
|
||||
langs_ptr = _libtess.TessBaseAPIGetAvailableLanguagesAsVector(api)
|
||||
if langs_ptr is not None:
|
||||
index = 0
|
||||
while langs_ptr[index]:
|
||||
lang = ctypes.cast(langs_ptr[index], c_char_p).value
|
||||
if lang is not None:
|
||||
langs.add(lang.decode())
|
||||
libc.free(langs_ptr[index])
|
||||
index += 1
|
||||
libc.free(langs_ptr)
|
||||
return sorted(langs)
|
||||
|
||||
async def recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str:
|
||||
if not langs:
|
||||
langs = self.__default_langs
|
||||
return (await aiotools.run_async(self.__inner_recognize, data, langs, left, top, right, bottom))
|
||||
|
||||
def __inner_recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str:
|
||||
with _tess_api(langs) as api:
|
||||
assert _libtess
|
||||
with io.BytesIO(data) as bio:
|
||||
with PilImage.open(bio) as image:
|
||||
if left >= 0 or top >= 0 or right >= 0 or bottom >= 0:
|
||||
left = (0 if left < 0 else min(image.width, left))
|
||||
top = (0 if top < 0 else min(image.height, top))
|
||||
right = (image.width if right < 0 else min(image.width, right))
|
||||
bottom = (image.height if bottom < 0 else min(image.height, bottom))
|
||||
if left < right and top < bottom:
|
||||
image.crop((left, top, right, bottom))
|
||||
|
||||
_libtess.TessBaseAPISetImage(api, image.tobytes("raw", "RGB"), image.width, image.height, 3, image.width * 3)
|
||||
text_ptr = None
|
||||
try:
|
||||
text_ptr = _libtess.TessBaseAPIGetUTF8Text(api)
|
||||
text = ctypes.cast(text_ptr, c_char_p).value
|
||||
if text is None:
|
||||
raise OcrError("Can't recognize image")
|
||||
return text.decode("utf-8")
|
||||
finally:
|
||||
if text_ptr is not None:
|
||||
libc.free(text_ptr)
|
||||
@ -28,6 +28,7 @@ import ctypes.util
|
||||
from ctypes import c_int
|
||||
from ctypes import c_uint32
|
||||
from ctypes import c_char_p
|
||||
from ctypes import c_void_p
|
||||
|
||||
|
||||
# =====
|
||||
@ -41,6 +42,7 @@ def _load_libc() -> ctypes.CDLL:
|
||||
("inotify_init", c_int, []),
|
||||
("inotify_add_watch", c_int, [c_int, c_char_p, c_uint32]),
|
||||
("inotify_rm_watch", c_int, [c_int, c_uint32]),
|
||||
("free", c_int, [c_void_p]),
|
||||
]:
|
||||
func = getattr(lib, name)
|
||||
if not func:
|
||||
@ -57,3 +59,4 @@ _libc = _load_libc()
|
||||
inotify_init = _libc.inotify_init
|
||||
inotify_add_watch = _libc.inotify_add_watch
|
||||
inotify_rm_watch = _libc.inotify_rm_watch
|
||||
free = _libc.free
|
||||
|
||||
@ -50,6 +50,9 @@ RUN pacman --noconfirm --ask=4 -Syy \
|
||||
python-hidapi \
|
||||
freetype2 \
|
||||
nginx-mainline \
|
||||
tesseract \
|
||||
tesseract-data-eng \
|
||||
tesseract-data-rus \
|
||||
ipmitool \
|
||||
socat \
|
||||
eslint \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user