Commit c4a186b9 authored by Marek Vavrusa's avatar Marek Vavrusa

Moved net-related code for utils to common.

parent 1206c78d
/* Copyright (C) 2011 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include <stdlib.h>
#include <netdb.h> // addrinfo
#include <poll.h> // poll
#include <sys/socket.h> // AF_INET (BSD)
#include <netinet/in.h> // ntohl (BSD)
#include <fcntl.h>
#include "utils/common/netio.h"
#include "utils/common/msg.h"
#include "libknot/util/descriptor.h" // KNOT_CLASS_IN
#include "common/errcode.h"
int get_socktype(const params_t *params, const uint16_t qtype)
{
switch (params->protocol) {
case PROTO_TCP:
return SOCK_STREAM;
case PROTO_UDP:
if (qtype == KNOT_RRTYPE_AXFR || qtype == KNOT_RRTYPE_IXFR) {
WARN("using UDP for zone transfer\n");
}
return SOCK_DGRAM;
default:
if (qtype == KNOT_RRTYPE_AXFR || qtype == KNOT_RRTYPE_IXFR) {
return SOCK_STREAM;
} else {
return SOCK_DGRAM;
}
}
}
int send_query(const params_t *params,
const query_t *query,
const server_t *server,
const uint8_t *data,
const size_t data_len)
{
struct addrinfo hints, *res;
struct pollfd pfd;
int sockfd;
memset(&hints, 0, sizeof hints);
// Set IP type.
if (params->ip == IP_4) {
hints.ai_family = AF_INET;
} else if (params->ip == IP_6) {
hints.ai_family = AF_INET6;
} else {
hints.ai_family = AF_UNSPEC;
}
// Set TCP or UDP.
hints.ai_socktype = get_socktype(params, query->type);
// Get connection parameters.
if (getaddrinfo(server->name, server->service, &hints, &res) != 0) {
WARN("can't use nameserver %s port %s\n",
server->name, server->service);
return -1;
}
// Create socket.
sockfd = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
if (sockfd == -1) {
WARN("can't create socket for nameserver %s port %s\n",
server->name, server->service);
return -1;
}
// Initialize poll descriptor structure.
pfd.fd = sockfd;
pfd.events = POLLOUT;
pfd.revents = 0;
// Set non-blocking socket.
if (fcntl(sockfd, F_SETFL, O_NONBLOCK) == -1) {
WARN("can't create non-blocking socket\n");
}
// Connect using socket.
if (connect(sockfd, res->ai_addr, res->ai_addrlen) == -1 &&
errno != EINPROGRESS) {
WARN("can't connect to nameserver %s port %s\n",
server->name, server->service);
shutdown(sockfd, 2);
return -1;
}
// Check for connection timeout.
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for connection to nameserver %s port %s\n",
server->name, server->service);
shutdown(sockfd, 2);
return -1;
}
// For TCP add leading length bytes.
if (hints.ai_socktype == SOCK_STREAM) {
uint16_t pktsize = htons(data_len);
if (send(sockfd, &pktsize, sizeof(pktsize), 0) !=
sizeof(pktsize)) {
WARN("TCP packet leading lenght\n");
}
}
// Send data.
if (send(sockfd, data, data_len, 0) != data_len) {
WARN("can't send query\n");
}
return sockfd;
}
int receive_msg(const params_t *params,
const query_t *query,
int sockfd,
uint8_t *out,
size_t out_len)
{
struct pollfd pfd;
// Initialize poll descriptor structure.
pfd.fd = sockfd;
pfd.events = POLLIN;
pfd.revents = 0;
if (get_socktype(params, query->type) == SOCK_STREAM) {
uint16_t msg_len;
uint32_t total = 0;
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for TCP message length\n");
return KNOT_ERROR;
}
if (recv(sockfd, &msg_len, sizeof(msg_len), 0) !=
sizeof(msg_len)) {
WARN("can't receive TCP message length\n");
return KNOT_ERROR;
}
// Convert number to host format.
msg_len = ntohs(msg_len);
// Receive whole answer message.
while (total < msg_len) {
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for TCP answer\n");
return KNOT_ERROR;
}
total += recv(sockfd, out + total, out_len - total, 0);
}
return msg_len;
} else {
// Wait for datagram data.
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for UDP answer\n");
return KNOT_ERROR;
}
// Receive UDP datagram.
ssize_t len = recv(sockfd, out, out_len, 0);
if (len <= 0) {
WARN("can't receive UDP answer\n");
return KNOT_ERROR;
}
return len;
}
return KNOT_EOK;
}
/* Copyright (C) 2011 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*!
* \file netio.h
*
* \author Daniel Salzman <daniel.salzman@nic.cz>
*
* \brief Networking abstraction for utilities.
*
* \addtogroup utils
* @{
*/
#ifndef _UTILS__NETIO_H_
#define _UTILS__NETIO_H_
#include <arpa/inet.h> // inet_pton
#include <sys/socket.h> // AF_INET (BSD)
#include <netinet/in.h> // in_addr (BSD)
#include "utils/common/params.h"
#include "utils/common/resolv.h"
int get_socktype(const params_t *params, const uint16_t qtype);
int send_query(const params_t *params, const query_t *query,
const server_t *server, const uint8_t *data, size_t data_len);
int receive_msg(const params_t *params, const query_t *query,
int sockfd, uint8_t *out, size_t out_len);
#endif // _UTILS__NETIO_H_
......@@ -17,13 +17,11 @@
#include "utils/common/params.h"
#include <stdlib.h> // free
#include <arpa/inet.h> // inet_pton
#include <sys/socket.h> // AF_INET (BSD)
#include <netinet/in.h> // in_addr (BSD)
#include "common/errcode.h" // KNOT_EOK
#include "libknot/util/descriptor.h" // KNOT_RRTYPE
#include "utils/common/msg.h" // WARN
#include "utils/common/netio.h"
#define IPV4_REVERSE_DOMAIN "in-addr.arpa."
#define IPV6_REVERSE_DOMAIN "ip6.arpa."
......
......@@ -57,6 +57,49 @@ typedef enum {
PROTO_UDP
} protocol_t;
#define DEFAULT_WAIT_INTERVAL 1
/*! \brief Types of host operation mode. */
typedef enum {
/*!< Classic query for name-class-type. */
HOST_MODE_DEFAULT,
/*!< Query for NS and all authoritative SOA records. */
HOST_MODE_LIST_SERIALS,
} host_mode_t;
/*! \brief Structure containing parameters for host. */
typedef struct {
/*!< List of nameservers to query to. */
list servers;
/*!< List of DNS queries to process. */
list queries;
/*!< Operation mode. */
host_mode_t mode;
/*!< Version of ip protocol to use. */
ip_version_t ip;
/*!< Type (TCP, UDP) protocol to use. */
protocol_t protocol;
/*!< Default class number. */
uint16_t class_num;
/*!< Default type number (16unsigned + -1 uninitialized). */
int32_t type_num;
/*!< SOA serial for IXFR query (32unsigned + -1 uninitialized). */
int64_t ixfr_serial;
/*!< Use recursion. */
bool recursion;
/*!< UDP buffer size. */
uint32_t udp_size;
/*!< Number of UDP retries. */
uint32_t retries;
/*!< Wait for reply in seconds (-1 means forever). */
int32_t wait;
/*!< Stop quering if servfail. */
bool servfail_stop;
/*!< Verbose mode. */
bool verbose;
} params_t;
query_t* create_query(const char *name, const uint16_t type);
void query_free(query_t *query);
......
......@@ -18,10 +18,6 @@
#include <stdlib.h> // free
#include <fcntl.h> // fcntl
#include <netdb.h> // addrinfo
#include <poll.h> // poll
#include <sys/socket.h> // AF_INET (BSD)
#include <netinet/in.h> // ntohl (BSD)
#include "common/lists.h" // list
#include "common/errcode.h" // KNOT_EOK
......@@ -32,28 +28,10 @@
#include "utils/common/msg.h" // WARN
#include "utils/common/resolv.h" // server_t
#include "utils/host/host_params.h" // host_params_t
#include "utils/host/host_params.h" // params_t
#include "utils/common/netio.h"
static bool use_tcp(const host_params_t *params, const uint16_t type)
{
switch (params->protocol) {
case PROTO_TCP:
return true;
case PROTO_UDP:
if (type == KNOT_RRTYPE_AXFR || type == KNOT_RRTYPE_IXFR) {
WARN("using UDP for zone transfer\n");
}
return false;
default:
if (type == KNOT_RRTYPE_AXFR || type == KNOT_RRTYPE_IXFR) {
return true;
} else {
return false;
}
}
}
static bool use_recursion(const host_params_t *params, const uint16_t type)
static bool use_recursion(const params_t *params, const uint16_t type)
{
if (type == KNOT_RRTYPE_AXFR || type == KNOT_RRTYPE_IXFR) {
return false;
......@@ -66,7 +44,7 @@ static bool use_recursion(const host_params_t *params, const uint16_t type)
}
}
static knot_packet_t* create_query_packet(const host_params_t *params,
static knot_packet_t* create_query_packet(const params_t *params,
const query_t *query,
uint8_t **data,
size_t *data_len)
......@@ -81,7 +59,7 @@ static knot_packet_t* create_query_packet(const host_params_t *params,
}
// Set packet buffer size.
if (use_tcp(params, query->type) == true) {
if (get_socktype(params, query->type) == SOCK_STREAM) {
// For TCP maximal dns packet size.
knot_packet_set_max_size(packet, MAX_PACKET_SIZE);
} else {
......@@ -187,160 +165,7 @@ static knot_packet_t* create_query_packet(const host_params_t *params,
return packet;
}
static int send_query(const host_params_t *params,
const query_t *query,
const server_t *server,
const uint8_t *data,
const size_t data_len)
{
struct addrinfo hints, *res;
struct pollfd pfd;
int sockfd;
memset(&hints, 0, sizeof hints);
// Set IP type.
if (params->ip == IP_4) {
hints.ai_family = AF_INET;
} else if (params->ip == IP_6) {
hints.ai_family = AF_INET6;
} else {
hints.ai_family = AF_UNSPEC;
}
// Set TCP or UDP.
if (use_tcp(params, query->type) == true) {
hints.ai_socktype = SOCK_STREAM;
} else {
hints.ai_socktype = SOCK_DGRAM;
}
// Get connection parameters.
if (getaddrinfo(server->name, server->service, &hints, &res) != 0) {
WARN("can't use nameserver %s port %s\n",
server->name, server->service);
return -1;
}
// Create socket.
sockfd = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
if (sockfd == -1) {
WARN("can't create socket for nameserver %s port %s\n",
server->name, server->service);
return -1;
}
// Initialize poll descriptor structure.
pfd.fd = sockfd;
pfd.events = POLLOUT;
pfd.revents = 0;
// Set non-blocking socket.
if (fcntl(sockfd, F_SETFL, O_NONBLOCK) == -1) {
WARN("can't create non-blocking socket\n");
}
// Connect using socket.
if (connect(sockfd, res->ai_addr, res->ai_addrlen) == -1 &&
errno != EINPROGRESS) {
WARN("can't connect to nameserver %s port %s\n",
server->name, server->service);
shutdown(sockfd, 2);
return -1;
}
// Check for connection timeout.
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for connection to nameserver %s port %s\n",
server->name, server->service);
shutdown(sockfd, 2);
return -1;
}
// For TCP add leading length bytes.
if (use_tcp(params, query->type) == true) {
uint16_t pktsize = htons(data_len);
if (send(sockfd, &pktsize, sizeof(pktsize), 0) !=
sizeof(pktsize)) {
WARN("TCP packet leading lenght\n");
}
}
// Send data.
if (send(sockfd, data, data_len, 0) != data_len) {
WARN("can't send query\n");
}
return sockfd;
}
static int receive_msg(const host_params_t *params,
const query_t *query,
const int sockfd,
uint8_t *out,
const size_t out_len)
{
struct pollfd pfd;
// Initialize poll descriptor structure.
pfd.fd = sockfd;
pfd.events = POLLIN;
pfd.revents = 0;
if (use_tcp(params, query->type) == true) {
uint16_t msg_len;
uint32_t total = 0;
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for TCP message length\n");
return KNOT_ERROR;
}
if (recv(sockfd, &msg_len, sizeof(msg_len), 0) !=
sizeof(msg_len)) {
WARN("can't receive TCP message length\n");
return KNOT_ERROR;
}
// Convert number to host format.
msg_len = ntohs(msg_len);
// Receive whole answer message.
while (total < msg_len) {
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for TCP answer\n");
return KNOT_ERROR;
}
total += recv(sockfd, out + total, out_len - total, 0);
}
return msg_len;
} else {
// Wait for datagram data.
if (poll(&pfd, 1, 1000 * params->wait) != 1) {
WARN("can't wait for UDP answer\n");
return KNOT_ERROR;
}
// Receive UDP datagram.
ssize_t len = recv(sockfd, out, out_len, 0);
if (len <= 0) {
WARN("can't receive UDP answer\n");
return KNOT_ERROR;
}
return len;
}
return KNOT_EOK;
}
static int process_query(const host_params_t *params, const query_t *query)
static int process_query(const params_t *params, const query_t *query)
{
const size_t out_len = MAX_PACKET_SIZE;
uint8_t out[out_len];
......@@ -381,7 +206,7 @@ static int process_query(const host_params_t *params, const query_t *query)
return KNOT_EOK;
}
int host_exec(const host_params_t *params)
int host_exec(const params_t *params)
{
node *query = NULL;
......
......@@ -32,7 +32,7 @@
#include "utils/host/host_params.h"
int host_exec(const host_params_t *params);
int host_exec(const params_t *params);
#endif // _HOST__HOST_EXEC_H_
......
......@@ -17,12 +17,12 @@
#include <stdlib.h> // EXIT_FAILURE
#include "common/errcode.h" // KNOT_EOK
#include "utils/host/host_params.h" // host_params_t
#include "utils/host/host_params.h" // params_t
#include "utils/host/host_exec.h" // host_exec
int main(int argc, char *argv[])
{
host_params_t params;
params_t params;
if (host_params_parse(&params, argc, argv) != KNOT_EOK) {
return EXIT_FAILURE;
......
......@@ -27,8 +27,9 @@
#include "utils/common/msg.h" // WARN
#include "utils/common/params.h" // parse_class
#include "utils/common/resolv.h" // get_nameservers
#include "utils/common/netio.h"
static void host_params_init(host_params_t *params)
static void host_params_init(params_t *params)
{
memset(params, 0, sizeof(*params));
......@@ -55,7 +56,7 @@ static void host_params_init(host_params_t *params)
params->verbose = false;
}
void host_params_clean(host_params_t *params)
void host_params_clean(params_t *params)
{
node *n = NULL, *nxt = NULL;
......@@ -77,54 +78,54 @@ void host_params_clean(host_params_t *params)
memset(params, 0, sizeof(*params));
}
static void host_params_flag_all(host_params_t *params)
static void host_params_flag_all(params_t *params)
{
params->type_num = KNOT_RRTYPE_ANY;
params->verbose = true;
}
static void host_params_flag_soa(host_params_t *params)
static void host_params_flag_soa(params_t *params)
{
params->type_num = KNOT_RRTYPE_SOA;
params->mode = HOST_MODE_LIST_SERIALS;
}
static void host_params_flag_axfr(host_params_t *params)
static void host_params_flag_axfr(params_t *params)
{
params->type_num = KNOT_RRTYPE_AXFR;
}
static void host_params_flag_nonrecursive(host_params_t *params)
static void host_params_flag_nonrecursive(params_t *params)
{
params->recursion = false;
}
static void host_params_flag_tcp(host_params_t *params)
static void host_params_flag_tcp(params_t *params)
{
params->protocol = PROTO_TCP;
}
static void host_params_flag_ipv4(host_params_t *params)
static void host_params_flag_ipv4(params_t *params)
{
params->ip = IP_4;
}
static void host_params_flag_ipv6(host_params_t *params)
static void host_params_flag_ipv6(params_t *params)
{
params->ip = IP_6;
}
static void host_params_flag_servfail(host_params_t *params)
static void host_params_flag_servfail(params_t *params)
{
params->servfail_stop = true;
}
static void host_params_flag_verbose(host_params_t *params)
static void host_params_flag_verbose(params_t *params)
{
params->verbose = true;
}
static void host_params_flag_nowait(host_params_t *params)
static void host_params_flag_nowait(params_t *params)
{
params->wait = -1;
}
......@@ -172,7 +173,7 @@ static int host_params_parse_wait(const char *value, int32_t *wait)
return KNOT_EOK;
}
static int host_params_parse_name(host_params_t *params, const char *name)
static int host_params_parse_name(params_t *params, const char *name)
{
char *reverse = get_reverse_name(name);
query_t *query;
......@@ -242,7 +243,7 @@ static int host_params_parse_name(host_params_t *params, const char *name)
return KNOT_EOK;
}
static int host_params_parse_server(host_params_t *params, const char *name)
static int host_params_parse_server(params_t *params, const char *name)
{
node *n = NULL, *nxt = NULL;
......@@ -270,7 +271,7 @@ static void host_params_help(int argc, char *argv[])
printf("Usage: %s [-aCdvlrT] [-4] [-6] [-c class] [-t type] {name} [server]\n", argv[0]);
}
int host_params_parse(host_params_t *params, int argc, char *argv[])
int host_params_parse(params_t *params, int argc, char *argv[])
{
int opt = 0;
......
......@@ -33,52 +33,11 @@
#include "common/lists.h" // list
#include "utils/common/params.h" // protocol_t
#define DEFAULT_WAIT_INTERVAL 1
/*! \brief Types of host operation mode. */
typedef enum {
/*!< Classic query for name-class-type. */
HOST_MODE_DEFAULT,
/*!< Query for NS and all authoritative SOA records. */
HOST_MODE_LIST_SERIALS,
} host_mode_t;
/*! \brief Structure containing parameters for host. */
typedef struct {
/*!< List of nameservers to query to. */
list servers;
/*!< List of DNS queries to process. */
list queries;
int host_params_parse(params_t *params, int argc, char *argv[]);
/*!< Operation mode. */
host_mode_t mode;
/*!< Version of ip protocol to use. */
ip_version_t ip;
/*!< Type (TCP, UDP) protocol to use. */
protocol_t protocol;
/*!< Default class number. */
uint16_t class_num;
/*!< Default type number (16unsigned + -1 uninitialized). */
int32_t type_num;
/*!< SOA serial for IXFR query (32unsigned + -1 uninitialized). */
int64_t ixfr_serial;
/*!< Use recursion. */
bool recursion;
/*!< UDP buffer size. */
uint32_t udp_size;
/*!< Number of UDP retries. */
uint32_t retries;
/*!< Wait for reply in seconds (-1 means forever). */
int32_t wait;
/*!< Stop quering if servfail. */
bool servfail_stop;
/*!< Verbose mode. */
bool verbose;
} host_params_t;
int host_params_parse(host_params_t *params, int argc, char *argv[]);
void host_params_clean(host_params_t *params);
void host_params_clean(params_t *params);
#endif // _HOST__HOST_PARAMS_H_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment