no pid limits

This commit is contained in:
Devaev Maxim 2020-03-01 20:50:50 +03:00
parent 44b0ab19bf
commit e855976f05

View File

@ -20,7 +20,6 @@
# ========================================================================== # # ========================================================================== #
import os
import multiprocessing import multiprocessing
import multiprocessing.queues import multiprocessing.queues
import multiprocessing.sharedctypes import multiprocessing.sharedctypes
@ -35,14 +34,11 @@ from . import aiotools
class AioProcessNotifier: class AioProcessNotifier:
def __init__(self) -> None: def __init__(self) -> None:
self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue() self.__queue: multiprocessing.queues.Queue = multiprocessing.Queue()
self.__pid = os.getpid()
def notify(self) -> None: def notify(self) -> None:
assert os.getpid() != self.__pid, "Child only"
self.__queue.put(None) self.__queue.put(None)
async def wait(self) -> None: async def wait(self) -> None:
assert os.getpid() == self.__pid, "Parent only"
while not (await aiotools.run_async(self.__inner_wait)): while not (await aiotools.run_async(self.__inner_wait)):
pass pass
@ -64,40 +60,32 @@ class AioSharedFlags:
notifier: AioProcessNotifier, notifier: AioProcessNotifier,
) -> None: ) -> None:
self.__local_flags = dict(initial) # To fast comparsion
self.__notifier = notifier self.__notifier = notifier
self.__shared_flags: Dict[str, multiprocessing.sharedctypes.RawValue] = { self.__flags: Dict[str, multiprocessing.sharedctypes.RawValue] = {
key: multiprocessing.RawValue("i", int(value)) # type: ignore key: multiprocessing.RawValue("i", int(value)) # type: ignore
for (key, value) in initial.items() for (key, value) in initial.items()
} }
self.__lock = multiprocessing.Lock() self.__lock = multiprocessing.Lock()
self.__pid = os.getpid()
def update(self, **kwargs: bool) -> None: def update(self, **kwargs: bool) -> None:
assert os.getpid() != self.__pid, "Child only"
changed = False changed = False
try: with self.__lock:
for (key, value) in kwargs.items(): for (key, value) in kwargs.items():
value = bool(value) value = int(value) # type: ignore
if self.__local_flags[key] != value: if self.__flags[key].value != value:
if not changed: self.__flags[key].value = value
self.__lock.acquire()
self.__shared_flags[key].value = int(value)
self.__local_flags[key] = value
changed = True changed = True
finally: if changed:
if changed: self.__notifier.notify()
self.__lock.release()
self.__notifier.notify()
async def get(self) -> Dict[str, bool]: async def get(self) -> Dict[str, bool]:
return (await aiotools.run_async(self.__inner_get)) return (await aiotools.run_async(self.__inner_get))
def __inner_get(self) -> Dict[str, bool]: def __inner_get(self) -> Dict[str, bool]:
assert os.getpid() == self.__pid, "Parent only"
with self.__lock: with self.__lock:
return { return {
key: bool(shared.value) key: bool(shared.value)
for (key, shared) in self.__shared_flags.items() for (key, shared) in self.__flags.items()
} }