msgloop.py 4.7 KB
Newer Older
1
import sys
2 3
import logging
import inspect
4
import signal
5

6
from types import SimpleNamespace
7 8 9

import zmq

Robin Obůrka's avatar
Robin Obůrka committed
10 11 12 13
from .network import SN
from .network import get_arg_parser
from .messages import encode_msg, parse_msg
from .exceptions import *
14 15 16 17 18


# Logger for the framework's own messages; boxes log through context.logger.
logger = logging.getLogger(name="sn_main")


19
class LoopHardFail(Exception):
    """Unrecoverable main-loop failure (too many consecutive errors).

    sn_main logs it and terminates the process with exit status 1.
    """


23 24
class LoopFail(Exception):
    """Failure of a single loop iteration, e.g. a box produced a broken output message.

    The main loop counts it like any other per-iteration error.
    """
25 26


27 28
def sn_main(box_name, process, setup=None, teardown=None, argparser=None):
    """Set up a box, run its message loop and always clean up afterwards.

    Args:
        box_name: name of the box; used as context.name and as its logger name.
        process: either a generator function ``process(context)`` yielding
            results, or a function ``process(context, msg_type, payload)``
            called for every received message.
        setup: optional zero-argument callable returning a dict whose items
            become extra attributes of the context.
        teardown: optional callable ``teardown(context)`` run before the
            sockets and ZMQ context are released.
        argparser: argument parser passed to SN; defaults to get_arg_parser().

    Exits the process with status 1 when the loop hard-fails.
    """
    sn_ctx = SN(zmq.Context.instance(), argparser or get_arg_parser())

    context = None
    try:
        user_data = get_user_data(setup)

        context = build_context(box_name, sn_ctx, user_data)
        check_configuration(context, process)

        logger.info("SN main starting loop for %s box", box_name)
        register_signals(context)
        _sn_main_loop(context, process)

    except LoopHardFail as e:
        logger.error("Hard Fail of box: %s", context.name)
        logger.exception(e)
        # sys.exit() raises SystemExit, so the finally clause still runs.
        sys.exit(1)

    except KeyboardInterrupt:
        pass

    finally:
        # context is still None if setup or context building failed.
        if context is not None:
            try:
                if teardown:
                    teardown(context)
            finally:
                # Release sockets and the ZMQ context even if the user's
                # teardown callback raised.
                teardown_context(context)
56 57


58 59 60 61 62 63
def get_user_data(setup):
    """Run the optional setup callback and return the user data it built.

    Args:
        setup: zero-argument callable returning a dict, or a falsy value.

    Returns:
        The dict returned by setup, or a new empty dict when there is no setup.

    Raises:
        SetupError: setup returned something that is not a dict.
    """
    if not setup:
        return {}

    data = setup()
    if not isinstance(data, dict):
        raise SetupError("Setup function didn't return a dictionary")
    return data


69 70 71 72
def build_context(box_name, sn_ctx, user_data):
    """Build the namespace object that is handed to the box callbacks.

    Args:
        box_name: name of the box; stored as ``name`` and used for ``logger``.
        sn_ctx: SN instance; its ``context`` and ``args`` are exposed, and its
            "in"/"out" sockets are looked up through get_socket().
        user_data: dict of extra attributes to add to the context.

    Returns:
        A SimpleNamespace with the framework attributes plus user_data items.

    Raises:
        SetupError: a key of user_data collides with a framework attribute.
    """
    ctx = {
        "name": box_name,
        "logger": logging.getLogger(box_name),
        "loop_continue": True,
        "errors_in_row": 0,
        "sn_ctx": sn_ctx,
        "zmq_ctx": sn_ctx.context,
        "args": sn_ctx.args,
        # Placeholders keep the reserved names (and attribute order); the real
        # sockets are opened only after user_data has been validated, so a
        # SetupError here cannot leak open sockets.
        "socket_recv": None,
        "socket_send": None,
    }

    for key in user_data:
        if key in ctx:
            # Exception() does not %-format its arguments; format explicitly.
            raise SetupError("Used reserved word in user_data: %s" % key)
    ctx.update(user_data)

    ctx["socket_recv"] = get_socket(sn_ctx, "in")
    ctx["socket_send"] = get_socket(sn_ctx, "out")

    return SimpleNamespace(**ctx)


