/*  Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <assert.h>

#include "libknot/attribute.h"
#include "knot/query/requestor.h"
#include "libknot/errcode.h"
#include "contrib/mempattern.h"
#include "contrib/net.h"
#include "contrib/sockaddr.h"

26 27
static bool use_tcp(struct knot_request *request)
{
28
	return (request->flags & KNOT_RQ_UDP) == 0;
29 30
}

Jan Včelák's avatar
Jan Včelák committed
31 32 33 34 35
static bool is_answer_to_query(const knot_pkt_t *query, const knot_pkt_t *answer)
{
	return knot_wire_get_id(query->wire) == knot_wire_get_id(answer->wire);
}

36
/*! \brief Ensure a socket is connected. */
37
static int request_ensure_connected(struct knot_request *request)
38
{
39 40 41 42 43
	if (request->fd >= 0) {
		return KNOT_EOK;
	}

	int sock_type = use_tcp(request) ? SOCK_STREAM : SOCK_DGRAM;
44 45 46
	request->fd = net_connected_socket(sock_type,
	                                  (struct sockaddr *)&request->remote,
	                                  (struct sockaddr *)&request->source);
47
	if (request->fd < 0) {
48
		return KNOT_ECONN;
49
	}
50 51

	return KNOT_EOK;
52 53
}

54
static int request_send(struct knot_request *request, int timeout_ms)
55
{
56
	/* Initiate non-blocking connect if not connected. */
57
	int ret = request_ensure_connected(request);
58 59
	if (ret != KNOT_EOK) {
		return ret;
60 61
	}

62
	/* Send query, construct if not exists. */
63 64
	knot_pkt_t *query = request->query;
	uint8_t *wire = query->wire;
65
	size_t wire_len = query->size;
66

67 68
	/* Send query. */
	if (use_tcp(request)) {
69
		ret = net_dns_tcp_send(request->fd, wire, wire_len, timeout_ms);
70
	} else {
71
		ret = net_dgram_send(request->fd, wire, wire_len, NULL);
72
	}
73
	if (ret != wire_len) {
74 75 76 77 78 79
		return KNOT_ECONN;
	}

	return KNOT_EOK;
}

80
static int request_recv(struct knot_request *request, int timeout_ms)
81
{
82 83 84
	knot_pkt_t *resp = request->resp;
	knot_pkt_clear(resp);

85
	/* Wait for readability */
86
	int ret = request_ensure_connected(request);
87 88
	if (ret != KNOT_EOK) {
		return ret;
89 90
	}

91
	/* Receive it */
92
	if (use_tcp(request)) {
93
		ret = net_dns_tcp_recv(request->fd, resp->wire, resp->max_size, timeout_ms);
94
	} else {
95
		ret = net_dgram_recv(request->fd, resp->wire, resp->max_size, timeout_ms);
96
	}
97
	if (ret <= 0) {
98
		resp->size = 0;
99 100 101
		if (ret == 0) {
			return KNOT_ECONN;
		}
102
		return ret;
103 104
	}

105
	resp->size = ret;
106
	return ret;
107 108
}

109
struct knot_request *knot_request_make(knot_mm_t *mm,
110 111
                                       const struct sockaddr *remote,
                                       const struct sockaddr *source,
112
                                       knot_pkt_t *query,
113
                                       const knot_tsig_key_t *tsig_key,
114
                                       unsigned flags)
115
{
116
	if (remote == NULL || query == NULL) {
117 118 119
		return NULL;
	}

120
	struct knot_request *request = mm_alloc(mm, sizeof(*request));
121 122 123
	if (request == NULL) {
		return NULL;
	}
124
	memset(request, 0, sizeof(*request));
125

126 127 128 129
	request->resp = knot_pkt_new(NULL, KNOT_WIRE_MAX_PKTSIZE, mm);
	if (request->resp == NULL) {
		mm_free(mm, request);
		return NULL;
130 131 132
	}

	request->query = query;
133
	request->fd = -1;
134
	request->flags = flags;
135 136 137 138 139 140 141
	memcpy(&request->remote, remote, sockaddr_len(remote));
	if (source) {
		memcpy(&request->source, source, sockaddr_len(source));
	} else {
		request->source.ss_family = AF_UNSPEC;
	}

142 143 144
	if (tsig_key && tsig_key->algorithm == DNSSEC_TSIG_UNKNOWN) {
		tsig_key = NULL;
	}
145 146
	tsig_init(&request->tsig, tsig_key);

147 148 149
	return request;
}

150
void knot_request_free(struct knot_request *request, knot_mm_t *mm)
151
{
152
	if (request == NULL) {
153
		return;
154 155
	}

156 157 158
	if (request->fd >= 0) {
		close(request->fd);
	}
159 160
	knot_pkt_free(request->query);
	knot_pkt_free(request->resp);
161
	tsig_cleanup(&request->tsig);
162

163
	mm_free(mm, request);
164 165
}

166 167 168
int knot_requestor_init(struct knot_requestor *requestor,
                        const knot_layer_api_t *proc, void *proc_param,
                        knot_mm_t *mm)
