testserver.py 9.97 KB
Newer Older
1
import argparse
2
import itertools
3
import logging
4
import os
5
import signal
6
import selectors
Marek Vavruša's avatar
Marek Vavruša committed
7
import socket
8
import sys
9
import threading
Marek Vavruša's avatar
Marek Vavruša committed
10
import time
11

Marek Vavruša's avatar
Marek Vavruša committed
12 13 14
import dns.message
import dns.rdatatype

15
from pydnstest import scenario
16

17

Marek Vavruša's avatar
Marek Vavruša committed
18 19 20
class TestServer:
    """
    Simulate a DNS server returning scripted or mirror DNS responses.

    The server listens on UDP and TCP sockets and serves all of them from a
    single background thread running query_io().
    """

    def __init__(self, test_scenario, root_addr, addr_family):
        """
        Initialize server instance; no sockets are opened until start().

        Args:
            test_scenario: scenario object used to compute replies
            root_addr: address the first (root) server socket is bound to
            addr_family: socket family (socket.AF_INET or socket.AF_INET6) of root_addr
        """
        self.thread = None
        self.srv_socks = []
        self.client_socks = []
        self.connections = []
        self.active = False
        self.active_lock = threading.Lock()
        self.condition = threading.Condition()
        self.scenario = test_scenario
        self.addr_map = []
        self.start_iface = 2
        self.cur_iface = self.start_iface
        self.kroot_local = root_addr
        self.addr_family = addr_family
        self.undefined_answers = 0

    def __del__(self):
        """ Cleanup after deletion: stop the server if it is still active. """
        # __init__ may not have run (or may have failed part-way), so do not
        # assume the attributes exist.
        lock = getattr(self, 'active_lock', None)
        if lock is None:
            return
        with lock:
            active = self.active
        if active:
            self.stop()

    def start(self, port=53):
        """
        Synchronous start.

        Binds UDP and TCP sockets on the root address and then on every
        address referenced by the scenario.

        Raises:
            Exception: if the server is already started
        """
        # Check and set under a single lock acquisition so that two concurrent
        # start() calls cannot both pass the check.
        with self.active_lock:
            if self.active:
                raise Exception('TestServer already started')
            self.active = True
        addr, _ = self.start_srv((self.kroot_local, port), self.addr_family)
        self.start_srv(addr, self.addr_family, socket.IPPROTO_TCP)
        self._bind_sockets()

    def stop(self):
        """ Stop socket server operation and close all sockets. """
        with self.active_lock:
            self.active = False
        if self.thread:
            self.thread.join()
            # Forget the finished thread; start_srv() only spawns a new one
            # when self.thread is None.
            self.thread = None
        for sock in itertools.chain(self.connections, self.srv_socks, self.client_socks):
            sock.close()
        self.client_socks = []
        self.srv_socks = []
        self.connections = []
        self.scenario = None

    def address(self):
        """ Returns list of addresses (getsockname() results) of opened server sockets """
        return [sock.getsockname() for sock in self.srv_socks]

    def handle_query(self, client):
        """
        Receive query from client socket and send an answer.

        If the scenario has no reply for the query, SERVFAIL is sent instead and
        the undefined_answers counter is incremented.

        Returns:
            True if a query was received and answered (socket should be kept open)
            False if no query could be read (caller closes TCP connections in this case)
        """
        log = logging.getLogger('pydnstest.testserver.handle_query')
        server_addr = client.getsockname()[0]
        query, client_addr = scenario.recvfrom_msg(client)
        if query is None:
            return False
        log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
        response, is_raw_data = self.scenario.reply(query, server_addr)
        if response:
            if not is_raw_data:
                data_to_wire = response.to_wire(max_size=65535)
                log.debug('response: %s', response)
            else:
                # scenario provided ready-made wire data, send it verbatim
                data_to_wire = response
                log.debug('raw response not printed')
        else:
            response = dns.message.make_response(query)
            response.set_rcode(dns.rcode.SERVFAIL)
            data_to_wire = response.to_wire()
            self.undefined_answers += 1
            self.scenario.current_step.log.error(
                'server %s has no response for question %s, answering with SERVFAIL',
                server_addr,
                '; '.join([str(rr) for rr in query.question]))

        scenario.sendto_msg(client, data_to_wire, client_addr)
        return True

    def query_io(self):
        """
        Main server process; runs in the background thread until stop().

        Raises:
            Exception: if the server is not active or an unexpected socket/event shows up
        """
        self.undefined_answers = 0
        with self.active_lock:
            if not self.active:
                raise Exception("[query_io] Test server not active")
        while True:
            # Wake up start_srv(), which blocks until this loop is running.
            with self.condition:
                self.condition.notify()
            with self.active_lock:
                if not self.active:
                    break
            # Work on a snapshot: start_srv() appends to srv_socks from the caller's thread.
            objects = self.srv_socks + self.connections
            # Close the selector after every poll so its descriptor is released promptly.
            with selectors.DefaultSelector() as sel:
                for obj in objects:
                    sel.register(obj, selectors.EVENT_READ)
                # Short timeout keeps the loop responsive to stop() and new sockets.
                items = sel.select(0.1)
            for key, event in items:
                sock = key.fileobj
                if event & selectors.EVENT_READ:
                    if sock in self.srv_socks:
                        if sock.proto == socket.IPPROTO_TCP:
                            # listening TCP socket: accept and poll the new connection
                            conn, _ = sock.accept()
                            self.connections.append(conn)
                        else:
                            self.handle_query(sock)
                    elif sock in self.connections:
                        if not self.handle_query(sock):
                            sock.close()
                            self.connections.remove(sock)
                    else:
                        raise Exception(
                            "[query_io] Socket IO internal error {}, exit"
                            .format(sock.getsockname()))
                else:
                    raise Exception("[query_io] Socket IO error {}, exit"
                                    .format(sock.getsockname()))

    def start_srv(self, address, family, proto=socket.IPPROTO_UDP):
        """
        Bind a server socket and start the listening thread if necessary.

        Args:
            address: (host, port, ...) tuple to bind to
            family: socket.AF_INET or socket.AF_INET6
            proto: socket.IPPROTO_UDP (default) or socket.IPPROTO_TCP

        Returns:
            (sockname, proto) pair of the socket serving the address; an already
            bound socket with the same family, address and protocol is reused.

        Raises:
            NotImplementedError: for unsupported family or protocol
        """
        assert address
        assert address[0]  # host
        assert address[1]  # port
        assert family
        assert proto
        if family == socket.AF_INET6:
            if not socket.has_ipv6:
                raise NotImplementedError("[start_srv] IPv6 is not supported by socket {0}"
                                          .format(socket))
        elif family != socket.AF_INET:
            raise NotImplementedError("[start_srv] unsupported protocol family {0}".format(family))

        if proto == socket.IPPROTO_TCP:
            socktype = socket.SOCK_STREAM
        elif proto == socket.IPPROTO_UDP:
            socktype = socket.SOCK_DGRAM
        else:
            raise NotImplementedError("[start_srv] unsupported protocol {0}".format(proto))

        if self.thread is None:
            self.thread = threading.Thread(target=self.query_io)
            self.thread.start()
            # query_io() notifies on every loop iteration, so this returns once
            # the thread is up and polling.
            with self.condition:
                self.condition.wait()

        # IPv6 socket names are 4-tuples (host, port, flowinfo, scope_id);
        # compare only (host, port) so already bound IPv6 sockets are found too.
        wanted = tuple(address[:2])
        for srv_sock in self.srv_socks:
            if (srv_sock.family == family
                    and srv_sock.getsockname()[:2] == wanted
                    and srv_sock.proto == proto):
                # same shape as the return value for a newly created socket
                return srv_sock.getsockname(), proto

        sock = socket.socket(family, socktype, proto)
        # SO_REUSEADDR influences bind(), so it has to be set beforehand.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(address)
        if proto == socket.IPPROTO_TCP:
            sock.listen(5)
        self.srv_socks.append(sock)
        sockname = sock.getsockname()
        return sockname, proto

    def _bind_sockets(self):
        """
        Bind test server to port 53 on all addresses referenced by test scenario.
        """
        # Bind to test servers
        for r in self.scenario.ranges:
            for addr in r.addresses:
                family = socket.AF_INET6 if ':' in addr else socket.AF_INET
                self.start_srv((addr, 53), family)

        # Bind addresses in ad-hoc REPLYs
        for s in self.scenario.steps:
            if s.type == 'REPLY':
                reply = s.data[0].message
                for rr in itertools.chain(reply.answer,
                                          reply.additional,
                                          reply.question,
                                          reply.authority):
                    for rd in rr:
                        if rd.rdtype == dns.rdatatype.A:
                            self.start_srv((rd.address, 53), socket.AF_INET)
                        elif rd.rdtype == dns.rdatatype.AAAA:
                            self.start_srv((rd.address, 53), socket.AF_INET6)

    def play(self, subject_addr):
        """ Play the scenario against the tested subject listening on subject_addr, port 53. """
        self.scenario.play({'': (subject_addr, 53)})
