# testserver.py
from __future__ import absolute_import

import argparse
import itertools
import logging
import os
import selectors
import signal
import socket
import sys
import threading
import time

import dns.message
import dns.rcode
import dns.rdatatype

from pydnstest import scenario


class TestServer:
    """ Simulates a UDP/TCP DNS server returning scripted or mirror DNS responses. """

    def __init__(self, test_scenario, root_addr, addr_family):
        """
        Initialize server instance.

        Args:
            test_scenario: scenario object whose reply() supplies the answers
            root_addr: host the first UDP/TCP listeners are bound to by start()
            addr_family: socket.AF_INET or socket.AF_INET6, used for root_addr
        """
        self.thread = None  # IO thread running query_io(); created lazily by start_srv()
        self.srv_socks = []  # listening UDP and TCP sockets
        self.client_socks = []  # closed in stop(); NOTE(review): not appended to in this class
        self.connections = []  # accepted TCP connections
        self.active = False  # guarded by active_lock; query_io() exits once it is False
        self.active_lock = threading.Lock()
        self.condition = threading.Condition()  # notified by query_io() on every loop pass
        self.scenario = test_scenario
        # NOTE(review): addr_map/start_iface/cur_iface are not read by this class.
        self.addr_map = []
        self.start_iface = 2
        self.cur_iface = self.start_iface
        self.kroot_local = root_addr
        self.addr_family = addr_family
        self.undefined_answers = 0  # queries answered with SERVFAIL (reset by query_io())

    def __del__(self):
        """ Stop the server if it is still active when the object is collected. """
        # __init__ may have failed before the lock existed; nothing to clean up then.
        lock = getattr(self, 'active_lock', None)
        if lock is None:
            return
        with lock:
            active = self.active
        if active:
            self.stop()

    def start(self, port=53):
        """
        Synchronous start.

        Binds UDP and TCP listeners on (root_addr, port), then on every address
        referenced by the scenario (see _bind_sockets). The IO thread is started
        by the first start_srv() call.

        Raises:
            Exception: if the server is already active.
        """
        # Check and set in one critical section so concurrent start() calls
        # cannot both pass the check.
        with self.active_lock:
            if self.active:
                raise Exception('TestServer already started')
            self.active = True

        addr, _ = self.start_srv((self.kroot_local, port), self.addr_family)
        # TCP listener on exactly the address the UDP socket ended up bound to.
        self.start_srv(addr, self.addr_family, socket.IPPROTO_TCP)
        self._bind_sockets()

    def stop(self):
        """ Stop the IO thread and close all sockets and accepted connections. """
        with self.active_lock:
            self.active = False

        if self.thread:
            # query_io() sees active == False after at most one select() timeout.
            self.thread.join()
            # Allow a later start_srv() to spawn a fresh IO thread.
            self.thread = None

        for conn in self.connections:
            conn.close()
        for srv_sock in self.srv_socks:
            srv_sock.close()
        for client_sock in self.client_socks:
            client_sock.close()
        self.client_socks = []
        self.srv_socks = []
        self.connections = []
        self.scenario = None

    def address(self):
        """ Returns list of local addresses (getsockname()) of all listening sockets. """
        return [srv_sock.getsockname() for srv_sock in self.srv_socks]

    def handle_query(self, client):
        """
        Receive query from client socket and send an answer.

        The answer comes from self.scenario.reply(). If it yields nothing, a
        SERVFAIL response is sent instead and undefined_answers is incremented.

        Returns:
            True if a query was answered and the client socket may stay open
            False if no query could be read; the caller closes TCP connections then
        """
        log = logging.getLogger('pydnstest.testserver.handle_query')
        server_addr = client.getsockname()[0]
        query, client_addr = scenario.recvfrom_msg(client)
        if query is None:
            return False
        log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
        response, is_raw_data = self.scenario.reply(query, server_addr)
        if response:
            if not is_raw_data:
                data_to_wire = response.to_wire(max_size=65535)
                log.debug('response: %s', response)
            else:
                # Raw data is already wire-format bytes.
                data_to_wire = response
                log.debug('raw response not printed')
        else:
            response = dns.message.make_response(query)
            response.set_rcode(dns.rcode.SERVFAIL)
            data_to_wire = response.to_wire()
            self.undefined_answers += 1
            self.scenario.current_step.log.error(
                'server %s has no response for question %s, answering with SERVFAIL',
                server_addr,
                '; '.join([str(rr) for rr in query.question]))

        scenario.sendto_msg(client, data_to_wire, client_addr)
        return True

    def query_io(self):
        """
        Main server loop, executed by the IO thread.

        Waits for readable sockets, accepts new TCP connections and answers
        queries until stop() clears ``active``.

        Raises:
            Exception: if called while the server is not active, or on an
                unexpected socket event.
        """
        self.undefined_answers = 0
        with self.active_lock:
            if not self.active:
                raise Exception("[query_io] Test server not active")
        while True:
            # Wakes start_srv(), which waits on this condition after starting the thread.
            with self.condition:
                self.condition.notify()
            with self.active_lock:
                if not self.active:
                    break
            objects = self.srv_socks + self.connections
            # Close the selector (and its OS handle) on every iteration.
            with selectors.DefaultSelector() as sel:
                for obj in objects:
                    sel.register(obj, selectors.EVENT_READ)
                items = sel.select(0.1)
            for key, event in items:
                sock = key.fileobj
                if event & selectors.EVENT_READ:
                    if sock in self.srv_socks:
                        if sock.proto == socket.IPPROTO_TCP:
                            conn, _ = sock.accept()
                            self.connections.append(conn)
                        else:
                            self.handle_query(sock)
                    elif sock in self.connections:
                        if not self.handle_query(sock):
                            sock.close()
                            self.connections.remove(sock)
                    else:
                        raise Exception(
                            "[query_io] Socket IO internal error {}, exit"
                            .format(sock.getsockname()))
                else:
                    raise Exception("[query_io] Socket IO error {}, exit"
                                    .format(sock.getsockname()))

    def start_srv(self, address, family, proto=socket.IPPROTO_UDP):
        """
        Start listening on ``address``; start the IO thread first if necessary.

        Args:
            address: (host, port, ...) tuple; port must be non-zero
            family: socket.AF_INET or socket.AF_INET6
            proto: socket.IPPROTO_UDP (default) or socket.IPPROTO_TCP

        Returns:
            (sockname, proto) tuple, both for a newly created socket and for an
            already existing socket with the same family/address/protocol.

        Raises:
            NotImplementedError: unsupported family or protocol, or no IPv6 support.
        """
        assert address
        assert address[0]  # host
        assert address[1]  # port
        assert family
        assert proto
        if family == socket.AF_INET6:
            if not socket.has_ipv6:
                raise NotImplementedError("[start_srv] IPv6 is not supported by socket {0}"
                                          .format(socket))
        elif family != socket.AF_INET:
            raise NotImplementedError("[start_srv] unsupported protocol family {0}".format(family))

        if proto == socket.IPPROTO_TCP:
            socktype = socket.SOCK_STREAM
        elif proto == socket.IPPROTO_UDP:
            socktype = socket.SOCK_DGRAM
        else:
            raise NotImplementedError("[start_srv] unsupported protocol {0}".format(proto))

        if self.thread is None:
            self.thread = threading.Thread(target=self.query_io)
            self.thread.start()
            # Block until query_io() signals that its loop is running.
            with self.condition:
                self.condition.wait()

        for srv_sock in self.srv_socks:
            # AF_INET6 socknames are 4-tuples, so compare only host and port.
            if (srv_sock.family == family
                    and srv_sock.getsockname()[:2] == tuple(address[:2])
                    and srv_sock.proto == proto):
                return srv_sock.getsockname(), proto

        sock = socket.socket(family, socktype, proto)
        try:
            # SO_REUSEADDR only has an effect when set before bind().
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(address)
            if proto == socket.IPPROTO_TCP:
                sock.listen(5)
        except OSError:
            # Do not leak the descriptor when bind/listen fails.
            sock.close()
            raise
        self.srv_socks.append(sock)
        return sock.getsockname(), proto

    def _bind_sockets(self):
        """
        Bind UDP listeners on port 53 for all addresses referenced by the test scenario:
        addresses of scenario ranges and A/AAAA records in ad-hoc REPLY steps.
        """
        # Bind to test servers
        for r in self.scenario.ranges:
            for addr in r.addresses:
                family = socket.AF_INET6 if ':' in addr else socket.AF_INET
                self.start_srv((addr, 53), family)

        # Bind addresses in ad-hoc REPLYs
        for s in self.scenario.steps:
            if s.type == 'REPLY':
                reply = s.data[0].message
                for rr in itertools.chain(reply.answer,
                                          reply.additional,
                                          reply.question,
                                          reply.authority):
                    for rd in rr:
                        if rd.rdtype == dns.rdatatype.A:
                            self.start_srv((rd.address, 53), socket.AF_INET)
                        elif rd.rdtype == dns.rdatatype.AAAA:
                            self.start_srv((rd.address, 53), socket.AF_INET6)

    def play(self, subject_addr):
        """ Play the scenario with (subject_addr, 53) registered under the '' key. """
        self.scenario.play({'': (subject_addr, 53)})


