/*  Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

17
#include <unistd.h>
18
#include <assert.h>
19
#include "daemon/bindings/impl.h"
20 21 22
#include "daemon/network.h"
#include "daemon/worker.h"
#include "daemon/io.h"
23
#include "daemon/tls.h"
24

25
/** Fill in the initial state of a network structure.
 *
 * Does nothing when `net` is NULL.  The event loop and TCP backlog are stored
 * as given; the endpoint map and the endpoint-kind registry start out empty.
 */
void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
{
	if (net == NULL) {
		return;
	}

	net->loop = loop;
	net->tcp_backlog = tcp_backlog;

	/* Containers for listening endpoints and for registered endpoint kinds. */
	net->endpoints = map_make(NULL);
	net->endpoint_kinds = trie_create(NULL);

	/* TLS state; the session ticket context is unsynchronized random by default. */
	net->tls_client_params = NULL;
	net->tls_session_ticket_ctx = tls_session_ticket_ctx_create(loop, NULL, 0);

	/* TCP timeouts. */
	net->tcp.in_idle_timeout = 10000;
	net->tcp.tls_handshake_timeout = TLS_MAX_HANDSHAKE_TIME;
}

40 41 42
/** Notify the registered function about endpoint getting open.
 * If log_port < 1, don't log it. */
static int endpoint_open_lua_cb(struct network *net, struct endpoint *ep,
43
				const char *log_addr)
44
{
45 46 47 48 49 50 51 52 53 54 55 56 57
	const bool ok = ep->flags.kind && !ep->handle && !ep->engaged && ep->fd != -1;
	if (!ok) {
		assert(!EINVAL);
		return kr_error(EINVAL);
	}
	/* First find callback in the endpoint registry. */
	struct worker_ctx *worker = net->loop->data; // LATER: the_worker
	lua_State *L = worker->engine->L;
	void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind,
				strlen(ep->flags.kind));
	if (!pp && net->missing_kind_is_error) {
		kr_log_error("warning: network socket kind '%s' not handled when opening '%s",
				ep->flags.kind, log_addr);
58 59
		if (ep->family != AF_UNIX)
			kr_log_error("#%d", ep->port);
60 61 62 63 64 65 66 67 68 69 70 71 72
		kr_log_error("'.  Likely causes: typo or not loading 'http' module.\n");
		/* No hard error, for now.  LATER: perhaps differentiate between
		 * explicit net.listen() calls and "just unused" systemd sockets.
		return kr_error(ENOENT);
		*/
	}
	if (!pp) return kr_ok();

	/* Now execute the callback. */
	const int fun_id = (char *)*pp - (char *)NULL;
	lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
	lua_pushboolean(L, true /* open */);
	lua_pushpointer(L, ep);
73
	if (ep->family == AF_UNIX) {
74 75
		lua_pushstring(L, log_addr);
	} else {
76
		lua_pushfstring(L, "%s#%d", log_addr, ep->port);
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
	}
	if (lua_pcall(L, 3, 0, 0)) {
		kr_log_error("error opening %s: %s\n", log_addr, lua_tostring(L, -1));
		return kr_error(ENOSYS); /* TODO: better value? */
	}
	ep->engaged = true;
	return kr_ok();
}

/** map_walk() visitor: run the "open" callback for every endpoint in the
 *  array that has a kind and is not engaged yet.
 *
 * \param key       map key; passed on as the address string for logging
 * \param endpoints the endpoint_array_t stored under the key
 * \param net       the struct network being walked
 * \return 0 on success, or the first non-zero callback result (stops early).
 */
static int engage_endpoint_array(const char *key, void *endpoints, void *net)
{
	endpoint_array_t *ep_array = endpoints;
	for (int idx = 0; idx < ep_array->len; ++idx) {
		struct endpoint *ep = &ep_array->at[idx];
		if (ep->engaged || !ep->flags.kind) {
			continue; /* nothing to do for this one */
		}
		const int err = endpoint_open_lua_cb(net, ep, key);
		if (err) {
			return err;
		}
	}
	return 0;
}
/** Run the "open" callbacks for all endpoints that have a kind but are not
 *  engaged yet, and switch the network into the mode where a missing kind
 *  handler gets reported.
 *
 * \return kr_ok() on success (or if already done), otherwise the error
 *         from the first failing endpoint.
 */
