no pid limits

This commit is contained in:
Devaev Maxim 2020-03-01 20:50:50 +03:00
parent 44b0ab19bf
commit e855976f05

View File

@ -20,7 +20,6 @@
# ========================================================================== #
import os
import multiprocessing
import multiprocessing.queues
import multiprocessing.sharedctypes
@ -35,14 +34,11 @@ from . import aiotools
class AioProcessNotifier:
def __init__(self) -> None:
self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
self.__pid = os.getpid()
def notify(self) -> None:
assert os.getpid() != self.__pid, "Child only"
self.__queue.put(None)
async def wait(self) -> None:
assert os.getpid() == self.__pid, "Parent only"
while not (await aiotools.run_async(self.__inner_wait)):
pass
@ -64,40 +60,32 @@ class AioSharedFlags:
notifier: AioProcessNotifier,
) -> None:
self.__local_flags = dict(initial) # To fast comparsion
self.__notifier = notifier
self.__shared_flags: Dict[str, multiprocessing.sharedctypes.RawValue] = {
self.__flags: Dict[str, multiprocessing.sharedctypes.RawValue] = {
key: multiprocessing.RawValue("i", int(value)) # type: ignore
for (key, value) in initial.items()
}
self.__lock = multiprocessing.Lock()
self.__pid = os.getpid()
def update(self, **kwargs: bool) -> None:
assert os.getpid() != self.__pid, "Child only"
changed = False
try:
with self.__lock:
for (key, value) in kwargs.items():
value = bool(value)
if self.__local_flags[key] != value:
if not changed:
self.__lock.acquire()
self.__shared_flags[key].value = int(value)
self.__local_flags[key] = value
value = int(value) # type: ignore
if self.__flags[key].value != value:
self.__flags[key].value = value
changed = True
finally:
if changed:
self.__lock.release()
self.__notifier.notify()
if changed:
self.__notifier.notify()
async def get(self) -> Dict[str, bool]:
return (await aiotools.run_async(self.__inner_get))
def __inner_get(self) -> Dict[str, bool]:
assert os.getpid() == self.__pid, "Parent only"
with self.__lock:
return {
key: bool(shared.value)
for (key, shared) in self.__shared_flags.items()
for (key, shared) in self.__flags.items()
}