network.c 8.73 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
/*  Copyright (C) 2015 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include <assert.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "daemon/network.h"
#include "daemon/worker.h"
#include "daemon/io.h"
#include "daemon/tls.h"

/* libuv 1.7.0+ is able to support SO_REUSEPORT for loadbalancing */
#if defined(UV_VERSION_HEX)
#if (__linux__ && SO_REUSEPORT)
/*
 * handle_init(type, loop, handle, family): initialize a uv_<type>_t handle
 * with uv_<type>_init_ex() for address family @family and enable
 * SO_REUSEPORT on its socket.
 *
 * NOTE: this macro may `return` from the *calling* function:
 *  - if uv_<type>_init_ex() fails, the never-initialized handle is freed,
 *    @handle is set to NULL and the libuv error code is returned;
 *  - if setsockopt() fails, kr_error(errno) is returned and the initialized
 *    handle is left in @handle for the caller to close.
 * @handle must therefore be a modifiable lvalue holding a malloc()ed handle.
 */
  #define handle_init(type, loop, handle, family) do { \
	int init_ret_ = uv_ ## type ## _init_ex((loop), (handle), (family)); \
	if (init_ret_ != 0) { \
		free(handle); \
		(handle) = NULL; \
		return init_ret_; \
	} \
	uv_os_fd_t fd_ = 0; \
	if (uv_fileno((uv_handle_t *)(handle), &fd_) == 0) { \
		int on_ = 1; \
		if (setsockopt(fd_, SOL_SOCKET, SO_REUSEPORT, &on_, sizeof(on_)) != 0) { \
			return kr_error(errno); \
		} \
	} \
  } while (0)
/* libuv 1.7.0+ is able to assign fd immediately */
#else
/*
 * Same as the SO_REUSEPORT variant minus the socket option. It may also
 * `return` from the calling function: on uv_<type>_init_ex() failure the
 * handle is freed, @handle is set to NULL and the libuv error is returned.
 */
  #define handle_init(type, loop, handle, family) do { \
	int init_ret_ = uv_ ## type ## _init_ex((loop), (handle), (family)); \
	if (init_ret_ != 0) { \
		free(handle); \
		(handle) = NULL; \
		return init_ret_; \
	} \
  } while (0)
#endif
#else
/* libuv without UV_VERSION_HEX (pre-1.7.0): plain uv_<type>_init(), @family is ignored. */
  #define handle_init(type, loop, handle, family) \
	uv_ ## type ## _init((loop), (handle))
#endif

49 50 51 52 53 54 55 56
void network_init(struct network *net, uv_loop_t *loop)
{
	if (net != NULL) {
		net->loop = loop;
		net->endpoints = map_make();
	}
}

57 58 59 60 61 62 63
static void close_handle(uv_handle_t *handle, bool force)
{
	if (force) { /* Force close if event loop isn't running. */
		uv_os_fd_t fd = 0;
		if (uv_fileno(handle, &fd) == 0) {
			close(fd);
		}
64 65
		handle->loop = NULL;
		io_free(handle);
66
	} else { /* Asynchronous close */
67
		uv_close(handle, io_free);
68 69 70 71 72 73 74
	}
}

static int close_endpoint(struct endpoint *ep, bool force)
{
	if (ep->udp) {
		close_handle((uv_handle_t *)ep->udp, force);
75
	}
76 77
	if (ep->tcp) {
		close_handle((uv_handle_t *)ep->tcp, force);
78 79 80 81 82 83 84
	}

	free(ep);
	return kr_ok();
}

/** Endpoint visitor (see @file map.h) */
85
static int close_key(const char *key, void *val, void *ext)
86 87 88
{
	endpoint_array_t *ep_array = val;
	for (size_t i = ep_array->len; i--;) {
89
		close_endpoint(ep_array->at[i], true);
90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
	}
	return 0;
}

/**
 * Map visitor: release the element storage of one endpoint array and then
 * the array structure. The endpoints themselves are not touched here.
 */
static int free_key(const char *key, void *val, void *ext)
{
	endpoint_array_t *endpoints = val;
	array_clear(*endpoints);
	free(endpoints);
	return kr_ok();
}

void network_deinit(struct network *net)
{
	if (net != NULL) {
105
		map_walk(&net->endpoints, close_key, 0);
106 107
		map_walk(&net->endpoints, free_key, 0);
		map_clear(&net->endpoints);
108 109
		tls_credentials_free(net->tls_credentials);
		net->tls_credentials = NULL;
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
	}
}

