Major code refactoring - moved to one location only (getIdentifiedDBMS() in common.py) the retrieval of identified/fingerprinted DBMS.

Minor bug fixes thanks to previous refactoring too.
This commit is contained in:
Bernardo Damele
2011-01-13 17:36:54 +00:00
parent a1d1f69c3f
commit 2ac8debea0
37 changed files with 342 additions and 314 deletions

View File

@@ -13,6 +13,7 @@ from xml.etree import ElementTree as ET
from lib.core.common import getCompiledRegex
from lib.core.common import getErrorParsedDBMSes
from lib.core.common import getIdentifiedDBMS
from lib.core.common import isDBMSVersionAtLeast
from lib.core.common import isTechniqueAvailable
from lib.core.common import randomInt
@@ -33,13 +34,6 @@ class Agent:
This class defines the SQL agent methods.
"""
def __init__(self):
kb.misc = advancedDict()
kb.misc.delimiter = randomStr(length=6)
kb.misc.start = ":%s:" % randomStr(length=3, lowercase=True)
kb.misc.stop = ":%s:" % randomStr(length=3, lowercase=True)
kb.misc.space = ":%s:" % randomStr(length=1, lowercase=True)
def payloadDirect(self, query):
if query.startswith("AND "):
query = query.replace("AND ", "SELECT ", 1)
@@ -211,8 +205,8 @@ class Agent:
payload = payload.replace("[ORIGVALUE]", origvalue)
if "[INFERENCE]" in payload:
if kb.dbms is not None:
inference = queries[kb.dbms].inference
if getIdentifiedDBMS() is not None:
inference = queries[getIdentifiedDBMS()].inference
if "dbms_version" in inference:
if isDBMSVersionAtLeast(inference.dbms_version):
@@ -223,11 +217,6 @@ class Agent:
inferenceQuery = inference.query
payload = payload.replace("[INFERENCE]", inferenceQuery)
elif hasattr(kb.misc, "testedDbms") and kb.misc.testedDbms is not None:
inferenceQuery = queries[kb.misc.testedDbms].inference.query
payload = payload.replace("[INFERENCE]", inferenceQuery)
else:
errMsg = "invalid usage of inference payload without "
errMsg += "knowledge of underlying DBMS"
@@ -275,17 +264,17 @@ class Agent:
# SQLite version 2 does not support neither CAST() nor IFNULL(),
# introduced only in SQLite version 3
if kb.dbms == DBMS.SQLITE:
if getIdentifiedDBMS() == DBMS.SQLITE:
return field
if field.startswith("(CASE"):
nulledCastedField = field
else:
nulledCastedField = queries[kb.dbms].cast.query % field
if kb.dbms == DBMS.ACCESS:
nulledCastedField = queries[kb.dbms].isnull.query % (nulledCastedField, nulledCastedField)
nulledCastedField = queries[getIdentifiedDBMS()].cast.query % field
if getIdentifiedDBMS() == DBMS.ACCESS:
nulledCastedField = queries[getIdentifiedDBMS()].isnull.query % (nulledCastedField, nulledCastedField)
else:
nulledCastedField = queries[kb.dbms].isnull.query % nulledCastedField
nulledCastedField = queries[getIdentifiedDBMS()].isnull.query % nulledCastedField
return nulledCastedField
@@ -324,7 +313,7 @@ class Agent:
fields = fields.replace(", ", ",")
fieldsSplitted = fields.split(",")
dbmsDelimiter = queries[kb.dbms].delimiter.query
dbmsDelimiter = queries[getIdentifiedDBMS()].delimiter.query
nulledCastedFields = []
for field in fieldsSplitted:
@@ -383,13 +372,13 @@ class Agent:
def simpleConcatQuery(self, query1, query2):
concatenatedQuery = ""
if kb.dbms == DBMS.MYSQL:
if getIdentifiedDBMS() == DBMS.MYSQL:
concatenatedQuery = "CONCAT(%s,%s)" % (query1, query2)
elif kb.dbms in ( DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE ):
elif getIdentifiedDBMS() in ( DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE ):
concatenatedQuery = "%s||%s" % (query1, query2)
elif kb.dbms in (DBMS.MSSQL, DBMS.SYBASE):
elif getIdentifiedDBMS() in (DBMS.MSSQL, DBMS.SYBASE):
concatenatedQuery = "%s+%s" % (query1, query2)
return concatenatedQuery
@@ -431,7 +420,7 @@ class Agent:
concatenatedQuery = query
fieldsSelectFrom, fieldsSelect, fieldsNoSelect, fieldsSelectTop, fieldsSelectCase, _, fieldsToCastStr = self.getFields(query)
if kb.dbms == DBMS.MYSQL:
if getIdentifiedDBMS() == DBMS.MYSQL:
if fieldsSelectCase:
concatenatedQuery = concatenatedQuery.replace("SELECT ", "CONCAT('%s'," % kb.misc.start, 1)
concatenatedQuery += ",'%s')" % kb.misc.stop
@@ -444,7 +433,7 @@ class Agent:
elif fieldsNoSelect:
concatenatedQuery = "CONCAT('%s',%s,'%s')" % (kb.misc.start, concatenatedQuery, kb.misc.stop)
elif kb.dbms in ( DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE ):
elif getIdentifiedDBMS() in ( DBMS.PGSQL, DBMS.ORACLE, DBMS.SQLITE ):
if fieldsSelectCase:
concatenatedQuery = concatenatedQuery.replace("SELECT ", "'%s'||" % kb.misc.start, 1)
concatenatedQuery += "||'%s'" % kb.misc.stop
@@ -457,10 +446,10 @@ class Agent:
elif fieldsNoSelect:
concatenatedQuery = "'%s'||%s||'%s'" % (kb.misc.start, concatenatedQuery, kb.misc.stop)
if kb.dbms == DBMS.ORACLE and " FROM " not in concatenatedQuery and ( fieldsSelect or fieldsNoSelect ):
if getIdentifiedDBMS() == DBMS.ORACLE and " FROM " not in concatenatedQuery and ( fieldsSelect or fieldsNoSelect ):
concatenatedQuery += " FROM DUAL"
elif kb.dbms in (DBMS.MSSQL, DBMS.SYBASE):
elif getIdentifiedDBMS() in (DBMS.MSSQL, DBMS.SYBASE):
if fieldsSelectTop:
topNum = re.search("\ASELECT\s+TOP\s+([\d]+)\s+", concatenatedQuery, re.I).group(1)
concatenatedQuery = concatenatedQuery.replace("SELECT TOP %s " % topNum, "TOP %s '%s'+" % (topNum, kb.misc.start), 1)
@@ -511,13 +500,13 @@ class Agent:
"""
if query.startswith("SELECT "):
query = query[len("SELECT "):]
query = query[len("SELECT "):]
inbandQuery = self.prefixQuery("UNION ALL SELECT ", prefix=prefix)
if query.startswith("TOP"):
topNum = re.search("\ATOP\s+([\d]+)\s+", query, re.I).group(1)
query = query[len("TOP %s " % topNum):]
topNum = re.search("\ATOP\s+([\d]+)\s+", query, re.I).group(1)
query = query[len("TOP %s " % topNum):]
inbandQuery += "TOP %s " % topNum
intoRegExp = re.search("(\s+INTO (DUMP|OUT)FILE\s+\'(.+?)\')", query, re.I)
@@ -526,7 +515,7 @@ class Agent:
intoRegExp = intoRegExp.group(1)
query = query[:query.index(intoRegExp)]
if kb.dbms == DBMS.ORACLE and inbandQuery.endswith(" FROM DUAL"):
if getIdentifiedDBMS() == DBMS.ORACLE and inbandQuery.endswith(" FROM DUAL"):
inbandQuery = inbandQuery[:-len(" FROM DUAL")]
for element in range(count):
@@ -546,7 +535,7 @@ class Agent:
conditionIndex = query.index(" FROM ")
inbandQuery += query[conditionIndex:]
if kb.dbms == DBMS.ORACLE or DBMS.ORACLE in getErrorParsedDBMSes():
if getIdentifiedDBMS() == DBMS.ORACLE:
if " FROM " not in inbandQuery:
inbandQuery += " FROM DUAL"
@@ -565,7 +554,7 @@ class Agent:
else:
inbandQuery += char
if kb.dbms == DBMS.ORACLE:
if getIdentifiedDBMS() == DBMS.ORACLE:
inbandQuery += " FROM DUAL"
inbandQuery = self.suffixQuery(inbandQuery, comment, suffix)
@@ -595,21 +584,21 @@ class Agent:
"""
limitedQuery = query
limitStr = queries[kb.dbms].limit.query
limitStr = queries[getIdentifiedDBMS()].limit.query
fromIndex = limitedQuery.index(" FROM ")
untilFrom = limitedQuery[:fromIndex]
fromFrom = limitedQuery[fromIndex+1:]
orderBy = False
if kb.dbms in ( DBMS.MYSQL, DBMS.PGSQL, DBMS.SQLITE ):
limitStr = queries[kb.dbms].limit.query % (num, 1)
if getIdentifiedDBMS() in ( DBMS.MYSQL, DBMS.PGSQL, DBMS.SQLITE ):
limitStr = queries[getIdentifiedDBMS()].limit.query % (num, 1)
limitedQuery += " %s" % limitStr
elif kb.dbms == DBMS.FIREBIRD:
limitStr = queries[kb.dbms].limit.query % (num+1, num+1)
elif getIdentifiedDBMS() == DBMS.FIREBIRD:
limitStr = queries[getIdentifiedDBMS()].limit.query % (num+1, num+1)
limitedQuery += " %s" % limitStr
elif kb.dbms == DBMS.ORACLE:
elif getIdentifiedDBMS() == DBMS.ORACLE:
if " ORDER BY " in limitedQuery and "(SELECT " in limitedQuery:
orderBy = limitedQuery[limitedQuery.index(" ORDER BY "):]
limitedQuery = limitedQuery[:limitedQuery.index(" ORDER BY ")]
@@ -621,7 +610,7 @@ class Agent:
limitedQuery = limitedQuery % fromFrom
limitedQuery += "=%d" % (num + 1)
elif kb.dbms in (DBMS.MSSQL, DBMS.SYBASE):
elif getIdentifiedDBMS() in (DBMS.MSSQL, DBMS.SYBASE):
forgeNotIn = True
if " ORDER BY " in limitedQuery:
@@ -635,7 +624,7 @@ class Agent:
limitedQuery = limitedQuery.replace("DISTINCT %s" % notDistinct, notDistinct)
if limitedQuery.startswith("SELECT TOP ") or limitedQuery.startswith("TOP "):
topNums = re.search(queries[kb.dbms].limitregexp.query, limitedQuery, re.I)
topNums = re.search(queries[getIdentifiedDBMS()].limitregexp.query, limitedQuery, re.I)
if topNums:
topNums = topNums.groups()
@@ -681,7 +670,7 @@ class Agent:
@rtype: C{str}
"""
return queries[kb.dbms if kb.dbms else kb.misc.testedDbms].case.query % expression
return queries[getIdentifiedDBMS()].case.query % expression
def addPayloadDelimiters(self, inpStr):
"""

View File

@@ -218,15 +218,15 @@ def formatDBMSfp(versions=None):
versions = kb.dbmsVersion
if isinstance(versions, basestring):
return "%s %s" % (kb.dbms, versions)
return "%s %s" % (getIdentifiedDBMS(), versions)
elif isinstance(versions, (list, set, tuple)):
return "%s %s" % (kb.dbms, " and ".join([version for version in versions]))
return "%s %s" % (getIdentifiedDBMS(), " and ".join([version for version in versions]))
elif not versions:
warnMsg = "unable to extensively fingerprint the back-end "
warnMsg += "DBMS version"
logger.warn(warnMsg)
return kb.dbms
return getIdentifiedDBMS()
def formatFingerprintString(values, chain=" or "):
strJoin = "|".join([v for v in values])
@@ -627,7 +627,7 @@ def parsePasswordHash(password):
if not password or password == " ":
password = "NULL"
if kb.dbms == DBMS.MSSQL and password != "NULL" and isHexEncodedString(password):
if getIdentifiedDBMS() == DBMS.MSSQL and password != "NULL" and isHexEncodedString(password):
hexPassword = password
password = "%s\n" % hexPassword
password += "%sheader: %s\n" % (blank, hexPassword[:6])
@@ -928,25 +928,25 @@ def parseUnionPage(output, expression, partial=False, condition=None, sort=True)
def getDelayQuery(andCond=False):
query = None
if kb.dbms in (DBMS.MYSQL, DBMS.PGSQL):
if getIdentifiedDBMS() in (DBMS.MYSQL, DBMS.PGSQL):
if not kb.data.banner:
conf.dbmsHandler.getVersionFromBanner()
banVer = kb.bannerFp["dbmsVersion"] if 'dbmsVersion' in kb.bannerFp else None
if banVer is None or (kb.dbms == DBMS.MYSQL and banVer >= "5.0.12") or (kb.dbms == DBMS.PGSQL and banVer >= "8.2"):
query = queries[kb.dbms].timedelay.query % conf.timeSec
if banVer is None or (getIdentifiedDBMS() == DBMS.MYSQL and banVer >= "5.0.12") or (getIdentifiedDBMS() == DBMS.PGSQL and banVer >= "8.2"):
query = queries[getIdentifiedDBMS()].timedelay.query % conf.timeSec
else:
query = queries[kb.dbms].timedelay.query2 % conf.timeSec
elif kb.dbms == DBMS.FIREBIRD:
query = queries[kb.dbms].timedelay.query
query = queries[getIdentifiedDBMS()].timedelay.query2 % conf.timeSec
elif getIdentifiedDBMS() == DBMS.FIREBIRD:
query = queries[getIdentifiedDBMS()].timedelay.query
else:
query = queries[kb.dbms].timedelay.query % conf.timeSec
query = queries[getIdentifiedDBMS()].timedelay.query % conf.timeSec
if andCond:
if kb.dbms in ( DBMS.MYSQL, DBMS.SQLITE ):
if getIdentifiedDBMS() in ( DBMS.MYSQL, DBMS.SQLITE ):
query = query.replace("SELECT ", "")
elif kb.dbms == DBMS.FIREBIRD:
elif getIdentifiedDBMS() == DBMS.FIREBIRD:
query = "(%s)>0" % query
return query
@@ -1763,7 +1763,7 @@ def aliasToDbmsEnum(value):
retVal = None
for key, item in dbmsDict.items():
if value in item[0]:
if value.lower() in item[0]:
retVal = key
break
@@ -2040,6 +2040,18 @@ def getErrorParsedDBMSes():
return kb.htmlFp
def getIdentifiedDBMS():
dbms = None
if kb.dbms is not None:
dbms = kb.dbms
elif conf.dbms is not None:
dbms = conf.dbms
elif getErrorParsedDBMSes() is not None:
dbms = getErrorParsedDBMSes()[0]
return aliasToDbmsEnum(dbms)
def showHttpErrorCodes():
"""
Shows all HTTP error codes raised till now

