refactoring

This commit is contained in:
Devaev Maxim 2020-10-29 13:51:58 +03:00
parent aaef672ac2
commit a5dbc1adea
3 changed files with 23 additions and 21 deletions

View File

@ -153,7 +153,7 @@ class _MouseWheelEvent(_BaseEvent):
# ===== # =====
class BasePhyConnection: class BasePhyConnection:
def send(self, request: bytes, receive: int) -> bytes: def send(self, request: bytes) -> bytes:
raise NotImplementedError raise NotImplementedError
@ -396,7 +396,7 @@ class BaseMcuHid(BaseHid, multiprocessing.Process): # pylint: disable=too-many-
def __send_request(self, conn: BasePhyConnection, request: bytes) -> bytes: def __send_request(self, conn: BasePhyConnection, request: bytes) -> bytes:
if not self.__noop: if not self.__noop:
response = conn.send(request, 4) response = conn.send(request)
else: else:
response = b"\x33\x20" # Magic + OK response = b"\x33\x20" # Magic + OK
response += struct.pack(">H", self.__make_crc16(response)) response += struct.pack(">H", self.__make_crc16(response))

View File

@ -45,11 +45,12 @@ class _SerialPhyConnection(BasePhyConnection):
def __init__(self, tty: serial.Serial) -> None: def __init__(self, tty: serial.Serial) -> None:
self.__tty = tty self.__tty = tty
def send(self, request: bytes, receive: int) -> bytes: def send(self, request: bytes) -> bytes:
assert len(request) == 8
if self.__tty.in_waiting: if self.__tty.in_waiting:
self.__tty.read_all() self.__tty.read_all()
assert self.__tty.write(request) == len(request) assert self.__tty.write(request) == 8
return self.__tty.read(receive) return self.__tty.read(4)
class _SerialPhy(BasePhy): class _SerialPhy(BasePhy):

View File

@ -31,6 +31,8 @@ from typing import Any
import spidev import spidev
from ...logging import get_logger
from ...yamlconf import Option from ...yamlconf import Option
from ...validators.basic import valid_int_f0 from ...validators.basic import valid_int_f0
@ -44,10 +46,6 @@ from ._mcu import BaseMcuHid
# ===== # =====
class SpiPhyError(Exception):
pass
class _SpiPhyConnection(BasePhyConnection): class _SpiPhyConnection(BasePhyConnection):
def __init__( def __init__(
self, self,
@ -60,41 +58,44 @@ class _SpiPhyConnection(BasePhyConnection):
self.__read_timeout = read_timeout self.__read_timeout = read_timeout
self.__read_delay = read_delay self.__read_delay = read_delay
def send(self, request: bytes, receive: int) -> bytes: self.__empty8 = b"\x00" * 8
assert 0 < receive <= len(request) self.__empty4 = b"\x00" * 4
def send(self, request: bytes) -> bytes:
assert len(request) == 8
dummy = b"\x00" * len(request)
deadline_ts = time.time() + self.__read_timeout deadline_ts = time.time() + self.__read_timeout
while time.time() < deadline_ts: while time.time() < deadline_ts:
garbage = bytes(self.__spi.xfer(dummy)) garbage = bytes(self.__spi.xfer(self.__empty8))
if garbage == dummy: if garbage == self.__empty8:
break break
else: else:
raise SpiPhyError("Timeout reached while reading a garbage") get_logger(0).error("SPI timeout reached while reading the a garbage")
return b""
self.__spi.xfer(request) self.__spi.xfer(request)
response: List[int] = [] response: List[int] = []
dummy = b"\x00" * receive
deadline_ts = time.time() + self.__read_timeout deadline_ts = time.time() + self.__read_timeout
found = False found = False
while time.time() < deadline_ts: while time.time() < deadline_ts:
if not found: if not found:
time.sleep(self.__read_delay) time.sleep(self.__read_delay)
for byte in self.__spi.xfer(dummy): for byte in self.__spi.xfer(self.__empty4):
if not found: if not found:
if byte == 0: if byte == 0:
continue continue
found = True found = True
response.append(byte) response.append(byte)
if len(response) >= receive: if len(response) >= 4:
break break
if len(response) >= receive: if len(response) >= 4:
break break
else: else:
raise SpiPhyError("Timeout reached while responce waiting") get_logger(0).error("SPI timeout reached while responce waiting")
return b""
assert len(response) == receive assert len(response) == 4
return bytes(response) return bytes(response)