some more refactorings

Miroslav Stampar
2012-02-16 14:42:28 +00:00
parent 6632aa7308
commit dcf7277a0f
12 changed files with 245 additions and 237 deletions

View File

@@ -1192,11 +1192,14 @@ def expandAsteriskForColumns(expression):
return expression
def getRange(count, dump=False, plusOne=False):
def getLimitRange(count, dump=False, plusOne=False):
"""
Returns range of values used in limit/offset constructs
"""
retVal = None
count = int(count)
indexRange = None
limitStart = 1
limitStop = count
limitStart, limitStop = 1, count
if dump:
if isinstance(conf.limitStop, int) and conf.limitStop > 0 and conf.limitStop < limitStop:
@@ -1205,11 +1208,15 @@ def getRange(count, dump=False, plusOne=False):
if isinstance(conf.limitStart, int) and conf.limitStart > 0 and conf.limitStart <= limitStop:
limitStart = conf.limitStart
indexRange = xrange(limitStart, limitStop + 1) if plusOne else xrange(limitStart - 1, limitStop)
retVal = xrange(limitStart, limitStop + 1) if plusOne else xrange(limitStart - 1, limitStop)
return indexRange
return retVal
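For reference, the renamed getLimitRange() maps a total row count (optionally clipped by the --start/--stop options) onto an iterable of row indexes, 1-based when plusOne is set and 0-based otherwise. A minimal standalone sketch of that logic (illustrative names, not the sqlmap function itself):

def limit_range(count, start=1, stop=None, plus_one=False):
    # clip the requested window to the actual number of rows
    stop = count if stop is None else min(stop, count)
    return range(start, stop + 1) if plus_one else range(start - 1, stop)

print(list(limit_range(3)))                  # [0, 1, 2] -> 0-based offsets
print(list(limit_range(3, plus_one=True)))   # [1, 2, 3] -> 1-based row numbers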
def parseUnionPage(output, unique=True):
"""
Returns resulting items from inband query inside provided page content
"""
if output is None:
return None
@@ -1250,7 +1257,7 @@ def parseUnionPage(output, unique=True):
def parseFilePaths(page):
"""
Detect (possible) absolute system paths inside the provided page content
Detects (possible) absolute system paths inside the provided page content
"""
if page:
@@ -1265,32 +1272,6 @@ def parseFilePaths(page):
if absFilePath not in kb.absFilePaths:
kb.absFilePaths.add(absFilePath)
def getDelayQuery(andCond=False):
query = None
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL):
if not kb.data.banner:
conf.dbmsHandler.getVersionFromBanner()
banVer = kb.bannerFp["dbmsVersion"] if 'dbmsVersion' in kb.bannerFp else None
if banVer is None or (Backend.isDbms(DBMS.MYSQL) and banVer >= "5.0.12") or (Backend.isDbms(DBMS.PGSQL) and banVer >= "8.2"):
query = queries[Backend.getIdentifiedDbms()].timedelay.query % conf.timeSec
else:
query = queries[Backend.getIdentifiedDbms()].timedelay.query2 % conf.timeSec
elif Backend.isDbms(DBMS.FIREBIRD):
query = queries[Backend.getIdentifiedDbms()].timedelay.query
else:
query = queries[Backend.getIdentifiedDbms()].timedelay.query % conf.timeSec
if andCond:
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.SQLITE ):
query = query.replace("SELECT ", "")
elif Backend.isDbms(DBMS.FIREBIRD):
query = "(%s)>0" % query
return query
def getLocalIP():
retVal = None
try:
@@ -1310,11 +1291,11 @@ def getRemoteIP():
def getFileType(filePath):
try:
magicFileType = magic.from_file(filePath)
_ = magic.from_file(filePath)
except:
return "unknown"
return "text" if "ASCII" in magicFileType or "text" in magicFileType else "binary"
return "text" if "ASCII" in _ or "text" in _ else "binary"
def getCharset(charsetType=None):
asciiTbl = []
@@ -1354,15 +1335,14 @@ def getCharset(charsetType=None):
return asciiTbl
def searchEnvPath(fileName):
envPaths = os.environ["PATH"]
def searchEnvPath(filename):
result = None
path = os.environ.get("PATH", "")
paths = path.split(";") if IS_WIN else path.split(":")
envPaths = envPaths.split(";") if IS_WIN else envPaths.split(":")
for envPath in envPaths:
envPath = envPath.replace(";", "")
result = os.path.exists(os.path.normpath(os.path.join(envPath, fileName)))
for _ in paths:
_ = _.replace(";", "")
result = os.path.exists(os.path.normpath(os.path.join(_, filename)))
if result:
break
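The same lookup can also be written with os.pathsep instead of a hard-coded ';'/':' split; a hedged sketch (illustrative, not the sqlmap implementation) that returns the matching path rather than a boolean:

import os

def search_env_path(filename):
    for directory in os.environ.get("PATH", "").split(os.pathsep):
        candidate = os.path.normpath(os.path.join(directory, filename))
        if os.path.exists(candidate):
            return candidate
    return None

# e.g. search_env_path("cmd.exe") on Windows, search_env_path("ls") elsewhere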
@@ -1394,28 +1374,40 @@ def urlEncodeCookieValues(cookieStr):
else:
return None
def directoryPath(path):
def directoryPath(filepath):
"""
Returns directory path for a given filepath
"""
retVal = None
if isWindowsDriveLetterPath(path):
retVal = ntpath.dirname(path)
if isWindowsDriveLetterPath(filepath):
retVal = ntpath.dirname(filepath)
else:
retVal = posixpath.dirname(path)
retVal = posixpath.dirname(filepath)
return retVal
def normalizePath(path):
def normalizePath(filepath):
"""
Returns normalized string representation of a given filepath
"""
retVal = None
if isWindowsDriveLetterPath(path):
retVal = ntpath.normpath(path)
if isWindowsDriveLetterPath(filepath):
retVal = ntpath.normpath(filepath)
else:
retVal = posixpath.normpath(path)
retVal = posixpath.normpath(filepath)
return retVal
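The ntpath/posixpath switch matters because posixpath treats a backslash as an ordinary character; a short illustration:

import ntpath, posixpath

print(ntpath.dirname(r"C:\Windows\Temp\evil.php"))     # C:\Windows\Temp
print(posixpath.dirname(r"C:\Windows\Temp\evil.php"))  # '' (no '/' separator found)
print(posixpath.dirname("/var/www/html/shell.php"))    # /var/www/html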
def safeStringFormat(formatStr, params):
retVal = formatStr.replace("%d", "%s")
def safeStringFormat(format_, params):
"""
Avoids problems with inappropriate string format strings
"""
retVal = format_.replace("%d", "%s")
if isinstance(params, basestring):
retVal = retVal.replace("%s", params)
@@ -1435,23 +1427,12 @@ def safeStringFormat(formatStr, params):
return retVal
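As an illustration of what safeStringFormat() guards against: %d markers are degraded to %s so that None or unicode parameters cannot break classic %-formatting. A simplified sketch (not the actual implementation, which substitutes the markers one by one):

def safe_string_format(format_, params):
    result = format_.replace("%d", "%s")
    if isinstance(params, str):
        return result.replace("%s", params)
    return result % tuple("NULL" if p is None else str(p) for p in params)

print(safe_string_format("LIMIT %d, %d", (1, None)))   # LIMIT 1, NULL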
def sanitizeAsciiString(subject):
if subject:
index = None
for i in xrange(len(subject)):
if ord(subject[i]) >= 128:
index = i
break
if index is None:
return subject
else:
return subject[:index] + "".join(subject[i] if ord(subject[i]) < 128 else '?' for i in xrange(index, len(subject)))
else:
return None
def getFilteredPageContent(page, onlyText=True):
"""
Returns filtered page content without script, style and/or comments
or all HTML tags
"""
retVal = page
# only if the page's charset has been successfully identified
@@ -2402,6 +2383,10 @@ def isTechniqueAvailable(technique):
return getTechniqueData(technique) is not None
def isInferenceAvailable():
"""
Returns True if any technique relying on inference is available
"""
return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME))
def setOptimize():
@@ -2619,7 +2604,7 @@ def listToStrValue(value):
def getExceptionFrameLocals():
"""
Returns dictionary with local variable content from frame
where exception was raised
where exception has been raised
"""
retVal = {}
@@ -2793,7 +2778,7 @@ def isNullValue(value):
def expandMnemonics(mnemonics, parser, args):
"""
Expand mnemonic options
Expands mnemonic options
"""
class MnemonicNode:
@@ -2876,7 +2861,7 @@ def expandMnemonics(mnemonics, parser, args):
def safeCSValue(value):
"""
Returns value safe for CSV dumping.
Returns value safe for CSV dumping
Reference: http://tools.ietf.org/html/rfc4180
"""
@@ -2890,6 +2875,10 @@ def safeCSValue(value):
return retVal
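A hedged sketch of the RFC 4180 rule that safeCSValue() applies: values containing the delimiter, a double quote or a line break get wrapped in double quotes, with embedded quotes doubled (simplified, not sqlmap's exact checks):

def safe_cs_value(value):
    if any(_ in value for _ in (',', '"', '\n')):
        return '"%s"' % value.replace('"', '""')
    return value

print(safe_cs_value('say "hi", please'))   # "say ""hi"", please"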
def filterPairValues(values):
"""
Returns only list-like values with length 2
"""
retVal = []
if not isNoneValue(values) and hasattr(values, '__iter__'):
@@ -2973,6 +2962,10 @@ def asciifyUrl(url, forceQuote=False):
return urlparse.urlunsplit([parts.scheme, netloc, path, query, parts.fragment])
def findPageForms(content, url, raise_=False, addToTargets=False):
"""
Parses given page content for possible forms
"""
class _(StringIO):
def __init__(self, content, url):
StringIO.__init__(self, unicodeencode(content, kb.pageEncoding) if isinstance(content, unicode) else content)
@@ -3016,15 +3009,18 @@ def findPageForms(content, url, raise_=False, addToTargets=False):
if not item.selected:
item.selected = True
break
request = form.click()
url = urldecode(request.get_full_url(), kb.pageEncoding)
method = request.get_method()
data = request.get_data() if request.has_data() else None
data = urldecode(data, kb.pageEncoding) if data and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in data else data
if not data and method and method.upper() == HTTPMETHOD.POST:
debugMsg = "invalid POST form with blank data detected"
logger.debug(debugMsg)
continue
target = (url, method, data, conf.cookie)
retVal.add(target)
else:
@@ -3041,6 +3037,10 @@ def findPageForms(content, url, raise_=False, addToTargets=False):
return retVal
def getHostHeader(url):
"""
Returns proper Host header value for a given target URL
"""
retVal = urlparse.urlparse(url).netloc
if any(retVal.endswith(':%d' % _) for _ in [80, 443]):
@@ -3048,7 +3048,11 @@ def getHostHeader(url):
return retVal
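Illustration of the Host header normalization: default ports are dropped from the netloc so requests stay canonical (sketch only; the helper name is illustrative):

import urlparse   # urllib.parse on Python 3

def host_header(url):
    netloc = urlparse.urlparse(url).netloc
    for port in (80, 443):
        if netloc.endswith(":%d" % port):
            netloc = netloc.rsplit(":", 1)[0]
    return netloc

print(host_header("http://www.target.com:80/index.php"))    # www.target.com
print(host_header("http://www.target.com:8080/index.php"))  # www.target.com:8080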
def executeCode(code, variables=None):
def evaluateCode(code, variables=None):
"""
Executes given Python code provided in string form
"""
try:
exec(code, variables)
except Exception, ex:
@@ -3056,21 +3060,39 @@ def executeCode(code, variables=None):
raise sqlmapGenericException, errMsg
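The renamed evaluateCode() is essentially exec() run against a caller-supplied namespace, so the caller can read back any variables the snippet sets. A minimal sketch of that pattern (generic exception used here in place of sqlmapGenericException):

def evaluate_code(code, variables=None):
    variables = {} if variables is None else variables
    try:
        exec(code, variables)
    except Exception as ex:
        raise RuntimeError("an error occurred while evaluating provided code ('%s')" % ex)
    return variables

namespace = evaluate_code("filtered = [_ for _ in range(10) if _ % 2 == 0]")
print(namespace["filtered"])   # [0, 2, 4, 6, 8]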
def serializeObject(object_):
"""
Serializes given object
"""
return pickle.dumps(object_)
def unserializeObject(value):
"""
Unserializes object from given serialized form
"""
retVal = None
if value:
retVal = pickle.loads(value.encode(UNICODE_ENCODING)) # pickle has problems with Unicode
return retVal
def resetCounter(counter):
kb.counters[counter] = 0
def resetCounter(technique):
"""
Resets query counter for a given technique
"""
def incrementCounter(counter):
if counter not in kb.counters:
resetCounter(counter)
kb.counters[counter] += 1
kb.counters[technique] = 0
def getCounter(counter):
return kb.counters.get(counter, 0)
def incrementCounter(technique):
"""
Increments query counter for a given technique
"""
kb.counters[technique] = getCounter(technique) + 1
def getCounter(technique):
"""
Returns query counter for a given technique
"""
return kb.counters.get(technique, 0)
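For context, kb.counters is just a dict keyed by technique, so the three helpers reduce to the pattern below; the .get() default also makes the previous reset-on-first-use check unnecessary (standalone sketch with illustrative names):

counters = {}

def reset_counter(technique):
    counters[technique] = 0

def increment_counter(technique):
    counters[technique] = get_counter(technique) + 1

def get_counter(technique):
    return counters.get(technique, 0)

increment_counter("BOOLEAN")
increment_counter("BOOLEAN")
print(get_counter("BOOLEAN"))   # 2
print(get_counter("TIME"))      # 0 (never incremented, no reset needed)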

View File

@@ -91,34 +91,32 @@ def urlencode(value, safe="%&=", convall=False, limit=False):
return value
count = 0
result = None
result = None if value is None else ""
if value is None:
return result
if value:
if convall or safe is None:
safe = ""
if convall or safe is None:
safe = ""
# corner case when character % really needs to be
# encoded (when not representing url encoded char)
# except in cases when tampering scripts are used
if all(map(lambda x: '%' in x, [safe, value])) and not kb.tamperFunctions:
value = re.sub("%(?![0-9a-fA-F]{2})", "%25", value, re.DOTALL | re.IGNORECASE)
# corner case when character % really needs to be
# encoded (when not representing url encoded char)
# except in cases when tampering scripts are used
if all(map(lambda x: '%' in x, [safe, value])) and not kb.tamperFunctions:
value = re.sub("%(?![0-9a-fA-F]{2})", "%25", value, re.DOTALL | re.IGNORECASE)
while True:
result = urllib.quote(utf8encode(value), safe)
while True:
result = urllib.quote(utf8encode(value), safe)
if limit and len(result) > URLENCODE_CHAR_LIMIT:
if count >= len(URLENCODE_FAILSAFE_CHARS):
break
while count < len(URLENCODE_FAILSAFE_CHARS):
safe += URLENCODE_FAILSAFE_CHARS[count]
count += 1
if safe[-1] in value:
if limit and len(result) > URLENCODE_CHAR_LIMIT:
if count >= len(URLENCODE_FAILSAFE_CHARS):
break
else:
break
while count < len(URLENCODE_FAILSAFE_CHARS):
safe += URLENCODE_FAILSAFE_CHARS[count]
count += 1
if safe[-1] in value:
break
else:
break
return result
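The reindented urlencode() keeps its "failsafe" loop: when the encoded value exceeds a length limit, extra characters are progressively added to the safe set so they stay unencoded. A simplified sketch of that idea (the constants and the unconditional retry are illustrative; the real code only retries when the newly whitelisted character actually occurs in the value):

import urllib   # urllib.parse.quote on Python 3

FAILSAFE_CHARS = "()'*"
CHAR_LIMIT = 2000

def shrinking_encode(value, safe="%&="):
    # assumes a unicode `value`
    count = 0
    while True:
        result = urllib.quote(value.encode("utf8"), safe)
        if len(result) <= CHAR_LIMIT or count >= len(FAILSAFE_CHARS):
            return result
        safe += FAILSAFE_CHARS[count]   # leave one more character unencoded and retry
        count += 1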

View File

@@ -41,45 +41,45 @@ class Dump:
"""
def __init__(self):
self.__outputFile = None
self.__outputFP = None
self.__outputBP = None
self.__lock = threading.Lock()
self._outputFile = None
self._outputFP = None
self._outputBP = None
self._lock = threading.Lock()
def __write(self, data, n=True, console=True):
def _write(self, data, n=True, console=True):
text = "%s%s" % (data, "\n" if n else " ")
if console:
dataToStdout(text)
if kb.get("multiThreadMode"):
self.__lock.acquire()
self._lock.acquire()
self.__outputBP.write(text)
self._outputBP.write(text)
if self.__outputBP.tell() > BUFFERED_LOG_SIZE:
if self._outputBP.tell() > BUFFERED_LOG_SIZE:
self.flush()
if kb.get("multiThreadMode"):
self.__lock.release()
self._lock.release()
kb.dataOutputFlag = True
def flush(self):
if self.__outputBP and self.__outputFP and self.__outputBP.tell() > 0:
_ = self.__outputBP.getvalue()
self.__outputBP.truncate(0)
self.__outputFP.write(_)
if self._outputBP and self._outputFP and self._outputBP.tell() > 0:
_ = self._outputBP.getvalue()
self._outputBP.truncate(0)
self._outputFP.write(_)
def __formatString(self, inpStr):
def _formatString(self, inpStr):
return restoreDumpMarkedChars(getUnicode(inpStr))
def setOutputFile(self):
self.__outputFile = "%s%slog" % (conf.outputPath, os.sep)
self.__outputFP = codecs.open(self.__outputFile, "ab", UNICODE_ENCODING)
self.__outputBP = StringIO.StringIO()
self._outputFile = "%s%slog" % (conf.outputPath, os.sep)
self._outputFP = codecs.open(self._outputFile, "ab", UNICODE_ENCODING)
self._outputBP = StringIO.StringIO()
def getOutputFile(self):
return self.__outputFile
return self._outputFile
def string(self, header, data, sort=True):
if isinstance(data, (list, tuple, set)):
@@ -90,21 +90,21 @@ class Dump:
data = getUnicode(data)
if data:
data = self.__formatString(data)
data = self._formatString(data)
if data[-1] == '\n':
data = data[:-1]
if "\n" in data:
self.__write("%s:\n---\n%s\n---\n" % (header, data))
self._write("%s:\n---\n%s\n---\n" % (header, data))
else:
self.__write("%s: '%s'\n" % (header, data))
self._write("%s: '%s'\n" % (header, data))
else:
self.__write("%s:\tNone\n" % header)
self._write("%s:\tNone\n" % header)
def lister(self, header, elements, sort=True):
if elements:
self.__write("%s [%d]:" % (header, len(elements)))
self._write("%s [%d]:" % (header, len(elements)))
if sort:
try:
@@ -116,12 +116,12 @@ class Dump:
for element in elements:
if isinstance(element, basestring):
self.__write("[*] %s" % element)
self._write("[*] %s" % element)
elif isinstance(element, (list, tuple, set)):
self.__write("[*] " + ", ".join(getUnicode(e) for e in element))
self._write("[*] " + ", ".join(getUnicode(e) for e in element))
if elements:
self.__write("")
self._write("")
def technic(self, header, data):
self.string(header, data)
@@ -147,13 +147,13 @@ class Dump:
self.lister("database management system users", users)
def userSettings(self, header, userSettings, subHeader):
self.__areAdmins = set()
self._areAdmins = set()
if userSettings:
self.__write("%s:" % header)
self._write("%s:" % header)
if isinstance(userSettings, (tuple, list, set)):
self.__areAdmins = userSettings[1]
self._areAdmins = userSettings[1]
userSettings = userSettings[0]
users = userSettings.keys()
@@ -167,16 +167,16 @@ class Dump:
else:
stringSettings = " [%d]:" % len(settings)
if user in self.__areAdmins:
self.__write("[*] %s (administrator)%s" % (user, stringSettings))
if user in self._areAdmins:
self._write("[*] %s (administrator)%s" % (user, stringSettings))
else:
self.__write("[*] %s%s" % (user, stringSettings))
self._write("[*] %s%s" % (user, stringSettings))
if settings:
settings.sort()
for setting in settings:
self.__write(" %s: %s" % (subHeader, setting))
self._write(" %s: %s" % (subHeader, setting))
print
def dbs(self,dbs):
@@ -198,23 +198,23 @@ class Dump:
for db, tables in dbTables.items():
tables.sort()
self.__write("Database: %s" % db if db else "Current database")
self._write("Database: %s" % db if db else "Current database")
if len(tables) == 1:
self.__write("[1 table]")
self._write("[1 table]")
else:
self.__write("[%d tables]" % len(tables))
self._write("[%d tables]" % len(tables))
self.__write("+%s+" % lines)
self._write("+%s+" % lines)
for table in tables:
if isinstance(table, (list, tuple, set)):
table = table[0]
blank = " " * (maxlength - len(normalizeUnicode(table) or str(table)))
self.__write("| %s%s |" % (table, blank))
self._write("| %s%s |" % (table, blank))
self.__write("+%s+\n" % lines)
self._write("+%s+\n" % lines)
else:
self.string("tables", dbTables)
@@ -246,17 +246,17 @@ class Dump:
maxlength2 = max(maxlength2, len("TYPE"))
lines2 = "-" * (maxlength2 + 2)
self.__write("Database: %s\nTable: %s" % (db if db else "Current database", table))
self._write("Database: %s\nTable: %s" % (db if db else "Current database", table))
if len(columns) == 1:
self.__write("[1 column]")
self._write("[1 column]")
else:
self.__write("[%d columns]" % len(columns))
self._write("[%d columns]" % len(columns))
if colType is not None:
self.__write("+%s+%s+" % (lines1, lines2))
self._write("+%s+%s+" % (lines1, lines2))
else:
self.__write("+%s+" % lines1)
self._write("+%s+" % lines1)
blank1 = " " * (maxlength1 - len("COLUMN"))
@@ -264,11 +264,11 @@ class Dump:
blank2 = " " * (maxlength2 - len("TYPE"))
if colType is not None:
self.__write("| Column%s | Type%s |" % (blank1, blank2))
self.__write("+%s+%s+" % (lines1, lines2))
self._write("| Column%s | Type%s |" % (blank1, blank2))
self._write("+%s+%s+" % (lines1, lines2))
else:
self.__write("| Column%s |" % blank1)
self.__write("+%s+" % lines1)
self._write("| Column%s |" % blank1)
self._write("+%s+" % lines1)
for column in colList:
colType = columns[column]
@@ -276,14 +276,14 @@ class Dump:
if colType is not None:
blank2 = " " * (maxlength2 - len(colType))
self.__write("| %s%s | %s%s |" % (column, blank1, colType, blank2))
self._write("| %s%s | %s%s |" % (column, blank1, colType, blank2))
else:
self.__write("| %s%s |" % (column, blank1))
self._write("| %s%s |" % (column, blank1))
if colType is not None:
self.__write("+%s+%s+\n" % (lines1, lines2))
self._write("+%s+%s+\n" % (lines1, lines2))
else:
self.__write("+%s+\n" % lines1)
self._write("+%s+\n" % lines1)
def dbTablesCount(self, dbTables):
if isinstance(dbTables, dict) and len(dbTables) > 0:
@@ -296,16 +296,16 @@ class Dump:
maxlength1 = max(maxlength1, len(normalizeUnicode(table) or str(table)))
for db, counts in dbTables.items():
self.__write("Database: %s" % db if db else "Current database")
self._write("Database: %s" % db if db else "Current database")
lines1 = "-" * (maxlength1 + 2)
blank1 = " " * (maxlength1 - len("Table"))
lines2 = "-" * (maxlength2 + 2)
blank2 = " " * (maxlength2 - len("Entries"))
self.__write("+%s+%s+" % (lines1, lines2))
self.__write("| Table%s | Entries%s |" % (blank1, blank2))
self.__write("+%s+%s+" % (lines1, lines2))
self._write("+%s+%s+" % (lines1, lines2))
self._write("| Table%s | Entries%s |" % (blank1, blank2))
self._write("+%s+%s+" % (lines1, lines2))
sortedCounts = counts.keys()
sortedCounts.sort(reverse=True)
@@ -321,9 +321,9 @@ class Dump:
for table in tables:
blank1 = " " * (maxlength1 - len(normalizeUnicode(table) or str(table)))
blank2 = " " * (maxlength2 - len(str(count)))
self.__write("| %s%s | %d%s |" % (table, blank1, count, blank2))
self._write("| %s%s | %d%s |" % (table, blank1, count, blank2))
self.__write("+%s+%s+\n" % (lines1, lines2))
self._write("+%s+%s+\n" % (lines1, lines2))
else:
logger.error("unable to retrieve the number of entries for any table")
@@ -365,7 +365,7 @@ class Dump:
separator += "+%s" % lines
separator += "+"
self.__write("Database: %s\nTable: %s" % (db if db else "Current database", table))
self._write("Database: %s\nTable: %s" % (db if db else "Current database", table))
if conf.replicate:
cols = []
@@ -402,11 +402,11 @@ class Dump:
rtable = replication.createTable(table, cols)
if count == 1:
self.__write("[1 entry]")
self._write("[1 entry]")
else:
self.__write("[%d entries]" % count)
self._write("[%d entries]" % count)
self.__write(separator)
self._write(separator)
for column in columns:
if column != "__infos__":
@@ -414,7 +414,7 @@ class Dump:
maxlength = int(info["length"])
blank = " " * (maxlength - len(column))
self.__write("| %s%s" % (column, blank), n=False)
self._write("| %s%s" % (column, blank), n=False)
if not conf.replicate:
if field == fields:
@@ -424,7 +424,7 @@ class Dump:
field += 1
self.__write("|\n%s" % separator)
self._write("|\n%s" % separator)
if not conf.replicate:
dataToDumpFile(dumpFP, "\n")
@@ -461,7 +461,7 @@ class Dump:
values.append(value)
maxlength = int(info["length"])
blank = " " * (maxlength - len(value))
self.__write("| %s%s" % (value, blank), n=False, console=console)
self._write("| %s%s" % (value, blank), n=False, console=console)
if not conf.replicate:
if field == fields:
@@ -477,12 +477,12 @@ class Dump:
except sqlmapValueException:
pass
self.__write("|", console=console)
self._write("|", console=console)
if not conf.replicate:
dataToDumpFile(dumpFP, "\n")
self.__write("%s\n" % separator)
self._write("%s\n" % separator)
if conf.replicate:
rtable.endTransaction()
@@ -502,26 +502,26 @@ class Dump:
msg = "Column%s found in the " % colConsiderStr
msg += "following databases:"
self.__write(msg)
self._write(msg)
printDbs = {}
_ = {}
for db, tblData in dbs.items():
for tbl, colData in tblData.items():
for col, dataType in colData.items():
if column.lower() in col.lower():
if db in printDbs:
if tbl in printDbs[db]:
printDbs[db][tbl][col] = dataType
if db in _:
if tbl in _[db]:
_[db][tbl][col] = dataType
else:
printDbs[db][tbl] = { col: dataType }
_[db][tbl] = { col: dataType }
else:
printDbs[db] = {}
printDbs[db][tbl] = { col: dataType }
_[db] = {}
_[db][tbl] = { col: dataType }
continue
self.dbTableColumns(printDbs)
self.dbTableColumns(_)
def query(self, query, queryRes):
self.string(query, queryRes)
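The driver behind the dump.py changes: double-underscore attributes are name-mangled per class, which blocks straightforward overriding in subclasses, while a single underscore is only a convention. A short illustration:

class Base(object):
    def report(self, data):
        return self.__write(data)        # resolves to self._Base__write
    def __write(self, data):
        return "Base wrote %s" % data

class Sub(Base):
    def __write(self, data):             # mangled to _Sub__write, never seen by report()
        return "Sub wrote %s" % data

print(Sub().report("x"))                 # Base wrote x -- mangling blocks the override
# With a single leading underscore (_write), Sub._write would override as expected.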

View File

@@ -249,6 +249,9 @@ SQL_STATEMENTS = {
# string representation for NULL value
NULL = "NULL"
# string representation for current database
CURRENT_DB = "CD"
# Regular expressions used for parsing error messages (--parse-errors)
ERROR_PARSING_REGEXES = (
r"<b>[^<]*(fatal|error|warning|exception)[^<]*</b>:?\s*(?P<result>.+?)<br\s*/?\s*>",