int network_engage_endpoints(struct network *net)
{
	if (net->missing_kind_is_error) {
		/* Already done earlier; maybe weird, but let's make it idempotent. */
		return kr_ok();
	}

	net->missing_kind_is_error = true;
	const int err = map_walk(&net->endpoints, engage_endpoint_array, net);
	if (!err) {
		return kr_ok();
	}

	/* Reset the flag to avoid the same errors when closing. */
	net->missing_kind_is_error = false;
	return err;
}


/** Notify the registered Lua function about an endpoint about to be closed.
 *
 * The callback registered for ep->flags.kind is called with
 * (false, ep, identifier string).  Failures are only logged.
 * If no callback is registered, nothing is called; that case is logged
 * as an internal error when net->missing_kind_is_error is set.
 */
static void endpoint_close_lua_cb(struct network *net, struct endpoint *ep)
{
	struct worker_ctx *worker = net->loop->data; // LATER: the_worker
	lua_State *L = worker->engine->L;
	void **pp = trie_get_try(net->endpoint_kinds, ep->flags.kind,
				strlen(ep->flags.kind));
	if (!pp) {
		if (net->missing_kind_is_error) {
			kr_log_error("internal error: missing kind '%s' in endpoint registry\n",
					ep->flags.kind);
		}
		return;
	}

	/* The registry value encodes a Lua reference to the callback. */
	const int fun_id = (char *)*pp - (char *)NULL;
	lua_rawgeti(L, LUA_REGISTRYINDEX, fun_id);
	lua_pushboolean(L, false /* close */);
	lua_pushpointer(L, ep);
	lua_pushstring(L, "FIXME:endpoint-identifier");
	if (lua_pcall(L, 3, 0, 0)) {
		kr_log_error("failed to close FIXME:endpoint-identifier: %s\n",
				lua_tostring(L, -1));
		/* lua_pcall() left the error object on the stack; remove it
		 * so that repeated failures don't grow the Lua stack. */
		lua_pop(L, 1);
	}
}

static void endpoint_close(struct network *net, struct endpoint *ep, bool force)
{
	assert(!ep->handle != !ep->flags.kind);
	if (ep->flags.kind) { /* Special endpoint. */
		if (ep->engaged) {
			endpoint_close_lua_cb(net, ep);
		}
		if (ep->fd > 0) {
			close(ep->fd); /* nothing to do with errors */
		}
		free_const(ep->flags.kind);
		return;
	}

151
	if (force) { /* Force close if event loop isn't running. */
152 153 154 155 156 157
		if (ep->fd >= 0) {
			close(ep->fd);
		}
		if (ep->handle) {
			ep->handle->loop = NULL;
			io_free(ep->handle);
158 159
		}
	} else { /* Asynchronous close */
160
		uv_close(ep->handle, io_free);
161 162 163
	}
}

164
/** Endpoint visitor (see @file map.h) */
165
static int close_key(const char *key, void *val, void *net)
166 167
{
	endpoint_array_t *ep_array = val;
168
	for (int i = 0; i < ep_array->len; ++i) {
169
		endpoint_close(net, &ep_array->at[i], true);
170 171 172 173 174 175 176 177 178 179 180 181
	}
	return 0;
}

/** Endpoint visitor: release the endpoint array stored under one key.
 *  Only memory is freed here; the endpoints themselves are not closed. */
static int free_key(const char *key, void *val, void *ext)
{
	(void)key;
	(void)ext;

	endpoint_array_t *eps = val;
	array_clear(*eps);
	free(eps);

	return kr_ok();
}

182 183 184 185 186 187 188 189
/** Trie visitor: release the Lua registry reference stored in one
 *  endpoint-kind entry.
 *
 * \param tv trie value encoding the reference number as a pointer offset
 * \param L  the lua_State owning the registry
 * \return always 0
 */
int kind_unregister(trie_val_t *tv, void *L)
{
	lua_State *lua = L;
	const int registry_ref = (char *)*tv - (char *)NULL;
	luaL_unref(lua, LUA_REGISTRYINDEX, registry_ref);
	return 0;
}

void network_close_force(struct network *net)
190 191
{
	if (net != NULL) {
192
		map_walk(&net->endpoints, close_key, net);
193 194
		map_walk(&net->endpoints, free_key, 0);
		map_clear(&net->endpoints);
195 196 197 198 199 200 201 202 203 204 205
	}
}

void network_deinit(struct network *net)
{
	if (net != NULL) {
		network_close_force(net);
		struct worker_ctx *worker = net->loop->data; // LATER: the_worker
		trie_apply(net->endpoint_kinds, kind_unregister, worker->engine->L);
		trie_free(net->endpoint_kinds);

206
		tls_credentials_free(net->tls_credentials);
207
		tls_client_params_free(net->tls_client_params);
208
		tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx);
209 210 211
		#ifndef NDEBUG
			memset(net, 0, sizeof(*net));
		#endif
212 213 214
	}
}