/**
 * Fetch or create the endpoint array for @p addr and append @p ep to it.
 *
 * On success the array stores the @p ep pointer. On failure @p ep is not
 * stored anywhere and stays owned by the caller. If the array had to be
 * created here and appending fails, the new (empty) array is removed from the
 * map and released again, so no empty entry is left behind.
 *
 * @return kr_ok() or kr_error(ENOMEM)
 */
static int insert_endpoint(struct network *net, const char *addr, struct endpoint *ep)
{
	bool created = false;

	/* Fetch or insert address into map */
	endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
	if (ep_array == NULL) {
		ep_array = malloc(sizeof(*ep_array));
		if (ep_array == NULL) {
			return kr_error(ENOMEM);
		}
		/* Initialize before the array becomes reachable through the map. */
		array_init(*ep_array);
		if (map_set(&net->endpoints, addr, ep_array) != 0) {
			free(ep_array);
			return kr_error(ENOMEM);
		}
		created = true;
	}

	if (array_push(*ep_array, ep) < 0) {
		if (created) {
			/* Roll back the entry created above. */
			map_del(&net->endpoints, addr);
			array_clear(*ep_array);
			free(ep_array);
		}
		return kr_error(ENOMEM);
	}
	return kr_ok();
}

/** Open endpoint protocols. */
static int open_endpoint(struct network *net, struct endpoint *ep, struct sockaddr *sa, uint32_t flags)
{
139
	int ret = 0;
140
	if (flags & NET_UDP) {
141 142 143 144
		ep->udp = malloc(sizeof(*ep->udp));
		if (!ep->udp) {
			return kr_error(ENOMEM);
		}
145
		memset(ep->udp, 0, sizeof(*ep->udp));
146
		handle_init(udp, net->loop, ep->udp, sa->sa_family); /* can return! */
147
		ret = udp_bind(ep->udp, sa);
148 149 150 151 152 153
		if (ret != 0) {
			return ret;
		}
		ep->flags |= NET_UDP;
	}
	if (flags & NET_TCP) {
154 155 156 157
		ep->tcp = malloc(sizeof(*ep->tcp));
		if (!ep->tcp) {
			return kr_error(ENOMEM);
		}
158
		memset(ep->tcp, 0, sizeof(*ep->tcp));
159
		handle_init(tcp, net->loop, ep->tcp, sa->sa_family); /* can return! */
160 161 162 163 164 165
		if (flags & NET_TLS) {
			ret = tcp_bind_tls(ep->tcp, sa);
			ep->flags |= NET_TLS;
		} else {
			ret = tcp_bind(ep->tcp, sa);
		}
166 167 168 169 170
		if (ret != 0) {
			return ret;
		}
		ep->flags |= NET_TCP;
	}
171
	return ret;
172 173
}

174
/** Open fd as endpoint. */
175
static int open_endpoint_fd(struct network *net, struct endpoint *ep, int fd, int sock_type, bool use_tls)
176
{
177
	int ret = kr_ok();
178
	if (sock_type == SOCK_DGRAM) {
179 180 181 182
		if (use_tls) {
			/* we do not support TLS over UDP */
			return kr_error(EBADF);
		}
183 184 185 186 187 188 189 190
		if (ep->udp) {
			return kr_error(EEXIST);
		}
		ep->udp = malloc(sizeof(*ep->udp));
		if (!ep->udp) {
			return kr_error(ENOMEM);
		}
		uv_udp_init(net->loop, ep->udp);
191
		ret = udp_bindfd(ep->udp, fd);
192 193 194 195 196
		if (ret != 0) {
			close_handle((uv_handle_t *)ep->udp, false);
			return ret;
		}
		ep->flags |= NET_UDP;
197
		return kr_ok();
198 199 200 201 202 203 204 205 206 207
	}
	if (sock_type == SOCK_STREAM) {
		if (ep->tcp) {
			return kr_error(EEXIST);
		}
		ep->tcp = malloc(sizeof(*ep->tcp));
		if (!ep->tcp) {
			return kr_error(ENOMEM);
		}
		uv_tcp_init(net->loop, ep->tcp);
208 209 210 211 212 213
		if (use_tls) {
			ret = tcp_bindfd_tls(ep->tcp, fd);
			ep->flags |= NET_TLS;
		} else {
			ret = tcp_bindfd(ep->tcp, fd);
		}
214 215 216 217 218
		if (ret != 0) {
			close_handle((uv_handle_t *)ep->tcp, false);
			return ret;
		}
		ep->flags |= NET_TCP;
219
		return kr_ok();
220
	}
221
	return kr_error(EINVAL);
222 223
}

