region: notify about enter/exit, unregion on exception

This commit is contained in:
Devaev Maxim
2020-03-02 02:13:47 +03:00
parent 8972357dbc
commit 3b16242cfa

View File

@@ -35,6 +35,7 @@ from typing import Coroutine
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Type from typing import Type
from typing import TypeVar from typing import TypeVar
from typing import Optional
from typing import Any from typing import Any
import aiofiles import aiofiles
@@ -126,10 +127,31 @@ async def afile_write_now(afile: aiofiles.base.AiofilesContextManager, data: byt
await run_async(os.fsync, afile.fileno()) await run_async(os.fsync, afile.fileno())
# =====
class AioNotifier:
def __init__(self) -> None:
self.__queue: asyncio.queues.Queue = asyncio.Queue()
async def notify(self) -> None:
await self.__queue.put(None)
async def wait(self) -> None:
await self.__queue.get()
while not self.__queue.empty():
await self.__queue.get()
# ===== # =====
class AioExclusiveRegion: class AioExclusiveRegion:
def __init__(self, exc_type: Type[Exception]) -> None: def __init__(
self,
exc_type: Type[Exception],
notifier: Optional[AioNotifier]=None,
) -> None:
self.__exc_type = exc_type self.__exc_type = exc_type
self.__notifier = notifier
self.__busy = False self.__busy = False
def is_busy(self) -> bool: def is_busy(self) -> bool:
@@ -138,11 +160,19 @@ class AioExclusiveRegion:
async def enter(self) -> None: async def enter(self) -> None:
if not self.__busy: if not self.__busy:
self.__busy = True self.__busy = True
try:
if self.__notifier:
await self.__notifier.notify()
except: # noqa: E722
self.__busy = False
raise
return return
raise self.__exc_type() raise self.__exc_type()
async def exit(self) -> None: async def exit(self) -> None:
self.__busy = False self.__busy = False
if self.__notifier:
await self.__notifier.notify()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def exit_only_on_exception(self) -> AsyncGenerator[None, None]: async def exit_only_on_exception(self) -> AsyncGenerator[None, None]:
@@ -162,18 +192,5 @@ class AioExclusiveRegion:
_exc: BaseException, _exc: BaseException,
_tb: types.TracebackType, _tb: types.TracebackType,
) -> None: ) -> None:
await self.exit() await self.exit()
# =====
class AioNotifier:
def __init__(self) -> None:
self.__queue: asyncio.queues.Queue = asyncio.Queue()
async def notify(self) -> None:
await self.__queue.put(None)
async def wait(self) -> None:
await self.__queue.get()
while not self.__queue.empty():
await self.__queue.get()