Rewrite integration test suite in Python
[tinc] / test / integration / testlib / notification.py
1 """Support for receiving notifications from tincd scripts."""
2
3 import os
4 import signal
5 import threading
6 import queue
7 import multiprocessing.connection as mp
8 import typing as T
9
10 from .log import log
11 from .event import Notification
12 from .const import MPC_FAMILY
13
14
15 def _get_key(name, script) -> str:
16     return f"{name}/{script}"
17
18
class NotificationServer:
    """Receive event notifications from tincd scripts.

    A daemon worker thread owns a ``multiprocessing.connection.Listener``
    (family ``MPC_FAMILY``, authenticated with ``authkey``). Every accepted
    connection delivers one ``Notification``, which is appended to a FIFO
    queue keyed by ``node/script``; callers drain those queues with ``get``.
    """

    address: T.Union[str, bytes]  # listener address, filled in by the worker thread
    authkey: bytes  # only to prevent accidental connections to wrong servers
    _lock: threading.Lock  # guards _notifications
    _ready: threading.Event  # set once the listener is bound and address is known
    _worker: T.Optional[threading.Thread]
    _notifications: T.Dict[str, queue.Queue]  # "node/script" -> pending notifications

    def __init__(self) -> None:
        self.address = ""
        self.authkey = os.urandom(8)
        self._lock = threading.Lock()
        self._ready = threading.Event()
        self._worker = threading.Thread(target=self._recv, daemon=True)
        self._notifications = {}

        log.debug("using authkey %s", self.authkey)

        self._worker.start()
        log.debug("waiting for notification worker to become ready")

        # Block until the worker has bound the listener, so that self.address
        # is valid as soon as the constructor returns.
        self._ready.wait()
        log.debug("notification worker is ready")

    @T.overload
    def get(self, node: str, script: str) -> Notification:
        """Receive notification from the specified node and script without a timeout.
        Doesn't return until a notification arrives.
        """

    @T.overload
    def get(self, node: str, script: str, timeout: float) -> T.Optional[Notification]:
        """Receive notification from the specified node and script with a timeout.
        If nothing arrives before it expires, None is returned.
        """

    def get(
        self, node: str, script: str, timeout: T.Optional[float] = None
    ) -> T.Optional[Notification]:
        """Receive notification from specified node and script. See overloads above.

        Args:
            node: name of the node the notification must come from.
            script: name of the script the notification must come from.
            timeout: seconds to wait (as for ``queue.Queue.get``); None blocks
                until a notification arrives.

        Returns:
            The oldest pending notification for ``node/script``, or None if
            ``timeout`` expired first.
        """
        que = self._queue(_get_key(node, script))
        try:
            return que.get(timeout=timeout)
        except queue.Empty:
            return None

    def _queue(self, key: str) -> queue.Queue:
        """Return the queue for ``key``, creating it on first use.

        Either the consumer (``get``) or the producer (the worker thread) may
        be the first to touch a key, so lookup-or-create happens under the lock.
        """
        with self._lock:
            que = self._notifications.get(key)
            if que is None:
                que = self._notifications[key] = queue.Queue()
            return que

    def _recv(self) -> None:
        """Worker-thread entry point: run the listener; on fatal errors log and
        send SIGTERM to the whole process group (pid 0)."""
        try:
            self._listen()
        except (OSError, AssertionError) as ex:
            log.error("recv notifications failed", exc_info=ex)
            os.kill(0, signal.SIGTERM)

    def _listen(self) -> None:
        """Bind the listener, publish its address, then serve connections forever."""
        with mp.Listener(family=MPC_FAMILY, authkey=self.authkey) as listener:
            assert not isinstance(listener.address, tuple)
            self.address = listener.address
            self._ready.set()
            while True:
                try:
                    with listener.accept() as conn:
                        self._handle_conn(conn)
                except (EOFError, mp.AuthenticationError) as ex:
                    # One bad client (wrong authkey, or a peer that hangs up
                    # before sending) must not end the worker thread: these
                    # errors are not handled by _recv, and losing the worker
                    # would leave get() callers without a timeout blocked forever.
                    log.error("rejected notification connection", exc_info=ex)

    def _handle_conn(self, conn: mp.Connection) -> None:
        """Read one Notification from ``conn`` and enqueue it for ``get``."""
        log.debug("accepted connection")

        # recv() unpickles the payload; this relies on the peer having passed
        # the authkey challenge performed by Listener.accept().
        data: Notification = conn.recv()
        assert isinstance(data, Notification)
        data.update_time()

        key = _get_key(data.node, data.script)
        log.debug('from "%s" received data "%s"', key, data)

        self._queue(key).put(data)
103
104
# Module-level shared server instance. Constructing it here, at import time,
# starts the listener thread and waits until its address is available.
notifications: NotificationServer = NotificationServer()