diff --git a/nbase/nbase_addrset.c b/nbase/nbase_addrset.c index 82aef9096..0ffb09214 100644 --- a/nbase/nbase_addrset.c +++ b/nbase/nbase_addrset.c @@ -187,37 +187,41 @@ void addrset_init(struct addrset *set) set->trie = (struct trie_node *) safe_zalloc(sizeof(struct trie_node)); } +void trie_free(struct trie_node *curr) +{ + /* Since we descend only down one side, we at most accumulate one tree's-depth, or 128. + * Add 4 for safety to account for special root node and special empty stack position 0. + */ + struct trie_node *stack[128+4]; + int i = 1; + + while (i > 0 && curr != NULL && curr != TRIE_NODE_TRUE) { + /* stash next_bit_one */ + if (curr->next_bit_one != NULL && curr->next_bit_one != TRIE_NODE_TRUE) { + stack[i++] = curr->next_bit_one; + } + /* if next_bit_zero is valid, descend */ + if (curr->next_bit_zero != NULL && curr->next_bit_zero != TRIE_NODE_TRUE) { + curr = curr->next_bit_zero; + } + else { + /* next_bit_one was stashed, next_bit_zero is invalid. Free it and move back up the stack. */ + free(curr); + curr = stack[--i]; + } + } +} + void addrset_free(struct addrset *set) { struct addrset_elem *elem, *next; - /* Since we descend only down one side, we at most accumulate one tree's-depth, or 128. - * Add 4 for safety to account for special root node and special empty stack position 0. - */ - struct trie_node *stack[128+4]; - struct trie_node *curr; - int i = 1; for (elem = set->head; elem != NULL; elem = next) { next = elem->next; free(elem); } - curr = set->trie; - while (i > 0) { - /* stash next_bit_one */ - if (curr->next_bit_one != NULL && curr->next_bit_one != TRIE_NODE_TRUE) { - stack[i++] = curr->next_bit_one; - } - /* if next_bit_zero is valid, descend */ - if (curr->next_bit_zero != NULL && curr->next_bit_zero != TRIE_NODE_TRUE) { - curr = curr->next_bit_zero; - } - else { - /* next_bit_one was stashed, next_bit_zero is invalid. Free it and move back up the stack. */ - free(curr); - curr = stack[--i]; - } - } + trie_free(set->trie); } @@ -315,6 +319,8 @@ struct trie_node *new_trie_node(const u32 *addr, const u32 *mask) new_node->addr[i] = addr[i]; new_node->mask[i] = mask[i]; } + /* New nodes default to matching true. Override if not. */ + new_node->next_bit_one = new_node->next_bit_zero = TRIE_NODE_TRUE; return new_node; } @@ -349,14 +355,10 @@ void trie_split (struct trie_node *this, const u32 *addr) } } -/* Convenient static mask used for new addresses. - * Will be replaced with user-supplied mask when we support netmasks */ -static const u32 MASK_ALL_ONES[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff }; - /* Tail-recursive helper for address insertion */ -void _trie_insert (struct trie_node *this, const u32 *addr) +void _trie_insert (struct trie_node *this, const u32 *addr, const u32 *mask) { - if (this == TRIE_NODE_TRUE) { + if (this == NULL || this == TRIE_NODE_TRUE) { /* Anything below this point always matches; no need to insert */ return; } @@ -370,24 +372,41 @@ void _trie_insert (struct trie_node *this, const u32 *addr) /* Split the netmask to ensure a match */ trie_split(this, addr); } + + for (size_t i=0; i < 4; i++) { + if (this->mask[i] > mask[i]) { + /* broader mask, truncate this one */ + this->mask[i] = mask[i]; + for (; i < 4; i++) { + this->mask[i] = 0; + } + /* The longer mask is superseded. Delete following nodes. */ + trie_free(this->next_bit_one); + trie_free(this->next_bit_zero); + /* Anything below here will always match. */ + this->next_bit_one = this->next_bit_zero = TRIE_NODE_TRUE; + return; + } + } + if (addr_next_bit_is_one(this->mask, addr)) { /* next bit is one: insert on the one branch */ if (this->next_bit_one == NULL) { /* Previously unmatching branch, always the case when splitting */ - this->next_bit_one = new_trie_node(addr, MASK_ALL_ONES); + this->next_bit_one = new_trie_node(addr, mask); } else { - return _trie_insert(this->next_bit_one, addr); + return _trie_insert(this->next_bit_one, addr, mask); } } else { /* next bit is zero: insert on the zero branch */ if (this->next_bit_zero == NULL) { /* Previously unmatching branch, always the case when splitting */ - this->next_bit_zero = new_trie_node(addr, MASK_ALL_ONES); + this->next_bit_zero = new_trie_node(addr, mask); } else { - return _trie_insert(this->next_bit_zero, addr); + return _trie_insert(this->next_bit_zero, addr, mask); } } } @@ -415,33 +434,68 @@ int sockaddr_to_addr(const struct sockaddr *sa, u32 *addr) return 1; } +int sockaddr_to_mask (const struct sockaddr *sa, int bits, u32 *mask) +{ + int unmasked_bits = 0; + if (bits >= 0) { + if (sa->sa_family == AF_INET) { + unmasked_bits = 32 - bits; + } +#ifdef HAVE_IPV6 + else if (sa->sa_family == AF_INET6) { + unmasked_bits = 128 - bits; + } +#endif + else { + return 0; + } + } + for (size_t i=0; i < 4; i++) { + if (unmasked_bits <= 32 * (3 - i)) { + mask[i] = 0xffffffff; + } + else if (unmasked_bits >= 32 * (4 - i)) { + mask[i] = 0; + } + else { + mask[i] = ~((1 << (unmasked_bits - (32 * (4 - i)))) - 1); + } + } + return 1; +} + /* Insert a sockaddr into the trie */ -void trie_insert (struct trie_node *this, const struct sockaddr *sa) +void trie_insert (struct trie_node *this, const struct sockaddr *sa, int bits) { u32 addr[4] = {0}; + u32 mask[4] = {0}; if (!sockaddr_to_addr(sa, addr)) { log_debug("Unknown address family %u, address not inserted.\n", sa->sa_family); return; } + if (!sockaddr_to_mask(sa, bits, mask)) { + log_debug("Bad netmask length %d for address family %u, address not inserted.\n", bits, sa->sa_family); + return; + } /* First node doesn't have a mask or address of its own; we have to check the * first bit manually. */ if (0x80000000 & addr[0]) { /* First bit is 1, so insert on ones branch */ if (this->next_bit_one == NULL) { /* Empty branch, just add it. */ - this->next_bit_one = new_trie_node(addr, MASK_ALL_ONES); + this->next_bit_one = new_trie_node(addr, mask); return; } - return _trie_insert(this->next_bit_one, addr); + return _trie_insert(this->next_bit_one, addr, mask); } else { /* First bit is 0, so insert on zeros branch */ if (this->next_bit_zero == NULL) { /* Empty branch, just add it. */ - this->next_bit_zero = new_trie_node(addr, MASK_ALL_ONES); + this->next_bit_zero = new_trie_node(addr, mask); return; } - return _trie_insert(this->next_bit_zero, addr); + return _trie_insert(this->next_bit_zero, addr, mask); } } @@ -615,20 +669,29 @@ int addrset_add_spec(struct addrset *set, const char *spec, int af, int dns) } } - if (netmask_bits < 0) { - /* See if it's a plain IP address */ - rc = resolve_name(local_spec, &addrs, af, 0); - if (rc == 0 && addrs != NULL) { - /* Add all addresses to the trie */ - for (addr = addrs; addr != NULL; addr = addr->ai_next) { - char addr_string[128]; - address_to_string(addr->ai_addr, addr->ai_addrlen, addr_string, sizeof(addr_string)); - trie_insert(set->trie, addr->ai_addr); - log_debug("Add IP %s to addrset (trie).\n", addr_string); + /* See if it's a plain IP address */ + rc = resolve_name(local_spec, &addrs, af, 0); + if (rc == 0 && addrs != NULL) { + /* Add all addresses to the trie */ + for (addr = addrs; addr != NULL; addr = addr->ai_next) { + char addr_string[128]; + if ((addr->ai_family == AF_INET && netmask_bits > 32) +#ifdef HAVE_IPV6 + || (addr->ai_family == AF_INET6 && netmask_bits > 128) +#endif + ) { + log_user("Illegal netmask in \"%s\". Must be smaller than address bit length.\n", spec); + free(local_spec); + freeaddrinfo(addrs); + return 0; } - freeaddrinfo(addrs); - return 1; + address_to_string(addr->ai_addr, addr->ai_addrlen, addr_string, sizeof(addr_string)); + trie_insert(set->trie, addr->ai_addr, netmask_bits); + log_debug("Add IP %s/%d to addrset (trie).\n", addr_string, netmask_bits); } + free(local_spec); + freeaddrinfo(addrs); + return 1; } elem = (struct addrset_elem *) safe_malloc(sizeof(*elem)); @@ -668,9 +731,6 @@ int addrset_add_spec(struct addrset *set, const char *spec, int af, int dns) for (addr = addrs; addr != NULL; addr = addr->ai_next) { char addr_string[128]; - elem = (struct addrset_elem *) safe_malloc(sizeof(*elem)); - memset(elem->u.ipv4.bits, 0, sizeof(elem->u.ipv4.bits)); - address_to_string(addr->ai_addr, addr->ai_addrlen, addr_string, sizeof(addr_string)); /* Note: it is possible that in this loop we are dealing with addresses @@ -680,49 +740,29 @@ int addrset_add_spec(struct addrset *set, const char *spec, int af, int dns) what you want if a /24 is applied to IPv6 and will cause an error if a /120 is applied to IPv4. */ if (addr->ai_family == AF_INET) { - const struct sockaddr_in *sin = (struct sockaddr_in *) addr->ai_addr; - uint8_t octets[4]; - - elem->type = ADDRSET_TYPE_IPV4_BITVECTOR; - - in_addr_to_octets(&sin->sin_addr, octets); - BIT_SET(elem->u.ipv4.bits[0], octets[0]); - BIT_SET(elem->u.ipv4.bits[1], octets[1]); - BIT_SET(elem->u.ipv4.bits[2], octets[2]); - BIT_SET(elem->u.ipv4.bits[3], octets[3]); if (netmask_bits > 32) { log_user("Illegal netmask in \"%s\". Must be between 0 and 32.\n", spec); - free(elem); + freeaddrinfo(addrs); return 0; } - apply_ipv4_netmask_bits(elem, netmask_bits); - log_debug("Add IPv4 %s/%ld to addrset.\n", addr_string, netmask_bits > 0 ? netmask_bits : 32); + log_debug("Add IPv4 %s/%ld to addrset (trie).\n", addr_string, netmask_bits > 0 ? netmask_bits : 32); #ifdef HAVE_IPV6 } else if (addr->ai_family == AF_INET6) { - const struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *) addr->ai_addr; - - elem->type = ADDRSET_TYPE_IPV6_NETMASK; - - elem->u.ipv6.addr = sin6->sin6_addr; - if (netmask_bits > 128) { log_user("Illegal netmask in \"%s\". Must be between 0 and 128.\n", spec); - free(elem); + freeaddrinfo(addrs); return 0; } - make_ipv6_netmask(&elem->u.ipv6.mask, netmask_bits); - log_debug("Add IPv6 %s/%ld to addrset.\n", addr_string, netmask_bits > 0 ? netmask_bits : 128); + log_debug("Add IPv6 %s/%ld to addrset (trie).\n", addr_string, netmask_bits > 0 ? netmask_bits : 128); #endif } else { log_debug("ignoring address %s for %s. Family %d socktype %d protocol %d.\n", addr_string, spec, addr->ai_family, addr->ai_socktype, addr->ai_protocol); - free(elem); continue; } - elem->next = set->head; - set->head = elem; + trie_insert(set->trie, addr->ai_addr, netmask_bits); } if (addrs != NULL)