diff --git a/ndiff/ndiff b/ndiff/ndiff index f3186c54e..e8ee4e180 100755 --- a/ndiff/ndiff +++ b/ndiff/ndiff @@ -489,59 +489,78 @@ def host_pairs(a, b): j += 1 class ScanDiff(object): - """A complete diff of two scans. It is a container for two scans and a dict - mapping hosts to HostDiffs.""" - def __init__(self, scan_a, scan_b): + """An abtract class for different diff output types. Subclasses must define + various output methods.""" + def __init__(self, scan_a, scan_b, f = sys.stdout): """Create a ScanDiff from the "before" scan_a and the "after" scan_b.""" self.scan_a = scan_a self.scan_b = scan_b - self.hosts = [] - self.pre_script_result_diffs = [] - self.post_script_result_diffs = [] - self.host_diffs = {} + self.f = f - self.diff() - - def diff(self): + def output(self): self.scan_a.sort_hosts() self.scan_b.sort_hosts() + self.output_beginning() + + pre_script_result_diffs = ScriptResultDiff.diff_lists(self.scan_a.pre_script_results, self.scan_b.pre_script_results) + self.output_pre_scripts(pre_script_result_diffs) + + cost = 0 # Currently we never consider diffing hosts with a different id # (address or host name), which could lead to better diffs. for host_a, host_b in host_pairs(self.scan_a.hosts, self.scan_b.hosts): h_diff = HostDiff(host_a, host_b) + cost += h_diff.cost if h_diff.cost > 0 or verbose: host = host_a or host_b - self.hosts.append(host) - self.host_diffs[host] = h_diff + self.output_host_diff(h_diff) - self.pre_script_result_diffs = ScriptResultDiff.diff_lists(self.scan_a.pre_script_results, self.scan_b.pre_script_results) - self.post_script_result_diffs = ScriptResultDiff.diff_lists(self.scan_a.post_script_results, self.scan_b.post_script_results) + post_script_result_diffs = ScriptResultDiff.diff_lists(self.scan_a.post_script_results, self.scan_b.post_script_results) + self.output_post_scripts(post_script_result_diffs) - def print_text(self, f = sys.stdout): - """Print this diff in a human-readable text form.""" + self.output_ending() + + return cost + +class ScanDiffText(ScanDiff): + def __init__(self, scan_a, scan_b, f = sys.stdout): + ScanDiff.__init__(self, scan_a, scan_b, f) + + def output_beginning(self): banner_a = format_banner(self.scan_a) banner_b = format_banner(self.scan_b) - if self.cost > 0 or verbose: - if banner_a != banner_b: - print >> f, u"-%s" % banner_a - print >> f, u"+%s" % banner_b - elif verbose: - print >> f, u" %s" % banner_a + if banner_a != banner_b: + print >> self.f, u"-%s" % banner_a + print >> self.f, u"+%s" % banner_b + elif verbose: + print >> self.f, u" %s" % banner_a + def output_pre_scripts(self, pre_script_result_diffs): print_script_result_diffs_text("Pre-scan script results", self.scan_a.pre_script_results, self.scan_b.pre_script_results, - self.pre_script_result_diffs) - - for host in self.hosts: - print >> f - - h_diff = self.host_diffs[host] - h_diff.print_text(f) + pre_script_result_diffs, self.f) + def output_post_scripts(self, post_script_result_diffs): print_script_result_diffs_text("Post-scan script results", self.scan_a.post_script_results, self.scan_b.post_script_results, - self.post_script_result_diffs) + post_script_result_diffs, self.f) + + def output_host_diff(self, h_diff): + print >> self.f + h_diff.print_text(self.f) + + def output_ending(self): + pass + +class ScanDiffXML(ScanDiff): + def __init__(self, scan_a, scan_b, f = sys.stdout): + ScanDiff.__init__(self, scan_a, scan_b, f) + + impl = xml.dom.minidom.getDOMImplementation() + self.document = impl.createDocument(None, None, None) + + self.writer = XMLWriter(f) def nmaprun_differs(self): for attr in ("scanner", "version", "args", "start_date", "end_date"): @@ -549,53 +568,46 @@ class ScanDiff(object): return True return False - def print_xml(self, f = sys.stdout): - impl = xml.dom.minidom.getDOMImplementation() - document = impl.createDocument(None, None, None) - - writer = XMLWriter(f) - - writer.startDocument() - writer.startElement(u"nmapdiff", {u"version": NDIFF_XML_VERSION}) - writer.startElement(u"scandiff", {}) + def output_beginning(self): + self.writer.startDocument() + self.writer.startElement(u"nmapdiff", {u"version": NDIFF_XML_VERSION}) + self.writer.startElement(u"scandiff", {}) if self.nmaprun_differs(): - writer.frag_a(self.scan_a.nmaprun_to_dom_fragment(document)) - writer.frag_b(self.scan_b.nmaprun_to_dom_fragment(document)) + self.writer.frag_a(self.scan_a.nmaprun_to_dom_fragment(self.document)) + self.writer.frag_b(self.scan_b.nmaprun_to_dom_fragment(self.document)) elif verbose: - writer.frag(self.scan_a.nmaprun_to_dom_fragment(document)) + self.writer.frag(self.scan_a.nmaprun_to_dom_fragment(self.document)) - # prerule script changes. - if len(self.pre_script_result_diffs) > 0 or verbose: - prescript_elem = document.createElement(u"prescript") + def output_pre_scripts(self, pre_script_result_diffs): + if len(pre_script_result_diffs) > 0 or verbose: + prescript_elem = self.document.createElement(u"prescript") frag = script_result_diffs_to_dom_fragment( prescript_elem, self.scan_a.pre_script_results, - self.scan_b.pre_script_results, self.pre_script_result_diffs, - document) - writer.frag(frag) + self.scan_b.pre_script_results, pre_script_result_diffs, + self.document) + self.writer.frag(frag) frag.unlink() - for host in self.hosts: - h_diff = self.host_diffs[host] - frag = h_diff.to_dom_fragment(document) - writer.frag(frag) - frag.unlink() - - # postrule script changes. - if len(self.post_script_result_diffs) > 0 or verbose: - postscript_elem = document.createElement(u"postscript") + def output_post_scripts(self, post_script_result_diffs): + if len(post_script_result_diffs) > 0 or verbose: + postscript_elem = self.document.createElement(u"postscript") frag = script_result_diffs_to_dom_fragment( postscript_elem, self.scan_a.post_script_results, - self.scan_b.post_script_results, self.post_script_result_diffs, - document) - writer.frag(frag) + self.scan_b.post_script_results, post_script_result_diffs, + self.document) + self.writer.frag(frag) frag.unlink() - writer.endElement(u"scandiff") - writer.endElement(u"nmapdiff") - writer.endDocument() + def output_host_diff(self, h_diff): + frag = h_diff.to_dom_fragment(self.document) + self.writer.frag(frag) + frag.unlink() - cost = property(lambda self: sum([hd.cost for hd in self.host_diffs.values()])) + def output_ending(self): + self.writer.endElement(u"scandiff") + self.writer.endElement(u"nmapdiff") + self.writer.endDocument() class HostDiff(object): """A diff of two Hosts. It contains the two hosts, variables describing what @@ -1354,14 +1366,13 @@ def main(): print >> sys.stderr, u"Can't open file: %s" % str(e) sys.exit(EXIT_ERROR) - diff = ScanDiff(scan_a, scan_b) - if output_format == "text": - diff.print_text() + diff = ScanDiffText(scan_a, scan_b) elif output_format == "xml": - diff.print_xml() + diff = ScanDiffXML(scan_a, scan_b) + cost = diff.output() - if diff.cost == 0: + if cost == 0: return EXIT_EQUAL else: return EXIT_DIFFERENT diff --git a/ndiff/ndifftest.py b/ndiff/ndifftest.py index a875dd5dd..6294f0dc2 100755 --- a/ndiff/ndifftest.py +++ b/ndiff/ndifftest.py @@ -274,25 +274,58 @@ class service_test(unittest.TestCase): serv.extrainfo = u"misconfigured" self.assertEqual(serv.version_string(), u"%s %s (%s)" % (serv.product, serv.version, serv.extrainfo)) +class ScanDiffSub(ScanDiff): + """A subclass of ScanDiff that counts diffs for testing.""" + def __init__(self, scan_a, scan_b, f = sys.stdout): + ScanDiff.__init__(self, scan_a, scan_b, f) + self.pre_script_result_diffs = [] + self.post_script_result_diffs = [] + self.host_diffs = [] + + def output_beginning(self): + pass + + def output_pre_scripts(self, pre_script_result_diffs): + self.pre_script_result_diffs = pre_script_result_diffs + + def output_post_scripts(self, post_script_result_diffs): + self.post_script_result_diffs = post_script_result_diffs + + def output_host_diff(self, h_diff): + self.host_diffs.append(h_diff) + + def output_ending(self): + pass + class scan_diff_test(unittest.TestCase): """Test the ScanDiff class.""" + def setUp(self): + self.blackhole = open("/dev/null", "w") + + def tearDown(self): + self.blackhole.close() + def test_self(self): scan = Scan() scan.load_from_file("test-scans/complex.xml") - diff = ScanDiff(scan, scan) - self.assertEqual(len(diff.host_diffs), 0) - self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys())) + diff = ScanDiffText(scan, scan, self.blackhole) + cost = diff.output() + self.assertEqual(cost, 0) + diff = ScanDiffXML(scan, scan, self.blackhole) + cost = diff.output() + self.assertEqual(cost, 0) def test_unknown_up(self): a = Scan() a.load_from_file("test-scans/empty.xml") b = Scan() b.load_from_file("test-scans/simple.xml") - diff = ScanDiff(a, b) - self.assertTrue(len(diff.hosts) >= 1) + diff = ScanDiffSub(a, b, self.blackhole) + diff.output() + self.assertEqual(len(diff.pre_script_result_diffs), 0) + self.assertEqual(len(diff.post_script_result_diffs), 0) self.assertEqual(len(diff.host_diffs), 1) - self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys())) - h_diff = diff.host_diffs.values()[0] + h_diff = diff.host_diffs[0] self.assertEqual(h_diff.host_a.state, None) self.assertEqual(h_diff.host_b.state, "up") @@ -301,11 +334,12 @@ class scan_diff_test(unittest.TestCase): a.load_from_file("test-scans/simple.xml") b = Scan() b.load_from_file("test-scans/empty.xml") - diff = ScanDiff(a, b) - self.assertTrue(len(diff.hosts) >= 1) + diff = ScanDiffSub(a, b, self.blackhole) + diff.output() + self.assertEqual(len(diff.pre_script_result_diffs), 0) + self.assertEqual(len(diff.post_script_result_diffs), 0) self.assertEqual(len(diff.host_diffs), 1) - self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys())) - h_diff = diff.host_diffs.values()[0] + h_diff = diff.host_diffs[0] self.assertEqual(h_diff.host_a.state, "up") self.assertEqual(h_diff.host_b.state, None) @@ -327,11 +361,10 @@ class scan_diff_test(unittest.TestCase): a.load_from_file("test-scans/%s.xml" % pair[0]) b = Scan() b.load_from_file("test-scans/%s.xml" % pair[1]) - diff = ScanDiff(a, b) + diff = ScanDiffSub(a, b) scan_apply_diff(a, diff) - diff = ScanDiff(a, b) - self.assertEqual(diff.host_diffs, {}) - self.assertEqual(set(diff.hosts), set(diff.host_diffs.keys())) + diff = ScanDiffSub(a, b) + self.assertEqual(diff.host_diffs, []) class host_diff_test(unittest.TestCase): """Test the HostDiff class.""" @@ -641,9 +674,9 @@ class scan_diff_xml_test(unittest.TestCase): a.load_from_file("test-scans/empty.xml") b = Scan() b.load_from_file("test-scans/simple.xml") - self.scan_diff = ScanDiff(a, b) f = StringIO.StringIO() - self.scan_diff.print_xml(f) + self.scan_diff = ScanDiffXML(a, b, f) + self.scan_diff.output() self.xml = f.getvalue() f.close() @@ -655,10 +688,11 @@ class scan_diff_xml_test(unittest.TestCase): def scan_apply_diff(scan, diff): """Apply a scan diff to the given scan.""" - for host in diff.hosts: + for h_diff in diff.host_diffs: + host = h_diff.host_a or h_diff.host_b if host not in scan.hosts: scan.hosts.append(host) - host_apply_diff(host, diff.host_diffs[host]) + host_apply_diff(host, h_diff) def host_apply_diff(host, diff): """Apply a host diff to the given host."""