215
/** Fetch or create endpoint array and insert endpoint (shallow memcpy). */
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
static int insert_endpoint(struct network *net, const char *addr, struct endpoint *ep)
{
	/* Fetch or insert address into map */
	endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
	if (ep_array == NULL) {
		ep_array = malloc(sizeof(*ep_array));
		if (ep_array == NULL) {
			return kr_error(ENOMEM);
		}
		if (map_set(&net->endpoints, addr, ep_array) != 0) {
			free(ep_array);
			return kr_error(ENOMEM);
		}
		array_init(*ep_array);
	}

232
	if (array_reserve(*ep_array, ep_array->len + 1)) {
233 234
		return kr_error(ENOMEM);
	}
235
	memcpy(&ep_array->at[ep_array->len++], ep, sizeof(*ep));
236
	return kr_ok();
237 238
}

239 240
/** Open endpoint protocols.  ep->flags were pre-set. */
static int open_endpoint(struct network *net, struct endpoint *ep,
241
			 const struct sockaddr *sa, const char *log_addr)
242
{
243
	if ((sa != NULL) == (ep->fd != -1)) {
244 245
		assert(!EINVAL);
		return kr_error(EINVAL);
246
	}
247 248
	if (ep->handle) {
		return kr_error(EEXIST);
249 250
	}

251
	if (sa) {
252 253
		ep->fd = io_bind(sa, ep->flags.sock_type);
		if (ep->fd < 0) return ep->fd;
254
	}
255 256
	if (ep->flags.kind) {
		/* This EP isn't to be managed internally after binding. */
257
		return endpoint_open_lua_cb(net, ep, log_addr);
258 259 260 261
	} else {
		ep->engaged = true;
		/* .engaged seems not really meaningful with .kind == NULL, but... */
	}
262

263 264
	if (ep->flags.sock_type == SOCK_DGRAM) {
		if (ep->flags.tls) {
265 266
			assert(!EINVAL);
			return kr_error(EINVAL);
267
		}
268
		uv_udp_t *ep_handle = malloc(sizeof(uv_udp_t));
269 270
		ep->handle = (uv_handle_t *)ep_handle;
		if (!ep->handle) {
271 272
			return kr_error(ENOMEM);
		}
273
		return io_listen_udp(net->loop, ep_handle, ep->fd);
274 275
	} /* else */

276
	if (ep->flags.sock_type == SOCK_STREAM) {
277
		uv_tcp_t *ep_handle = malloc(sizeof(uv_tcp_t));
278 279
		ep->handle = (uv_handle_t *)ep_handle;
		if (!ep->handle) {
280 281
			return kr_error(ENOMEM);
		}
282 283
		return io_listen_tcp(net->loop, ep_handle, ep->fd,
					net->tcp_backlog, ep->flags.tls);
284 285 286
	} /* else */

	assert(!EINVAL);
287
	return kr_error(EINVAL);
288 289
}

290 291 292 293
/** @internal Find an endpoint on `addr` with the given port and flags.
 *
 * Beware that there might be multiple matches, though that's not common;
 * the first one in the array is returned.
 *
 * \return pointer into the endpoint array, or NULL if nothing matches.
 */
static struct endpoint * endpoint_get(struct network *net, const char *addr,
					uint16_t port, endpoint_flags_t flags)
{
	endpoint_array_t *eps = map_get(&net->endpoints, addr);
	if (eps == NULL) {
		return NULL;
	}
	for (int idx = 0; idx < eps->len; ++idx) {
		struct endpoint *candidate = &eps->at[idx];
		if (candidate->port != port) {
			continue;
		}
		if (endpoint_flags_eq(candidate->flags, flags)) {
			return candidate;
		}
	}
	return NULL;
}

308 309
/** \note pass either sa != NULL xor ep.fd != -1;
 *  \note ownership of ep.flags.* is taken on success. */
310
static int create_endpoint(struct network *net, const char *addr_str,
311
				struct endpoint *ep, const struct sockaddr *sa)
312 313
{
	/* Bind interfaces */
314
	int ret = open_endpoint(net, ep, sa, addr_str);
315
	if (ret == 0) {
316
		ret = insert_endpoint(net, addr_str, ep);
317
	}
318 319
	if (ret != 0 && ep->handle) {
		endpoint_close(net, ep, false);
320 321 322 323
	}
	return ret;
}

