# proxy.py
from contextlib import contextmanager, ContextDecorator
import os
import subprocess
from typing import Any, Dict, Iterator, Optional

import dns
import dns.rcode
import pytest

from kresd import CERTS_DIR, Forward, Kresd, make_kresd, make_port, PYTESTS_DIR
import utils


# Static hints loaded into the forward-target kresd: each of these names
# maps to the IPv4 loopback address. Insertion order is significant because
# callers pick the first key as the sanity-check query name.
HINTS = dict.fromkeys(
    ('0.foo.', '1.foo.', '2.foo.', '3.foo.'),
    '127.0.0.1',
)


def resolve_hint(sock, qname):
    """Query *qname* over *sock* and assert it resolves to its HINTS address.

    The query is built with ``utils.get_msgbuff`` and the reply parsed with
    ``utils.receive_parse_answer``. Asserts that the reply id matches the
    query id, the rcode is NOERROR, and the first record of the answer
    section carries the address listed for *qname* in ``HINTS``.
    """
    query, expected_id = utils.get_msgbuff(qname)
    sock.sendall(query)
    reply = utils.receive_parse_answer(sock)
    assert reply.id == expected_id
    assert reply.rcode() == dns.rcode.NOERROR
    first_record = reply.answer[0][0]
    assert first_record.address == HINTS[qname]


