1
0
mirror of https://github.com/nmap/nmap.git synced 2025-12-27 09:59:04 +00:00

Do output as a byproduct of calculating the diff.

This doesn't require keeping the whole diff in memory until the end.
This commit is contained in:
david
2011-12-21 06:59:46 +00:00
parent d08bb02073
commit 5be27e7aad
2 changed files with 133 additions and 88 deletions

View File

@@ -274,25 +274,58 @@ class service_test(unittest.TestCase):
serv.extrainfo = u"misconfigured"
self.assertEqual(serv.version_string(), u"%s %s (%s)" % (serv.product, serv.version, serv.extrainfo))
class ScanDiffSub(ScanDiff):
"""A subclass of ScanDiff that counts diffs for testing."""
def __init__(self, scan_a, scan_b, f = sys.stdout):
ScanDiff.__init__(self, scan_a, scan_b, f)
self.pre_script_result_diffs = []
self.post_script_result_diffs = []
self.host_diffs = []
def output_beginning(self):
pass
def output_pre_scripts(self, pre_script_result_diffs):
self.pre_script_result_diffs = pre_script_result_diffs
def output_post_scripts(self, post_script_result_diffs):
self.post_script_result_diffs = post_script_result_diffs
def output_host_diff(self, h_diff):
self.host_diffs.append(h_diff)
def output_ending(self):
pass
class scan_diff_test(unittest.TestCase):
"""Test the ScanDiff class."""
def setUp(self):
self.blackhole = open("/dev/null", "w")
def tearDown(self):
self.blackhole.close()
def test_self(self):
scan = Scan()
scan.load_from_file("test-scans/complex.xml")
diff = ScanDiff(scan, scan)
self.assertEqual(len(diff.host_diffs), 0)
self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys()))
diff = ScanDiffText(scan, scan, self.blackhole)
cost = diff.output()
self.assertEqual(cost, 0)
diff = ScanDiffXML(scan, scan, self.blackhole)
cost = diff.output()
self.assertEqual(cost, 0)
def test_unknown_up(self):
a = Scan()
a.load_from_file("test-scans/empty.xml")
b = Scan()
b.load_from_file("test-scans/simple.xml")
diff = ScanDiff(a, b)
self.assertTrue(len(diff.hosts) >= 1)
diff = ScanDiffSub(a, b, self.blackhole)
diff.output()
self.assertEqual(len(diff.pre_script_result_diffs), 0)
self.assertEqual(len(diff.post_script_result_diffs), 0)
self.assertEqual(len(diff.host_diffs), 1)
self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys()))
h_diff = diff.host_diffs.values()[0]
h_diff = diff.host_diffs[0]
self.assertEqual(h_diff.host_a.state, None)
self.assertEqual(h_diff.host_b.state, "up")
@@ -301,11 +334,12 @@ class scan_diff_test(unittest.TestCase):
a.load_from_file("test-scans/simple.xml")
b = Scan()
b.load_from_file("test-scans/empty.xml")
diff = ScanDiff(a, b)
self.assertTrue(len(diff.hosts) >= 1)
diff = ScanDiffSub(a, b, self.blackhole)
diff.output()
self.assertEqual(len(diff.pre_script_result_diffs), 0)
self.assertEqual(len(diff.post_script_result_diffs), 0)
self.assertEqual(len(diff.host_diffs), 1)
self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys()))
h_diff = diff.host_diffs.values()[0]
h_diff = diff.host_diffs[0]
self.assertEqual(h_diff.host_a.state, "up")
self.assertEqual(h_diff.host_b.state, None)
@@ -327,11 +361,10 @@ class scan_diff_test(unittest.TestCase):
a.load_from_file("test-scans/%s.xml" % pair[0])
b = Scan()
b.load_from_file("test-scans/%s.xml" % pair[1])
diff = ScanDiff(a, b)
diff = ScanDiffSub(a, b)
scan_apply_diff(a, diff)
diff = ScanDiff(a, b)
self.assertEqual(diff.host_diffs, {})
self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys()))
diff = ScanDiffSub(a, b)
self.assertEqual(diff.host_diffs, [])
class host_diff_test(unittest.TestCase):
"""Test the HostDiff class."""
@@ -641,9 +674,9 @@ class scan_diff_xml_test(unittest.TestCase):
a.load_from_file("test-scans/empty.xml")
b = Scan()
b.load_from_file("test-scans/simple.xml")
self.scan_diff = ScanDiff(a, b)
f = StringIO.StringIO()
self.scan_diff.print_xml(f)
self.scan_diff = ScanDiffXML(a, b, f)
self.scan_diff.output()
self.xml = f.getvalue()
f.close()
@@ -655,10 +688,11 @@ class scan_diff_xml_test(unittest.TestCase):
def scan_apply_diff(scan, diff):
"""Apply a scan diff to the given scan."""
for host in diff.hosts:
for h_diff in diff.host_diffs:
host = h_diff.host_a or h_diff.host_b
if host not in scan.hosts:
scan.hosts.append(host)
host_apply_diff(host, diff.host_diffs[host])
host_apply_diff(host, h_diff)
def host_apply_diff(host, diff):
"""Apply a host diff to the given host."""