Commit e7c5c102 authored by Grigorii Demidov's avatar Grigorii Demidov Committed by Petr Špaček

daemon: reuse outbound TCP connections if possible; TLS over outbound TCP connection

parent 5aea7a1f
......@@ -398,6 +398,107 @@ static int net_tls(lua_State *L)
return 1;
}
static int print_tls_param(const char *key, void *val, void *data)
{
if (!val) {
return 0;
}
struct tls_client_paramlist_entry *entry = (struct tls_client_paramlist_entry *)val;
lua_State *L = (lua_State *)data;
lua_newtable(L);
lua_newtable(L);
lua_newtable(L);
for (size_t i = 0; i < entry->pins.len; ++i) {
lua_pushnumber(L, i + 1);
lua_pushstring(L, entry->pins.at[i]);
lua_settable(L, -3);
}
lua_setfield(L, -2, "pins");
lua_newtable(L);
for (size_t i = 0; i < entry->ca_files.len; ++i) {
lua_pushnumber(L, i + 1);
lua_pushstring(L, entry->ca_files.at[i]);
lua_settable(L, -3);
}
lua_setfield(L, -2, "ca files");
lua_setfield(L, -2, key);
return 0;
}
static int print_tls_client_params(lua_State *L)
{
struct engine *engine = engine_luaget(L);
if (!engine) {
return 0;
}
struct network *net = &engine->net;
if (!net) {
return 0;
}
if (net->tls_client_params.root == 0 ) {
return 0;
}
map_walk(&net->tls_client_params, print_tls_param, (void *)L);
return 1;
}
static int net_tls_client(lua_State *L)
{
struct engine *engine = engine_luaget(L);
if (!engine) {
return 0;
}
struct network *net = &engine->net;
if (!net) {
return 0;
}
/* Only return current credentials. */
if (lua_gettop(L) == 0) {
return print_tls_client_params(L);
}
const char *full_addr = NULL;
const char *ca_file = NULL;
const char *pin = NULL;
if ((lua_gettop(L) == 1) && lua_isstring(L, 1)) {
full_addr = lua_tostring(L, 1);
} else if ((lua_gettop(L) == 3) && lua_isstring(L, 1) && lua_isstring(L, 2) && lua_isstring(L, 3)) {
full_addr = lua_tostring(L, 1);
ca_file = lua_tostring(L, 2);
pin = lua_tostring(L, 3);
} else {
format_error(L, "net.tls_client either takes one parameter (\"address\") either takes three ones: (\"address\", \"ca_file\", \"pin\")");
lua_error(L);
}
char addr[INET6_ADDRSTRLEN];
uint16_t port = 0;
if (kr_straddr_split(full_addr, addr, sizeof(addr), &port) != kr_ok()) {
format_error(L, "invalid IP address");
lua_error(L);
}
if (port == 0) {
port = 53;
}
int r = tls_client_params_set(&net->tls_client_params, addr, port, ca_file, pin);
if (r != 0) {
lua_pushstring(L, strerror(ENOMEM));
lua_error(L);
}
lua_pushboolean(L, true);
return 1;
}
static int net_tls_padding(lua_State *L)
{
struct engine *engine = engine_luaget(L);
......@@ -508,6 +609,8 @@ int lib_net(lua_State *L)
{ "bufsize", net_bufsize },
{ "tcp_pipeline", net_pipeline },
{ "tls", net_tls },
{ "tls_server", net_tls },
{ "tls_client", net_tls_client },
{ "tls_padding", net_tls_padding },
{ "outgoing_v4", net_outgoing_v4 },
{ "outgoing_v6", net_outgoing_v6 },
......
......@@ -48,15 +48,18 @@ static void check_bufsize(uv_handle_t* handle)
static void session_clear(struct session *s)
{
assert(s->outgoing || s->tasks.len == 0);
assert(s->tasks.len == 0 && s->waiting.len == 0);
array_clear(s->tasks);
array_clear(s->waiting);
tls_free(s->tls_ctx);
tls_client_ctx_free(s->tls_client_ctx);
memset(s, 0, sizeof(*s));
}
void session_free(struct session *s)
{
if (s) {
assert(s->tasks.len == 0 && s->waiting.len == 0);
session_clear(s);
free(s);
}
......@@ -89,6 +92,8 @@ static void session_release(struct worker_ctx *worker, uv_handle_t *handle)
if (!s) {
return;
}
assert(s->waiting.len == 0 && s->tasks.len == 0);
assert(s->buffering == NULL);
if (!s->outgoing && handle->type == UV_TCP) {
worker_end_tcp(worker, handle); /* to free the buffering task */
}
......@@ -158,8 +163,10 @@ static int udp_bind_finalize(uv_handle_t *handle)
{
check_bufsize((uv_handle_t *)handle);
/* Handle is already created, just create context. */
handle->data = session_new();
assert(handle->data);
struct session *session = session_new();
assert(session);
session->handle = handle;
handle->data = session;
return io_start_read((uv_handle_t *)handle);
}
......@@ -189,20 +196,14 @@ int udp_bindfd(uv_udp_t *handle, int fd)
return udp_bind_finalize((uv_handle_t *)handle);
}
static void tcp_timeout(uv_handle_t *timer)
{
uv_handle_t *handle = timer->data;
uv_close(handle, io_free);
}
static void tcp_timeout_trigger(uv_timer_t *timer)
{
uv_handle_t *handle = timer->data;
struct session *session = handle->data;
struct session *session = timer->data;
if (session->tasks.len > 0) {
uv_timer_again(timer);
} else {
uv_close((uv_handle_t *)timer, tcp_timeout);
uv_timer_stop(timer);
worker_session_close(session);
}
}
......@@ -210,12 +211,16 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
{
uv_loop_t *loop = handle->loop;
struct session *s = handle->data;
if (s->closing) {
return;
}
struct worker_ctx *worker = loop->data;
/* TCP pipelining is rather complicated and requires cooperation from the worker
* so the whole message reassembly and demuxing logic is inside worker */
int ret = 0;
if (s->has_tls) {
ret = tls_process(worker, handle, (const uint8_t *)buf->base, nread);
ret = s->outgoing ? tls_client_process(worker, handle, (const uint8_t *)buf->base, nread) :
tls_process(worker, handle, (const uint8_t *)buf->base, nread);
} else {
ret = worker_process_tcp(worker, handle, (const uint8_t *)buf->base, nread);
}
......@@ -226,7 +231,7 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
if (!s->outgoing && !uv_is_closing((uv_handle_t *)&s->timeout)) {
uv_timer_stop(&s->timeout);
if (s->tasks.len == 0) {
uv_close((uv_handle_t *)&s->timeout, tcp_timeout);
worker_session_close(s);
} else { /* If there are tasks running, defer until they finish. */
uv_timer_start(&s->timeout, tcp_timeout_trigger, 1, KR_CONN_RTT_MAX/2);
}
......@@ -265,7 +270,7 @@ static void _tcp_accept(uv_stream_t *master, int status, bool tls)
}
uv_timer_t *timer = &session->timeout;
uv_timer_init(master->loop, timer);
timer->data = client;
timer->data = session;
uv_timer_start(timer, tcp_timeout_trigger, KR_CONN_RTT_MAX/2, KR_CONN_RTT_MAX/2);
io_start_read((uv_handle_t *)client);
}
......@@ -379,8 +384,12 @@ void io_create(uv_loop_t *loop, uv_handle_t *handle, int type)
}
struct worker_ctx *worker = loop->data;
handle->data = session_borrow(worker);
assert(handle->data);
struct session *session = session_borrow(worker);
assert(session);
session->handle = handle;
handle->data = session;
session->timeout.data = session;
uv_timer_init(worker->loop, &session->timeout);
}
void io_deinit(uv_handle_t *handle)
......@@ -388,6 +397,7 @@ void io_deinit(uv_handle_t *handle)
if (!handle) {
return;
}
struct session *session = handle->data;
uv_loop_t *loop = handle->loop;
if (loop && loop->data) {
struct worker_ctx *worker = loop->data;
......
......@@ -18,22 +18,35 @@
#include <uv.h>
#include <libknot/packet/pkt.h>
#include <gnutls/gnutls.h>
#include "lib/generic/array.h"
#include "daemon/worker.h"
struct qr_task;
struct tls_ctx_t;
struct tls_client_ctx_t;
/* Per-session (TCP or UDP) persistent structure,
* that exists between remote counterpart and a local socket.
*/
struct session {
bool outgoing;
bool outgoing; /**< True: to upstream; false: from a client. */
bool throttled;
bool has_tls;
bool connected;
bool closing;
union inaddr peer;
uv_handle_t *handle;
uv_timer_t timeout;
struct qr_task *buffering; /**< Worker buffers the incomplete TCP query here. */
struct tls_ctx_t *tls_ctx;
array_t(struct qr_task *) tasks;
struct tls_client_ctx_t *tls_client_ctx;
uint8_t msg_hdr[4]; /**< Buffer for DNS message header. */
ssize_t msg_hdr_idx; /**< The number of bytes in msg_hdr filled so far. */
qr_tasklist_t tasks;
qr_tasklist_t waiting;
ssize_t bytes_to_skip;
};
void session_free(struct session *s);
......
......@@ -18,6 +18,7 @@
#include <signal.h>
#include <stdlib.h>
#include <string.h>
#include <signal.h>
#include <getopt.h>
#include <libgen.h>
#include <uv.h>
......
......@@ -51,6 +51,7 @@ void network_init(struct network *net, uv_loop_t *loop)
if (net != NULL) {
net->loop = loop;
net->endpoints = map_make();
net->tls_client_params = map_make();
}
}
......@@ -106,6 +107,7 @@ void network_deinit(struct network *net)
map_walk(&net->endpoints, free_key, 0);
map_clear(&net->endpoints);
tls_credentials_free(net->tls_credentials);
tls_client_params_free(&net->tls_client_params);
net->tls_credentials = NULL;
}
}
......
......@@ -46,6 +46,7 @@ struct network {
uv_loop_t *loop;
map_t endpoints;
struct tls_credentials *tls_credentials;
map_t tls_client_params;
};
void network_init(struct network *net, uv_loop_t *loop);
......
This diff is collapsed.
......@@ -20,11 +20,13 @@
#include <gnutls/gnutls.h>
#include <libknot/packet/pkt.h>
#include "lib/defines.h"
#include "lib/generic/array.h"
#include "lib/generic/map.h"
#define MAX_TLS_PADDING KR_EDNS_PAYLOAD
struct tls_ctx_t;
struct tls_credentials;
struct tls_client_ctx_t;
struct tls_credentials {
int count;
char *tls_cert;
......@@ -34,6 +36,20 @@ struct tls_credentials {
char *ephemeral_servicename;
};
struct tls_client_paramlist_entry {
array_t(const char *) ca_files;
array_t(const char *) pins;
gnutls_certificate_credentials_t credentials;
};
typedef enum tls_client_hs_state {
TLS_HS_NOT_STARTED = 0,
TLS_HS_IN_PROGRESS,
TLS_HS_DONE
} tls_client_hs_state_t;
typedef int (*tls_handshake_cb) (struct session *session, int status);
/*! Create an empty TLS context in query context */
struct tls_ctx_t* tls_new(struct worker_ctx *worker);
......@@ -66,3 +82,32 @@ void tls_credentials_log_pins(struct tls_credentials *tls_credentials);
/*! Generate new ephemeral TLS credentials. */
struct tls_credentials * tls_get_ephemeral_credentials(struct engine *engine);
/*! Set TLS authentication parameters for given address. */
int tls_client_params_set(map_t *tls_client_paramlist,
const char *addr, uint16_t port,
const char *ca_file, const char *pin);
/*! Free TLS authentication parameters. */
int tls_client_params_free(map_t *tls_client_paramlist);
/*! Allocate new client TLS context */
struct tls_client_ctx_t *tls_client_ctx_new(const struct tls_client_paramlist_entry *entry);
int tls_client_process(struct worker_ctx *worker, uv_stream_t *handle,
const uint8_t *buf, ssize_t nread);
/*! Free client TLS context */
void tls_client_ctx_free(struct tls_client_ctx_t *ctx);
int tls_client_connect_start(struct tls_client_ctx_t *ctx, struct session *session,
tls_handshake_cb handshake_cb);
void tls_client_close(struct tls_client_ctx_t *ctx);
int tls_client_push(struct qr_task *task, uv_handle_t *handle, knot_pkt_t *pkt);
tls_client_hs_state_t tls_client_get_hs_state(const struct tls_client_ctx_t *ctx);
int tls_client_ctx_set_params(struct tls_client_ctx_t *ctx,
const struct tls_client_paramlist_entry *entry);
\ No newline at end of file
This diff is collapsed.
......@@ -21,8 +21,12 @@
#include "lib/generic/map.h"
/** Query resolution task (opaque). */
struct qr_task;
/** Worker state (opaque). */
struct worker_ctx;
/** Transport session (opaque). */
struct session;
/** Worker callback */
typedef void (*worker_cb_t)(struct worker_ctx *worker, struct kr_request *req, void *baton);
......@@ -31,14 +35,20 @@ struct worker_ctx *worker_create(struct engine *engine, knot_mm_t *pool,
int worker_id, int worker_count);
/**
* Process incoming packet (query or answer to subrequest).
* Process an incoming packet (query from a client or answer from upstream).
*
* @param worker the singleton worker
* @param handle socket through which the request came
* @param query the packet, or NULL on an error from the transport layer
* @param addr the address from which the packet came (or NULL, possibly, on error)
* @return 0 or an error code
*/
int worker_submit(struct worker_ctx *worker, uv_handle_t *handle, knot_pkt_t *query,
const struct sockaddr* addr);
/**
* Process incoming DNS/TCP message fragment(s).
* Process incoming DNS message fragment(s) that arrived over a stream (TCP, TLS).
*
* If the fragment contains only a partial message, it is buffered.
* If the fragment contains a complete query or completes current fragment, execute it.
* @return the number of newly-completed requests (>=0) or an error code
......@@ -55,6 +65,7 @@ int worker_end_tcp(struct worker_ctx *worker, uv_handle_t *handle);
/**
* Schedule query for resolution.
*
* After resolution finishes, invoke on_complete with baton.
* @return 0 or an error code
*
* @note the options passed are |-combined with struct kr_context::options
......@@ -66,15 +77,27 @@ int worker_resolve(struct worker_ctx *worker, knot_pkt_t *query, struct kr_qflag
/** Collect worker mempools */
void worker_reclaim(struct worker_ctx *worker);
/** Closes given session */
void worker_session_close(struct session *session);
/** @cond internal */
/** Number of request within timeout window. */
#define MAX_PENDING KR_NSREP_MAXADDR
/** Maximum response time from TCP upstream, milliseconds */
#define MAX_TCP_INACTIVITY 10000
/** Freelist of available mempools. */
typedef array_t(void *) mp_freelist_t;
/** List of query resolution tasks. */
typedef array_t(struct qr_task *) qr_tasklist_t;
/** Session list. */
typedef array_t(struct session *) qr_sessionlist_t;
/** \details Worker state is meant to persist during the whole life of daemon. */
struct worker_ctx {
struct engine *engine;
......@@ -94,6 +117,7 @@ struct worker_ctx {
#endif
struct {
size_t concurrent;
size_t rconcurrent;
size_t udp;
size_t tcp;
size_t ipv4;
......@@ -103,6 +127,12 @@ struct worker_ctx {
size_t timeout;
} stats;
bool too_many_open;
size_t rconcurrent_highwatermark;
/* List of active outbound TCP sessions */
map_t tcp_connected;
/* List of outbound TCP sessions waiting to be accepted */
map_t tcp_waiting;
map_t outgoing;
mp_freelist_t pool_mp;
mp_freelist_t pool_ioreq;
......@@ -110,34 +140,6 @@ struct worker_ctx {
knot_mm_t pkt_pool;
};
/** Query resolution task. */
struct qr_task
{
struct kr_request req;
struct worker_ctx *worker;
struct session *session;
knot_pkt_t *pktbuf;
array_t(struct qr_task *) waiting;
uv_handle_t *pending[MAX_PENDING];
uint16_t pending_count;
uint16_t addrlist_count;
uint16_t addrlist_turn;
uint16_t timeouts;
uint16_t iter_count;
uint16_t bytes_remaining;
struct sockaddr *addrlist;
uv_timer_t *timeout;
worker_cb_t on_complete;
void *baton;
struct {
union inaddr addr;
union inaddr dst_addr;
uv_handle_t *handle;
} source;
uint32_t refs;
bool finished : 1;
bool leading : 1;
};
/** @endcond */
......@@ -823,7 +823,7 @@ int kr_make_query(struct kr_query *query, knot_pkt_t *pkt)
char name_str[KNOT_DNAME_MAXLEN], type_str[16];
knot_dname_to_str(name_str, query->sname, sizeof(name_str));
knot_rrtype_to_string(query->stype, type_str, sizeof(type_str));
QVERBOSE_MSG(query, "'%s' type '%s' id was assigned, parent id %hu\n",
QVERBOSE_MSG(query, "'%s' type '%s' id was assigned, parent id %u\n",
name_str, type_str, query->parent ? query->parent->id : 0);
}
return kr_ok();
......
......@@ -1520,7 +1520,6 @@ int kr_resolve_checkout(struct kr_request *request, struct sockaddr *src,
if (ret != 0) {
return kr_error(EINVAL);
}
WITH_VERBOSE {
char qname_str[KNOT_DNAME_MAXLEN], zonecut_str[KNOT_DNAME_MAXLEN], ns_str[INET6_ADDRSTRLEN], type_str[16];
knot_dname_to_str(qname_str, knot_pkt_qname(packet), sizeof(qname_str));
......
......@@ -318,6 +318,31 @@ uint16_t kr_inaddr_port(const struct sockaddr *addr)
}
}
int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
{
int ret = kr_ok();
if (!addr || !buf || !buflen) {
return kr_error(EINVAL);
}
char str[INET6_ADDRSTRLEN + 6];
if (!inet_ntop(addr->sa_family, kr_inaddr(addr), str, sizeof(str))) {
return kr_error(errno);
}
int len = strlen(str);
str[len] = '#';
u16tostr((uint8_t *)&str[len + 1], kr_inaddr_port(addr));
len += 6;
str[len] = 0;
if (len >= *buflen) {
ret = kr_error(ENOSPC);
} else {
memcpy(buf, str, len + 1);
}
*buflen = len;
return ret;
}
int kr_straddr_family(const char *addr)
{
if (!addr) {
......@@ -396,6 +421,84 @@ int kr_straddr_subnet(void *dst, const char *addr)
return bit_len;
}
int kr_straddr_split(const char *addr, char *buf, size_t buflen, uint16_t *port)
{
const int base = 10;
long p = 0;
size_t addrlen = strlen(addr);
char *p_start = strchr(addr, '@');
char *p_end;
if (!p_start) {
p_start = strchr(addr, '#');
}
if (p_start) {
if (p_start[1] != '\0'){
p = strtol(p_start + 1, &p_end, base);
if (*p_end != '\0' || p <= 0 || p > UINT16_MAX) {
return kr_error(EINVAL);
}
}
addrlen = p_start - addr;
}
/* Check if address is valid. */
if (addrlen >= INET6_ADDRSTRLEN) {
return kr_error(EINVAL);
}
char str[INET6_ADDRSTRLEN];
struct sockaddr_storage ss;
memcpy(str, addr, addrlen); str[addrlen] = '\0';
int family = kr_straddr_family(str);
if (family == kr_error(EINVAL) || !inet_pton(family, str, &ss)) {
return kr_error(EINVAL);
}
/* Address and port contains valid values, return it to caller */
if (buf) {
if (addrlen >= buflen) {
return kr_error(ENOSPC);
}
memcpy(buf, addr, addrlen); buf[addrlen] = '\0';
}
if (port) {
*port = (uint16_t)p;
}
return kr_ok();
}
int kr_straddr_join(const char *addr, uint16_t port, char *buf, size_t *buflen)
{
if (!addr || !buf || !buflen) {
return kr_error(EINVAL);
}
struct sockaddr_storage ss;
int family = kr_straddr_family(addr);
if (family == kr_error(EINVAL) || !inet_pton(family, addr, &ss)) {
return kr_error(EINVAL);
}
int len = strlen(addr);
if (len + 6 >= *buflen) {
return kr_error(ENOSPC);
}
memcpy(buf, addr, len + 1);
buf[len] = '#';
u16tostr((uint8_t *)&buf[len + 1], port);
len += 6;
buf[len] = 0;
*buflen = len;
return kr_ok();
}
int kr_bitcmp(const char *a, const char *b, int bits)
{
/* We're using the function from lua directly, so at least for now
......
......@@ -197,10 +197,13 @@ int kr_inaddr_len(const struct sockaddr *addr);
/** Port. */
KR_EXPORT KR_PURE
uint16_t kr_inaddr_port(const struct sockaddr *addr);
/** String representation for given address as "<addr>#<port>" */
KR_EXPORT
int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen);
/** Return address type for string. */
KR_EXPORT KR_PURE
int kr_straddr_family(const char *addr);
/** Return address length in given family. */
/** Return address length in given family (struct in*_addr). */
KR_EXPORT KR_CONST
int kr_family_len(int family);
/** Create a sockaddr* from string+port representation (also accepts IPv6 link-local). */
......@@ -211,6 +214,26 @@ struct sockaddr * kr_straddr_socket(const char *addr, int port);
KR_EXPORT
int kr_straddr_subnet(void *dst, const char *addr);
/** Splits ip address specified as "addr@port" or "addr#port" into addr and port
* and performs validation.
* @note if #port part isn't present, then port will be set to 0.
* buf and\or port can be set to NULL.
* @return kr_error(EINVAL) - addr part doesn't contains valid ip address or
* #port part is out-of-range (either < 0 either > UINT16_MAX)
* kr_error(ENOSP) - buflen is too small
*/
KR_EXPORT
int kr_straddr_split(const char *addr, char *buf, size_t buflen, uint16_t *port);
/** Formats ip address and port in "addr#port" format.
* and performs validation.
* @note Port always formatted as five-character string with leading zeros.
* @return kr_error(EINVAL) - addr or buf is NULL or buflen is 0 or
* addr doesn't contain a valid ip address
* kr_error(ENOSP) - buflen is too small
*/
KR_EXPORT
int kr_straddr_join(const char *addr, uint16_t port, char *buf, size_t *buflen);
/** Compare memory bitwise. The semantics is "the same" as for memcmp().
* The partial byte is considered with more-significant bits first,
* so this is e.g. suitable for comparing IP prefixes. */
......@@ -300,6 +323,19 @@ static inline const char *lua_push_printf(lua_State *L, const char *fmt, ...)
return ret;
}
/** @internal Return string representation of addr.
* @note return pointer to static string
*/
static inline char *kr_straddr(const struct sockaddr *addr)
{
assert(addr != NULL);
/* We are the sinle-threaded application */
static char str[INET6_ADDRSTRLEN + 6];
size_t len = sizeof(str);
int ret = kr_inaddr_str(addr, str, &len);
return ret != kr_ok() || len == 0 ? NULL : str;
}
/** The current time in monotonic milliseconds.
*
* \note it may be outdated in case of long callbacks; see uv_now().
......
......@@ -120,6 +120,32 @@ local function forward(target)
end
end
-- Forward request and all subrequests to upstream over TCP; validate answers
local function tcp_forward(target)
local list = {}
if type(target) == 'table' then
for _, v in pairs(target) do
table.insert(list, addr2sock(v))
assert(#list <= 4, 'at most 4 TCP_FORWARD targets are supported')
end
else
table.insert(list, addr2sock(target))
end
return function(state, req)
local qry = req:current()
req.options.FORWARD = true
req.options.NO_MINIMIZE = true
qry.flags.FORWARD = true
qry.flags.ALWAYS_CUT = false
qry.flags.NO_MINIMIZE = true
qry.flags.AWAIT_CUT = true
req.options.TCP = true
qry.flags.TCP = true
set_nslist(qry, list)
return state
end
end
-- Rewrite records in packet
local function reroute(tbl, names)
-- Import renumbering rules
......@@ -236,7 +262,8 @@ end
local policy = {
-- Policies
PASS = 1, DENY = 2, DROP = 3, TC = 4, QTRACE = 5,
FORWARD = forward, STUB = stub, REROUTE = reroute, MIRROR = mirror, FLAGS = flags,
FORWARD = forward, TCP_FORWARD = tcp_forward,
STUB = stub, REROUTE = reroute, MIRROR = mirror, FLAGS = flags,
-- Special values