169
{
170
	if (requestor == NULL || proc == NULL) {
171
		return KNOT_EINVAL;
172 173
	}

174
	memset(requestor, 0, sizeof(*requestor));
175

176
	requestor->mm = mm;
177
	knot_layer_init(&requestor->layer, mm, proc);
178
	knot_layer_begin(&requestor->layer, proc_param);
179 180

	return KNOT_EOK;
181 182
}

183
void knot_requestor_clear(struct knot_requestor *requestor)
184
{
185 186 187 188
	if (requestor == NULL) {
		return;
	}

189
	knot_layer_finish(&requestor->layer);
190

191
	memset(requestor, 0, sizeof(*requestor));
192 193
}

194 195
static int request_reset(struct knot_requestor *req,
                         struct knot_request *last)
196
{
197 198 199
	knot_layer_reset(&req->layer);
	tsig_reset(&last->tsig);

200 201 202 203 204 205 206 207
	if (req->layer.flags & KNOT_RQ_LAYER_CLOSE) {
		req->layer.flags &= ~KNOT_RQ_LAYER_CLOSE;
		if (last->fd >= 0) {
			close(last->fd);
			last->fd = -1;
		}
	}

208 209 210
	if (req->layer.state == KNOT_STATE_RESET) {
		return KNOT_LAYER_ERROR;
	}
211

212 213
	return KNOT_EOK;
}
214

215 216 217 218 219
static int request_produce(struct knot_requestor *req,
                           struct knot_request *last,
                           int timeout_ms)
{
	knot_layer_produce(&req->layer, last->query);
220

221 222 223
	int ret = tsig_sign_packet(&last->tsig, last->query);
	if (ret != KNOT_EOK) {
		return ret;
224 225
	}

226
	// TODO: verify condition
227
	if (req->layer.state == KNOT_STATE_CONSUME) {
228 229 230 231 232
		ret = request_send(last, timeout_ms);
	}

	return ret;
}
233

234 235 236 237 238 239 240
static int request_consume(struct knot_requestor *req,
                           struct knot_request *last,
                           int timeout_ms)
{
	int ret = request_recv(last, timeout_ms);
	if (ret < 0) {
		return ret;
241 242
	}

243 244 245 246 247
	ret = knot_pkt_parse(last->resp, 0);
	if (ret != KNOT_EOK) {
		return ret;
	}

Jan Včelák's avatar
Jan Včelák committed
248 249 250 251
	if (!is_answer_to_query(last->query, last->resp)) {
		return KNOT_EMALF;
	}

252 253 254 255 256
	ret = tsig_verify_packet(&last->tsig, last->resp);
	if (ret != KNOT_EOK) {
		return ret;
	}

257 258 259 260
	if (tsig_unsigned_count(&last->tsig) >= 100) {
		return KNOT_TSIG_EBADSIG;
	}

261 262
	knot_layer_consume(&req->layer, last->resp);

263 264 265
	return KNOT_EOK;
}

266
static bool layer_active(knot_layer_state_t state)
267 268 269
{
	switch (state) {
	case KNOT_STATE_CONSUME:
270 271
	case KNOT_STATE_PRODUCE:
	case KNOT_STATE_RESET:
272 273 274 275 276 277
		return true;
	default:
		return false;
	}
}

278 279 280 281 282 283 284 285 286 287 288 289 290 291 292
static int request_io(struct knot_requestor *req, struct knot_request *last,
                      int timeout_ms)
{
	switch (req->layer.state) {
	case KNOT_STATE_CONSUME:
		return request_consume(req, last, timeout_ms);
	case KNOT_STATE_PRODUCE:
		return request_produce(req, last, timeout_ms);
	case KNOT_STATE_RESET:
		return request_reset(req, last);
	default:
		return KNOT_EINVAL;
	}
}

293 294
int knot_requestor_exec(struct knot_requestor *requestor,
                        struct knot_request *request,
295
                        int timeout_ms)
296
{
297 298 299 300
	if (!requestor || !request) {
		return KNOT_EINVAL;
	}

301 302
	int ret = KNOT_EOK;

303 304
	requestor->layer.tsig = &request->tsig;

305
	/* Do I/O until the processing is satisifed or fails. */
306
	while (layer_active(requestor->layer.state)) {
307
		ret = request_io(requestor, request, timeout_ms);
308
		if (ret != KNOT_EOK) {
309
			knot_layer_finish(&requestor->layer);
310 311
			return ret;
		}
312 313 314
	}

	/* Expect complete request. */
315
	if (requestor->layer.state != KNOT_STATE_DONE) {
316
		ret = KNOT_LAYER_ERROR;
317 318
	}

319 320
	/* Verify last TSIG */
	if (tsig_unsigned_count(&request->tsig) != 0) {
321
		ret = KNOT_TSIG_EBADSIG;
322 323
	}

324
	/* Finish current query processing. */
325
	knot_layer_finish(&requestor->layer);
326 327 328

	return ret;
}