class Proxy(ContextDecorator):
    """Run an external proxy executable for the duration of a ``with`` block.

    Subclasses set ``PATH`` to the executable and may extend ``get_args()``.
    The base class has an empty ``PATH``, so entering it skips the test.
    Because it derives from ``ContextDecorator``, an instance can also be
    used as a function decorator.
    """

    # Path to the proxy executable; overridden by subclasses.
    PATH = ''
    # Seconds to wait for the proxy to exit after terminate() before kill().
    TERMINATE_TIMEOUT = 5

    def __init__(
                self,
                local_ip: str = '127.0.0.1',
                local_port: Optional[int] = None,
                upstream_ip: str = '127.0.0.1',
                upstream_port: Optional[int] = None
            ) -> None:
        self.local_ip = local_ip
        self.local_port = local_port
        self.upstream_ip = upstream_ip
        self.upstream_port = upstream_port
        # Handle of the running proxy process; None while not entered.
        self.proxy: Optional[subprocess.Popen] = None

    def get_args(self):
        """Return the command-line arguments for the proxy executable.

        Always includes ``--local`` and ``--upstream``. ``--lport`` and
        ``--uport`` are added only when the corresponding port is set.
        """
        args = []
        args.append('--local')
        args.append(self.local_ip)
        if self.local_port is not None:
            args.append('--lport')
            args.append(str(self.local_port))
        args.append('--upstream')
        args.append(self.upstream_ip)
        if self.upstream_port is not None:
            args.append('--uport')
            args.append(str(self.upstream_port))
        return args

    def __enter__(self):
        """Start the proxy executable and return ``self``.

        Skips the current test if ``PATH`` does not exist. The executable is
        launched from its own directory with stdout/stderr discarded.
        """
        if not os.path.exists(self.PATH):
            pytest.skip("proxy executable '{}' not found (did you compile it?)".format(self.PATH))

        cwd, cmd = os.path.split(self.PATH)
        cmd = './' + cmd
        args = [cmd] + self.get_args()
        print(' '.join(args))

        # A bare file name yields cwd == '', which Popen rejects; None keeps
        # the current directory, where the os.path.exists() check found it.
        self.proxy = subprocess.Popen(
            args, cwd=cwd or None, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop and reap the proxy process; exceptions are never suppressed."""
        if self.proxy is None:
            return
        proc, self.proxy = self.proxy, None
        proc.terminate()
        try:
            # Reap the child so it does not linger as a zombie (and trigger a
            # ResourceWarning); escalate if it ignores the terminate request.
            proc.wait(timeout=self.TERMINATE_TIMEOUT)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()


class TLSProxy(Proxy):
    """Proxy backed by the ``tlsproxy`` binary under ``PYTESTS_DIR/proxy``.

    As used by ``kresd_tls_client``, clients connect to it over TLS and it
    forwards to the upstream address over TCP.

    :param certname: base name of the certificate/key pair in ``CERTS_DIR``
        (``<certname>.cert.pem`` / ``<certname>.key.pem``); ``None`` means no
        ``--cert``/``--key`` arguments and ``cert_path``/``key_path`` of None.
    :param close: when set, passed through to tlsproxy as ``--close <close>``.
    :param rehandshake: when true, passes ``--rehandshake`` to tlsproxy.
    :param force_tls13: when true, passes ``--tls13`` to tlsproxy.
    """

    PATH = os.path.join(PYTESTS_DIR, 'proxy', 'tlsproxy')

    def __init__(
                self,
                local_ip: str = '127.0.0.1',
                local_port: Optional[int] = None,
                upstream_ip: str = '127.0.0.1',
                upstream_port: Optional[int] = None,
                certname: Optional[str] = 'tt',
                close: Optional[int] = None,
                rehandshake: bool = False,
                force_tls13: bool = False
            ) -> None:
        super().__init__(local_ip, local_port, upstream_ip, upstream_port)
        if certname is not None:
            self.cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem')
            self.key_path = os.path.join(CERTS_DIR, certname + '.key.pem')
        else:
            self.cert_path = None
            self.key_path = None
        self.close = close
        self.rehandshake = rehandshake
        self.force_tls13 = force_tls13

    def get_args(self):
        """Extend the base proxy arguments with the TLS-specific options."""
        args = super().get_args()
        if self.cert_path is not None:
            args.append('--cert')
            args.append(self.cert_path)
        if self.key_path is not None:
            args.append('--key')
            args.append(self.key_path)
        if self.close is not None:
            args.append('--close')
            args.append(str(self.close))
        if self.rehandshake:
            args.append('--rehandshake')
        if self.force_tls13:
            args.append('--tls13')
        return args


@contextmanager
def kresd_tls_client(
            workdir: str,
            proxy: TLSProxy,
            kresd_tls_client_kwargs: Optional[Dict[Any, Any]] = None,
            kresd_fwd_target_kwargs: Optional[Dict[Any, Any]] = None
        ) -> Iterator[Kresd]:
    """kresd_tls_client --(tls)--> tlsproxy --(tcp)--> kresd_fwd_target

    Starts a forward-target kresd in ``<workdir>/kresd_fwd_target`` with
    ``HINTS`` and checks it by resolving the first hint over TCP. Then it
    assigns *proxy* a fresh local port (from ``make_port``) and the target's
    port as upstream, and starts the proxy. Finally it starts a second kresd
    in ``<workdir>/kresd_tls_client`` forwarding with ``proto='tls'`` to the
    proxy (hostname ``transport-test-server.com``, ``ca_file`` set to
    ``proxy.cert_path``) and yields that instance.

    Note: *proxy*'s ``local_port`` and ``upstream_port`` are overwritten, and
    both working directories must not exist yet (``os.makedirs`` raises).
    """
    allowed_ips = {'127.0.0.1', '::1'}
    assert proxy.local_ip in allowed_ips, "only localhost IPs supported for proxy"
    assert proxy.upstream_ip in allowed_ips, "only localhost IPs are supported for proxy"

    if kresd_tls_client_kwargs is None:
        kresd_tls_client_kwargs = dict()
    if kresd_fwd_target_kwargs is None:
        kresd_fwd_target_kwargs = dict()

    # run forward target instance
    dir1 = os.path.join(workdir, 'kresd_fwd_target')
    os.makedirs(dir1)

    with make_kresd(dir1, hints=HINTS, **kresd_fwd_target_kwargs) as kresd_fwd_target:
        # Sanity-check the target before wiring up the proxy. The socket is
        # only needed for this query, so close it rather than leaking it for
        # the whole test (presumably a socket.socket, given sendall()).
        sock = kresd_fwd_target.ip_tcp_socket()
        try:
            resolve_hint(sock, next(iter(HINTS)))
        finally:
            sock.close()

        proxy.local_port = make_port('127.0.0.1', '::1')
        proxy.upstream_port = kresd_fwd_target.port

        with proxy:
            # run test kresd instance
            dir2 = os.path.join(workdir, 'kresd_tls_client')
            os.makedirs(dir2)
            forward = Forward(
                proto='tls', ip=proxy.local_ip, port=proxy.local_port,
                hostname='transport-test-server.com', ca_file=proxy.cert_path)
            with make_kresd(dir2, forward=forward, **kresd_tls_client_kwargs) as kresd:
                yield kresd