network.py 5.29 KB
Newer Older
Martin Prudek's avatar
Martin Prudek committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
#!/usr/bin/env python3
import argparse

import msgpack
import zmq


class InvalidMsgError(Exception):
    """ Raised by parse_msg when a received message cannot be parsed. """

def parse_msg(data):
    """ Gets a Sentinel-type ZMQ message and parses message type and its
    payload.

    data -- sequence of message frames. Frame 0 is the UTF-8 encoded
        message type and frame 1 is the msgpack-encoded payload. Any
        further frames are ignored.

    Returns a tuple (msg_type, payload). msg_type is a str; payload is the
    unpacked msgpack object, with msgpack strings decoded to str.

    Raises InvalidMsgError if data has fewer than two frames.
    """
    # Keep the try minimal: only the frame indexing signals a short message.
    try:
        type_frame, payload_frame = data[0], data[1]
    except IndexError as err:
        raise InvalidMsgError("Not enough parts in message") from err

    msg_type = str(type_frame, encoding="UTF-8")
    # raw=False decodes msgpack str objects as UTF-8 str. It replaces the
    # former encoding="UTF-8" argument, which msgpack deprecated and then
    # removed in 1.0 (passing it there raises TypeError).
    payload = msgpack.unpackb(payload_frame, raw=False)

    return msg_type, payload


def encode_msg(msg_type, data):
    """ Build the two frames of a Sentinel-type ZMQ message.

    msg_type -- message type, encoded to bytes as UTF-8
    data -- payload object, serialized with msgpack.packb

    Returns a tuple (type_frame, payload_frame).
    """
    return (
        bytes(msg_type, encoding="UTF-8"),
        msgpack.packb(data),
    )

def socket_builder(res_expected, config_list):
    """ Build the ZMQ sockets described by command-line resource configs.

    res_expected -- names of the resources (sockets) the caller needs, in
        the order the sockets should be returned
    config_list -- strings "sockname,[connect/bind],SOCK_TYPE,IP,PORT" as
        accepted by resource_parser. Repeating a sockname adds another
        connection to the same socket.

    Returns a tuple of ZMQ sockets in the order of res_expected.

    Raises SockConfigError in three cases:
    - config_list names a resource that is not in res_expected;
    - an expected resource has no config;
    - a config is invalid.
    Both resource-name checks run before any socket is created. Sockets
    already created are closed if a later config fails.
    """
    # Materialize once: the expected names are iterated several times.
    expected = list(res_expected)
    res_avail = resource_parser(config_list)

    unexpected = set(res_avail).difference(expected)
    if unexpected:
        raise SockConfigError(
                "Unexpected resource provided: "
                + str(unexpected)
        )
    for res in expected:
        if res not in res_avail:
            raise SockConfigError("Resource not provided: " + res)

    sockets = list()
    try:
        for res in expected:
            sc = None
            # resource_parser yields [direction, SOCK_TYPE, IP, PORT] items
            for direction, socktype, addr, port in res_avail[res]:
                if sc is None:
                    # the first config of a resource creates its socket
                    sc = SockConfig(socktype, direction, addr, port)
                    sockets.append(sc.socket)
                else:
                    sc.add_connection(socktype, direction, addr, port)
    except Exception:
        # Don't leave earlier sockets bound/connected after a failure.
        for sock in sockets:
            sock.close()
        raise
    return tuple(sockets)


def resource_parser(config_list):
    """ Group socket configuration strings by resource name.

    Each item of config_list is a string "sockname,[conn/bind],SOCK_TYPE,
    IP,PORT". One ZMQ socket can serve several connections, so every
    sockname maps to the list of its remaining four fields, in input order:
    {sockname: [[direction, SOCK_TYPE, IP, PORT], ...]}

    Raises SockConfigError for an item that does not split into exactly
    five comma-separated fields.
    """
    grouped = {}
    for entry in config_list:
        fields = entry.split(",")
        if len(fields) != 5:
            raise SockConfigError("Invalid resource: " + entry)
        name, *rest = fields
        grouped.setdefault(name, []).append(rest)
    return grouped


class SockConfigError(Exception):
    """ Raised when a socket configuration is missing, malformed or
    inconsistent with an existing socket.
    """


