Commit b8f47276 authored by Robin Obůrka's avatar Robin Obůrka

msgloop: New approach to message loop

parent eac0bfb2
import sys
import logging import logging
import inspect import inspect
import signal import signal
from collections import namedtuple from types import SimpleNamespace
import zmq import zmq
...@@ -14,23 +15,16 @@ from .exceptions import * ...@@ -14,23 +15,16 @@ from .exceptions import *
logger = logging.getLogger("sn_main") logger = logging.getLogger("sn_main")
EnvData = namedtuple("EnvData", [
"name",
"logger",
])
class LoopHardFail(Exception):
class SignalReceived(Exception):
pass pass
def signal_handler(signum, frame):
raise SignalReceived()
def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=None): def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=None):
ctx = SN(zmq.Context.instance(), argparser or get_arg_parser(), args=args) zmq_ctx = zmq.Context.instance()
socket_recv, socket_send = detect_and_get_sockets(ctx) sn_ctx = SN(zmq_ctx, argparser or get_arg_parser(), args=args)
socket_recv = get_socket(sn_ctx, "in")
socket_send = get_socket(sn_ctx, "out")
if not socket_recv and not socket_send: if not socket_recv and not socket_send:
raise LoopError("Neither input nor output socket provided") raise LoopError("Neither input nor output socket provided")
...@@ -39,58 +33,121 @@ def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=N ...@@ -39,58 +33,121 @@ def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=N
logger.info("SN main starting loop for %s box", box_name) logger.info("SN main starting loop for %s box", box_name)
env_data = init_env_data(box_name) data_for_context = {
"zmq_ctx": zmq_ctx,
"sn_ctx": sn_ctx,
"args": sn_ctx.args,
"socket_recv": socket_recv,
"socket_send": socket_send,
}
for sig in [ signal.SIGHUP, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT ]: context = None
signal.signal(sig, signal_handler)
try: try:
user_data = setup() if setup else None user_data = get_user_data(setup)
_sn_main_loop(env_data, user_data, socket_recv, socket_send, setup, process, teardown) context = build_context(box_name, data_for_context, user_data)
except SignalReceived as e: register_signals(context)
logger.info("Box %s stopped by signal", box_name)
except LoopError as e: _sn_main_loop(context, process)
raise e
except AssertionError as e: except LoopHardFail as e:
# For pytest logger.error("Hard Fail of box: %s", context.name)
raise e sys.exit(1)
except Exception as e: except KeyboardInterrupt:
logger.error("Uncaught exception from loop") pass
logger.exception(e)
finally: finally:
if teardown: if context:
teardown(user_data) # Is possible that context wasn't built yet (e.g. error in setup callback)
ctx.context.destroy() if teardown:
teardown(context)
teardown_context(context)
def init_env_data(box_name): def get_socket(context, sock_name):
return EnvData( socket = None
name = box_name, try:
logger = logging.getLogger(box_name) socket = context.get_socket(sock_name)
)
except UndefinedSocketError as e:
pass
def _sn_main_loop(env_data, user_data, socket_recv, socket_send, setup=None, process=None, teardown=None): return socket
if socket_recv:
try:
while True:
msg_in = socket_recv.recv_multipart()
msg_type, payload = parse_msg(msg_in)
result = process(env_data, user_data, msg_type, payload)
process_result(socket_send, result)
except InvalidMsgError as e: def get_user_data(setup):
logger.error("Received broken message") if setup:
user_data = setup()
if isinstance(user_data, dict):
return user_data
else:
raise LoopError("Setup function didn't return a dictionary")
return {}
def build_context(box_name, context_data, user_data):
ctx = {
"name": box_name,
"logger": logging.getLogger(box_name),
"loop_continue": True,
"errors_in_row": 0,
}
ctx.update(context_data)
for k, v in user_data.items():
if k in ctx:
raise LoopError("Used reserved word in user_data: %s", k)
else:
ctx[k] = v
return SimpleNamespace(**ctx)
def teardown_context(ctx):
if ctx.socket_recv:
ctx.socket_recv.close()
if ctx.socket_send:
ctx.socket_send.close()
ctx.zmq_ctx.destroy()
def _sn_main_loop(context, process):
if inspect.isgeneratorfunction(process):
generate_output_message = process(context)
else: else:
for result in process(env_data, user_data): generate_output_message = generate_processed_msg(context, process)
process_result(socket_send, result)
while context.loop_continue:
try:
result = next(generate_output_message)
process_result(context.socket_send, result)
context.errors_in_row = 0
except StopIteration as e:
context.logger.warning("Box %s raised StopIteration - unexpected behavior", context.name)
break
except Exception as e:
logger.error("Uncaught exception from loop")
logger.exception(e)
context.errors_in_row += 1
if context.errors_in_row > 10:
raise LoopHardFail("Many errors in row.")
def generate_processed_msg(context, process):
while True:
msg = context.socket_recv.recv_multipart()
msg_type, payload = parse_msg(msg)
result = process(context, msg_type, payload)
yield result
def process_result(socket_send, result): def process_result(socket_send, result):
...@@ -99,33 +156,19 @@ def process_result(socket_send, result): ...@@ -99,33 +156,19 @@ def process_result(socket_send, result):
return return
if not socket_send: if not socket_send:
# TODO: Hard fail?
raise LoopError("Box generated output but there is any output socket. Bad configuration?") raise LoopError("Box generated output but there is any output socket. Bad configuration?")
try: msg_type, payload = result
msg_type, payload = result msg_out = encode_msg(msg_type, payload)
msg_out = encode_msg(msg_type, payload) socket_send.send_multipart(msg_out)
socket_send.send_multipart(msg_out) # TODO: Hard fail on InvalidMsgError in box output?
except (ValueError, InvalidMsgError) as e:
# Invalid message on input means that a received some bad message and I
# just want to not fail. Invalid message on output means a
# programmer error of the box author and I need to distinguish between
# them.
raise LoopError("Box generates broken messages")
def detect_and_get_sockets(context): def register_signals(ctx):
socket_recv = None def signal_handler(signum, frame):
socket_send = None ctx.logger.info("Signal %s received", signum)
ctx.loop_continue = False
try: for sig in [ signal.SIGHUP, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT ]:
socket_recv = context.get_socket("in") signal.signal(sig, signal_handler)
except UndefinedSocketError as e:
pass
try:
socket_send = context.get_socket("out")
except UndefinedSocketError as e:
pass
return socket_recv, socket_send
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment