testserver.py 9.74 KB
Newer Older
1
import argparse
2
import itertools
3
import logging
4
import os
5
import signal
6
import selectors
Marek Vavruša's avatar
Marek Vavruša committed
7
import socket
8
import sys
9
import threading
Marek Vavruša's avatar
Marek Vavruša committed
10
import time
11

Marek Vavruša's avatar
Marek Vavruša committed
12 13 14
import dns.message
import dns.rdatatype

15
from pydnstest import scenario, mock_client
16

17

Marek Vavruša's avatar
Marek Vavruša committed
18 19 20
class TestServer:
    """ This simulates UDP DNS server returning scripted or mirror DNS responses. """

    def __init__(self, test_scenario, root_addr, addr_family):
        """
        Initialize server instance.

        Args:
            test_scenario: scenario object; this class uses its ``reply()``,
                ``play()``, ``ranges``, ``steps`` and ``current_step`` members
            root_addr: local address bound by start() (together with its port)
            addr_family: socket.AF_INET or socket.AF_INET6 for root_addr
        """
        self.thread = None  # background thread running query_io()
        self.srv_socks = []  # listening UDP and TCP sockets
        self.client_socks = []  # only closed/reset by stop() within this class
        self.connections = []  # accepted TCP connections
        self.active = False  # guarded by active_lock; query_io() exits when False
        self.active_lock = threading.Lock()
        self.condition = threading.Condition()  # query_io() notifies once per loop
        self.scenario = test_scenario
        # addr_map / start_iface / cur_iface are not used inside this class;
        # presumably read by other modules — verify before removing
        self.addr_map = []
        self.start_iface = 2
        self.cur_iface = self.start_iface
        self.kroot_local = root_addr
        self.addr_family = addr_family
        self.undefined_answers = 0  # count of SERVFAIL answers to unscripted queries

    def __del__(self):
        """ Stop the server if it is still active when the object is collected. """
        with self.active_lock:
            active = self.active
        if active:
            self.stop()

    def start(self, port=53):
        """
        Synchronous start.

        Binds UDP and TCP sockets on (root_addr, port), which also launches the
        I/O thread, then binds port 53 on all addresses used by the scenario.

        Raises:
            Exception: if the server has already been started
        """
        with self.active_lock:
            # check and set in one critical section so two concurrent
            # start() calls cannot both pass the check
            if self.active:
                raise Exception('TestServer already started')
            self.active = True

        addr, _ = self.start_srv((self.kroot_local, port), self.addr_family)
        self.start_srv(addr, self.addr_family, socket.IPPROTO_TCP)
        self._bind_sockets()

    def stop(self):
        """ Stop socket server operation and close every socket it owns. """
        with self.active_lock:
            self.active = False

        if self.thread:
            # query_io() re-checks `active` after each 0.1 s select timeout
            self.thread.join()
            # forget the finished thread so a later start_srv() launches a new one
            self.thread = None

        for conn in self.connections:
            conn.close()
        for srv_sock in self.srv_socks:
            srv_sock.close()
        for client_sock in self.client_socks:
            client_sock.close()
        self.client_socks = []
        self.srv_socks = []
        self.connections = []
        self.scenario = None

    def address(self):
        """ Returns list of getsockname() results of all listening sockets """
        return [srv_sock.getsockname() for srv_sock in self.srv_socks]

    def handle_query(self, client):
        """
        Receive query from client socket and send an answer.

        Returns:
            True if client socket should be closed by caller
            False if client socket should be kept open
        """
        log = logging.getLogger('pydnstest.testserver.handle_query')
        server_addr = client.getsockname()[0]
        query, client_addr = mock_client.recvfrom_msg(client)
        if query is None:
            return False
        log.debug('server %s received query from %s: %s', server_addr, client_addr, query)

        message = self.scenario.reply(query, server_addr)
        if not message:
            log.debug('ignoring')
            return True
        elif isinstance(message, scenario.DNSReplyServfail):
            # the scenario had no scripted answer; count it so callers can detect it
            self.undefined_answers += 1
            self.scenario.current_step.log.error(
                'server %s has no response for question %s, answering with SERVFAIL',
                server_addr,
                '; '.join([str(rr) for rr in query.question]))
        else:
            log.debug('response: %s', message)

        mock_client.sendto_msg(client, message.to_wire(), client_addr)
        return True

    def query_io(self):
        """
        Main server loop; runs in self.thread until stop() clears `active`.

        Accepts TCP connections on listening TCP sockets and answers queries
        arriving on UDP sockets and accepted TCP connections.

        Raises:
            Exception: if the server is not active, or on an unexpected socket event
        """
        self.undefined_answers = 0
        with self.active_lock:
            if not self.active:
                raise Exception("[query_io] Test server not active")
        while True:
            # wakes start_srv(), which waits until this loop is running
            with self.condition:
                self.condition.notify()
            with self.active_lock:
                if not self.active:
                    break
            # copy: start_srv() may append to srv_socks from another thread
            objects = self.srv_socks + self.connections
            # the selector may hold an OS handle (e.g. an epoll fd); close it on
            # every iteration instead of leaving it to garbage collection
            with selectors.DefaultSelector() as sel:
                for obj in objects:
                    sel.register(obj, selectors.EVENT_READ)
                items = sel.select(0.1)
            for key, event in items:
                sock = key.fileobj
                if event & selectors.EVENT_READ:
                    if sock in self.srv_socks:
                        if sock.proto == socket.IPPROTO_TCP:
                            conn, _ = sock.accept()
                            self.connections.append(conn)
                        else:
                            self.handle_query(sock)
                    elif sock in self.connections:
                        if not self.handle_query(sock):
                            sock.close()
                            self.connections.remove(sock)
                    else:
                        raise Exception(
                            "[query_io] Socket IO internal error {}, exit"
                            .format(sock.getsockname()))
                else:
                    raise Exception("[query_io] Socket IO error {}, exit"
                                    .format(sock.getsockname()))

    def start_srv(self, address, family, proto=socket.IPPROTO_UDP):
        """
        Bind a listening socket on `address`, starting the I/O thread if necessary.

        Args:
            address: (host, port) tuple; AF_INET6 4-tuples are accepted as well
            family: socket.AF_INET or socket.AF_INET6
            proto: socket.IPPROTO_UDP (default) or socket.IPPROTO_TCP

        Returns:
            (sockname, proto) where sockname is getsockname() of the socket
            serving `address`; an already bound matching socket is reused.

        Raises:
            NotImplementedError: unsupported address family or protocol
            OSError: if the socket cannot be bound or put into listening state
        """
        assert address
        assert address[0]  # host
        assert address[1]  # port
        assert family
        assert proto
        if family == socket.AF_INET6:
            if not socket.has_ipv6:
                raise NotImplementedError("[start_srv] IPv6 is not supported by socket {0}"
                                          .format(socket))
        elif family != socket.AF_INET:
            raise NotImplementedError("[start_srv] unsupported protocol family {0}".format(family))

        if proto == socket.IPPROTO_TCP:
            socktype = socket.SOCK_STREAM
        elif proto == socket.IPPROTO_UDP:
            socktype = socket.SOCK_DGRAM
        else:
            raise NotImplementedError("[start_srv] unsupported protocol {0}".format(proto))

        if self.thread is None:
            self.thread = threading.Thread(target=self.query_io)
            self.thread.start()
            with self.condition:
                self.condition.wait()

        # AF_INET6 getsockname() returns (host, port, flowinfo, scope_id), so
        # compare only host and port or IPv6 duplicates are never recognized
        wanted = tuple(address[:2])
        for srv_sock in self.srv_socks:
            if (srv_sock.family == family
                    and srv_sock.proto == proto
                    and srv_sock.getsockname()[:2] == wanted):
                return srv_sock.getsockname(), proto

        sock = socket.socket(family, socktype, proto)
        try:
            # SO_REUSEADDR only influences bind() when set before it
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(address)
            if proto == socket.IPPROTO_TCP:
                sock.listen(5)
        except OSError:
            sock.close()  # do not leak the descriptor on a failed bind/listen
            raise
        self.srv_socks.append(sock)
        return sock.getsockname(), proto

    def _bind_sockets(self):
        """
        Bind test server to port 53 on all addresses referenced by test scenario.

        Addresses come from the scenario's ranges and from A/AAAA records
        found in every section of ad-hoc REPLY steps.
        """
        # Bind to test servers
        for r in self.scenario.ranges:
            for addr in r.addresses:
                family = socket.AF_INET6 if ':' in addr else socket.AF_INET
                self.start_srv((addr, 53), family)

        # Bind addresses in ad-hoc REPLYs
        for s in self.scenario.steps:
            if s.type == 'REPLY':
                reply = s.data[0].message
                for rr in itertools.chain(reply.answer,
                                          reply.additional,
                                          reply.question,
                                          reply.authority):
                    for rd in rr:
                        if rd.rdtype == dns.rdatatype.A:
                            self.start_srv((rd.address, 53), socket.AF_INET)
                        elif rd.rdtype == dns.rdatatype.AAAA:
                            self.start_srv((rd.address, 53), socket.AF_INET6)

    def play(self, subject_addr):
        """ Run the scenario against the subject at (subject_addr, 53), mapped to key ''. """
        self.scenario.play({'': (subject_addr, 53)})
