Commit 11663762 authored by Marek Vavruša's avatar Marek Vavruša

tcp: don't rely on the MSG_WAITALL, recv() loop instead

The MSG_WAITALL doesn't work as expected everywhere and
can return partial answer (for example with signal interrupt).
Also implemented timeouts for TCP receive.
parent 348e9908
......@@ -139,7 +139,7 @@ static int cmd_remote_reply(int c)
}
/* Read response packet. */
int n = tcp_recv(c, pkt->wire, pkt->max_size, NULL);
int n = tcp_recv_msg(c, pkt->wire, pkt->max_size, NULL);
if (n <= 0) {
dbg_server("remote: couldn't receive response = %s\n", knot_strerror(n));
knot_pkt_free(&pkt);
......@@ -258,7 +258,7 @@ static int cmd_remote(const char *cmd, uint16_t rrt, int argc, char *argv[])
}
/* Send and free packet. */
int ret = tcp_send(s, pkt->wire, pkt->size);
int ret = tcp_send_msg(s, pkt->wire, pkt->size);
knot_pkt_free(&pkt);
/* Evaluate and wait for reply. */
......
......@@ -399,7 +399,7 @@ static int remote_senderr(int c, uint8_t *qbuf, size_t buflen)
{
knot_wire_set_qr(qbuf);
knot_wire_set_rcode(qbuf, KNOT_RCODE_REFUSED);
return tcp_send(c, qbuf, buflen);
return tcp_send_msg(c, qbuf, buflen);
}
/* Public APIs. */
......@@ -472,7 +472,7 @@ int remote_recv(int sock, struct sockaddr *addr, uint8_t* buf, size_t *buflen)
}
/* Receive data. */
int n = tcp_recv(c, buf, *buflen, addr);
int n = tcp_recv_msg(c, buf, *buflen, NULL);
*buflen = n;
if (n <= 0) {
dbg_server("remote: failed to receive data\n");
......@@ -521,7 +521,7 @@ static int remote_send_chunk(int c, knot_pkt_t *query, const char* d, uint16_t l
goto failed;
}
ret = tcp_send(c, resp->wire, resp->size);
ret = tcp_send_msg(c, resp->wire, resp->size);
failed:
......
......@@ -46,12 +46,13 @@ static struct request *request_make(mm_ctx_t *mm)
return request;
}
static void request_close(mm_ctx_t *mm, struct request *request)
static int request_close(mm_ctx_t *mm, struct request *request)
{
/* Reset processing if didn't complete. */
if (request->state != NS_PROC_DONE) {
knot_process_reset(&request->process);
}
knot_process_finish(&request->process);
rem_node(&request->data.node);
......@@ -59,6 +60,8 @@ static void request_close(mm_ctx_t *mm, struct request *request)
knot_pkt_free(&request->data.query);
mm_free(mm, request->pkt_buf);
mm_free(mm, request);
return KNOT_EOK;
}
/*! \brief Wait for socket readiness. */
......@@ -78,10 +81,13 @@ static int request_wait(int fd, int state, struct timeval *timeout)
}
}
static int request_send(struct request *request, struct timeval *timeout)
static int request_send(struct request *request, const struct timeval *timeout)
{
/* Each request has unique timeout. */
struct timeval tv = { timeout->tv_sec, timeout->tv_usec };
/* Wait for writeability. */
int ret = request_wait(request->data.fd, NS_PROC_FULL, timeout);
int ret = request_wait(request->data.fd, NS_PROC_FULL, &tv);
if (ret <= 0) {
return KNOT_EAGAIN;
}
......@@ -96,26 +102,24 @@ static int request_send(struct request *request, struct timeval *timeout)
/* Send query. */
knot_pkt_t *query = request->data.query;
ret = tcp_send(request->data.fd, query->wire, query->size);
if (ret <= 0) {
ret = tcp_send_msg(request->data.fd, query->wire, query->size);
if (ret != query->size) {
return KNOT_ECONN;
}
return KNOT_EOK;
}
static int request_recv(struct request *request, struct timeval *timeout)
static int request_recv(struct request *request, const struct timeval *timeout)
{
/* Wait for response. */
int ret = request_wait(request->data.fd, NS_PROC_MORE, timeout);
if (ret <= 0) {
return NS_PROC_FAIL;
}
/* Each request has unique timeout. */
struct timeval tv = { timeout->tv_sec, timeout->tv_usec };
/* Receive it */
ret = tcp_recv(request->data.fd, request->pkt_buf, KNOT_WIRE_MAX_PKTSIZE, NULL);
if (ret <= 0) {
return NS_PROC_FAIL;
int ret = tcp_recv_msg(request->data.fd, request->pkt_buf,
KNOT_WIRE_MAX_PKTSIZE, &tv);
if (ret < 0) {
return ret;
}
return ret;
......@@ -195,8 +199,7 @@ int requestor_dequeue(struct requestor *requestor)
}
struct request *last = HEAD(requestor->pending);
request_close(requestor->mm, last);
return KNOT_EOK;
return request_close(requestor->mm, last);
}
static int exec_request(struct request *last, struct timeval *timeout)
......@@ -215,14 +218,16 @@ static int exec_request(struct request *last, struct timeval *timeout)
/* Receive and process expected answers. */
while (last->state == NS_PROC_MORE) {
int rcvd = request_recv(last, timeout);
if (rcvd <= 0) {
if (rcvd < 0) {
return rcvd;
}
last->state = knot_process_in(last->pkt_buf, rcvd, &last->process);
if (last->state == NS_PROC_FAIL) {
return KNOT_EMALF;
}
}
/* Expect complete request. */
if (last->state != NS_PROC_DONE) {
return KNOT_EMALF;
}
return KNOT_EOK;
......
......@@ -84,7 +84,6 @@ static enum fdset_sweep_state tcp_sweep(fdset_t *set, int i, void *data)
log_server_notice("Connection '%s' was terminated due to inactivity.\n",
addr_str);
close(fd);
return FDSET_SWEEP;
}
......@@ -105,8 +104,17 @@ static int tcp_handle(tcp_context_t *tcp, int fd,
rx->iov_len = KNOT_WIRE_MAX_PKTSIZE;
tx->iov_len = KNOT_WIRE_MAX_PKTSIZE;
/* Receive peer name. */
socklen_t addrlen = sizeof(struct sockaddr_storage);
if (getpeername(fd, (struct sockaddr *)&ss, &addrlen) < 0) {
;
}
/* Timeout. */
struct timeval tmout = { conf()->max_conn_reply, 0 };
/* Receive data. */
int ret = tcp_recv(fd, rx->iov_base, rx->iov_len, (struct sockaddr *)&ss);
int ret = tcp_recv_msg(fd, rx->iov_base, rx->iov_len, &tmout);
if (ret <= 0) {
dbg_net("tcp: client on fd=%d disconnected\n", fd);
if (ret == KNOT_EAGAIN) {
......@@ -137,7 +145,7 @@ static int tcp_handle(tcp_context_t *tcp, int fd,
/* If it has response, send it. */
if (tx_len > 0) {
if (tcp_send(fd, tx->iov_base, tx_len) != tx_len) {
if (tcp_send_msg(fd, tx->iov_base, tx_len) != tx_len) {
ret = KNOT_ECONNREFUSED;
break;
}
......@@ -193,7 +201,56 @@ int tcp_accept(int fd)
return incoming;
}
int tcp_send(int fd, const uint8_t *msg, size_t msglen)
/*! \brief Wait for data and return true if data arrived. */
static int tcp_wait_for_data(int fd, struct timeval *timeout)
{
fd_set set;
FD_ZERO(&set);
FD_SET(fd, &set);
return select(fd + 1, &set, NULL, NULL, timeout);
}
int tcp_recv_data(int fd, uint8_t *buf, int len, struct timeval *timeout)
{
int ret = 0;
int rcvd = 0;
int flags = 0;
#ifdef MSG_NOSIGNAL
flags |= MSG_NOSIGNAL;
#endif
while (rcvd < len) {
/* Receive data. */
ret = recv(fd, buf + rcvd, len - rcvd, flags);
if (ret > 0) {
rcvd += ret;
continue;
}
/* Check for disconnected socket. */
if (ret == 0) {
return KNOT_ECONNREFUSED;
}
/* Check for no data available. */
if (errno == EAGAIN || errno == EINTR) {
/* Continue only if timeout didn't expire. */
ret = tcp_wait_for_data(fd, timeout);
if (ret) {
continue;
} else {
return KNOT_EAGAIN;
}
} else {
return KNOT_ECONN;
}
}
return rcvd;
}
int tcp_send_msg(int fd, const uint8_t *msg, size_t msglen)
{
/* Create iovec for gathered write. */
struct iovec iov[2];
......@@ -207,62 +264,41 @@ int tcp_send(int fd, const uint8_t *msg, size_t msglen)
int total_len = iov[0].iov_len + iov[1].iov_len;
int sent = writev(fd, iov, 2);
if (sent != total_len) {
return KNOT_ERROR;
return KNOT_ECONN;
}
return msglen; /* Do not count the size prefix. */
}
int tcp_recv(int fd, uint8_t *buf, size_t len, struct sockaddr *addr)
int tcp_recv_msg(int fd, uint8_t *buf, size_t len, struct timeval *timeout)
{
/* Flags. */
int flags = MSG_WAITALL;
#ifdef MSG_NOSIGNAL
flags |= MSG_NOSIGNAL;
#endif
if (buf == NULL || fd < 0) {
return KNOT_EINVAL;
}
/* Receive size. */
unsigned short pktsize = 0;
int n = recv(fd, &pktsize, sizeof(unsigned short), flags);
if (n < 0) {
if (errno == EAGAIN) {
return KNOT_EAGAIN;
} else {
return KNOT_ERROR;
}
int ret = tcp_recv_data(fd, (uint8_t *)&pktsize, sizeof(pktsize), timeout);
if (ret != sizeof(pktsize)) {
return ret;
}
pktsize = ntohs(pktsize);
dbg_net("tcp: incoming packet size=%hu on fd=%d\n",
pktsize, fd);
dbg_net("tcp: incoming packet size=%hu on fd=%d\n", pktsize, fd);
// Check packet size
if (len < pktsize) {
return KNOT_ENOMEM;
}
/* Get peer name. */
if (addr) {
socklen_t addrlen = sizeof(struct sockaddr_storage);
if (getpeername(fd, addr, &addrlen) < 0) {
return KNOT_EMALF;
}
}
/* Receive payload. */
n = recv(fd, buf, pktsize, flags);
if (n < 0) {
if (errno == EAGAIN) {
return KNOT_EAGAIN;
} else {
return KNOT_ERROR;
}
ret = tcp_recv_data(fd, buf, pktsize, timeout);
if (ret != pktsize) {
return ret;
}
dbg_net("tcp: received packet size=%d on fd=%d\n",
n, fd);
return n;
dbg_net("tcp: received packet size=%d on fd=%d\n", ret, fd);
return ret;
}
static int tcp_event_accept(tcp_context_t *tcp, unsigned i)
......
......@@ -50,7 +50,19 @@
int tcp_accept(int fd);
/*!
* \brief Send TCP message.
* \brief Receive a block of data from TCP socket with wait.
*
* \param fd File descriptor.
* \param buf Data buffer.
* \param len Block length.
* \param timeout Timeout for the operation, NULL for infinite.
*
* \return number of bytes received or an error
*/
int tcp_recv_data(int fd, uint8_t *buf, int len, struct timeval *timeout);
/*!
* \brief Send a TCP message.
*
* \param fd Associated socket.
* \param msg Buffer for a query wireformat.
......@@ -59,21 +71,21 @@ int tcp_accept(int fd);
* \retval Number of sent data on success.
* \retval KNOT_ERROR on error.
*/
int tcp_send(int fd, const uint8_t *msg, size_t msglen);
int tcp_send_msg(int fd, const uint8_t *msg, size_t msglen);
/*!
* \brief Receive TCP message.
* \brief Receive a TCP message.
*
* \param fd Associated socket.
* \param buf Buffer for incoming bytestream.
* \param len Buffer maximum size.
* \param addr Source address.
* \param timeout Message receive timeout.
*
* \retval Number of read bytes on success.
* \retval KNOT_ERROR on error.
* \retval KNOT_ENOMEM on potential buffer overflow.
*/
int tcp_recv(int fd, uint8_t *buf, size_t len, struct sockaddr *addr);
int tcp_recv_msg(int fd, uint8_t *buf, size_t len, struct timeval *timeout);
/*!
* \brief TCP handler thread runnable.
......
......@@ -46,13 +46,13 @@ static void* responder_thread(void *arg)
if (client < 0) {
break;
}
int len = tcp_recv(client, buf, sizeof(buf), NULL);
int len = tcp_recv_msg(client, buf, sizeof(buf), NULL);
if (len < KNOT_WIRE_HEADER_SIZE) {
close(client);
break;
}
knot_wire_set_qr(buf);
tcp_send(client, buf, len);
tcp_send_msg(client, buf, len);
close(client);
}
return NULL;
......@@ -155,7 +155,7 @@ int main(int argc, char *argv[])
/* Terminate responder. */
int responder = net_connected_socket(SOCK_STREAM, &remote.addr, NULL, 0);
assert(responder > 0);
tcp_send(responder, (const uint8_t *)"", 1);
tcp_send_msg(responder, (const uint8_t *)"", 1);
(void) pthread_join(thread, 0);
close(responder);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment