msgloop.py 7.34 KB
Newer Older
1
import sys
2 3
import logging
import inspect
4
import signal
5

6
from types import SimpleNamespace
7 8 9

import zmq

Robin Obůrka's avatar
Robin Obůrka committed
10 11 12 13
from .network import SN
from .network import get_arg_parser
from .messages import encode_msg, parse_msg
from .exceptions import *
14 15


16
class LoopHardFail(Exception):
    """Unrecoverable failure of the message loop.

    Raised by ``SNBox.run_loop`` once too many consecutive iterations have
    failed; ``SNBox.run`` catches it, logs it and exits with status 1.
    """


20 21
class LoopFail(Exception):
    """Failure of a single loop iteration caused by the box's own output.

    Raised when a box produces an output message that cannot be handled;
    ``SNBox.run_loop`` logs it and counts it as one error in a row.
    """
22 23


24 25
class SNBox:
    """Base class of an SN box: owns the ZMQ/SN contexts and runs the message loop.

    Abstract subclasses implement ``check_configuration``,
    ``get_processed_message`` and ``process_result``. Final boxes implement
    ``process`` and may override the ``setup``, ``teardown`` and
    ``before_first_request`` hooks.
    """

    # Number of consecutive failed loop iterations tolerated before the loop
    # gives up with LoopHardFail. Subclasses may override it.
    max_errors_in_row = 10

    def __init__(self, box_name, argparser=None):
        # Local contexts for dependencies
        self.zmq_ctx = zmq.Context.instance()
        self.sn_ctx = SN(self.zmq_ctx, argparser or get_arg_parser())

        # Important values provided to box
        self.name = box_name
        self.logger = logging.getLogger(box_name)
        self.args = self.sn_ctx.args

        # Error management of the loop
        self.loop_continue = True
        self.errors_in_row = 0

        # User data.
        # Data generated by setup() are placed into a separate variable:
        # final boxes shouldn't use "self" - we want to isolate their values.
        self.context = None

    # Core methods - will be implemented in non-abstract boxes
    def check_configuration(self):
        raise NotImplementedError("check_configuration")

    def get_processed_message(self):
        raise NotImplementedError("get_processed_message")

    def process_result(self, result):
        raise NotImplementedError("process_result")

    # Public API for boxes - will be optionally implemented in final boxes
    def setup(self):
        """Return a dict of user data; it becomes ``self.context``."""
        return {}

    def teardown(self):
        """Clean up whatever setup() created. Default: nothing."""
        pass

    def before_first_request(self):
        """Optionally return a result to be processed before the loop starts."""
        pass

    def process(self, msg_type, payload):
        raise NotImplementedError("process")

    # Provided functionality - should be final implementation
    def run(self):
        """Run the box: check configuration, set up user data, loop, clean up.

        Once the configuration check has passed, ``teardown`` and
        ``teardown_box`` always run. LoopHardFail is logged and the process
        exits with status 1; KeyboardInterrupt ends the box quietly.
        """
        # This is the only way to be sure that the check will be called.
        # Constructors will be overwritten in non-abstract boxes.
        self.check_configuration()

        try:
            self.context = self.get_user_data()

            self.logger.info("SNBox starting loop for %s box", self.name)
            self.register_signals()

            self.before_loop()
            self.run_loop()

        except LoopHardFail as e:
            self.logger.error("Hard Fail of box: %s", self.name)
            self.logger.exception(e)
            # sys.exit() raises SystemExit, so the finally clause still runs.
            sys.exit(1)

        except KeyboardInterrupt:
            pass

        finally:
            self.teardown()  # Clean-up data generated by setup()
            self.teardown_box()  # Clean-up my local contexts

    def get_user_data(self):
        """Call setup() and wrap its dict result in a SimpleNamespace.

        Raises:
            SetupError: setup() did not return a dict.
        """
        user_data = self.setup()

        if isinstance(user_data, dict):
            return SimpleNamespace(**user_data)
        else:
            raise SetupError("Setup function didn't return a dictionary")

    def register_signals(self):
        """Make SIGTERM/SIGQUIT/SIGABRT stop run_loop before its next iteration."""
        def signal_handler(signum, frame):
            self.logger.info("Signal %s received", signum)
            self.loop_continue = False

        for sig in [signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT]:
            signal.signal(sig, signal_handler)

    def before_loop(self):
        """Run before_first_request() and forward a truthy result to process_result()."""
        result = self.before_first_request()
        if result:
            self.process_result(result)

    def teardown_box(self):
        """Destroy the ZMQ context used by this box."""
        self.zmq_ctx.destroy()

    def run_loop(self):
        """Process messages until stopped, exhausted, or failing repeatedly.

        The loop ends when ``loop_continue`` is cleared (see register_signals)
        or when ``get_processed_message`` raises StopIteration. SetupError and
        NotImplementedError mean a programmer error and propagate immediately.
        Any other exception is logged and counted; more than
        ``max_errors_in_row`` consecutive failures raise LoopHardFail.
        """
        while self.loop_continue:
            try:
                result = self.get_processed_message()
                self.process_result(result)
                self.errors_in_row = 0

            except StopIteration:
                self.logger.info("Box %s raised StopIteration", self.name)
                break

            except (SetupError, NotImplementedError):
                # These errors are show-stoppers: they mean a programmer error
                # and there is no reason to try to recover.
                raise

            except Exception as e:
                self.logger.error("Uncaught exception from loop: %s", type(e).__name__)
                self.logger.exception(e)

                self.errors_in_row += 1
                if self.errors_in_row > self.max_errors_in_row:
                    raise LoopHardFail("Many errors in row.") from e

    # Helper methods
    def get_socket(self, sock_name):
        """Return the socket ``sock_name`` from the SN context, or None if undefined.

        A missing socket is not an error here; subclasses decide in
        check_configuration() whether the socket is required.
        """
        try:
            return self.sn_ctx.get_socket(sock_name)
        except UndefinedSocketError:
            return None