View File

@@ -31,7 +31,7 @@ class DBMS:
MSSQL = "Microsoft SQL Server"
MYSQL = "MySQL"
ORACLE = "Oracle"
PGSQL = "PostgreSQL"
PGSQL = "PostgreSQL"
SQLITE = "SQLite"
SYBASE = "Sybase"

View File

@@ -34,6 +34,7 @@ from lib.core.common import parseTargetDirect
from lib.core.common import parseTargetUrl
from lib.core.common import paths
from lib.core.common import randomRange
from lib.core.common import randomStr
from lib.core.common import readCachedFileContent
from lib.core.common import readInput
from lib.core.common import runningAsAdmin
@@ -46,6 +47,7 @@ from lib.core.data import paths
from lib.core.data import queries
from lib.core.datatype import advancedDict
from lib.core.datatype import injectionDict
from lib.core.enums import DBMS
from lib.core.enums import HTTPMETHOD
from lib.core.enums import PAYLOAD
from lib.core.enums import PRIORITY
@@ -1165,6 +1167,12 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.threadException = False
kb.threadData = {}
kb.misc = advancedDict()
kb.misc.delimiter = randomStr(length=6)
kb.misc.start = ":%s:" % randomStr(length=3, lowercase=True)
kb.misc.stop = ":%s:" % randomStr(length=3, lowercase=True)
kb.misc.space = ":%s:" % randomStr(length=1, lowercase=True)
if flushAll:
kb.keywords = set(getFileItems(paths.SQL_KEYWORDS))
kb.tamperFunctions = []

