bulk commit for safe/unsafe identificator naming (done and tested for all 4 major DBMSes) and one bug fix for --search-column on MSSQL (inside queries)

This commit is contained in:
Miroslav Stampar
2011-03-28 11:01:55 +00:00
parent 19a6f86954
commit 73e5d20ade
4 changed files with 114 additions and 63 deletions

View File

@@ -807,12 +807,12 @@ class Enumeration:
if "," in conf.db:
dbs = conf.db.split(",")
query += " WHERE "
query += " OR ".join("%s = '%s'" % (condition, db) for db in dbs)
query += " OR ".join("%s = '%s'" % (condition, self.__unsafeSQLIdentificatorNaming(db)) for db in dbs)
else:
query += " WHERE %s='%s'" % (condition, conf.db)
query += " WHERE %s='%s'" % (condition, self.__unsafeSQLIdentificatorNaming(conf.db))
elif conf.excludeSysDbs:
query += " WHERE "
query += " AND ".join("%s != '%s'" % (condition, db) for db in self.excludeDbsList)
query += " AND ".join("%s != '%s'" % (condition, self.__unsafeSQLIdentificatorNaming(db)) for db in self.excludeDbsList)
infoMsg = "skipping system databases '%s'" % ", ".join(db for db in self.excludeDbsList)
logger.info(infoMsg)
@@ -835,6 +835,8 @@ class Enumeration:
value = newValue
for db, table in value:
db = self.__safeSQLIdentificatorNaming(db)
table = self.__safeSQLIdentificatorNaming(table, True)
if not kb.data.cachedTables.has_key(db):
kb.data.cachedTables[db] = [table]
else:
@@ -855,7 +857,7 @@ class Enumeration:
if Backend.getIdentifiedDbms() in (DBMS.SQLITE, DBMS.FIREBIRD, DBMS.MAXDB, DBMS.ACCESS):
query = rootQuery.blind.count
else:
query = rootQuery.blind.count % db
query = rootQuery.blind.count % self.__unsafeSQLIdentificatorNaming(db)
count = inject.getValue(query, inband=False, error=False, expected=EXPECTED.INT, charsetType=2)
if not isNumPosStrValue(count):
@@ -880,10 +882,11 @@ class Enumeration:
elif Backend.getIdentifiedDbms() in (DBMS.SQLITE, DBMS.FIREBIRD):
query = rootQuery.blind.query % index
else:
query = rootQuery.blind.query % (db, index)
query = rootQuery.blind.query % (self.__unsafeSQLIdentificatorNaming(db), index)
table = inject.getValue(query, inband=False, error=False)
tables.append(table)
kb.hintValue = table
table = self.__safeSQLIdentificatorNaming(table, True)
tables.append(table)
if tables:
kb.data.cachedTables[db] = tables
@@ -908,8 +911,6 @@ class Enumeration:
if "." in conf.tbl:
if not conf.db:
conf.db, conf.tbl = conf.tbl.split(".")
elif Backend.getIdentifiedDbms() == DBMS.MSSQL:
conf.tbl = "%s.%s" % (DEFAULT_MSSQL_SCHEMA, conf.tbl)
self.forceDbmsEnum()
@@ -933,7 +934,7 @@ class Enumeration:
logger.error(errMsg)
bruteForce = True
conf.tbl = self.__safeSQLIdentificatorNaming(conf.tbl)
conf.tbl = self.__safeSQLIdentificatorNaming(conf.tbl, True)
conf.db = self.__safeSQLIdentificatorNaming(conf.db)
if bruteForce:
@@ -973,8 +974,8 @@ class Enumeration:
if Backend.getIdentifiedDbms() == DBMS.ORACLE:
conf.col = conf.col.upper()
colList = conf.col.split(",")
condQuery = " AND (" + " OR ".join("%s LIKE '%s'" % (condition, "%" + col + "%") for col in colList) + ")"
infoMsg += "like '%s' " % ", ".join(col for col in colList)
condQuery = " AND (" + " OR ".join("%s LIKE '%s'" % (condition, "%" + self.__unsafeSQLIdentificatorNaming(col) + "%") for col in colList) + ")"
infoMsg += "like '%s' " % ", ".join(self.__unsafeSQLIdentificatorNaming(col) for col in colList)
else:
condQuery = ""
@@ -984,16 +985,16 @@ class Enumeration:
if isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION) or isTechniqueAvailable(PAYLOAD.TECHNIQUE.ERROR) or conf.direct:
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ):
query = rootQuery.inband.query % (conf.tbl, conf.db)
query = rootQuery.inband.query % (self.__unsafeSQLIdentificatorNaming(conf.tbl), self.__unsafeSQLIdentificatorNaming(conf.db))
query += condQuery
elif Backend.getIdentifiedDbms() == DBMS.ORACLE:
query = rootQuery.inband.query % conf.tbl.upper()
query = rootQuery.inband.query % self.__unsafeSQLIdentificatorNaming(conf.tbl.upper())
query += condQuery
elif Backend.getIdentifiedDbms() == DBMS.MSSQL:
query = rootQuery.inband.query % (conf.db, conf.db,
conf.db, conf.db,
conf.db, conf.db,
conf.db, conf.tbl if '.' not in conf.tbl else conf.tbl.split('.')[1])
conf.db, self.__unsafeSQLIdentificatorNaming(conf.tbl))
query += condQuery.replace("[DB]", conf.db)
elif Backend.getIdentifiedDbms() == DBMS.SQLITE:
query = rootQuery.inband.query % conf.tbl
@@ -1024,16 +1025,16 @@ class Enumeration:
logger.info(infoMsg)
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ):
query = rootQuery.blind.count % (conf.tbl, conf.db)
query = rootQuery.blind.count % (self.__unsafeSQLIdentificatorNaming(conf.tbl), self.__unsafeSQLIdentificatorNaming(conf.db))
query += condQuery
elif Backend.getIdentifiedDbms() == DBMS.ORACLE:
query = rootQuery.blind.count % conf.tbl.upper()
query = rootQuery.blind.count % self.__unsafeSQLIdentificatorNaming(conf.tbl.upper())
query += condQuery
elif Backend.getIdentifiedDbms() in DBMS.MSSQL:
query = rootQuery.blind.count % (conf.db, conf.db, \
conf.tbl if '.' not in conf.tbl else conf.tbl.split('.')[1])
self.__unsafeSQLIdentificatorNaming(conf.tbl))
query += condQuery.replace("[DB]", conf.db)
elif Backend.getIdentifiedDbms() == DBMS.FIREBIRD:
@@ -1061,18 +1062,18 @@ class Enumeration:
for index in indexRange:
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ):
query = rootQuery.blind.query % (conf.tbl, conf.db)
query = rootQuery.blind.query % (self.__unsafeSQLIdentificatorNaming(conf.tbl), self.__unsafeSQLIdentificatorNaming(conf.db))
query += condQuery
field = None
elif Backend.getIdentifiedDbms() == DBMS.ORACLE:
query = rootQuery.blind.query % (conf.tbl.upper())
query = rootQuery.blind.query % self.__unsafeSQLIdentificatorNaming(conf.tbl.upper())
query += condQuery
field = None
elif Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
query = rootQuery.blind.query % (conf.db, conf.db,
conf.db, conf.db,
conf.db, conf.db,
conf.tbl if '.' not in conf.tbl else conf.tbl.split('.')[1])
self.__unsafeSQLIdentificatorNaming(conf.tbl))
query += condQuery.replace("[DB]", conf.db)
field = condition.replace("[DB]", conf.db)
elif Backend.getIdentifiedDbms() == DBMS.FIREBIRD:
@@ -1083,17 +1084,15 @@ class Enumeration:
query = agent.limitQuery(index, query, field)
column = inject.getValue(query, inband=False, error=False)
column = self.__safeSQLIdentificatorNaming(column)
if not onlyColNames:
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ):
query = rootQuery.blind.query2 % (conf.tbl, column, conf.db)
query = rootQuery.blind.query2 % (self.__unsafeSQLIdentificatorNaming(conf.tbl), column, self.__unsafeSQLIdentificatorNaming(conf.db))
elif Backend.getIdentifiedDbms() == DBMS.ORACLE:
query = rootQuery.blind.query2 % (conf.tbl.upper(), column)
query = rootQuery.blind.query2 % (self.__unsafeSQLIdentificatorNaming(conf.tbl.upper()), column)
elif Backend.getIdentifiedDbms() == DBMS.MSSQL:
query = rootQuery.blind.query2 % (conf.db, conf.db, conf.db,
conf.db, column, conf.db,
conf.db, conf.db, conf.tbl if '.' not in conf.tbl else conf.tbl.split('.')[1])
conf.db, conf.db, self.__unsafeSQLIdentificatorNaming(conf.tbl))
elif Backend.getIdentifiedDbms() == DBMS.FIREBIRD:
query = rootQuery.blind.query2 % (conf.tbl, column)
@@ -1102,8 +1101,10 @@ class Enumeration:
if Backend.getIdentifiedDbms() == DBMS.FIREBIRD:
colType = firebirdTypes[colType] if colType in firebirdTypes else colType
column = self.__safeSQLIdentificatorNaming(column)
columns[column] = colType
else:
column = self.__safeSQLIdentificatorNaming(column)
columns[column] = None
if columns:
@@ -1208,12 +1209,15 @@ class Enumeration:
return entries, lengths
def __safeSQLIdentificatorNaming(self, value):
def __safeSQLIdentificatorNaming(self, value, isTable=False):
"""
Returns a safe representation of SQL identificator name
"""
retVal = value
if isinstance(value, basestring):
if isTable and Backend.getIdentifiedDbms() == DBMS.MSSQL and '.' not in value:
value = "%s.%s" % (DEFAULT_MSSQL_SCHEMA, value)
parts = value.split('.')
for i in range(len(parts)):
if not re.match(r"\A[A-Za-z0-9_]+\Z", parts[i]):
@@ -1222,6 +1226,7 @@ class Enumeration:
elif Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.ORACLE, DBMS.PGSQL):
parts[i] = "\"%s\"" % parts[i].strip("\"")
retVal = ".".join(parts)
return retVal
def __unsafeSQLIdentificatorNaming(self, value):
@@ -1255,8 +1260,6 @@ class Enumeration:
if "." in conf.tbl:
if not conf.db:
conf.db, conf.tbl = conf.tbl.split(".")
elif Backend.getIdentifiedDbms() == DBMS.MSSQL:
conf.tbl = "%s.%s" % (DEFAULT_MSSQL_SCHEMA, conf.tbl)
self.forceDbmsEnum()
@@ -1270,7 +1273,7 @@ class Enumeration:
rootQuery = queries[Backend.getIdentifiedDbms()].dump_table
conf.tbl = self.__safeSQLIdentificatorNaming(conf.tbl)
conf.tbl = self.__safeSQLIdentificatorNaming(conf.tbl, True)
conf.db = self.__safeSQLIdentificatorNaming(conf.db)
if conf.col:
@@ -1602,21 +1605,23 @@ class Enumeration:
dbConsider, dbCondParam = self.likeOrExact("database")
for db in dbList:
db = self.__safeSQLIdentificatorNaming(db)
infoMsg = "searching database"
if dbConsider == "1":
infoMsg += "s like"
infoMsg += " '%s'" % db
infoMsg += " '%s'" % self.__unsafeSQLIdentificatorNaming(db)
logger.info(infoMsg)
if conf.excludeSysDbs:
exclDbsQuery = "".join(" AND '%s' != %s" % (db, dbCond) for db in self.excludeDbsList)
exclDbsQuery = "".join(" AND '%s' != %s" % (self.__unsafeSQLIdentificatorNaming(db), dbCond) for db in self.excludeDbsList)
infoMsg = "skipping system databases '%s'" % ", ".join(db for db in self.excludeDbsList)
logger.info(infoMsg)
else:
exclDbsQuery = ""
dbQuery = "%s%s" % (dbCond, dbCondParam)
dbQuery = dbQuery % db
dbQuery = dbQuery % self.__unsafeSQLIdentificatorNaming(db)
if isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION) or isTechniqueAvailable(PAYLOAD.TECHNIQUE.ERROR) or conf.direct:
if Backend.getIdentifiedDbms() == DBMS.MYSQL and not kb.data.has_information_schema:
@@ -1632,12 +1637,13 @@ class Enumeration:
values = [ values ]
for value in values:
value = self.__safeSQLIdentificatorNaming(value)
foundDbs.append(value)
else:
infoMsg = "fetching number of databases"
if dbConsider == "1":
infoMsg += "s like"
infoMsg += " '%s'" % db
infoMsg += " '%s'" % self.__unsafeSQLIdentificatorNaming(db)
logger.info(infoMsg)
if Backend.getIdentifiedDbms() == DBMS.MYSQL and not kb.data.has_information_schema:
@@ -1652,7 +1658,7 @@ class Enumeration:
warnMsg = "no database"
if dbConsider == "1":
warnMsg += "s like"
warnMsg += " '%s' found" % db
warnMsg += " '%s' found" % self.__unsafeSQLIdentificatorNaming(db)
logger.warn(warnMsg)
continue
@@ -1668,7 +1674,9 @@ class Enumeration:
query += exclDbsQuery
query = agent.limitQuery(index, query, dbCond)
foundDbs.append(inject.getValue(query, inband=False, error=False))
value = inject.getValue(query, inband=False, error=False)
value = self.__safeSQLIdentificatorNaming(value)
foundDbs.append(value)
return foundDbs
@@ -1707,17 +1715,19 @@ class Enumeration:
tblConsider, tblCondParam = self.likeOrExact("table")
for tbl in tblList:
tbl = self.__safeSQLIdentificatorNaming(tbl, True)
if Backend.getIdentifiedDbms() == DBMS.ORACLE:
tbl = tbl.upper()
infoMsg = "searching table"
if tblConsider == "1":
infoMsg += "s like"
infoMsg += " '%s'" % tbl
infoMsg += " '%s'" % self.__unsafeSQLIdentificatorNaming(tbl)
logger.info(infoMsg)
if conf.excludeSysDbs:
exclDbsQuery = "".join(" AND '%s' != %s" % (db, dbCond) for db in self.excludeDbsList)
exclDbsQuery = "".join(" AND '%s' != %s" % (self.__unsafeSQLIdentificatorNaming(db), dbCond) for db in self.excludeDbsList)
infoMsg = "skipping system databases '%s'" % ", ".join(db for db in self.excludeDbsList)
logger.info(infoMsg)
else:
@@ -1737,6 +1747,9 @@ class Enumeration:
values = [ values ]
for foundDb, foundTbl in values:
foundDb = self.__safeSQLIdentificatorNaming(foundDb)
foundTbl = self.__safeSQLIdentificatorNaming(foundTbl, True)
if foundDb is None or foundTbl is None:
continue
@@ -1748,7 +1761,7 @@ class Enumeration:
infoMsg = "fetching number of databases with table"
if tblConsider == "1":
infoMsg += "s like"
infoMsg += " '%s'" % tbl
infoMsg += " '%s'" % self.__unsafeSQLIdentificatorNaming(tbl)
logger.info(infoMsg)
query = rootQuery.blind.count
@@ -1760,7 +1773,7 @@ class Enumeration:
warnMsg = "no databases have table"
if tblConsider == "1":
warnMsg += "s like"
warnMsg += " '%s'" % tbl
warnMsg += " '%s'" % self.__unsafeSQLIdentificatorNaming(tbl)
logger.warn(warnMsg)
continue
@@ -1773,6 +1786,7 @@ class Enumeration:
query += exclDbsQuery
query = agent.limitQuery(index, query)
foundDb = inject.getValue(query, inband=False, error=False)
foundDb = self.__safeSQLIdentificatorNaming(foundDb)
if foundDb not in foundTbls:
foundTbls[foundDb] = []
@@ -1784,14 +1798,16 @@ class Enumeration:
continue
for db in foundTbls.keys():
db = self.__safeSQLIdentificatorNaming(db)
infoMsg = "fetching number of table"
if tblConsider == "1":
infoMsg += "s like"
infoMsg += " '%s' in database '%s'" % (tbl, db)
infoMsg += " '%s' in database '%s'" % (self.__unsafeSQLIdentificatorNaming(tbl), db)
logger.info(infoMsg)
query = rootQuery.blind.count2
query = query % db
query = query % self.__unsafeSQLIdentificatorNaming(db)
query += " AND %s" % tblQuery
count = inject.getValue(query, inband=False, error=False, expected=EXPECTED.INT, charsetType=2)
@@ -1799,7 +1815,7 @@ class Enumeration:
warnMsg = "no table"
if tblConsider == "1":
warnMsg += "s like"
warnMsg += " '%s' " % tbl
warnMsg += " '%s' " % self.__unsafeSQLIdentificatorNaming(tbl)
warnMsg += "in database '%s'" % db
logger.warn(warnMsg)
@@ -1809,11 +1825,12 @@ class Enumeration:
for index in indexRange:
query = rootQuery.blind.query2
query = query % db
query = query % self.__unsafeSQLIdentificatorNaming(db)
query += " AND %s" % tblQuery
query = agent.limitQuery(index, query)
foundTbl = inject.getValue(query, inband=False, error=False)
kb.hintValue = foundTbl
foundTbl = self.__safeSQLIdentificatorNaming(foundTbl, True)
foundTbls[db].append(foundTbl)
return foundTbls
@@ -1862,10 +1879,12 @@ class Enumeration:
colConsider, colCondParam = self.likeOrExact("column")
for column in colList:
column = self.__safeSQLIdentificatorNaming(column)
infoMsg = "searching column"
if colConsider == "1":
infoMsg += "s like"
infoMsg += " '%s'" % column
infoMsg += " '%s'" % self.__unsafeSQLIdentificatorNaming(column)
logger.info(infoMsg)
foundCols[column] = {}
@@ -1878,7 +1897,7 @@ class Enumeration:
exclDbsQuery = ""
colQuery = "%s%s" % (colCond, colCondParam)
colQuery = colQuery % column
colQuery = colQuery % self.__unsafeSQLIdentificatorNaming(column)
if isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION) or isTechniqueAvailable(PAYLOAD.TECHNIQUE.ERROR) or conf.direct:
query = rootQuery.inband.query
@@ -1891,6 +1910,9 @@ class Enumeration:
values = [ values ]
for foundDb, foundTbl in values:
foundDb = self.__safeSQLIdentificatorNaming(foundDb)
foundTbl = self.__safeSQLIdentificatorNaming(foundTbl, True)
if foundDb is None or foundTbl is None:
continue
@@ -1945,6 +1967,7 @@ class Enumeration:
query += exclDbsQuery
query = agent.limitQuery(index, query)
db = inject.getValue(query, inband=False, error=False)
db = self.__safeSQLIdentificatorNaming(db)
if db not in dbs:
dbs[db] = {}
@@ -1957,10 +1980,12 @@ class Enumeration:
colQuery = colQuery % column
for db in dbData:
db = self.__safeSQLIdentificatorNaming(db)
infoMsg = "fetching number of tables containing column"
if colConsider == "1":
infoMsg += "s like"
infoMsg += " '%s' in database '%s'" % (column, db)
infoMsg += " '%s' in database '%s'" % (self.__unsafeSQLIdentificatorNaming(column), db)
logger.info(infoMsg)
query = rootQuery.blind.count2
@@ -1988,6 +2013,8 @@ class Enumeration:
tbl = inject.getValue(query, inband=False, error=False)
kb.hintValue = tbl
tbl = self.__safeSQLIdentificatorNaming(tbl, True)
if tbl not in dbs[db]:
dbs[db][tbl] = {}