150 151


152 153 154 155 156
class SNPipelineBox(SNBox):
    """Box that receives messages on the "in" socket and sends results to "out"."""

    def __init__(self, box_name, argparser=None):
        super().__init__(box_name, argparser)
        # Either socket may be None when it is not configured;
        # check_configuration() reports that before the loop starts.
        self.socket_recv = self.get_socket("in")
        self.socket_send = self.get_socket("out")

    def check_configuration(self):
        """Raise SetupError unless both the input and output sockets exist."""
        if not self.socket_recv:
            raise SetupError("Input socket wasn't provided")
        if not self.socket_send:
            raise SetupError("Output socket wasn't provided")

    def teardown_box(self):
        """Close the sockets that were obtained, then release the base contexts."""
        # Unconfigured sockets are None; skip them instead of failing with
        # AttributeError during cleanup.
        for sock in (self.socket_recv, self.socket_send):
            if sock is not None:
                sock.close()
        super().teardown_box()

    def get_processed_message(self):
        """Receive one multipart message, parse it and pass it to process()."""
        msg = self.socket_recv.recv_multipart()
        msg_type, payload = parse_msg(msg)

        return self.process(msg_type, payload)

    def process_result(self, result):
        """Encode ``result`` (a ``(msg_type, payload)`` pair) and send it.

        A falsy result means the box has no answer and nothing is sent.

        Raises:
            LoopFail: ``result`` is not a valid pair or cannot be encoded.
        """
        if not result:
            # The box hasn't any reasonable answer
            return

        try:
            msg_type, payload = result
            msg_out = encode_msg(msg_type, payload)
        except (TypeError, ValueError, InvalidMsgError) as err:
            # TypeError covers e.g. a non-iterable result that cannot be unpacked.
            raise LoopFail("Generated broken output message. Possibly bug in box.") from err

        # Sending stays outside the try so transport errors are not reported
        # as a broken message.
        self.socket_send.send_multipart(msg_out)
187

188

189 190 191 192
class SNGeneratorBox(SNBox):
    """Output-only box: process() is a generator whose items are sent to "out"."""

    def __init__(self, box_name, argparser=None):
        super().__init__(box_name, argparser)
        self.socket_send = self.get_socket("out")

        # Make sure process() is a generator function before creating the iterator
        self.check_configuration()

        self.process_iterator = self.process()

    def check_configuration(self):
        """Raise SetupError unless "out" exists and process() is a generator function."""
        if not self.socket_send:
            raise SetupError("Output socket wasn't provided")
        if not inspect.isgeneratorfunction(self.process):
            raise SetupError("Generator is expected for output-only box")

    def teardown_box(self):
        """Close the output socket if it was obtained, then release the base contexts."""
        # An unconfigured socket is None; skip it instead of raising AttributeError.
        if self.socket_send is not None:
            self.socket_send.close()
        super().teardown_box()

    def get_processed_message(self):
        """Return the next item produced by process().

        StopIteration from an exhausted generator ends run_loop.
        """
        return next(self.process_iterator)

    def process_result(self, result):
        """Encode ``result`` (a ``(msg_type, payload)`` pair) and send it.

        A falsy result means the box has no answer and nothing is sent.

        Raises:
            LoopFail: ``result`` is not a valid pair or cannot be encoded.
        """
        if not result:
            # The box hasn't any reasonable answer
            return

        try:
            msg_type, payload = result
            msg_out = encode_msg(msg_type, payload)
        except (TypeError, ValueError, InvalidMsgError) as err:
            # TypeError covers e.g. a non-iterable result that cannot be unpacked.
            raise LoopFail("Generated broken output message. Possibly bug in box.") from err

        # Sending stays outside the try so transport errors are not reported
        # as a broken message.
        self.socket_send.send_multipart(msg_out)
224

225

226 227 228 229
class SNTerminationBox(SNBox):
    """Input-only box: consumes messages from "in" and must not produce output."""

    def __init__(self, box_name, argparser=None):
        super().__init__(box_name, argparser)
        # May be None when the socket is not configured; see check_configuration().
        self.socket_recv = self.get_socket("in")

    def check_configuration(self):
        """Raise SetupError unless the input socket exists."""
        if not self.socket_recv:
            raise SetupError("Input socket wasn't provided")

    def teardown_box(self):
        """Close the input socket if it was obtained, then release the base contexts."""
        # An unconfigured socket is None; skip it instead of raising AttributeError.
        if self.socket_recv is not None:
            self.socket_recv.close()
        super().teardown_box()

    def get_processed_message(self):
        """Receive one multipart message, parse it and pass it to process()."""
        msg = self.socket_recv.recv_multipart()
        msg_type, payload = parse_msg(msg)

        return self.process(msg_type, payload)

    def process_result(self, result):
        """Reject any output: a truthy result from process() raises LoopFail."""
        if result:
            raise LoopFail("Input-only box generated output message. Possibly bug in box.")