diff --git a/nselib/dns.lua b/nselib/dns.lua index c0acf4f57..88e827f1b 100644 --- a/nselib/dns.lua +++ b/nselib/dns.lua @@ -307,7 +307,7 @@ end -- * id: numeric value to use for the DNS transaction id -- * nsid: If true, queries the server for the nameserver identifier (RFC 5001) -- * subnet: table, if set perform a edns-client-subnet lookup. The table should contain the fields: --- family - string can be either inet or inet6 +-- family - IPv4: "inet" or 1 (default), IPv6: "inet6" or 2 -- address - string containing the originating subnet IP address -- mask - number containing the number of subnet bits -- @return true if a dns response was received and contained an answer of the requested type, @@ -360,9 +360,6 @@ function query(dname, options) if ( options.nsid ) then addNSID(pkt, dnssec) elseif ( options.subnet ) then - local family = { ["inet"] = 1, ["inet6"] = 2 } - assert( family[options.subnet.family], "Unsupported subnet family") - options.subnet.family = family[options.subnet.family] addClientSubnet(pkt, dnssec, options.subnet ) elseif ( dnssec.DO ) then addOPT(pkt, {DO = true}) @@ -1402,13 +1399,18 @@ end -- @param pkt Table representing DNS packet. -- @param Z Table of Z flags. Only DO is supported. -- @param client_subnet table containing the following fields --- family - 1 IPv4, 2 - IPv6 +-- family - IPv4: "inet" or 1 (default), IPv6: "inet6" or 2 -- mask - byte containing the length of the subnet mask -- address - string containing the IP address function addClientSubnet(pkt,Z,subnet) + local family = subnet.family or 1 + if type(family) == "string" then + family = ({inet=1,inet6=2})[family] + end + assert(family == 1 or family == 2, "Unsupported subnet family") local code = 8 -- https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-11 local scope_mask = 0 -- In requests, it MUST be set to 0 see draft - local data = bin.pack(">SCCA",subnet.family or 1,subnet.mask,scope_mask,ipOps.ip_to_str(subnet.address)) + local data = bin.pack(">SCCA",family,subnet.mask,scope_mask,ipOps.ip_to_str(subnet.address)) local opt = bin.pack(">SS",code, #data) .. data addOPT(pkt,Z,opt) end