View File

@@ -13,6 +13,7 @@ from lib.core.common import aliasToDbmsEnum
from lib.core.common import dataToSessionFile
from lib.core.common import formatFingerprintString
from lib.core.common import getFilteredPageContent
from lib.core.common import getIdentifiedDBMS
from lib.core.common import readInput
from lib.core.convert import base64pickle
from lib.core.convert import base64unpickle
@@ -140,7 +141,7 @@ def setDbms(dbms):
if dbmsRegExp:
dbms = dbmsRegExp.group(1)
kb.dbms = dbms
kb.dbms = aliasToDbmsEnum(dbms)
logger.info("the back-end DBMS is %s" % kb.dbms)
@@ -340,7 +341,7 @@ def resumeConfKb(expression, url, value):
if '.' in table:
db, table = table.split('.')
else:
db = "%s%s" % (kb.dbms, METADB_SUFFIX)
db = "%s%s" % (getIdentifiedDBMS(), METADB_SUFFIX)
logMsg = "resuming brute forced table name "
logMsg += "'%s' from session file" % table
@@ -355,7 +356,7 @@ def resumeConfKb(expression, url, value):
if '.' in table:
db, table = table.split('.')
else:
db = "%s%s" % (kb.dbms, METADB_SUFFIX)
db = "%s%s" % (getIdentifiedDBMS(), METADB_SUFFIX)
logMsg = "resuming brute forced column name "
logMsg += "'%s' for table '%s' from session file" % (colName, table)