Marek Vavruša's avatar
Marek Vavruša committed
217

218 219 220 221 222 223

def empty_test_case():
    """
    Build the built-in fallback test case.

    The scenario is parsed from ``empty.rpl`` stored next to this module
    (presumably a catch-all mirror scenario — see that file). The config binds
    the server to 127.0.0.10 over IPv4.

    Returns:
        tuple: (first item returned by scenario.parse_file(), config dict)
    """
    # Mirror server
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config = {
        'ROOT_ADDR': '127.0.0.10',
        '_SOCKET_FAMILY': socket.AF_INET,
    }
    parsed = scenario.parse_file(module_dir + "/empty.rpl")
    return parsed[0], config
228

229 230 231 232 233 234 235 236

def standalone_self_test():
    """
    Self-test code

    Runs a TestServer for the scenario given by --scenario (or the built-in
    empty test case) until SIGINT or SIGTERM, then exits with 128 + signum.

    Usage:
    LD_PRELOAD=libsocket_wrapper.so SOCKET_WRAPPER_DIR=/tmp $PYTHON -m pydnstest.testserver --help

    Raises:
        ValueError: if --step names a step ID that is not in the scenario
    """
    logging.basicConfig(level=logging.DEBUG)
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--scenario', help='absolute path to test scenario',
                           required=False)
    argparser.add_argument('--step', help='step # in the scenario (default: first)',
                           required=False, type=int)
    args = argparser.parse_args()
    if args.scenario:
        test_scenario, test_config_text = scenario.parse_file(args.scenario)
        test_config, _ = scenario.parse_config(test_config_text, True, os.getcwd())
    else:
        test_scenario, test_config = empty_test_case()

    # compare with None: an explicit "--step 0" must not fall back to the first step
    if args.step is not None:
        # decide "not found" from the search itself, not from whatever value
        # current_step held before; the last matching step is selected
        matching = [step for step in test_scenario.steps if step.id == args.step]
        if not matching:
            raise ValueError('step ID %s not found in scenario' % args.step)
        test_scenario.current_step = matching[-1]
    else:
        test_scenario.current_step = test_scenario.steps[0]

    server = TestServer(test_scenario, test_config['ROOT_ADDR'], test_config['_SOCKET_FAMILY'])
    server.start()

    logging.info("[==========] Mirror server running at %s", server.address())

    def kill(signum, frame):  # pylint: disable=unused-argument
        logging.info("[==========] Shutdown.")
        server.stop()
        sys.exit(128 + signum)

    signal.signal(signal.SIGINT, kill)
    signal.signal(signal.SIGTERM, kill)

    # all work happens in the server thread; just idle until a signal arrives
    while True:
        time.sleep(0.5)
274 275 276 277 278


if __name__ == '__main__':
    # Script entry point; the self-test lives in a function so that its
    # locals do not become module-level globals.
    standalone_self_test()