324
/** Start listening on an already opened socket (e.g. passed by a supervisor).
 *
 * The socket type, local address and port are queried from the descriptor;
 * flags.sock_type is overwritten by the queried type.
 *
 * \return result of create_endpoint() on success path, otherwise
 *         kr_error(errno) if getsockopt()/getsockname() fails,
 *         kr_error(EINVAL) for TLS on a datagram socket without a kind,
 *         kr_error(EBADF) for a socket type other than datagram or stream,
 *         kr_error(EAFNOSUPPORT) for a family other than AF_INET/AF_INET6.
 */
int network_listen_fd(struct network *net, int fd, endpoint_flags_t flags)
{
	/* Extract fd's socket type. */
	socklen_t type_len = sizeof(flags.sock_type);
	if (getsockopt(fd, SOL_SOCKET, SO_TYPE, &flags.sock_type, &type_len) != 0) {
		return kr_error(errno);
	}
	const bool is_dgram = (flags.sock_type == SOCK_DGRAM);
	const bool is_stream = (flags.sock_type == SOCK_STREAM);
	if (is_dgram && !flags.kind && flags.tls) {
		assert(!EINVAL); /* Perhaps DTLS some day. */
		return kr_error(EINVAL);
	}
	if (!is_dgram && !is_stream) {
		return kr_error(EBADF);
	}

	/* Extract local address for this socket. */
	struct sockaddr_storage ss = { .ss_family = AF_UNSPEC };
	socklen_t ss_len = sizeof(ss);
	if (getsockname(fd, (struct sockaddr *)&ss, &ss_len) != 0) {
		return kr_error(errno);
	}

	struct endpoint ep = {
		.flags = flags,
		.family = ss.ss_family,
		.fd = fd,
	};

	/* Extract address string and port. */
	char addr_str[INET6_ADDRSTRLEN]; /* https://tools.ietf.org/html/rfc4291 */
	switch (ss.ss_family) {
	case AF_INET: {
		const struct sockaddr_in *sin = (const struct sockaddr_in *)&ss;
		uv_ip4_name(sin, addr_str, sizeof(addr_str));
		ep.port = ntohs(sin->sin_port);
		break;
	}
	case AF_INET6: {
		const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)&ss;
		uv_ip6_name(sin6, addr_str, sizeof(addr_str));
		ep.port = ntohs(sin6->sin6_port);
		break;
	}
	default:
		return kr_error(EAFNOSUPPORT);
	}

	/* always create endpoint for supervisor supplied fd
	 * even if addr+port is not unique */
	return create_endpoint(net, addr_str, &ep, NULL);
}

370 371
/** Start listening on the given address and port.
 *
 * An address containing ':' is parsed as IPv6, anything else as IPv4.
 *
 * \return result of create_endpoint() on success path, otherwise
 *         kr_error(EINVAL) on NULL net/addr or zero port,
 *         kr_error(EADDRINUSE) if an endpoint with the same address, port
 *         and flags already exists, or the libuv address parsing error.
 */
int network_listen(struct network *net, const char *addr, uint16_t port,
		   endpoint_flags_t flags)
{
	const bool args_ok = net != NULL && addr != 0 && port != 0;
	if (!args_ok) {
		assert(!EINVAL);
		return kr_error(EINVAL);
	}
	if (endpoint_get(net, addr, port, flags) != NULL) {
		return kr_error(EADDRINUSE); /* Already listening */
	}

	/* Parse address. */
	union inaddr sa;
	const bool looks_like_ip6 = strchr(addr, ':') != NULL;
	const int parse_ret = looks_like_ip6
		? uv_ip6_addr(addr, port, &sa.ip6)
		: uv_ip4_addr(addr, port, &sa.ip4);
	if (parse_ret != 0) {
		return parse_ret;
	}

	/* fd == -1 makes the endpoint get bound to the parsed address. */
	struct endpoint ep = {
		.flags = flags,
		.fd = -1,
		.port = port,
		.family = sa.ip.sa_family,
	};
	return create_endpoint(net, addr, &ep, &sa.ip);
}

401
int network_close(struct network *net, const char *addr, int port)
402
{
403
	endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
404
	if (!ep_array) {
405 406 407
		return kr_error(ENOENT);
	}

408
	size_t i = 0;
409
	bool matched = false; /*< at least one match */
410
	while (i < ep_array->len) {
411
		struct endpoint *ep = &ep_array->at[i];
412
		if (port < 0 || ep->port == port) {
413
			endpoint_close(net, ep, false);
414 415 416 417 418 419 420 421 422 423
			array_del(*ep_array, i);
			matched = true;
			/* do not advance i */
		} else {
			++i;
		}
	}
	if (!matched) {
		return kr_error(ENOENT);
	}
424

425 426
	/* Collapse key if it has no endpoint. */
	if (ep_array->len == 0) {
427
		array_clear(*ep_array);
428 429 430 431 432 433
		free(ep_array);
		map_del(&net->endpoints, addr);
	}

	return kr_ok();
}
434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449