224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
/** @internal Fetch endpoint array and offset of the address/port query. */
static endpoint_array_t *network_get(struct network *net, const char *addr, uint16_t port, size_t *index)
{
	endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
	if (ep_array) {
		for (size_t i = ep_array->len; i--;) {
			struct endpoint *ep = ep_array->at[i];
			if (ep->port == port) {
				*index = i;
				return ep_array;
			}
		}
	}
	return NULL;
}

240
int network_listen_fd(struct network *net, int fd, bool use_tls)
241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
{
	/* Extract local address and socket type. */
	int sock_type = SOCK_DGRAM;
	socklen_t len = sizeof(sock_type);
	int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &len);	
	if (ret != 0) {
		return kr_error(EBADF);
	}
	/* Extract local address for this socket. */
	struct sockaddr_storage ss;
	socklen_t addr_len = sizeof(ss);
	ret = getsockname(fd, (struct sockaddr *)&ss, &addr_len);
	if (ret != 0) {
		return kr_error(EBADF);
	}
	int port = 0;
257
	char addr_str[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
	if (ss.ss_family == AF_INET) {
		uv_ip4_name((const struct sockaddr_in*)&ss, addr_str, sizeof(addr_str));
		port = ntohs(((struct sockaddr_in *)&ss)->sin_port);
	} else if (ss.ss_family == AF_INET6) {
		uv_ip6_name((const struct sockaddr_in6*)&ss, addr_str, sizeof(addr_str));
		port = ntohs(((struct sockaddr_in6 *)&ss)->sin6_port);
	} else {
		return kr_error(EAFNOSUPPORT);
	}
	/* Fetch or create endpoint for this fd */
	size_t index = 0;
	endpoint_array_t *ep_array = network_get(net, addr_str, port, &index);
	if (!ep_array) {
		struct endpoint *ep = malloc(sizeof(*ep));
		memset(ep, 0, sizeof(*ep));
		ep->flags = NET_DOWN;
		ep->port = port;
		ret = insert_endpoint(net, addr_str, ep);
		if (ret != 0) {
			return ret;
		}
		ep_array = network_get(net, addr_str, port, &index);
	}
	/* Open fd in found/created endpoint. */
	struct endpoint *ep = ep_array->at[index];
	assert(ep != NULL);
	/* Create a libuv struct for this socket. */
285
	return open_endpoint_fd(net, ep, fd, sock_type, use_tls);
286 287
}

288 289 290 291 292 293
int network_listen(struct network *net, const char *addr, uint16_t port, uint32_t flags)
{
	if (net == NULL || addr == 0 || port == 0) {
		return kr_error(EINVAL);
	}

294
	/* Already listening */
295 296
	size_t index = 0;
	if (network_get(net, addr, port, &index)) {
297 298 299
		return kr_ok();
	}

300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321
	/* Parse address. */
	int ret = 0;
	struct sockaddr_storage sa;
	if (strchr(addr, ':') != NULL) {
		ret = uv_ip6_addr(addr, port, (struct sockaddr_in6 *)&sa);
	} else {
		ret = uv_ip4_addr(addr, port, (struct sockaddr_in *)&sa);
	}
	if (ret != 0) {
		return ret;
	}

	/* Bind interfaces */
	struct endpoint *ep = malloc(sizeof(*ep));
	memset(ep, 0, sizeof(*ep));
	ep->flags = NET_DOWN;
	ep->port = port;
	ret = open_endpoint(net, ep, (struct sockaddr *)&sa, flags);
	if (ret == 0) {
		ret = insert_endpoint(net, addr, ep);
	}
	if (ret != 0) {
322
		close_endpoint(ep, false);
323 324 325 326 327 328 329
	}

	return ret;
}

int network_close(struct network *net, const char *addr, uint16_t port)
{
330 331 332
	size_t index = 0;
	endpoint_array_t *ep_array = network_get(net, addr, port, &index);
	if (!ep_array) {
333 334 335 336
		return kr_error(ENOENT);
	}

	/* Close endpoint in array. */
337 338
	close_endpoint(ep_array->at[index], false);
	array_del(*ep_array, index);
339

340 341 342 343 344 345 346 347
	/* Collapse key if it has no endpoint. */
	if (ep_array->len == 0) {
		free(ep_array);
		map_del(&net->endpoints, addr);
	}

	return kr_ok();
}