Marek Vavruša's avatar
Marek Vavruša committed
221

222 223 224 225 226 227

def empty_test_case():
    """
    Build the default (mirror server) test case.

    The scenario is parsed from "empty.rpl" stored next to this module; the
    accompanying config places the root server on 127.0.0.10 using IPv4.

    Returns:
        (scenario, config) pair
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    rpl_path = module_dir + "/empty.rpl"
    config = dict([('ROOT_ADDR', '127.0.0.10'),
                   ('_SOCKET_FAMILY', socket.AF_INET)])
    parsed = scenario.parse_file(rpl_path)
    return parsed[0], config
232

233 234 235 236 237 238 239 240

def standalone_self_test():
    """
    Self-test code: run the test server in foreground until SIGINT/SIGTERM.

    Without --scenario the built-in empty (mirror) test case is used.

    Usage:
    LD_PRELOAD=libsocket_wrapper.so SOCKET_WRAPPER_DIR=/tmp $PYTHON -m pydnstest.testserver --help

    Raises:
        ValueError: if --step does not match any step ID in the scenario
    """
    logging.basicConfig(level=logging.DEBUG)
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--scenario', help='absolute path to test scenario',
                           required=False)
    argparser.add_argument('--step', help='step # in the scenario (default: first)',
                           required=False, type=int)
    args = argparser.parse_args()
    if args.scenario:
        test_scenario, test_config_text = scenario.parse_file(args.scenario)
        test_config, _ = scenario.parse_config(test_config_text, True, os.getcwd())
    else:
        test_scenario, test_config = empty_test_case()

    # Compare with None so that an explicit "--step 0" is not silently ignored.
    if args.step is not None:
        # Decide "not found" from the search result itself instead of relying on
        # whatever current_step the freshly parsed scenario happens to hold.
        matching = [step for step in test_scenario.steps if step.id == args.step]
        if not matching:
            raise ValueError('step ID %s not found in scenario' % args.step)
        # the last matching step wins, same as assigning inside a loop without break
        test_scenario.current_step = matching[-1]
    else:
        test_scenario.current_step = test_scenario.steps[0]

    server = TestServer(test_scenario, test_config['ROOT_ADDR'], test_config['_SOCKET_FAMILY'])
    server.start()

    logging.info("[==========] Mirror server running at %s", server.address())

    def kill(signum, frame):  # pylint: disable=unused-argument
        """ Signal handler: stop the server and exit with 128 + signal number. """
        logging.info("[==========] Shutdown.")
        server.stop()
        sys.exit(128 + signum)

    signal.signal(signal.SIGINT, kill)
    signal.signal(signal.SIGTERM, kill)

    # The server runs in its own thread; keep the main thread alive so it can
    # receive the signals handled above.
    while True:
        time.sleep(0.5)
278 279 280 281 282


if __name__ == "__main__":
    # All work happens inside a function so that running the module as
    # a script does not create module-level (global) variables.
    standalone_self_test()