/** React to a hostname change: if the current TLS credentials are ephemeral,
 *  try to replace them by freshly generated ones.
 *
 *  On failure the existing credentials stay in use and an error is logged.
 */
void network_new_hostname(struct network *net, struct engine *engine)
{
	const bool ephemeral = net->tls_credentials &&
		net->tls_credentials->ephemeral_servicename;
	if (!ephemeral) {
		return; /* nothing that depends on the hostname */
	}

	struct tls_credentials *newcreds = tls_get_ephemeral_credentials(engine);
	if (!newcreds) {
		kr_log_error("[tls] Failed to update ephemeral X.509 cert with new hostname, using existing one\n");
		return;
	}

	tls_credentials_release(net->tls_credentials);
	net->tls_credentials = newcreds;
	kr_log_info("[tls] Updated ephemeral X.509 cert with new hostname\n");
}
450

451
#ifdef SO_ATTACH_BPF
452 453 454 455 456 457 458 459
static int set_bpf_cb(const char *key, void *val, void *ext)
{
	endpoint_array_t *endpoints = (endpoint_array_t *)val;
	assert(endpoints != NULL);
	int *bpffd = (int *)ext;
	assert(bpffd != NULL);

	for (size_t i = 0; i < endpoints->len; i++) {
460
		struct endpoint *endpoint = &endpoints->at[i];
461
		uv_os_fd_t sockfd = -1;
462 463
		if (endpoint->handle != NULL)
			uv_fileno(endpoint->handle, &sockfd);
464 465 466
		assert(sockfd != -1);

		if (setsockopt(sockfd, SOL_SOCKET, SO_ATTACH_BPF, bpffd, sizeof(int)) != 0) {
467
			return 1; /* return error (and stop iterating over net->endpoints) */
468 469
		}
	}
470
	return 0; /* OK */
471
}
472
#endif
473

474
/** Attach a BPF program to all listening endpoints.
 *
 * \param bpf_fd file descriptor of the BPF program
 * \return 1 on success; 0 if attaching failed on some endpoint (in which
 *         case network_clear_bpf() is called) or if SO_ATTACH_BPF is not
 *         available at build time.
 */
int network_set_bpf(struct network *net, int bpf_fd)
{
#ifdef SO_ATTACH_BPF
	const bool failed = map_walk(&net->endpoints, set_bpf_cb, &bpf_fd) != 0;
	if (failed) {
		/* set_bpf_cb() has returned error. */
		network_clear_bpf(net);
		return 0;
	}
	return 1;
#else
	(void)net;
	(void)bpf_fd;
	kr_log_error("[network] SO_ATTACH_BPF socket option doesn't supported\n");
	return 0;
#endif
}

491
#ifdef SO_DETACH_BPF
492 493 494 495 496 497
static int clear_bpf_cb(const char *key, void *val, void *ext)
{
	endpoint_array_t *endpoints = (endpoint_array_t *)val;
	assert(endpoints != NULL);

	for (size_t i = 0; i < endpoints->len; i++) {
498
		struct endpoint *endpoint = &endpoints->at[i];
499
		uv_os_fd_t sockfd = -1;
500 501
		if (endpoint->handle != NULL)
			uv_fileno(endpoint->handle, &sockfd);
502 503
		assert(sockfd != -1);

504 505 506 507 508
		if (setsockopt(sockfd, SOL_SOCKET, SO_DETACH_BPF, NULL, 0) != 0) {
			kr_log_error("[network] failed to clear SO_DETACH_BPF socket option\n");
		}
		/* Proceed even if setsockopt() failed,
		 * as we want to process all opened sockets. */
509
	}
510
	return 0;
511
}
512
#endif
513 514 515

/** Detach the BPF program from all listening endpoints.
 *  Only logs an error if SO_DETACH_BPF is not available at build time. */
void network_clear_bpf(struct network *net)
{
#ifndef SO_DETACH_BPF
	(void)net;
	kr_log_error("[network] SO_DETACH_BPF socket option doesn't supported\n");
#else
	map_walk(&net->endpoints, clear_bpf_cb, NULL);
#endif
}