diff --git a/nbase/nbase.h b/nbase/nbase.h index 3869b2d8a..f6baf1091 100644 --- a/nbase/nbase.h +++ b/nbase/nbase.h @@ -562,11 +562,15 @@ char *executable_path(const char *argv0); /* addrset management functions and definitions */ /* A set of addresses. Used to match against allow/deny lists. */ struct addrset_elem; +/* A radix tree (trie) used to match quickly against allow/deny lists. */ +struct trie_node; /* A set of addresses. Used to match against allow/deny lists. */ struct addrset { /* Linked list of struct addset_elem. */ struct addrset_elem *head; + /* Radix tree for faster matching of certain cases */ + struct trie_node *trie; }; void nbase_set_log(void (*log_user_func)(const char *, ...),void (*log_debug_func)(const char *, ...)); diff --git a/nbase/nbase_addrset.c b/nbase/nbase_addrset.c index c9985317d..82aef9096 100644 --- a/nbase/nbase_addrset.c +++ b/nbase/nbase_addrset.c @@ -156,19 +156,334 @@ void nbase_set_log(void (*log_user_func)(const char *, ...),void (*log_debug_fun log_debug = log_debug_func; } +/* Node for a radix tree (trie) used to match certain addresses. + * Currently, only individual numeric IP and IPv6 addresses are matched using + * the trie. */ +struct trie_node { + /* The address prefix that this node represents. */ + u32 addr[4]; + /* The prefix mask. Bits in addr that are not within this mask are ignored. */ + u32 mask[4]; + /* Addresses with the next bit after the mask equal to 1 are on this branch. */ + struct trie_node *next_bit_one; + /* Addresses with the next bit after the mask equal to 0 are on this branch. */ + struct trie_node *next_bit_zero; +}; + +/* Special node pointer to represent "all possible addresses" + * This will be used to represent netmask specifications. */ +static struct trie_node *TRIE_NODE_TRUE = NULL; + void addrset_init(struct addrset *set) { set->head = NULL; + /* We could simply allocate one byte to get a unique address, but this + * feels safer and is not too large. */ + if (TRIE_NODE_TRUE == NULL) { + TRIE_NODE_TRUE = (struct trie_node *) safe_zalloc(sizeof(struct trie_node)); + } + + /* Allocate the first node of the IPv4 trie */ + set->trie = (struct trie_node *) safe_zalloc(sizeof(struct trie_node)); } void addrset_free(struct addrset *set) { struct addrset_elem *elem, *next; + /* Since we descend only down one side, we at most accumulate one tree's-depth, or 128. + * Add 4 for safety to account for special root node and special empty stack position 0. + */ + struct trie_node *stack[128+4]; + struct trie_node *curr; + int i = 1; for (elem = set->head; elem != NULL; elem = next) { next = elem->next; free(elem); } + + curr = set->trie; + while (i > 0) { + /* stash next_bit_one */ + if (curr->next_bit_one != NULL && curr->next_bit_one != TRIE_NODE_TRUE) { + stack[i++] = curr->next_bit_one; + } + /* if next_bit_zero is valid, descend */ + if (curr->next_bit_zero != NULL && curr->next_bit_zero != TRIE_NODE_TRUE) { + curr = curr->next_bit_zero; + } + else { + /* next_bit_one was stashed, next_bit_zero is invalid. Free it and move back up the stack. */ + free(curr); + curr = stack[--i]; + } + } +} + + +/* Public domain log2 function. https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogLookup */ +static const char LogTable256[256] = { +#define LT(n) n, n, n, n, n, n, n, n, n, n, n, n, n, n, n, n + -1, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + LT(4), LT(5), LT(5), LT(6), LT(6), LT(6), LT(6), + LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), LT(7) +}; + +/* Returns a mask representing the common prefix between 2 values. */ +u32 common_mask(u32 a, u32 b) +{ + u8 r; // r will be lg(v) + u32 t, tt; // temporaries + u32 v = a ^ b; + + if ((tt = v >> 16)) + { + r = (t = tt >> 8) ? 24 + LogTable256[t] : 16 + LogTable256[tt]; + } + else + { + r = (t = v >> 8) ? 8 + LogTable256[t] : LogTable256[v]; + } + if (r + 1 >= 32) { + /* shifting this many bits would overflow. Just return max mask */ + return 0xffffffff; + } + else { + return ~((1 << (r + 1)) - 1); + } +} + +/* Given a mask and a value, return the value of the bit immediately following + * the masked bits. */ +u32 next_bit_is_one(u32 mask, u32 value) { + if (mask == 0) { + /* no masked bits, check the first bit. */ + return (0x80000000 & value); + } + else if (mask == 0xffffffff) { + /* Imaginary bit off the end we will say is 0 */ + return 0; + } + /* isolate the bit by overlapping the mask with its inverse */ + return ((mask >> 1) & ~mask) & value; +} + +/* Given a mask and an address, return true if the first unmasked bit is one */ +u32 addr_next_bit_is_one(const u32 *mask, const u32 *addr) { + u32 curr_mask; + for (size_t i = 0; i < 4; i++) { + curr_mask = mask[i]; + if (curr_mask < 0xffffffff) { + /* Only bother checking the first not-completely-masked portion of the address */ + return next_bit_is_one(curr_mask, addr[i]); + } + } + /* Mask must be all ones, meaning that the next bit is off the end, and clearly not 1. */ + return 0; +} + +/* Return true if the masked portion of a and b is identical */ +int mask_matches(u32 mask, u32 a, u32 b) +{ + return !(mask & (a ^ b)); +} + +/* Apply a mask and check if 2 addresses are equal */ +int addr_matches(const u32 *mask, const u32 *sa, const u32 *sb) +{ + u32 curr_mask; + for (size_t i = 0; i < 4; i++) { + curr_mask = mask[i]; + if (curr_mask == 0) { + /* No more applicable bits */ + break; + } + else if (!mask_matches(curr_mask, sa[i], sb[i])) { + /* Doesn't match. */ + return 0; + } + } + /* All applicable bits match. */ + return 1; +} + +/* Helper function to allocate and initialize a new node */ +struct trie_node *new_trie_node(const u32 *addr, const u32 *mask) +{ + struct trie_node *new_node = (struct trie_node *) safe_zalloc(sizeof(struct trie_node)); + for (size_t i=0; i < 4; i++) { + new_node->addr[i] = addr[i]; + new_node->mask[i] = mask[i]; + } + return new_node; +} + +/* Split a node into 2: one that matches the greatest common prefix with addr + * and one that does not. */ +void trie_split (struct trie_node *this, const u32 *addr) +{ + u32 new_mask[4] = {0,0,0,0}; + /* Calculate the mask of the common prefix */ + for (size_t i=0; i < 4; i++) { + new_mask[i] = common_mask(this->addr[i], addr[i]); + if (new_mask[i] < 0xffffffff) { + break; + } + } + /* Make a copy of this node to continue matching what it has been */ + struct trie_node *new_node = new_trie_node(this->addr, this->mask); + new_node->next_bit_one = this->next_bit_one; + new_node->next_bit_zero = this->next_bit_zero; + /* Adjust this node to the smaller mask */ + for (size_t i=0; i < 4; i++) { + this->mask[i] = new_mask[i]; + } + /* Put the new node on the appropriate branch */ + if (addr_next_bit_is_one(this->mask, this->addr)) { + this->next_bit_one = new_node; + this->next_bit_zero = NULL; + } + else { + this->next_bit_zero = new_node; + this->next_bit_one = NULL; + } +} + +/* Convenient static mask used for new addresses. + * Will be replaced with user-supplied mask when we support netmasks */ +static const u32 MASK_ALL_ONES[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff }; + +/* Tail-recursive helper for address insertion */ +void _trie_insert (struct trie_node *this, const u32 *addr) +{ + if (this == TRIE_NODE_TRUE) { + /* Anything below this point always matches; no need to insert */ + return; + } + if (addr_matches(this->mask, this->addr, addr)) { + if (1 & this->mask[3]) { + /* 1. end of address: duplicate. return; */ + return; + } + } + else { + /* Split the netmask to ensure a match */ + trie_split(this, addr); + } + if (addr_next_bit_is_one(this->mask, addr)) { + /* next bit is one: insert on the one branch */ + if (this->next_bit_one == NULL) { + /* Previously unmatching branch, always the case when splitting */ + this->next_bit_one = new_trie_node(addr, MASK_ALL_ONES); + } + else { + return _trie_insert(this->next_bit_one, addr); + } + } + else { + /* next bit is zero: insert on the zero branch */ + if (this->next_bit_zero == NULL) { + /* Previously unmatching branch, always the case when splitting */ + this->next_bit_zero = new_trie_node(addr, MASK_ALL_ONES); + } + else { + return _trie_insert(this->next_bit_zero, addr); + } + } +} + +/* Helper function to turn a sockaddr into an array of u32, used internally */ +int sockaddr_to_addr(const struct sockaddr *sa, u32 *addr) +{ + if (sa->sa_family == AF_INET) { + /* IPv4-mapped IPv6 address */ + addr[0] = addr[1] = 0; + addr[2] = 0xffff; + addr[3] = ntohl(((struct sockaddr_in *) sa)->sin_addr.s_addr); + } +#ifdef HAVE_IPV6 + else if (sa->sa_family == AF_INET6) { + unsigned char *addr6 = ((struct sockaddr_in6 *) sa)->sin6_addr.s6_addr; + for (size_t i=0; i < 4; i++) { + addr[i] = (addr6[i*4] << 24) + (addr6[i*4+1] << 16) + (addr6[i*4+2] << 8) + addr6[i*4+3]; + } + } +#endif + else { + return 0; + } + return 1; +} + +/* Insert a sockaddr into the trie */ +void trie_insert (struct trie_node *this, const struct sockaddr *sa) +{ + u32 addr[4] = {0}; + if (!sockaddr_to_addr(sa, addr)) { + log_debug("Unknown address family %u, address not inserted.\n", sa->sa_family); + return; + } + /* First node doesn't have a mask or address of its own; we have to check the + * first bit manually. */ + if (0x80000000 & addr[0]) { + /* First bit is 1, so insert on ones branch */ + if (this->next_bit_one == NULL) { + /* Empty branch, just add it. */ + this->next_bit_one = new_trie_node(addr, MASK_ALL_ONES); + return; + } + return _trie_insert(this->next_bit_one, addr); + } + else { + /* First bit is 0, so insert on zeros branch */ + if (this->next_bit_zero == NULL) { + /* Empty branch, just add it. */ + this->next_bit_zero = new_trie_node(addr, MASK_ALL_ONES); + return; + } + return _trie_insert(this->next_bit_zero, addr); + } +} + +/* Tail-recursive helper for matching addresses */ +int _trie_match (const struct trie_node *this, const u32 *addr) +{ + if (this == TRIE_NODE_TRUE) { + return 1; + } + if (this == NULL) { + return 0; + } + if (addr_matches(this->mask, this->addr, addr)) { + if (1 & this->mask[3]) { + /* We've matched all possible bits! Yay! */ + return 1; + } + else if (addr_next_bit_is_one(this->mask, addr)) { + return _trie_match(this->next_bit_one, addr); + } + else { + return _trie_match(this->next_bit_zero, addr); + } + } + return 0; +} + +int trie_match (const struct trie_node *this, const struct sockaddr *sa) +{ + u32 addr[4] = {0}; + if (!sockaddr_to_addr(sa, addr)) { + log_debug("Unknown address family %u, cannot match.\n", sa->sa_family); + return 0; + } + /* Manually check first bit to decide which branch to match against */ + if (0x80000000 & addr[0]) { + return _trie_match(this->next_bit_one, addr); + } + else { + return _trie_match(this->next_bit_zero, addr); + } + return 0; } /* A debugging function to print out the contents of an addrset_elem. For IPv4 @@ -300,6 +615,22 @@ int addrset_add_spec(struct addrset *set, const char *spec, int af, int dns) } } + if (netmask_bits < 0) { + /* See if it's a plain IP address */ + rc = resolve_name(local_spec, &addrs, af, 0); + if (rc == 0 && addrs != NULL) { + /* Add all addresses to the trie */ + for (addr = addrs; addr != NULL; addr = addr->ai_next) { + char addr_string[128]; + address_to_string(addr->ai_addr, addr->ai_addrlen, addr_string, sizeof(addr_string)); + trie_insert(set->trie, addr->ai_addr); + log_debug("Add IP %s to addrset (trie).\n", addr_string); + } + freeaddrinfo(addrs); + return 1; + } + } + elem = (struct addrset_elem *) safe_malloc(sizeof(*elem)); memset(elem->u.ipv4.bits, 0, sizeof(elem->u.ipv4.bits)); @@ -648,6 +979,11 @@ int addrset_contains(const struct addrset *set, const struct sockaddr *sa) { struct addrset_elem *elem; + /* First check the trie. */ + if (trie_match(set->trie, sa)) + return 1; + + /* If that didn't match, check the rest of the addrset_elem in order */ for (elem = set->head; elem != NULL; elem = elem->next) { if (addrset_elem_match(elem, sa)) return 1;