mirror of
https://github.com/mofeng-git/One-KVM.git
synced 2025-12-13 01:30:31 +08:00
aiotools.shield_fg()
This commit is contained in:
parent
d995349b63
commit
aa630988cc
@ -84,6 +84,47 @@ _RetvalT = TypeVar("_RetvalT")
|
|||||||
|
|
||||||
|
|
||||||
# =====
|
# =====
|
||||||
|
class _ArmoredFuture(asyncio.Future):
|
||||||
|
def cancel(self, *_, **__) -> bool: # type: ignore
|
||||||
|
# FIXME: Выяснить, почему это работает
|
||||||
|
return False
|
||||||
|
|
||||||
|
def forced_cancel(self) -> bool:
|
||||||
|
return super().cancel()
|
||||||
|
|
||||||
|
|
||||||
|
def shield_fg(aw: Awaitable): # type: ignore
|
||||||
|
# XXX: Копия asyncio.shield() с небольшими изменениями
|
||||||
|
|
||||||
|
inner = asyncio.ensure_future(aw)
|
||||||
|
if inner.done():
|
||||||
|
return inner
|
||||||
|
outer = _ArmoredFuture(loop=inner.get_loop())
|
||||||
|
|
||||||
|
def inner_done(_) -> None: # type: ignore
|
||||||
|
if outer.cancelled():
|
||||||
|
if not inner.cancelled():
|
||||||
|
inner.exception() # Mark inner's result as retrieved
|
||||||
|
return
|
||||||
|
|
||||||
|
if inner.cancelled():
|
||||||
|
outer.forced_cancel()
|
||||||
|
else:
|
||||||
|
err = inner.exception()
|
||||||
|
if err is not None:
|
||||||
|
outer.set_exception(err)
|
||||||
|
else:
|
||||||
|
outer.set_result(inner.result())
|
||||||
|
|
||||||
|
def outer_done(_) -> None: # type: ignore
|
||||||
|
if not inner.done():
|
||||||
|
inner.remove_done_callback(inner_done)
|
||||||
|
|
||||||
|
inner.add_done_callback(inner_done)
|
||||||
|
outer.add_done_callback(outer_done)
|
||||||
|
return outer
|
||||||
|
|
||||||
|
|
||||||
def atomic(func: _FunctionT) -> _FunctionT:
|
def atomic(func: _FunctionT) -> _FunctionT:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
|||||||
@ -22,9 +22,12 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from kvmd.aiotools import AioExclusiveRegion
|
from kvmd.aiotools import AioExclusiveRegion
|
||||||
|
from kvmd.aiotools import shield_fg
|
||||||
|
|
||||||
|
|
||||||
# =====
|
# =====
|
||||||
@ -115,3 +118,34 @@ async def test_fail__region__access_two() -> None:
|
|||||||
assert not region.is_busy()
|
assert not region.is_busy()
|
||||||
await region.exit()
|
await region.exit()
|
||||||
assert not region.is_busy()
|
assert not region.is_busy()
|
||||||
|
|
||||||
|
|
||||||
|
# =====
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ok__shield_fg() -> None:
|
||||||
|
ops: List[str] = []
|
||||||
|
|
||||||
|
async def foo(op: str, delay: float) -> None: # pylint: disable=disallowed-name
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
ops.append(op)
|
||||||
|
|
||||||
|
async def bar() -> None: # pylint: disable=disallowed-name
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
raise RuntimeError()
|
||||||
|
finally:
|
||||||
|
await shield_fg(foo("foo1", 2.0))
|
||||||
|
ops.append("foo1-noexc")
|
||||||
|
finally:
|
||||||
|
await shield_fg(foo("foo2", 1.0))
|
||||||
|
ops.append("foo2-noexc")
|
||||||
|
finally:
|
||||||
|
ops.append("done")
|
||||||
|
|
||||||
|
task = asyncio.create_task(bar())
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
task.cancel()
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
assert ops == ["foo1", "foo2", "foo2-noexc", "done"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user