msgloop.py 4.42 KB
Newer Older
1
import sys
2 3
import logging
import inspect
4
import signal
5

6
from types import SimpleNamespace
7 8 9

import zmq

Robin Obůrka's avatar
Robin Obůrka committed
10 11 12 13
from .network import SN
from .network import get_arg_parser
from .messages import encode_msg, parse_msg
from .exceptions import *
14 15 16 17 18


# Logger of the message-loop framework itself; each box additionally gets its
# own logger named after the box (created in build_context).
logger = logging.getLogger("sn_main")


19
class LoopHardFail(Exception):
    """Raised by the message loop when a box fails too many times in a row.

    ``sn_main`` catches it, logs the failure and terminates the process with
    exit status 1.
    """


23
def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=None):
    """Run a box: set up its sockets and context, then drive its message loop.

    :param box_name: name of the box, used in log messages and as the name of
        the box's own logger
    :param process: either a generator function taking the context (required
        for boxes without an input socket) or a callable
        ``process(context, msg_type, payload)`` invoked for every received
        message; a produced ``(msg_type, payload)`` pair is sent to the output
        socket
    :param setup: optional zero-argument callable returning a dict whose
        entries are added to the context
    :param teardown: optional callable invoked with the context on shutdown
    :param argparser: argument parser to use instead of ``get_arg_parser()``
    :param args: forwarded to ``SN`` as its ``args`` argument
    :raises LoopError: when neither socket is configured, when an output-only
        box is not a generator function, or when ``setup`` returns a non-dict

    The process exits with status 1 when the loop hard-fails. The sockets and
    the ZeroMQ context are released on every exit path.
    """
    zmq_ctx = zmq.Context.instance()
    sn_ctx = SN(zmq_ctx, argparser or get_arg_parser(), args=args)
    socket_recv = get_socket(sn_ctx, "in")
    socket_send = get_socket(sn_ctx, "out")

    data_for_context = {
        "zmq_ctx": zmq_ctx,
        "sn_ctx": sn_ctx,
        "args": sn_ctx.args,
        "socket_recv": socket_recv,
        "socket_send": socket_send,
    }

    context = None

    try:
        # Validation happens inside the try so that the sockets opened above
        # are still released by the finally clause when it fails.
        if not socket_recv and not socket_send:
            raise LoopError("Neither input nor output socket provided")
        if not socket_recv and not inspect.isgeneratorfunction(process):
            raise LoopError("Generator is expected for output-only box")

        logger.info("SN main starting loop for %s box", box_name)

        user_data = get_user_data(setup)
        context = build_context(box_name, data_for_context, user_data)

        register_signals(context)

        _sn_main_loop(context, process)

    except LoopHardFail:
        logger.error("Hard Fail of box: %s", box_name)
        sys.exit(1)

    except KeyboardInterrupt:
        pass

    finally:
        if context is not None:
            try:
                if teardown:
                    teardown(context)
            finally:
                # Release ZeroMQ resources even if the user's teardown raised.
                teardown_context(context)
        else:
            # The context was never built (bad socket configuration, or an
            # error in setup/build_context); release the resources anyway.
            teardown_context(SimpleNamespace(**data_for_context))
67 68


69 70 71 72
def get_socket(context, sock_name):
    """Return the socket named ``sock_name`` from ``context``, or None.

    :param context: object providing ``get_socket(name)`` (the ``SN`` instance)
    :param sock_name: name of the socket to look up, e.g. ``"in"`` or ``"out"``
    :returns: the socket, or None when ``context`` reports it as undefined
    """
    try:
        return context.get_socket(sock_name)
    except UndefinedSocketError:
        # A box may legitimately lack an input or an output socket; callers
        # check for None instead.
        return None
78 79


80 81 82 83 84 85 86
def get_user_data(setup):
    """Run the optional ``setup`` callback and return the data it produces.

    :param setup: zero-argument callable returning a dict, or None
    :returns: the dict returned by ``setup``, or an empty dict when no
        ``setup`` callback was given
    :raises LoopError: if ``setup`` returns anything other than a dict
    """
    if not setup:
        return {}

    user_data = setup()
    if not isinstance(user_data, dict):
        raise LoopError("Setup function didn't return a dictionary")
    return user_data