class SockConfig:
    """ One ZMQ socket built from a validated (type, direction, addr, port)
    configuration.

    The socket is created from zmq.Context.instance() with IPv6 enabled and
    is bound or connected to its first endpoint. More endpoints can be
    connected with add_connection().
    """

    # a ZMQ feature: one socket can have multiple connections
    class ZMQConnection:
        """ A single TCP endpoint (address and port) of a ZMQ socket. """

        def __init__(self, addr, port):
            self.addr = addr
            self.port = port
            # endpoint string passed to socket.bind() / socket.connect()
            self.connection = self.get_connection()

        def get_connection(self):
            """ Return the endpoint as a ZMQ TCP URL "tcp://addr:port". """
            return "tcp://{}:{}".format(self.addr, self.port)


    SOCKET_TYPE_MAP = {
        "REQ": zmq.REQ,
        "REP": zmq.REP,
        "DEALER": zmq.DEALER,
        "ROUTER": zmq.ROUTER,
        "PUB": zmq.PUB,
        "SUB": zmq.SUB,
        "PUSH": zmq.PUSH,
        "PULL": zmq.PULL,
        "PAIR": zmq.PAIR,
    }

    DIRECTIONS = [
        "connect",
        "bind",
    ]

    def __init__(self, socktype, direction, addr, port):
        """ Initializes the ZMQ socket and its first connection.

        socktype -- key of SOCKET_TYPE_MAP, e.g. "REQ"
        direction -- "connect" or "bind"
        addr, port -- endpoint of the first connection

        The list of all connections is kept so duplicates can be detected
        by add_connection(). Raises SockConfigError for invalid params. If
        bind/connect fails, the socket is closed and the error re-raised.
        """
        self.check_params_validity(socktype, direction, addr, port)

        zmq_connection = self.ZMQConnection(addr, port)
        self.connections = [zmq_connection]

        ctx = zmq.Context.instance()
        self.socket = ctx.socket(self.socktype)
        try:
            self.socket.ipv6 = True
            if self.direction == "bind":
                self.socket.bind(zmq_connection.connection)
            elif self.direction == "connect":
                self.socket.connect(zmq_connection.connection)
        except Exception:
            # don't leak the freshly created socket on a rejected endpoint
            self.socket.close()
            raise

    def add_connection(self, socktype, direction, addr, port):
        """ Adds another ZMQ connection to an existing ZMQ socket.

        The new connection must have the same socket type as the socket.
        Both the socket's direction and the new direction must be
        "connect". Raises SockConfigError in these cases:
        - invalid params;
        - type mismatch;
        - direction mismatch;
        - an addr/port pair that is already connected.
        The socket's recorded type and direction are never modified here.
        """
        # Validate without touching self.socktype/self.direction, so the
        # comparisons below see the socket's real configuration.
        new_type, new_direction = self._validate_params(
            socktype, direction, addr, port)

        if self.socktype != new_type:
            raise SockConfigError("Socket type does not match")

        if self.direction == "bind" or new_direction == "bind":
            raise SockConfigError("Socket direction mismatch")

        for con in self.connections:
            if con.addr == addr and con.port == port:
                raise SockConfigError("Creating duplicate connection")

        zmq_connection = self.ZMQConnection(addr, port)
        self.socket.connect(zmq_connection.connection)
        # record the connection only once connect() has succeeded
        self.connections.append(zmq_connection)

    def check_params_validity(self, socktype, direction, addr, port):
        """ Checks whether all the params are present and ZMQ-compliant.

        On success, stores the zmq socket-type constant in self.socktype
        and the direction string in self.direction. Raises SockConfigError
        otherwise, leaving both attributes unchanged.
        """
        self.socktype, self.direction = self._validate_params(
            socktype, direction, addr, port)

    @staticmethod
    def _validate_params(socktype, direction, addr, port):
        """ Validate params without side effects.

        Returns (zmq socket-type constant, direction). Raises SockConfigError
        if a param is missing or socktype/direction is unknown.
        """
        if not socktype:
            raise SockConfigError("Missing socket type")
        if not direction:
            raise SockConfigError("Missing socket direction")
        if not addr:
            raise SockConfigError("Missing address")
        if not port:
            raise SockConfigError("Missing port")

        if socktype not in SockConfig.SOCKET_TYPE_MAP:
            raise SockConfigError("Unknown socket option", socktype)

        if direction not in SockConfig.DIRECTIONS:
            raise SockConfigError("Unknown direction option", direction)

        return SockConfig.SOCKET_TYPE_MAP[socktype], direction