/* network.c */
/*  Copyright (C) 2015 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "daemon/network.h"
#include "daemon/worker.h"
#include "daemon/io.h"

/* libuv 1.7.0+ is able to support SO_REUSEPORT for loadbalancing */
#if (defined(ENABLE_REUSEPORT) || defined(UV_VERSION_HEX)) && (__linux__ && SO_REUSEPORT)
/** @internal Best-effort SO_REUSEPORT on a handle created by uv_*_init_ex().
 *  @param handle    the handle that was just initialised
 *  @param init_ret  status returned by the uv_*_init_ex() call
 *  @return init_ret, so that handle_init() evaluates to the init status */
static inline int handle_set_reuseport(uv_handle_t *handle, int init_ret)
{
	if (init_ret != 0) {
		/* Handle is not initialised; uv_fileno() on it would read garbage. */
		return init_ret;
	}
	uv_os_fd_t fd = 0;
	if (uv_fileno(handle, &fd) == 0) {
		int on = 1;
		/* Best effort: a failure only means no load balancing. */
		(void) setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &on, sizeof(on));
	}
	return 0;
}
  /* Evaluates to the uv_*_init_ex() status; usable as a plain statement too. */
  #define handle_init(type, loop, handle, family) \
	handle_set_reuseport((uv_handle_t *)(handle), \
	                     uv_ ## type ## _init_ex((loop), (handle), (family)))
#else
  /* Evaluates to the uv_*_init() status; the address family is unused here. */
  #define handle_init(type, loop, handle, family) \
	uv_ ## type ## _init((loop), (handle))
#endif

36 37 38 39 40 41 42 43
void network_init(struct network *net, uv_loop_t *loop)
{
	if (net != NULL) {
		net->loop = loop;
		net->endpoints = map_make();
	}
}

44
static void free_handle(uv_handle_t *handle)
45
{
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
	free(handle);
}

/** Close a libuv handle and release its memory.
 *  @param handle  handle to close (heap-allocated, freed here or in the close callback)
 *  @param force   false: asynchronous uv_close() with free_handle() as callback;
 *                 true: the event loop isn't running, so close the descriptor
 *                 (if the handle has one) and free the handle immediately. */
static void close_handle(uv_handle_t *handle, bool force)
{
	if (!force) {
		uv_close(handle, free_handle);
		return;
	}

	/* Force close if event loop isn't running. */
	uv_os_fd_t sock = 0;
	if (uv_fileno(handle, &sock) == 0) {
		close(sock);
	}
	free_handle(handle);
}

/** Close the UDP and TCP handles of an endpoint (whichever are set) and free it.
 *  @param ep     endpoint to dispose of; invalid after the call
 *  @param force  passed through to close_handle()
 *  @return kr_ok() */
static int close_endpoint(struct endpoint *ep, bool force)
{
	if (ep->udp) {
		close_handle((uv_handle_t *)ep->udp, force);
	}
	if (ep->tcp) {
		close_handle((uv_handle_t *)ep->tcp, force);
	}
	free(ep);
	return kr_ok();
}

/** Endpoint visitor (see @file map.h) */
76
static int close_key(const char *key, void *val, void *ext)
77 78 79
{
	endpoint_array_t *ep_array = val;
	for (size_t i = ep_array->len; i--;) {
80
		close_endpoint(ep_array->at[i], true);
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
	}
	return 0;
}

static int free_key(const char *key, void *val, void *ext)
{
	endpoint_array_t *ep_array = val;
	array_clear(*ep_array);
	free(ep_array);
	return kr_ok();
}

void network_deinit(struct network *net)
{
	if (net != NULL) {
96
		map_walk(&net->endpoints, close_key, 0);
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
		map_walk(&net->endpoints, free_key, 0);
		map_clear(&net->endpoints);
	}
}

/** Fetch or create endpoint array and insert endpoint. */
static int insert_endpoint(struct network *net, const char *addr, struct endpoint *ep)
{
	/* Fetch or insert address into map */
	endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
	if (ep_array == NULL) {
		ep_array = malloc(sizeof(*ep_array));
		if (ep_array == NULL) {
			return kr_error(ENOMEM);
		}
		if (map_set(&net->endpoints, addr, ep_array) != 0) {
			free(ep_array);
			return kr_error(ENOMEM);
		}
		array_init(*ep_array);
	}

119 120 121 122
	if (array_push(*ep_array, ep) < 0) {
		return kr_error(ENOMEM);
	}
	return kr_ok();
123 124 125 126 127 128
}

/** Open endpoint protocols.
 *  Allocates, initialises and binds a UDP and/or TCP handle for @a sa, as
 *  selected by NET_UDP / NET_TCP in @a flags. Each bound protocol is recorded
 *  in ep->flags.
 *  On failure, handles created so far stay attached to @a ep (ep->udp/ep->tcp).
 *  The caller must release them; network_listen() does so via close_endpoint().
 *  @return kr_ok(), kr_error(ENOMEM), or the udp_bind()/tcp_bind() error */
static int open_endpoint(struct network *net, struct endpoint *ep, struct sockaddr *sa, uint32_t flags)
{
	/* NOTE(review): the handle_init() status is not checked here. Depending on
	 * the build, its macro form may not yield a value. A failed init would leave
	 * an uninitialised handle that is still bound and later closed. */
	if (flags & NET_UDP) {
		ep->udp = malloc(sizeof(*ep->udp));
		if (ep->udp == NULL) {
			return kr_error(ENOMEM);
		}
		handle_init(udp, net->loop, ep->udp, sa->sa_family);
		int ret = udp_bind(ep->udp, sa);
		if (ret != 0) {
			return ret;
		}
		ep->flags |= NET_UDP;
	}
	if (flags & NET_TCP) {
		ep->tcp = malloc(sizeof(*ep->tcp));
		if (ep->tcp == NULL) {
			return kr_error(ENOMEM);
		}
		handle_init(tcp, net->loop, ep->tcp, sa->sa_family);
		int ret = tcp_bind(ep->tcp, sa);
		if (ret != 0) {
			return ret;
		}
		ep->flags |= NET_TCP;
	}
	return kr_ok();
}

155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
/** @internal Fetch endpoint array and offset of the address/port query. */
static endpoint_array_t *network_get(struct network *net, const char *addr, uint16_t port, size_t *index)
{
	endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
	if (ep_array) {
		for (size_t i = ep_array->len; i--;) {
			struct endpoint *ep = ep_array->at[i];
			if (ep->port == port) {
				*index = i;
				return ep_array;
			}
		}
	}
	return NULL;
}

171 172 173 174 175 176
int network_listen(struct network *net, const char *addr, uint16_t port, uint32_t flags)
{
	if (net == NULL || addr == 0 || port == 0) {
		return kr_error(EINVAL);
	}

177
	/* Already listening */
178 179
	size_t index = 0;
	if (network_get(net, addr, port, &index)) {
180 181 182
		return kr_ok();
	}

183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
	/* Parse address. */
	int ret = 0;
	struct sockaddr_storage sa;
	if (strchr(addr, ':') != NULL) {
		ret = uv_ip6_addr(addr, port, (struct sockaddr_in6 *)&sa);
	} else {
		ret = uv_ip4_addr(addr, port, (struct sockaddr_in *)&sa);
	}
	if (ret != 0) {
		return ret;
	}

	/* Bind interfaces */
	struct endpoint *ep = malloc(sizeof(*ep));
	memset(ep, 0, sizeof(*ep));
	ep->flags = NET_DOWN;
	ep->port = port;
	ret = open_endpoint(net, ep, (struct sockaddr *)&sa, flags);
	if (ret == 0) {
		ret = insert_endpoint(net, addr, ep);
	}
	if (ret != 0) {
205
		close_endpoint(ep, false);
206 207 208 209 210 211 212
	}

	return ret;
}

int network_close(struct network *net, const char *addr, uint16_t port)
{
213 214 215
	size_t index = 0;
	endpoint_array_t *ep_array = network_get(net, addr, port, &index);
	if (!ep_array) {
216 217 218 219
		return kr_error(ENOENT);
	}

	/* Close endpoint in array. */
220 221
	close_endpoint(ep_array->at[index], false);
	array_del(*ep_array, index);
222

223 224 225 226 227 228 229 230
	/* Collapse key if it has no endpoint. */
	if (ep_array->len == 0) {
		free(ep_array);
		map_del(&net->endpoints, addr);
	}

	return kr_ok();
}