94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
def get_socket(context, sock_name):
    """Return the socket called sock_name from context, or None if it is undefined."""
    try:
        return context.get_socket(sock_name)
    except UndefinedSocketError:
        # Sockets are optional: a box may be input-only or output-only.
        return None


def check_configuration(context, process):
    """Validate that the available sockets fit the kind of process callback.

    Raises:
        SetupError: neither socket exists, or there is no input socket and
            process is not a generator function.
    """
    has_input = bool(context.socket_recv)
    has_output = bool(context.socket_send)

    if not (has_input or has_output):
        raise SetupError("Neither input nor output socket provided")
    if not has_input and not inspect.isgeneratorfunction(process):
        raise SetupError("Generator is expected for output-only box")


def teardown_context(context):
    """Close whichever sockets exist, then destroy the ZMQ context."""
    for sock in (context.socket_recv, context.socket_send):
        if sock:
            sock.close()
    context.zmq_ctx.destroy()
118 119 120 121 122


def _sn_main_loop(context, process):
    """Produce results from the box and send them until the loop stops.

    The loop ends when context.loop_continue is cleared, when a generator
    box is exhausted, or when more than 10 consecutive iterations fail.

    Raises:
        SetupError: configuration problem detected while running.
        LoopHardFail: too many consecutive errors.
    """
    if inspect.isgeneratorfunction(process):
        output_generator = process(context)

        def next_result():
            return next(output_generator)
    else:
        # Call the plain process function directly for every message instead
        # of wrapping it in a generator: an exception escaping a generator
        # finishes it, so the next next() would raise StopIteration and
        # silently end the box after the first error.
        def next_result():
            return _process_next_msg(context, process)

    while context.loop_continue:
        try:
            result = next_result()
            process_result(context.socket_send, result)
            context.errors_in_row = 0

        except StopIteration:
            context.logger.info("Box %s raised StopIteration", context.name)
            break

        except SetupError:
            # Configuration errors are not retried.
            raise

        except Exception as e:
            logger.error("Uncaught exception from loop")
            logger.exception(e)

            context.errors_in_row += 1
            if context.errors_in_row > 10:
                raise LoopHardFail("Many errors in row.") from e


def _process_next_msg(context, process):
    """Receive one message, decode it and return process()'s result for it."""
    msg = context.socket_recv.recv_multipart()
    msg_type, payload = parse_msg(msg)
    try:
        return process(context, msg_type, payload)
    except StopIteration as e:
        # StopIteration means "generator box finished"; a plain process()
        # leaking it is a box bug and is counted as an ordinary error.
        raise RuntimeError("process() raised StopIteration") from e


def generate_processed_msg(context, process):
    """Endlessly receive messages and yield process()'s result for each one.

    Each iteration reads a multipart message from context.socket_recv,
    decodes it with parse_msg() and calls process(context, msg_type, payload).
    As with any generator, an exception propagating out of it finishes it.
    """
    while True:
        frames = context.socket_recv.recv_multipart()
        kind, body = parse_msg(frames)
        yield process(context, kind, body)
156

157 158

def process_result(socket_send, result):
    """Encode a box result and send it on the output socket.

    Args:
        socket_send: output socket, or a falsy value when the box has none.
        result: (msg_type, payload) pair, or a falsy value when there is
            nothing to send.

    Raises:
        SetupError: result is non-empty but there is no output socket.
        LoopFail: result is not a (msg_type, payload) pair or encode_msg
            rejects it.
    """
    if not result:
        # The box is output-only or it has no reasonable answer.
        return

    if not socket_send:
        raise SetupError("Box generated output but there isn't any output socket. Bad configuration?")

    try:
        msg_type, payload = result
        msg_out = encode_msg(msg_type, payload)
    except (TypeError, ValueError, InvalidMsgError) as err:
        # TypeError covers a non-iterable result; ValueError a wrong-length one.
        raise LoopFail("Generated broken output message. Possibly bug in box.") from err

    # Sending is outside the try so socket errors are not misreported as a
    # broken message.
    socket_send.send_multipart(msg_out)
173 174


175
def register_signals(context):
    """Install one handler that asks the main loop to stop on termination signals.

    The handler logs the signal number and clears context.loop_continue.
    """
    def _request_stop(signum, _frame):
        context.logger.info("Signal %s received", signum)
        context.loop_continue = False

    stop_signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT)
    for signum in stop_signals:
        signal.signal(signum, _request_stop)