View File

@@ -12,6 +12,7 @@ import os
import rlcompleter
from lib.core import readlineng as readline
from lib.core.common import getIdentifiedDBMS
from lib.core.data import kb
from lib.core.data import paths
from lib.core.data import queries
@@ -29,7 +30,7 @@ def loadHistory():
def queriesForAutoCompletion():
autoComplQueries = {}
for item in queries[kb.dbms]._toflat():
for item in queries[getIdentifiedDBMS()]._toflat():
if item._has_key('query') and len(item.query) > 1 and item._name != 'blind':
autoComplQueries[item.query] = None

View File

@@ -7,18 +7,15 @@ Copyright (c) 2006-2010 sqlmap developers (http://sqlmap.sourceforge.net/)
See the file 'doc/COPYING' for copying permission
"""
from lib.core.common import getErrorParsedDBMSes
from lib.core.data import kb
from lib.core.common import getIdentifiedDBMS
from lib.core.datatype import advancedDict
class Unescaper(advancedDict):
def unescape(self, expression, quote=True, dbms=None):
if hasattr(kb, "dbms") and kb.dbms is not None:
return self[kb.dbms](expression, quote=quote)
elif hasattr(kb.misc, "testedDbms") and kb.misc.testedDbms is not None:
return self[kb.misc.testedDbms](expression, quote=quote)
elif getErrorParsedDBMSes():
return self[getErrorParsedDBMSes()[0]](expression, quote=quote)
identifiedDbms = getIdentifiedDBMS()
if identifiedDbms is not None:
return self[identifiedDbms](expression, quote=quote)
elif dbms is not None:
return self[dbms](expression, quote=quote)
else: