aiotools.shield_fg()

This commit is contained in:
Maxim Devaev 2022-08-07 18:42:00 +03:00
parent d995349b63
commit aa630988cc
2 changed files with 75 additions and 0 deletions

View File

@ -84,6 +84,47 @@ _RetvalT = TypeVar("_RetvalT")
# ===== # =====
class _ArmoredFuture(asyncio.Future):
def cancel(self, *_, **__) -> bool: # type: ignore
# FIXME: Выяснить, почему это работает
return False
def forced_cancel(self) -> bool:
return super().cancel()
def shield_fg(aw: Awaitable): # type: ignore
# XXX: Копия asyncio.shield() с небольшими изменениями
inner = asyncio.ensure_future(aw)
if inner.done():
return inner
outer = _ArmoredFuture(loop=inner.get_loop())
def inner_done(_) -> None: # type: ignore
if outer.cancelled():
if not inner.cancelled():
inner.exception() # Mark inner's result as retrieved
return
if inner.cancelled():
outer.forced_cancel()
else:
err = inner.exception()
if err is not None:
outer.set_exception(err)
else:
outer.set_result(inner.result())
def outer_done(_) -> None: # type: ignore
if not inner.done():
inner.remove_done_callback(inner_done)
inner.add_done_callback(inner_done)
outer.add_done_callback(outer_done)
return outer
def atomic(func: _FunctionT) -> _FunctionT: def atomic(func: _FunctionT) -> _FunctionT:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: async def wrapper(*args: Any, **kwargs: Any) -> Any:

View File

@ -22,9 +22,12 @@
import asyncio import asyncio
from typing import List
import pytest import pytest
from kvmd.aiotools import AioExclusiveRegion from kvmd.aiotools import AioExclusiveRegion
from kvmd.aiotools import shield_fg
# ===== # =====
@ -115,3 +118,34 @@ async def test_fail__region__access_two() -> None:
assert not region.is_busy() assert not region.is_busy()
await region.exit() await region.exit()
assert not region.is_busy() assert not region.is_busy()
# =====
@pytest.mark.asyncio
async def test_ok__shield_fg() -> None:
ops: List[str] = []
async def foo(op: str, delay: float) -> None: # pylint: disable=disallowed-name
await asyncio.sleep(delay)
ops.append(op)
async def bar() -> None: # pylint: disable=disallowed-name
try:
try:
try:
raise RuntimeError()
finally:
await shield_fg(foo("foo1", 2.0))
ops.append("foo1-noexc")
finally:
await shield_fg(foo("foo2", 1.0))
ops.append("foo2-noexc")
finally:
ops.append("done")
task = asyncio.create_task(bar())
await asyncio.sleep(0.1)
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
assert ops == ["foo1", "foo2", "foo2-noexc", "done"]