diff --git a/protocols.cc b/protocols.cc index fd1cc89bc..216b13342 100644 --- a/protocols.cc +++ b/protocols.cc @@ -64,16 +64,29 @@ #include "protocols.h" #include "NmapOps.h" -#include "charpool.h" +#include "string_pool.h" #include "nmap_error.h" #include "utils.h" +#include + extern NmapOps o; -static int numipprots = 0; -static struct protocol_list *protocol_table[PROTOCOL_TABLE_SIZE]; -static int protocols_initialized = 0; + +struct strcmp_comparator { + bool operator()(const char *a, const char *b) const { + return strcmp(a, b) < 0; + } +}; + +// IP Protocol number is 8 bits wide +// protocol_table[IPPROTO_TCP] == {"tcp", 6} +static struct nprotoent *protocol_table[UCHAR_MAX]; +// proto_map["tcp"] = {"tcp", 6} +typedef std::map ProtoMap; +static ProtoMap proto_map; static int nmap_protocols_init() { + static int protocols_initialized = 0; if (protocols_initialized) return 0; char filename[512]; @@ -83,7 +96,6 @@ static int nmap_protocols_init() { char *p; char line[1024]; int lineno = 0; - struct protocol_list *current, *previous; int res; if (nmap_fetchfile(filename, sizeof(filename), "nmap-protocols") != 1) { @@ -108,35 +120,27 @@ static int nmap_protocols_init() { if (*p == '#') continue; res = sscanf(line, "%127s %hu", protocolname, &protno); - if (res !=2) + if (res !=2 || protno > UCHAR_MAX) { + error("Parse error in protocols file %s line %d", filename, lineno); continue; + } + + struct nprotoent ent; + // Using string_pool means we don't have to copy this data; the pointer is unique! + ent.p_name = string_pool_insert(protocolname); + ent.p_proto = protno; + std::pair status = proto_map.insert(std::pair(ent.p_name, ent)); /* Now we make sure our protocols don't have duplicates */ - for(current = protocol_table[protno % PROTOCOL_TABLE_SIZE], previous = NULL; - current; current = current->next) { - if (protno == current->protoent->p_proto) { - if (o.debugging) { - error("Protocol %d is duplicated in protocols file %s", ntohs(protno), filename); - } - break; + if (!status.second) { + if (o.debugging > 1) { + error("Protocol %d is duplicated in protocols file %s", protno, filename); } - previous = current; - } - if (current) continue; - - numipprots++; - - current = (struct protocol_list *) cp_alloc(sizeof(struct protocol_list)); - current->protoent = (struct nprotoent *) cp_alloc(sizeof(struct nprotoent)); - current->next = NULL; - if (previous == NULL) { - protocol_table[protno % PROTOCOL_TABLE_SIZE] = current; - } else { - previous->next = current; } - current->protoent->p_name = cp_strdup(protocolname); - current->protoent->p_proto = protno; + + assert(!protocol_table[protno]); + protocol_table[protno] = &status.first->second; } fclose(fp); protocols_initialized = 1; @@ -151,18 +155,24 @@ static int nmap_protocols_init() { int addprotocolsfromservmask(char *mask, u8 *porttbl) { - struct protocol_list *current; - int bucket, t=0; + ProtoMap::const_iterator it; + int t=0; - if (!protocols_initialized && nmap_protocols_init() == -1) + if (nmap_protocols_init() != 0) fatal("%s: Couldn't get protocol numbers", __func__); - for(bucket = 0; bucket < PROTOCOL_TABLE_SIZE; bucket++) { - for(current = protocol_table[bucket % PROTOCOL_TABLE_SIZE]; current; current = current->next) { - if (wildtest(mask, current->protoent->p_name)) { - porttbl[ntohs(current->protoent->p_proto)] |= SCAN_PROTOCOLS; - t++; - } + // Check for easy ones: plain string match. + it = proto_map.find(mask); + if (it != proto_map.end()) { + // Matched! No need to try wildtest on everything. + porttbl[it->second.p_proto] |= SCAN_PROTOCOLS; + return 1; + } + // No match? iterate and use wildtest. + for(it = proto_map.begin(); it != proto_map.end(); it++) { + if (wildtest(mask, it->second.p_name)) { + porttbl[it->second.p_proto] |= SCAN_PROTOCOLS; + t++; } } @@ -172,17 +182,10 @@ int addprotocolsfromservmask(char *mask, u8 *porttbl) { struct nprotoent *nmap_getprotbynum(int num) { - struct protocol_list *current; if (nmap_protocols_init() == -1) return NULL; - for(current = protocol_table[num % PROTOCOL_TABLE_SIZE]; - current; current = current->next) { - if (num == current->protoent->p_proto) - return current->protoent; - } - - /* Couldn't find it ... oh well. */ - return NULL; + assert(num >= 0 && num < UCHAR_MAX); + return protocol_table[num]; } diff --git a/protocols.h b/protocols.h index 26c6a2b5b..9ecee451d 100644 --- a/protocols.h +++ b/protocols.h @@ -75,18 +75,11 @@ #include "libnetutil/netutil.h" #endif -#define PROTOCOL_TABLE_SIZE 256 - struct nprotoent { const char *p_name; short p_proto; }; -struct protocol_list { - struct nprotoent *protoent; - struct protocol_list *next; -}; - int addprotocolsfromservmask(char *mask, u8 *porttbl); struct nprotoent *nmap_getprotbynum(int num);