def empty_test_case():
    """
    Build the fallback (scenario, config) pair, which answers to any query on 127.0.0.10.

    Returns:
        tuple: first item parsed from ``empty.rpl`` next to this module, and a
        config dict with ROOT_ADDR 127.0.0.10 and IPv4 (_SOCKET_FAMILY AF_INET).
    """
    # Mirror server
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config = {
        'ROOT_ADDR': '127.0.0.10',
        '_SOCKET_FAMILY': socket.AF_INET,
    }
    parsed = scenario.parse_file(module_dir + "/empty.rpl")
    return parsed[0], config


def standalone_self_test():
    """
    Self-test code: run a TestServer on a scenario until SIGINT/SIGTERM.

    Without --scenario, the scenario from empty_test_case() is used. Without
    --step, the first scenario step becomes the current step.

    Usage:
    LD_PRELOAD=libsocket_wrapper.so SOCKET_WRAPPER_DIR=/tmp $PYTHON -m pydnstest.testserver --help

    Raises:
        ValueError: if --step does not match the ID of any scenario step.
    """
    logging.basicConfig(level=logging.DEBUG)
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--scenario', help='absolute path to test scenario',
                           required=False)
    argparser.add_argument('--step', help='step # in the scenario (default: first)',
                           required=False, type=int)
    args = argparser.parse_args()
    if args.scenario:
        test_scenario, test_config_text = scenario.parse_file(args.scenario)
        test_config, _ = scenario.parse_config(test_config_text, True, os.getcwd())
    else:
        test_scenario, test_config = empty_test_case()

    # Compare with None so that step ID 0 is honoured too.
    if args.step is not None:
        for step in test_scenario.steps:
            if step.id == args.step:
                test_scenario.current_step = step
                break
        else:
            raise ValueError('step ID %s not found in scenario' % args.step)
    else:
        test_scenario.current_step = test_scenario.steps[0]

    server = TestServer(test_scenario, test_config['ROOT_ADDR'], test_config['_SOCKET_FAMILY'])
    server.start()

    logging.info("[==========] Mirror server running at %s", server.address())

    def kill(signum, frame):  # pylint: disable=unused-argument
        logging.info("[==========] Shutdown.")
        server.stop()
        sys.exit(128 + signum)

    signal.signal(signal.SIGINT, kill)
    signal.signal(signal.SIGTERM, kill)

    # Keep the main thread alive; queries are served by the server's IO thread.
    while True:
        time.sleep(0.5)


if __name__ == '__main__':
    # Run the self-test only when executed as a script; keeping its logic inside
    # a function avoids creating module-level globals.
    standalone_self_test()