refactoring

This commit is contained in:
Maxim Devaev 2024-09-16 23:07:38 +03:00
parent b779c18530
commit c57334f214
3 changed files with 68 additions and 53 deletions

View File

@ -11,15 +11,16 @@ from ... import aioproc
from ...logging import get_logger from ...logging import get_logger
from .stun import StunNatType
from .stun import Stun from .stun import Stun
# ===== # =====
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class _Netcfg: class _Netcfg:
nat_type: str = dataclasses.field(default="") nat_type: StunNatType = dataclasses.field(default=StunNatType.ERROR)
src_ip: str = dataclasses.field(default="") src_ip: str = dataclasses.field(default="")
ext_ip: str = dataclasses.field(default="") ext_ip: str = dataclasses.field(default="")
stun_host: str = dataclasses.field(default="") stun_host: str = dataclasses.field(default="")
stun_port: int = dataclasses.field(default=0) stun_port: int = dataclasses.field(default=0)
@ -92,8 +93,9 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
async def __get_netcfg(self) -> _Netcfg: async def __get_netcfg(self) -> _Netcfg:
src_ip = (self.__get_default_ip() or "0.0.0.0") src_ip = (self.__get_default_ip() or "0.0.0.0")
(stun, (nat_type, ext_ip)) = await self.__get_stun_info(src_ip) info = await self.__stun.get_info(src_ip, 0)
return _Netcfg(nat_type, src_ip, ext_ip, stun.host, stun.port) # В текущей реализации _Netcfg() это копия StunInfo()
return _Netcfg(**dataclasses.asdict(info))
def __get_default_ip(self) -> str: def __get_default_ip(self) -> str:
try: try:
@ -115,13 +117,6 @@ class JanusRunner: # pylint: disable=too-many-instance-attributes
get_logger().error("Can't get default IP: %s", tools.efmt(err)) get_logger().error("Can't get default IP: %s", tools.efmt(err))
return "" return ""
async def __get_stun_info(self, src_ip: str) -> tuple[Stun, tuple[str, str]]:
try:
return (self.__stun, (await self.__stun.get_info(src_ip, 0)))
except Exception as err:
get_logger().error("Can't get STUN info: %s", tools.efmt(err))
return (self.__stun, ("", ""))
# ===== # =====
@aiotools.atomic_fg @aiotools.atomic_fg

View File

