From 188a3acade8a7b0055f1efcde5497bef1a32cfff Mon Sep 17 00:00:00 2001 From: dmiller Date: Mon, 31 Oct 2022 17:27:56 +0000 Subject: [PATCH] Clarify and optimize top-ports checking --- services.cc | 199 ++++++++++++++++++++++++++-------------------------- 1 file changed, 100 insertions(+), 99 deletions(-) diff --git a/services.cc b/services.cc index ff7629075..d96c4ef19 100644 --- a/services.cc +++ b/services.cc @@ -72,6 +72,8 @@ #include #include +#include +#include /* This structure is the key for looking up services in the port/proto -> service map. */ @@ -332,30 +334,32 @@ static int port_compare(const void *a, const void *b) { } +template +class C_array_iterator: public std::iterator { + T *ptr; + public: + C_array_iterator(T *_ptr=NULL) : ptr(_ptr) {} + C_array_iterator(const C_array_iterator &other) : ptr(other.ptr) {} + C_array_iterator& operator=(T *_ptr) {ptr = _ptr; return *this;} + C_array_iterator& operator++() {ptr++; return *this;} + C_array_iterator operator++(int) {C_array_iterator retval = *this; ++(*this); return retval;} + C_array_iterator& operator--() {ptr--; return *this;} + C_array_iterator operator--(int) {C_array_iterator retval = *this; --(*this); return retval;} + bool operator==(C_array_iterator &other) const {return ptr == other.ptr;} + bool operator!=(C_array_iterator &other) const {return !(*this == other);} + bool operator<(C_array_iterator &other) const {return ptr < other.ptr;} + C_array_iterator& operator+=(std::ptrdiff_t n) {ptr += n; return *this;} + C_array_iterator& operator-=(std::ptrdiff_t n) {ptr -= n; return *this;} + std::ptrdiff_t operator+(const C_array_iterator &other) {return ptr + other.ptr;} + std::ptrdiff_t operator-(const C_array_iterator &other) {return ptr - other.ptr;} + T& operator*() const {return *ptr;} +}; -// is_port_member() returns true if serv is an element of ptsdata. -// This could be implemented MUCH more efficiently but it should only be -// called when you use a non-default top-ports or port-ratio value TOGETHER WITH -// a -p portlist. - -static bool is_port_member(const struct scan_lists *ptsdata, const struct service_node *serv) { - int i; - - if (strcmp(serv->s_proto, "tcp") == 0) { - for (i=0; itcp_count; i++) - if (serv->s_port == ptsdata->tcp_ports[i]) - return true; - } else if (strcmp(serv->s_proto, "udp") == 0) { - for (i=0; iudp_count; i++) - if (serv->s_port == ptsdata->udp_ports[i]) - return true; - } else if (strcmp(serv->s_proto, "sctp") == 0) { - for (i=0; isctp_count; i++) - if (serv->s_port == ptsdata->sctp_ports[i]) - return true; - } - - return false; +// is_port_member() returns true if serv->s_port is an element of pts. +static bool is_port_member(unsigned short *pts, int count, const struct service_node *serv) { + C_array_iterator begin = pts; + C_array_iterator end = pts + count; + return std::binary_search(begin, end, serv->s_port); } // gettoppts() sets its third parameter, a scan_list, with the most @@ -375,7 +379,6 @@ static bool is_port_member(const struct scan_lists *ptsdata, const struct servic // function if o.TCPScan() || o.UDPScan() || o.SCTPScan() void gettoppts(double level, const char *portlist, struct scan_lists * ports, const char *exclude_ports) { - int ti=0, ui=0, si=0; struct scan_lists ptsdata = { 0 }; bool ptsdata_initialized = false; const struct service_node *current; @@ -421,85 +424,83 @@ void gettoppts(double level, const char *portlist, struct scan_lists * ports, co if (ptsdata_initialized && exclude_ports) removepts(exclude_ports, &ptsdata); - if (level < 1) { - for (i = services_by_ratio.begin(); i != services_by_ratio.end(); i++) { - current = &(*i); - if (ptsdata_initialized && !is_port_member(&ptsdata, current)) - continue; - if (current->ratio >= level) { - if (o.TCPScan() && strcmp(current->s_proto, "tcp") == 0) - ports->tcp_count++; - else if (o.UDPScan() && strcmp(current->s_proto, "udp") == 0) - ports->udp_count++; - else if (o.SCTPScan() && strcmp(current->s_proto, "sctp") == 0) - ports->sctp_count++; - } else { - break; - } - } + /* Max number of ports for each protocol cannot be more than the minimum of: + * 1. all of them (65536) + * 2. requested ports (ptsdata) + * 3. the number in services db (numXXXports) + */ + int tcpmax = o.TCPScan() ? (ptsdata_initialized ? ptsdata.tcp_count : 65536) : 0; + tcpmax = MIN(tcpmax, numtcpports); + int udpmax = o.UDPScan() ? (ptsdata_initialized ? ptsdata.udp_count : 65536) : 0; + udpmax = MIN(udpmax, numudpports); + int sctpmax = o.SCTPScan() ? (ptsdata_initialized ? ptsdata.sctp_count : 65536) : 0; + sctpmax = MIN(sctpmax, numsctpports); - if (ports->tcp_count) - ports->tcp_ports = (unsigned short *)safe_zalloc(ports->tcp_count * sizeof(unsigned short)); - - if (ports->udp_count) - ports->udp_ports = (unsigned short *)safe_zalloc(ports->udp_count * sizeof(unsigned short)); - - if (ports->sctp_count) - ports->sctp_ports = (unsigned short *)safe_zalloc(ports->sctp_count * sizeof(unsigned short)); - - ports->prots = NULL; - - for (i = services_by_ratio.begin(); i != services_by_ratio.end(); i++) { - current = &(*i); - if (ptsdata_initialized && !is_port_member(&ptsdata, current)) - continue; - if (current->ratio >= level) { - if (o.TCPScan() && strcmp(current->s_proto, "tcp") == 0) - ports->tcp_ports[ti++] = current->s_port; - else if (o.UDPScan() && strcmp(current->s_proto, "udp") == 0) - ports->udp_ports[ui++] = current->s_port; - else if (o.SCTPScan() && strcmp(current->s_proto, "sctp") == 0) - ports->sctp_ports[si++] = current->s_port; - } else { - break; - } - } - } else if (level >= 1) { + // If level is positive integer, it's the max number of ports. + if (level >= 1) { if (level > 65536) fatal("Level argument to gettoppts (%g) is too large", level); - - if (o.TCPScan()) { - ports->tcp_count = MIN((int) level, numtcpports); - ports->tcp_ports = (unsigned short *)safe_zalloc(ports->tcp_count * sizeof(unsigned short)); - } - if (o.UDPScan()) { - ports->udp_count = MIN((int) level, numudpports); - ports->udp_ports = (unsigned short *)safe_zalloc(ports->udp_count * sizeof(unsigned short)); - } - if (o.SCTPScan()) { - ports->sctp_count = MIN((int) level, numsctpports); - ports->sctp_ports = (unsigned short *)safe_zalloc(ports->sctp_count * sizeof(unsigned short)); - } - - ports->prots = NULL; - - for (i = services_by_ratio.begin(); i != services_by_ratio.end(); i++) { - current = &(*i); - if (ptsdata_initialized && !is_port_member(&ptsdata, current)) - continue; - if (o.TCPScan() && strcmp(current->s_proto, "tcp") == 0 && ti < ports->tcp_count) - ports->tcp_ports[ti++] = current->s_port; - else if (o.UDPScan() && strcmp(current->s_proto, "udp") == 0 && ui < ports->udp_count) - ports->udp_ports[ui++] = current->s_port; - else if (o.SCTPScan() && strcmp(current->s_proto, "sctp") == 0 && si < ports->sctp_count) - ports->sctp_ports[si++] = current->s_port; - } - - if (ti < ports->tcp_count) ports->tcp_count = ti; - if (ui < ports->udp_count) ports->udp_count = ui; - if (si < ports->sctp_count) ports->sctp_count = si; - } else + tcpmax = MIN((int) level, tcpmax); + udpmax = MIN((int) level, udpmax); + sctpmax = MIN((int) level, sctpmax); + // Now force the ratio comparison to always be true: + level = 0; + } + else if (level <= 0) { fatal("Argument to gettoppts (%g) should be a positive ratio below 1 or an integer of 1 or higher", level); + } + // else level is a ratio between 0 and 1 + + // These could be 0/false if the scan type was not requested. + if (tcpmax) { + ports->tcp_ports = (unsigned short *)safe_zalloc(tcpmax * sizeof(unsigned short)); + } + if (udpmax) { + ports->udp_ports = (unsigned short *)safe_zalloc(udpmax * sizeof(unsigned short)); + } + if (sctpmax) { + ports->sctp_ports = (unsigned short *)safe_zalloc(sctpmax * sizeof(unsigned short)); + } + + ports->prots = NULL; + + // Loop until we get enough or run out of candidates + for (i = services_by_ratio.begin(); i != services_by_ratio.end() && (tcpmax || udpmax || sctpmax); i++) { + current = &(*i); + if (current->ratio < level) { + break; + } + switch (current->s_proto[0]) { + case 't': + if (tcpmax && strcmp(current->s_proto, "tcp") == 0 + && (!ptsdata_initialized || + is_port_member(ptsdata.tcp_ports, ptsdata.tcp_count, current)) + ) { + ports->tcp_ports[ports->tcp_count++] = current->s_port; + tcpmax--; + } + break; + case 'u': + if (udpmax && strcmp(current->s_proto, "udp") == 0 + && (!ptsdata_initialized || + is_port_member(ptsdata.udp_ports, ptsdata.udp_count, current)) + ) { + ports->udp_ports[ports->udp_count++] = current->s_port; + udpmax--; + } + break; + case 's': + if (sctpmax && strcmp(current->s_proto, "sctp") == 0 + && (!ptsdata_initialized || + is_port_member(ptsdata.sctp_ports, ptsdata.sctp_count, current)) + ) + ports->sctp_ports[ports->sctp_count++] = current->s_port; + sctpmax--; + break; + default: + break; + } + } if (ptsdata_initialized) { free_scan_lists(&ptsdata);