def build_context(box_name, context_data, user_data):
    """Assemble the namespace object that is handed to every box callback.

    The namespace starts with the framework fields ``name``, ``logger`` (a
    logger named after the box), ``loop_continue`` and ``errors_in_row``. It
    is then extended with ``context_data`` and finally with ``user_data``.

    :param box_name: name of the box; also used as its logger name
    :param context_data: framework-provided entries (sockets, contexts, args)
    :param user_data: dictionary returned by the user's ``setup`` callback
    :returns: ``types.SimpleNamespace`` exposing all entries as attributes
    :raises LoopError: if ``user_data`` reuses the name of any entry already
        present in the context
    """
    ctx = {
        "name": box_name,
        "logger": logging.getLogger(box_name),
        "loop_continue": True,
        "errors_in_row": 0,
    }

    ctx.update(context_data)

    for key, value in user_data.items():
        if key in ctx:
            # Exceptions do not apply %-style arguments the way logging calls
            # do, so format the message here; otherwise the message would be
            # a tuple.
            raise LoopError("Used reserved word in user_data: %s" % (key,))
        ctx[key] = value

    return SimpleNamespace(**ctx)


def teardown_context(ctx):
    """Release the ZeroMQ resources referenced by a box context.

    Each socket that is present (``socket_recv`` first, then ``socket_send``)
    is closed. The ZeroMQ context itself is destroyed afterwards.
    """
    for sock in (ctx.socket_recv, ctx.socket_send):
        if not sock:
            continue
        sock.close()
    ctx.zmq_ctx.destroy()


def _sn_main_loop(context, process, max_errors_in_row=10):
    """Drive the box until it stops, sending every result it produces.

    ``process`` is either a generator function, which is called once with
    ``context`` and whose yielded values are sent as they come, or a plain
    callable applied to every incoming message via ``generate_processed_msg``.

    The loop runs while ``context.loop_continue`` is true (the signal handlers
    clear it). An exception from the box or from sending its result is logged
    and counted in ``context.errors_in_row``. A successful iteration resets the
    counter.

    :param context: namespace built by ``build_context``
    :param process: the box's generator function or per-message callable
    :param max_errors_in_row: number of consecutive failures tolerated before
        giving up
    :raises LoopHardFail: once more than ``max_errors_in_row`` consecutive
        iterations have failed
    """
    is_generator_box = inspect.isgeneratorfunction(process)
    if is_generator_box:
        generate_output_message = process(context)
    else:
        generate_output_message = generate_processed_msg(context, process)

    while context.loop_continue:
        try:
            result = next(generate_output_message)
            process_result(context.socket_send, result)
            context.errors_in_row = 0

        except StopIteration:
            context.logger.warning("Box %s raised StopIteration - unexpected behavior", context.name)
            break

        except Exception:
            logger.exception("Uncaught exception from loop")

            context.errors_in_row += 1
            if context.errors_in_row > max_errors_in_row:
                raise LoopHardFail("Many errors in row.")

            if not is_generator_box:
                # An exception escaping a generator finishes it for good. The
                # next ``next()`` would then raise StopIteration and end the
                # loop after a single bad message, so start a fresh
                # message-processing generator to make the retry logic work.
                # (A user generator cannot be resumed once it has raised.)
                generate_output_message = generate_processed_msg(context, process)


def generate_processed_msg(context, process):
    """Turn a per-message ``process`` callable into an endless result generator.

    Each step receives the next multipart message from
    ``context.socket_recv`` and decodes it with ``parse_msg``. It then yields
    whatever ``process(context, msg_type, payload)`` returns.
    """
    while True:
        frames = context.socket_recv.recv_multipart()
        msg_type, payload = parse_msg(frames)
        yield process(context, msg_type, payload)
151

152 153

def process_result(socket_send, result):
    """Encode a box result and send it on the output socket.

    :param socket_send: output ZeroMQ socket, or None when the box has none
    :param result: ``(msg_type, payload)`` pair produced by the box, or a
        falsy value when there is nothing to send
    :raises LoopError: if the box produced a result but has no output socket
    """
    if not result:
        # The box is input-only (has no output) or has no reasonable answer
        return

    if not socket_send:
        # TODO: Hard fail?
        raise LoopError("Box generated output but there is no output socket. Bad configuration?")

    msg_type, payload = result
    msg_out = encode_msg(msg_type, payload)
    socket_send.send_multipart(msg_out)
    # TODO: Hard fail on InvalidMsgError in box output?
166 167


168 169 170 171
def register_signals(ctx):
    """Install handlers that ask the main loop to stop on termination signals.

    On SIGHUP, SIGTERM, SIGQUIT or SIGABRT the handler logs the signal number
    and clears ``ctx.loop_continue``. The message loop checks that flag before
    each iteration.

    NOTE: Python only allows ``signal.signal`` to be called from the main
    thread, so this must run there.
    """
    def signal_handler(signum, frame):
        ctx.logger.info("Signal %s received", signum)
        ctx.loop_continue = False

    for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT):
        signal.signal(sig, signal_handler)