msgloop.py 4.48 KB
Newer Older
1
import sys
2 3
import logging
import inspect
4
import signal
5

6
from types import SimpleNamespace
7 8 9

import zmq

Robin Obůrka's avatar
Robin Obůrka committed
10 11 12 13
from .network import SN
from .network import get_arg_parser
from .messages import encode_msg, parse_msg
from .exceptions import *
14 15 16 17 18


logger = logging.getLogger("sn_main")


19
class LoopHardFail(Exception):
    """Signal that the main loop failed too many times in a row and must stop."""


23 24


25 26
def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=None):
    """Run a box: build its context, run the message loop and clean up.

    Args:
        box_name: Name of the box; stored in the context and used as the
            name of the box logger.
        process: Either a plain callable invoked as
            ``process(context, msg_type, payload)`` for every received
            message, or a generator function invoked once as
            ``process(context)`` that yields output messages.
        setup: Optional callable taking no arguments and returning a dict
            whose items are added to the context.
        teardown: Optional callable invoked with the context before the
            sockets and the ZMQ context are released.
        argparser: Optional argument parser handed to ``SN``; when falsy
            the one from ``get_arg_parser()`` is used.
        args: Passed to ``SN`` as the ``args`` keyword argument.

    Raises:
        SystemExit: With status 1 when the loop raises ``LoopHardFail``.

    ``KeyboardInterrupt`` is swallowed so that Ctrl-C ends the box quietly;
    any other exception (e.g. ``SetupError``) propagates after cleanup.
    """
    sn_ctx = SN(zmq.Context.instance(), argparser or get_arg_parser(), args=args)

    context = None
    try:
        user_data = get_user_data(setup)

        context = build_context(box_name, sn_ctx, user_data)
        check_configuration(context, process)

        logger.info("SN main starting loop for %s box", box_name)
        register_signals(context)
        _sn_main_loop(context, process)

    except LoopHardFail:
        logger.error("Hard Fail of box: %s", context.name)
        sys.exit(1)

    except KeyboardInterrupt:
        pass

    finally:
        if context:
            # Is possible that context wasn't built yet (e.g. error in setup callback)
            try:
                if teardown:
                    teardown(context)
            finally:
                # Release sockets and the ZMQ context even when the user
                # teardown callback raises; its exception still propagates.
                teardown_context(context)
52 53


54 55 56 57 58 59
def get_user_data(setup):
    """Return the dictionary produced by the optional ``setup`` callback.

    Args:
        setup: Callable taking no arguments, or a falsy value when the box
            has no setup step.

    Returns:
        The dict returned by ``setup()``, or a new empty dict when no
        ``setup`` was given.

    Raises:
        SetupError: When ``setup()`` returns something other than a dict.
    """
    if not setup:
        return {}

    produced = setup()
    if not isinstance(produced, dict):
        raise SetupError("Setup function didn't return a dictionary")

    return produced


65 66 67 68
def build_context(box_name, sn_ctx, user_data):
    """Build the namespace object shared by the loop and the box callbacks.

    Args:
        box_name: Name of the box; also used as the box logger name.
        sn_ctx: ``SN`` instance providing ``context``, ``args`` and the
            sockets named "in" and "out".
        user_data: Dict of extra attributes supplied by the setup callback.

    Returns:
        A ``SimpleNamespace`` holding the built-in attributes listed below
        plus every item of ``user_data``.

    Raises:
        SetupError: When ``user_data`` uses a key that collides with one of
            the built-in attribute names.
    """
    socket_recv = get_socket(sn_ctx, "in")
    socket_send = get_socket(sn_ctx, "out")

    ctx = {
        "name": box_name,
        "logger": logging.getLogger(box_name),
        "loop_continue": True,
        "errors_in_row": 0,
        "sn_ctx": sn_ctx,
        "zmq_ctx": sn_ctx.context,
        "args": sn_ctx.args,
        "socket_recv": socket_recv,
        "socket_send": socket_send,
    }

    for k, v in user_data.items():
        if k in ctx:
            # Interpolate here: exceptions do not apply %-formatting to
            # their arguments the way logging calls do.
            raise SetupError("Used reserved word in user_data: %s" % (k,))
        ctx[k] = v

    return SimpleNamespace(**ctx)


