Commit b8f47276 authored by Robin Obůrka's avatar Robin Obůrka

msgloop: New approach to message loop

parent eac0bfb2
import sys
import logging import logging
import inspect import inspect
import signal import signal
from collections import namedtuple from types import SimpleNamespace
import zmq import zmq
...@@ -14,23 +15,16 @@ from .exceptions import * ...@@ -14,23 +15,16 @@ from .exceptions import *
logger = logging.getLogger("sn_main") logger = logging.getLogger("sn_main")
EnvData = namedtuple("EnvData", [
"name",
"logger",
])
class LoopHardFail(Exception):
class SignalReceived(Exception):
pass pass
def signal_handler(signum, frame):
raise SignalReceived()
def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=None): def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=None):
ctx = SN(zmq.Context.instance(), argparser or get_arg_parser(), args=args) zmq_ctx = zmq.Context.instance()
socket_recv, socket_send = detect_and_get_sockets(ctx) sn_ctx = SN(zmq_ctx, argparser or get_arg_parser(), args=args)
socket_recv = get_socket(sn_ctx, "in")
socket_send = get_socket(sn_ctx, "out")
if not socket_recv and not socket_send: if not socket_recv and not socket_send:
raise LoopError("Neither input nor output socket provided") raise LoopError("Neither input nor output socket provided")
...@@ -39,58 +33,121 @@ def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=N ...@@ -39,58 +33,121 @@ def sn_main(box_name, process, setup=None, teardown=None, argparser=None, args=N
logger.info("SN main starting loop for %s box", box_name) logger.info("SN main starting loop for %s box", box_name)
env_data = init_env_data(box_name) data_for_context = {
"zmq_ctx": zmq_ctx,
"sn_ctx": sn_ctx,
"args": sn_ctx.args,
"socket_recv": socket_recv,
"socket_send": socket_send,
}
for sig in [ signal.SIGHUP, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT ]: context = None
signal.signal(sig, signal_handler)
try: try:
user_data = setup() if setup else None user_data = get_user_data(setup)
_sn_main_loop(env_data, user_data, socket_recv, socket_send, setup, process, teardown) context = build_context(box_name, data_for_context, user_data)
except SignalReceived as e: register_signals(context)
logger.info("Box %s stopped by signal", box_name)
except LoopError as e: _sn_main_loop(context, process)
raise e
except AssertionError as e: except LoopHardFail as e:
# For pytest logger.error("Hard Fail of box: %s", context.name)
raise e sys.exit(1)
except Exception as e: except KeyboardInterrupt:
logger.error("Uncaught exception from loop") pass
logger.exception(e)
finally: finally:
if context:
# Is possible that context wasn't built yet (e.g. error in setup callback)
if teardown: if teardown:
teardown(user_data) teardown(context)
ctx.context.destroy() teardown_context(context)
def init_env_data(box_name): def get_socket(context, sock_name):
return EnvData( socket = None
name = box_name, try:
logger = logging.getLogger(box_name) socket = context.get_socket(sock_name)
)
except UndefinedSocketError as e:
pass
return socket
def get_user_data(setup):
if setup:
user_data = setup()
if isinstance(user_data, dict):
return user_data
else:
raise LoopError("Setup function didn't return a dictionary")
return {}
def build_context(box_name, context_data, user_data):
ctx = {
"name": box_name,
"logger": logging.getLogger(box_name),
"loop_continue": True,
"errors_in_row": 0,
}
ctx.update(context_data)
for k, v in user_data.items():
if k in ctx:
raise LoopError("Used reserved word in user_data: %s", k)
else:
ctx[k] = v
return SimpleNamespace(**ctx)
def _sn_main_loop(env_data, user_data, socket_recv, socket_send, setup=None, process=None, teardown=None):
if socket_recv:
try:
while True:
msg_in = socket_recv.recv_multipart()
msg_type, payload = parse_msg(msg_in)
result = process(env_data, user_data, msg_type, payload) def teardown_context(ctx):
process_result(socket_send, result) if ctx.socket_recv:
ctx.socket_recv.close()
if ctx.socket_send:
ctx.socket_send.close()
ctx.zmq_ctx.destroy()
except InvalidMsgError as e:
logger.error("Received broken message")
def _sn_main_loop(context, process):
if inspect.isgeneratorfunction(process):
generate_output_message = process(context)
else: else:
for result in process(env_data, user_data): generate_output_message = generate_processed_msg(context, process)
process_result(socket_send, result)
while context.loop_continue:
try:
result = next(generate_output_message)
process_result(context.socket_send, result)
context.errors_in_row = 0
except StopIteration as e:
context.logger.warning("Box %s raised StopIteration - unexpected behavior", context.name)
break
except Exception as e:
logger.error("Uncaught exception from loop")
logger.exception(e)
context.errors_in_row += 1
if context.errors_in_row > 10:
raise LoopHardFail("Many errors in row.")
def generate_processed_msg(context, process):
while True:
msg = context.socket_recv.recv_multipart()
msg_type, payload = parse_msg(msg)
result = process(context, msg_type, payload)
yield result
def process_result(socket_send, result): def process_result(socket_send, result):
...@@ -99,33 +156,19 @@ def process_result(socket_send, result): ...@@ -99,33 +156,19 @@ def process_result(socket_send, result):
return return
if not socket_send: if not socket_send:
# TODO: Hard fail?
raise LoopError("Box generated output but there is any output socket. Bad configuration?") raise LoopError("Box generated output but there is any output socket. Bad configuration?")
try:
msg_type, payload = result msg_type, payload = result
msg_out = encode_msg(msg_type, payload) msg_out = encode_msg(msg_type, payload)
socket_send.send_multipart(msg_out) socket_send.send_multipart(msg_out)
# TODO: Hard fail on InvalidMsgError in box output?
except (ValueError, InvalidMsgError) as e:
# Invalid message on input means that a received some bad message and I
# just want to not fail. Invalid message on output means a
# programmer error of the box author and I need to distinguish between
# them.
raise LoopError("Box generates broken messages")
def detect_and_get_sockets(context):
socket_recv = None
socket_send = None
try:
socket_recv = context.get_socket("in")
except UndefinedSocketError as e:
pass
try: def register_signals(ctx):
socket_send = context.get_socket("out") def signal_handler(signum, frame):
except UndefinedSocketError as e: ctx.logger.info("Signal %s received", signum)
pass ctx.loop_continue = False
return socket_recv, socket_send for sig in [ signal.SIGHUP, signal.SIGTERM, signal.SIGQUIT, signal.SIGABRT ]:
signal.signal(sig, signal_handler)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment