diff --git a/nselib/http.lua b/nselib/http.lua index c6c8ca56a..5e1935577 100644 --- a/nselib/http.lua +++ b/nselib/http.lua @@ -170,15 +170,31 @@ local get_default_port = url.get_default_port --- Get a value suitable for the Host header field. -- See RFC 2616 sections 14.23 and 5.2. -local function get_host_field(host, port) +local function get_host_field(host, port, scheme) + -- If the global header is set by script-arg, use that. if host_header then return host_header end + -- If there's no host, we can't invent a name. if not host then return nil end local number = (type(port) == "number") and port or port.number - if number == 443 or number == 80 then - return stdnse.get_hostname(host) + if scheme then + -- Caller provided scheme. If it's default, return just the hostname. + if number == get_default_port(scheme) then + return stdnse.get_hostname(host) + end else - return stdnse.get_hostname(host) .. ":" .. number + scheme = url.get_default_scheme(port) + if scheme then + -- Caller did not provide scheme, and this port has a default scheme. + local ssl_port = shortport.ssl(host, port) + if (ssl_port and scheme == 'https') or + (not ssl_port and scheme == 'http') then + -- If it's SSL and https, or if it's plaintext and http, return just the hostname. + return stdnse.get_hostname(host) + end + end end + -- No special cases matched, so include the port number in the host header + return stdnse.get_hostname(host) .. ":" .. number end -- Skip *( SP | HT ) starting at offset. See RFC 2616, section 2.2. @@ -366,6 +382,11 @@ local function validate_options(options) stdnse.debug1("http: options.redirect_ok must be a function or boolean or number") bad = true end + elseif(key == 'scheme') then + if type(value) ~= 'string' then + stdnse.debug1("http: options.scheme must be a string") + bad = true + end else stdnse.debug1("http: Unknown key in the options table: %s", key) end @@ -1086,7 +1107,7 @@ local function build_request(host, port, method, path, options) local mod_options = { header = { Connection = "close", - Host = get_host_field(host, port), + Host = get_host_field(host, port, options.scheme), ["User-Agent"] = USER_AGENT } } @@ -1604,6 +1625,7 @@ function get(host, port, path, options) if(not(validate_options(options))) then return http_error("Options failed to validate.") end + options = options or {} local redir_check = get_redirect_ok(host, port, options) local response, state, location local u = { host = host, port = port, path = path } @@ -1617,6 +1639,8 @@ function get(host, port, path, options) if ( not(u) ) then break end + -- Allow redirect to change scheme (e.g. redirect to https) + options.scheme = u.scheme or options.scheme location = location or {} table.insert(location, response.header.location) until( not(redir_check(u)) ) @@ -1640,6 +1664,7 @@ function get_url( u, options ) port.service = parsed.scheme port.number = parsed.port or get_default_port(parsed.scheme) or 80 + options.scheme = options.scheme or parsed.scheme local path = parsed.path or "/" if parsed.query then @@ -1678,6 +1703,7 @@ function head(host, port, path, options) if(not(validate_options(options))) then return http_error("Options failed to validate.") end + options = options or {} local redir_check = get_redirect_ok(host, port, options) local response, state, location local u = { host = host, port = port, path = path } @@ -1691,6 +1717,8 @@ function head(host, port, path, options) if ( not(u) ) then break end + -- Allow redirect to change scheme (e.g. redirect to https) + options.scheme = u.scheme or options.scheme location = location or {} table.insert(location, response.header.location) until( not(redir_check(u)) ) diff --git a/nselib/url.lua b/nselib/url.lua index 31a2905e1..24b2c3ae8 100644 --- a/nselib/url.lua +++ b/nselib/url.lua @@ -415,11 +415,31 @@ local get_default_port_ports = {http=80, https=443} -- @param scheme for determining the port, such as "http" or "https". -- @return A port number as an integer, such as 443 for scheme "https", -- or nil in case of an undefined scheme ------------------------------------------------------------------------------ function get_default_port (scheme) return get_default_port_ports[(scheme or ""):lower()] end +local function invert(t) + local out = {} + for k, v in pairs(t) do + out[v] = k + end + return out +end + +get_default_scheme_schemes = invert(get_default_port_ports) + +--- +-- Provides the default URI scheme for a given port. +-- +-- @param port A port number as a number or port table +-- @return scheme for addressing the port, such as "http" or "https". +----------------------------------------------------------------------------- +function get_default_scheme (port) + local number = (type(port) == "number") and port or port.number + return get_default_scheme_schemes[port] +end + if not unittest.testing() then return _ENV end