@ -4,6 +4,7 @@ import ipaddress
import struct import struct
import secrets import secrets
import dataclasses import dataclasses
import enum
from ... import tools from ... import tools
from ... import aiotools from ... import aiotools
@ -12,29 +13,39 @@ from ...logging import get_logger
# ===== # =====
class StunNatType(enum.Enum):
ERROR = ""
BLOCKED = "Blocked"
OPEN_INTERNET = "Open Internet"
SYMMETRIC_UDP_FW = "Symmetric UDP Firewall"
FULL_CONE_NAT = "Full Cone NAT"
RESTRICTED_NAT = "Restricted NAT"
RESTRICTED_PORT_NAT = "Restricted Port NAT"
SYMMETRIC_NAT = "Symmetric NAT"
CHANGED_ADDR_ERROR = "Error when testing on Changed-IP and Port"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class StunAddress: class StunInfo:
ip: str nat_type: StunNatType
src_ip: str
ext_ip: str
stun_host: str
stun_port: int
@dataclasses.dataclass(frozen=True)
class _StunAddress:
ip: str
port: int port: int
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class StunResponse: class _StunResponse:
ok: bool ok: bool
ext: (StunAddress | None) = dataclasses.field(default=None) ext: (_StunAddress | None) = dataclasses.field(default=None)
src: (StunAddress | None) = dataclasses.field(default=None) src: (_StunAddress | None) = dataclasses.field(default=None)
changed: (StunAddress | None) = dataclasses.field(default=None) changed: (_StunAddress | None) = dataclasses.field(default=None)
class StunNatType:
BLOCKED = "Blocked"
OPEN_INTERNET = "Open Internet"
SYMMETRIC_UDP_FW = "Symmetric UDP Firewall"
FULL_CONE_NAT = "Full Cone NAT"
RESTRICTED_NAT = "Restricted NAT"
RESTRICTED_PORT_NAT = "Restricted Port NAT"
SYMMETRIC_NAT = "Symmetric NAT"
CHANGED_ADDR_ERROR = "Error when testing on Changed-IP and Port"
# ===== # =====
@ -50,33 +61,44 @@ class Stun:
retries_delay: float, retries_delay: float,
) -> None: ) -> None:
self.host = host self.__host = host
self.port = port self.__port = port
self.__timeout = timeout self.__timeout = timeout
self.__retries = retries self.__retries = retries
self.__retries_delay = retries_delay self.__retries_delay = retries_delay
self.__sock: (socket.socket | None) = None self.__sock: (socket.socket | None) = None
async def get_info(self, src_ip: str, src_port: int) -> tuple[str, str]: async def get_info(self, src_ip: str, src_port: int) -> StunInfo:
(family, _, _, _, addr) = socket.getaddrinfo(src_ip, src_port, type=socket.SOCK_DGRAM)[0] (family, _, _, _, addr) = socket.getaddrinfo(src_ip, src_port, type=socket.SOCK_DGRAM)[0]
nat_type = StunNatType.ERROR
ext_ip = ""
try: try:
with socket.socket(family, socket.SOCK_DGRAM) as self.__sock: with socket.socket(family, socket.SOCK_DGRAM) as self.__sock:
self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.__sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.__sock.settimeout(self.__timeout) self.__sock.settimeout(self.__timeout)
self.__sock.bind(addr) self.__sock.bind(addr)
(nat_type, response) = await self.__get_nat_type(src_ip) (nat_type, response) = await self.__get_nat_type(src_ip)
return (nat_type, (response.ext.ip if response.ext is not None else "")) ext_ip = (response.ext.ip if response.ext is not None else "")
except Exception as err:
get_logger(0).error("Can't get STUN info: %s", tools.efmt(err))
finally: finally:
self.__sock = None self.__sock = None
return StunInfo(
nat_type=nat_type,
src_ip=src_ip,
ext_ip=ext_ip,
stun_host=self.__host,
stun_port=self.__port,
)
async def __get_nat_type(self, src_ip: str) -> tuple[str, StunResponse]: # pylint: disable=too-many-return-statements async def __get_nat_type(self, src_ip: str) -> tuple[StunNatType, _StunResponse]: # pylint: disable=too-many-return-statements
first = await self.__make_request("First probe") first = await self.__make_request("First probe", self.__host, b"")
if not first.ok: if not first.ok:
return (StunNatType.BLOCKED, first) return (StunNatType.BLOCKED, first)
request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000006) # Change-Request
response = await self.__make_request("Change request [ext_ip == src_ip]", request) response = await self.__make_request("Change request [ext_ip == src_ip]", self.__host, request)
if first.ext is not None and first.ext.ip == src_ip: if first.ext is not None and first.ext.ip == src_ip:
if response.ok: if response.ok:
@ -88,20 +110,20 @@ class Stun:
if first.changed is None: if first.changed is None:
raise RuntimeError(f"Changed addr is None: {first}") raise RuntimeError(f"Changed addr is None: {first}")
response = await self.__make_request("Change request [ext_ip != src_ip]", addr=first.changed) response = await self.__make_request("Change request [ext_ip != src_ip]", first.changed, b"")
if not response.ok: if not response.ok:
return (StunNatType.CHANGED_ADDR_ERROR, response) return (StunNatType.CHANGED_ADDR_ERROR, response)
if response.ext == first.ext: if response.ext == first.ext:
request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002) request = struct.pack(">HHI", 0x0003, 0x0004, 0x00000002)
response = await self.__make_request("Change port", request, addr=first.changed.ip) response = await self.__make_request("Change port", first.changed.ip, request)
if response.ok: if response.ok:
return (StunNatType.RESTRICTED_NAT, response) return (StunNatType.RESTRICTED_NAT, response)
return (StunNatType.RESTRICTED_PORT_NAT, response) return (StunNatType.RESTRICTED_PORT_NAT, response)
return (StunNatType.SYMMETRIC_NAT, response) return (StunNatType.SYMMETRIC_NAT, response)
async def __make_request(self, ctx: str, request: bytes=b"", addr: (StunAddress | str | None)=None) -> StunResponse: async def __make_request(self, ctx: str, addr: (_StunAddress | str), request: bytes) -> _StunResponse:
# TODO: Support IPv6 and RFC 5389 # TODO: Support IPv6 and RFC 5389
# The first 4 bytes of the response are the Type (2) and Length (2) # The first 4 bytes of the response are the Type (2) and Length (2)
# The 5th byte is Reserved # The 5th byte is Reserved
@ -111,13 +133,10 @@ class Stun:
# More info at: https://tools.ietf.org/html/rfc3489#section-11.2.1 # More info at: https://tools.ietf.org/html/rfc3489#section-11.2.1
# And at: https://tools.ietf.org/html/rfc5389#section-15.1 # And at: https://tools.ietf.org/html/rfc5389#section-15.1
if isinstance(addr, StunAddress): if isinstance(addr, _StunAddress):
addr_t = (addr.ip, addr.port) addr_t = (addr.ip, addr.port)
elif isinstance(addr, str): else: # str
addr_t = (addr, self.port) addr_t = (addr, self.__port)
else:
assert addr is None
addr_t = (self.host, self.port)
# https://datatracker.ietf.org/doc/html/rfc5389#section-6 # https://datatracker.ietf.org/doc/html/rfc5389#section-6
trans_id = b"\x21\x12\xA4\x42" + secrets.token_bytes(12) trans_id = b"\x21\x12\xA4\x42" + secrets.token_bytes(12)
@ -130,9 +149,9 @@ class Stun:
if error: if error:
get_logger(0).error("%s: Can't perform STUN request after %d retries; last error: %s", get_logger(0).error("%s: Can't perform STUN request after %d retries; last error: %s",
ctx, self.__retries, error) ctx, self.__retries, error)
return StunResponse(ok=False) return _StunResponse(ok=False)
parsed: dict[str, StunAddress] = {} parsed: dict[str, _StunAddress] = {}
offset = 0 offset = 0
remaining = len(response) remaining = len(response)
while remaining > 0: while remaining > 0:
@ -148,7 +167,7 @@ class Stun:
parsed[field] = self.__parse_address(response[offset:], (trans_id if attr_type == 0x0020 else b"")) parsed[field] = self.__parse_address(response[offset:], (trans_id if attr_type == 0x0020 else b""))
offset += attr_len offset += attr_len
remaining -= (4 + attr_len) remaining -= (4 + attr_len)
return StunResponse(ok=True, **parsed) return _StunResponse(ok=True, **parsed)
async def __inner_make_request(self, trans_id: bytes, request: bytes, addr: tuple[str, int]) -> tuple[bytes, str]: async def __inner_make_request(self, trans_id: bytes, request: bytes, addr: tuple[str, int]) -> tuple[bytes, str]:
assert self.__sock is not None assert self.__sock is not None
@ -172,13 +191,13 @@ class Stun:
return (response[20 : 20 + payload_len], "") # noqa: E203 return (response[20 : 20 + payload_len], "") # noqa: E203
def __parse_address(self, data: bytes, trans_id: bytes) -> StunAddress: def __parse_address(self, data: bytes, trans_id: bytes) -> _StunAddress:
family = data[1] family = data[1]
port = struct.unpack(">H", self.__trans_xor(data[2:4], trans_id))[0] port = struct.unpack(">H", self.__trans_xor(data[2:4], trans_id))[0]
if family == 0x01: if family == 0x01:
return StunAddress(str(ipaddress.IPv4Address(self.__trans_xor(data[4:8], trans_id))), port) return _StunAddress(str(ipaddress.IPv4Address(self.__trans_xor(data[4:8], trans_id))), port)
elif family == 0x02: elif family == 0x02:
return StunAddress(str(ipaddress.IPv6Address(self.__trans_xor(data[4:20], trans_id))), port) return _StunAddress(str(ipaddress.IPv6Address(self.__trans_xor(data[4:20], trans_id))), port)
raise RuntimeError(f"Unknown family; received: {family}") raise RuntimeError(f"Unknown family; received: {family}")
def __trans_xor(self, data: bytes, trans_id: bytes) -> bytes: def __trans_xor(self, data: bytes, trans_id: bytes) -> bytes:

View File

@ -1,8 +1,9 @@
[flake8] [flake8]
inline-quotes = double inline-quotes = double
max-line-length = 160 max-line-length = 160
ignore = W503, E227, E241, E252, Q003 ignore = W503, E221, E227, E241, E252, Q003
# W503 line break before binary operator # W503 line break before binary operator
# E221 multiple spaces before operator
# E227 missing whitespace around bitwise or shift operator # E227 missing whitespace around bitwise or shift operator
# E241 multiple spaces after # E241 multiple spaces after
# E252 missing whitespace around parameter equals # E252 missing whitespace around parameter equals