90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
def get_socket(context, sock_name):
    """Look up ``sock_name`` on ``context``, tolerating an undefined socket.

    Returns:
        Whatever ``context.get_socket(sock_name)`` returns, or ``None`` when
        that call raises ``UndefinedSocketError``.
    """
    try:
        return context.get_socket(sock_name)
    except UndefinedSocketError:
        # A missing socket is a valid configuration; the caller gets None.
        return None


def check_configuration(context, process):
    """Validate that the available sockets fit the kind of ``process`` given.

    Raises:
        SetupError: When the context has neither an input nor an output
            socket, or when there is no input socket and ``process`` is not
            a generator function.
    """
    has_input = bool(context.socket_recv)

    if not (has_input or context.socket_send):
        raise SetupError("Neither input nor output socket provided")

    # Without an input socket the box has to produce messages by itself.
    if not (has_input or inspect.isgeneratorfunction(process)):
        raise SetupError("Generator is expected for output-only box")


def teardown_context(context):
    """Close the sockets present on ``context`` and destroy its ZMQ context."""
    receiver = context.socket_recv
    if receiver:
        receiver.close()

    sender = context.socket_send
    if sender:
        sender.close()

    context.zmq_ctx.destroy()
114 115 116 117 118


def _sn_main_loop(context, process):
    """Pull results from the box and send them until asked to stop.

    Args:
        context: Namespace built by ``build_context``; ``loop_continue``,
            ``errors_in_row``, ``socket_send``, ``logger`` and ``name`` are
            used here.
        process: Generator function called once as ``process(context)``, or
            a plain callable wrapped by ``generate_processed_msg``.

    Raises:
        LoopHardFail: When more than 10 iterations in a row end with an
            exception.
    """
    is_generator_box = inspect.isgeneratorfunction(process)
    if is_generator_box:
        generate_output_message = process(context)
    else:
        generate_output_message = generate_processed_msg(context, process)

    while context.loop_continue:
        try:
            result = next(generate_output_message)
            process_result(context.socket_send, result)
            context.errors_in_row = 0

        except StopIteration:
            context.logger.warning("Box %s raised StopIteration - unexpected behavior", context.name)
            break

        except Exception as e:
            logger.error("Uncaught exception from loop")
            logger.exception(e)

            context.errors_in_row += 1
            if context.errors_in_row > 10:
                raise LoopHardFail("Many errors in row.")

            if not is_generator_box:
                # A generator that let an exception escape is finished and
                # would raise StopIteration on the next call, ending the
                # loop after a single error. The wrapper keeps no state, so
                # start a fresh one and carry on with the next message.
                # A user-supplied generator cannot be resumed this way and
                # is left untouched.
                generate_output_message = generate_processed_msg(context, process)


def generate_processed_msg(context, process):
    """Endlessly yield what ``process`` returns for each received message.

    Every iteration reads one multipart message from ``context.socket_recv``,
    splits it with ``parse_msg`` into a type and a payload and passes both,
    together with ``context``, to ``process``.
    """
    while True:
        frames = context.socket_recv.recv_multipart()
        kind, body = parse_msg(frames)
        yield process(context, kind, body)
149

150 151

def process_result(socket_send, result):
    """Encode and send one result produced by the box, if there is any.

    Args:
        socket_send: Output socket, or a falsy value when the box has none.
        result: Falsy when there is nothing to send, otherwise a two-item
            ``(msg_type, payload)`` sequence.

    Raises:
        LoopError: When ``result`` is truthy but there is no output socket.
    """
    if not result:
        # The box is output-only or it hasn't any reasonable answer
        return

    if not socket_send:
        # TODO: Hard fail?
        raise LoopError("Box generated output but there isn't any output socket. Bad configuration?")

    msg_type, payload = result
    msg_out = encode_msg(msg_type, payload)
    socket_send.send_multipart(msg_out)
    # TODO: Hard fail on InvalidMsgError in box output?
164 165


166
def register_signals(context):
    """Install a handler that asks the main loop to stop on termination signals.

    The handler is registered for SIGHUP, SIGTERM, SIGQUIT and SIGABRT. It
    logs the signal number through ``context.logger`` and sets
    ``context.loop_continue`` to ``False``.
    """
    def _request_stop(signum, frame):
        context.logger.info("Signal %s received", signum)
        context.loop_continue = False

    handled = (signal.SIGHUP, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT)
    for signum in handled:
        signal.signal(signum, _request_stop)