mirror of
https://github.com/sqlmapproject/sqlmap.git
synced 2025-12-07 21:21:33 +00:00
refactoring and stabilization of multithreading
This commit is contained in:
@@ -20,6 +20,9 @@ import threading
|
||||
import urllib2
|
||||
import urlparse
|
||||
|
||||
import lib.core.common
|
||||
import lib.core.threads
|
||||
|
||||
from extra.clientform.clientform import ParseResponse
|
||||
from extra.clientform.clientform import ParseError
|
||||
from extra.keepalive import keepalive
|
||||
@@ -109,6 +112,7 @@ from lib.request.basicauthhandler import SmartHTTPBasicAuthHandler
|
||||
from lib.request.certhandler import HTTPSCertAuthHandler
|
||||
from lib.request.rangehandler import HTTPRangeHandler
|
||||
from lib.request.redirecthandler import SmartRedirectHandler
|
||||
from lib.request.templates import getPageTemplate
|
||||
from lib.utils.google import Google
|
||||
|
||||
authHandler = urllib2.BaseHandler()
|
||||
@@ -1360,8 +1364,10 @@ def __setKnowledgeBaseAttributes(flushAll=True):
|
||||
kb.locks.cacheLock = threading.Lock()
|
||||
kb.locks.logLock = threading.Lock()
|
||||
kb.locks.ioLock = threading.Lock()
|
||||
kb.locks.countLock = threading.Lock()
|
||||
|
||||
kb.matchRatio = None
|
||||
kb.multiThreadMode = False
|
||||
kb.nullConnection = None
|
||||
kb.pageTemplate = None
|
||||
kb.pageTemplates = dict()
|
||||
@@ -1701,6 +1707,10 @@ def __basicOptionValidation():
|
||||
errMsg += "to get the full list of supported charsets"
|
||||
raise sqlmapSyntaxException, errMsg
|
||||
|
||||
def __resolveCrossReferences():
|
||||
lib.core.threads.readInput = readInput
|
||||
lib.core.common.getPageTemplate = getPageTemplate
|
||||
|
||||
def init(inputOptions=advancedDict(), overrideOptions=False):
|
||||
"""
|
||||
Set attributes into both configuration and knowledge base singletons
|
||||
@@ -1720,6 +1730,7 @@ def init(inputOptions=advancedDict(), overrideOptions=False):
|
||||
__setMultipleTargets()
|
||||
__setTamperingFunctions()
|
||||
__setTrafficOutputFP()
|
||||
__resolveCrossReferences()
|
||||
|
||||
parseTargetUrl()
|
||||
parseTargetDirect()
|
||||
|
||||
@@ -14,6 +14,7 @@ from lib.core.data import kb
|
||||
from lib.core.data import logger
|
||||
from lib.core.datatype import advancedDict
|
||||
from lib.core.exception import sqlmapThreadException
|
||||
from lib.core.settings import MAX_NUMBER_OF_THREADS
|
||||
|
||||
shared = advancedDict()
|
||||
|
||||
@@ -39,6 +40,9 @@ class ThreadData():
|
||||
def getCurrentThreadUID():
|
||||
return hash(threading.currentThread())
|
||||
|
||||
def readInput(message, default=None):
|
||||
pass
|
||||
|
||||
def getCurrentThreadData():
|
||||
"""
|
||||
Returns current thread's dependent data
|
||||
@@ -49,12 +53,40 @@ def getCurrentThreadData():
|
||||
kb.threadData[threadUID] = ThreadData()
|
||||
return kb.threadData[threadUID]
|
||||
|
||||
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True):
|
||||
def exceptionHandledFunction(threadFunction):
|
||||
try:
|
||||
threadFunction()
|
||||
except KeyboardInterrupt:
|
||||
kb.threadContinue = False
|
||||
kb.threadException = True
|
||||
raise
|
||||
except:
|
||||
kb.threadContinue = False
|
||||
kb.threadException = True
|
||||
|
||||
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False):
|
||||
threads = []
|
||||
|
||||
kb.multiThreadMode = True
|
||||
kb.threadContinue = True
|
||||
kb.threadException = False
|
||||
|
||||
if threadChoice and numThreads == 1:
|
||||
while True:
|
||||
message = "please enter number of threads? [Enter for %d (current)] " % numThreads
|
||||
choice = readInput(message, default=str(numThreads))
|
||||
if choice and choice.isdigit():
|
||||
if int(choice) > MAX_NUMBER_OF_THREADS:
|
||||
errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS
|
||||
logger.critical(errMsg)
|
||||
else:
|
||||
numThreads = int(choice)
|
||||
break
|
||||
|
||||
if numThreads == 1:
|
||||
warnMsg = "running in a single-thread mode. This could take a while."
|
||||
logger.warn(warnMsg)
|
||||
|
||||
if numThreads > 1:
|
||||
infoMsg = "starting %d threads" % numThreads
|
||||
logger.info(infoMsg)
|
||||
@@ -64,7 +96,7 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
|
||||
|
||||
# Start the threads
|
||||
for numThread in range(numThreads):
|
||||
thread = threading.Thread(target=threadFunction, name=str(numThread))
|
||||
thread = threading.Thread(target=exceptionHandledFunction, name=str(numThread), args=[threadFunction])
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
@@ -98,6 +130,8 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
|
||||
raise
|
||||
|
||||
finally:
|
||||
kb.multiThreadMode = False
|
||||
kb.bruteMode = False
|
||||
kb.threadContinue = True
|
||||
kb.threadException = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user