#!/usr/bin/env python


import logging
import os
import random
import re
import time
import urllib2

from lib.common import version
from lib.common import Common
from lib.google import Google


class Option(Common):
    """
    This class defines methods to check, parse and set options
    based upon command line parameters values and extra file

    @author: Bernardo Damele
    """

    def __init__(self):
        self.logger = logging.getLogger("sqlmapLog")


    def __urllib2Opener(self):
        """
        This method creates and installs the urllib2 OpenerDirector
        globally, using the collected HTTP headers and the proxy and
        authentication handlers (when set)
        """

        # Unset handlers are stored as empty strings by __effectiveRun():
        # build_opener() expects handler classes or instances, so only
        # pass the ones that were actually configured
        handlers = [handler for handler in (self.proxyHandler, self.authHandler) if handler]

        opener = urllib2.build_opener(*handlers)
        opener.addheaders = self.httpHeaders
        urllib2.install_opener(opener)


    def __setGoogleDorking(self):
        """
        This method checks and set the Google dorking functionality:
        it gets a Google session cookie, runs the dork expression and
        stores the testable hosts in self.args.testableHosts

        Raises Exception when the session cookie cannot be parsed, when
        the dork expression returns no results or when none of the
        results has parameters to test
        """

        logMsg = "first request to Google to get the session cookie"
        self.logger.info(logMsg)

        googleObj = Google(self.proxyHandler)
        googleCookie = googleObj.getCookie()

        if not googleCookie:
            errMsg  = "unable to parse Google response header "
            errMsg += "'set-cookie' to get the session cookie"
            raise Exception(errMsg)

        googleCookie = googleCookie.groups()[0]

        logMsg = "sqlmap got '%s' as session cookie" % googleCookie
        self.logger.info(logMsg)

        matches = googleObj.search(self.args.googleDork)

        if not matches:
            errMsg  = "unable to find results for your "
            errMsg += "Google dork expression"
            raise Exception(errMsg)

        self.args.testableHosts = googleObj.getTestableHosts()

        if not self.args.testableHosts:
            errMsg  = "sqlmap got %d results " % len(matches)
            errMsg += "for your Google dork expression, but none "
            errMsg += "of them has parameters to test for SQL "
            errMsg += "injection"
            raise Exception(errMsg)

        logMsg  = "sqlmap got %d results for your " % len(matches)
        logMsg += "Google dork expression, "

        if len(matches) == len(self.args.testableHosts):
            logMsg += "all "
        else:
            logMsg += "%d " % len(self.args.testableHosts)

        logMsg += "of them are testable hosts"
        self.logger.info(logMsg)


    def __parseResumeFile(self):
        """
        This method parses the existing output text file and fills
        self.args.resumedQueries as { url: { expression: value } }

        Only lines with exactly two '][' separators are considered. When
        the same (url, expression) pair appears more than once, the
        longest value wins (ties go to the later line)
        """

        fd = open(self.args.outputFile, "r")

        try:
            for line in fd:
                if line.count("][") != 2:
                    continue

                # Exactly two separators always split into three parts
                url, expression, value = line.split("][")

                if not value:
                    continue

                if value[-1] == "\n":
                    value = value[:-1]

                urlQueries = self.args.resumedQueries.setdefault(url, {})

                if expression not in urlQueries or len(value) >= len(urlQueries[expression]):
                    urlQueries[expression] = value
        finally:
            fd.close()


    def __setOutputResume(self):
        """
        This method checks and set the output text file and '--resume'
        command line option

        If the output file already exists, its previous query results are
        loaded into self.args.resumedQueries. The file is then opened in
        append mode, a timestamp header is written, and the handle is
        stored in self.args.writeFile (None if it cannot be written)
        """

        if self.args.resume and not self.args.outputFile:
            warnMsg  = "you did not provide the text file to "
            warnMsg += "resume queries output from"
            self.logger.warning(warnMsg)
        elif self.args.outputFile:
            if os.path.exists(self.args.outputFile):
                self.__parseResumeFile()

            writeFile = None

            try:
                writeFile = open(self.args.outputFile, "a")
                writeFile.write("\n[%s]\n" % time.strftime("%X %x"))
                writeFile.flush()
            except (IOError, OSError):
                # Best effort: go on without an output file, but never
                # leave a half-working handle in self.args.writeFile
                if writeFile is not None:
                    try:
                        writeFile.close()
                    except (IOError, OSError):
                        pass

                writeFile = None

                warnMsg  = "unable to write to the output text file "
                warnMsg += "'%s'" % self.args.outputFile
                self.logger.warning(warnMsg)

            self.args.writeFile = writeFile


    def __setRemoteDBMS(self):
        """
        This method checks the command line argument '--remote-dbms'
        value. A 'mysql <version>' value is split into dbms 'mysql' and
        self.args.MySQLVer; an unsupported DBMS raises Exception
        """

        self.args.dbms = self.args.dbms.lower()
        mysqlVersion = re.search(r"mysql ([\d\.]+)", self.args.dbms)

        if mysqlVersion:
            self.args.dbms = "mysql"
            self.args.MySQLVer = mysqlVersion.groups()[0]
        else:
            self.args.MySQLVer = None

        if self.args.dbms not in ( "mysql", "postgresql", "microsoft sql server" ):
            errMsg  = "you provided an unsupported remote database "
            errMsg += "management system. The supported DBMS are "
            errMsg += "'MySQL', 'PostgreSQL' and 'Microsoft SQL "
            errMsg += "Server'. If you do not know the remote "
            errMsg += "DBMS, do not provide it and sqlmap will "
            errMsg += "fingerprint it for you"
            raise Exception(errMsg)


    def __setHTTPProxy(self):
        """
        This method defines the HTTP proxy to pass by all HTTP requests

        Raises Exception when the value is not in 'http://url:port' format
        """

        if not re.search(r"http://[\w\:\/\.\-\_]+\:[\d]+", self.args.proxy, re.I):
            errMsg  = "proxy value must be in format "
            errMsg += "'http://url:port'"
            raise Exception(errMsg)

        proxyRegExp = re.search(r"([\w\:\/\.\-\_\:]+)\:([\d]+)", self.args.proxy, re.I)
        proxyUrl, proxyPort = proxyRegExp.groups()
        self.proxyHandler = urllib2.ProxyHandler({"http": "%s:%s" % (proxyUrl, proxyPort)})


    def __setHTTPCookies(self):
        """
        This method set the HTTP Cookie header (together with a
        'Connection: Keep-Alive' header)
        """

        self.httpHeaders.append(("Connection", "Keep-Alive"))
        self.httpHeaders.append(("Cookie", self.args.cookie))


    def __setHTTPPasswordMgr(self, link, authUsername, authPassword):
        """
        This method initializes and returns an urllib2 HTTP
        authentication password manager

        @param link: url whose authority (host:port) and path prefix
        the credentials apply to, for any realm
        @param authUsername: HTTP authentication username
        @param authPassword: HTTP authentication password
        """

        passwordMgr = urllib2.HTTPPasswordMgrWithDefaultRealm()
        passwordMgr.add_password(None, link, authUsername, authPassword)

        return passwordMgr


    def __setHTTPAuthentication(self):
        """
        This method checks and set the HTTP authentication method
        (Basic or Digest), username and password

        Basic takes precedence when both are provided. Raises Exception
        when the value is not in 'username:password' format or when there
        is no target url to bind the credentials to
        """

        if self.args.bAuth:
            if self.args.dAuth:
                warnMsg  = "both Basic and Digest HTTP authentication "
                warnMsg += "values provided, using Basic"
                self.logger.warning(warnMsg)

            authValue = self.args.bAuth
            authHandlerClass = urllib2.HTTPBasicAuthHandler
        else:
            authValue = self.args.dAuth
            authHandlerClass = urllib2.HTTPDigestAuthHandler

        if ":" not in authValue:
            errMsg  = "HTTP Authentication value must be "
            errMsg += "in format 'username:password'"
            raise Exception(errMsg)

        # A user-id cannot contain a colon while a password can, so
        # split on the first colon only
        authUsername, authPassword = authValue.split(":", 1)

        if not self.args.url:
            errMsg = "HTTP authentication requires a target url"
            raise Exception(errMsg)

        # Bind the credentials to the whole target host:port, not only
        # to the page being tested
        transport, host, port = self.__splitURL(self.args.url)
        link = "%s://%s:%d/" % (transport, host, port)

        passwordMgr = self.__setHTTPPasswordMgr(link, authUsername, authPassword)
        self.authHandler = authHandlerClass(passwordMgr)


    def __setURLParameters(self):
        """
        This method checks, set the url parameters and performs checks
        on 'data' command line parameter value for HTTP method POST

        Raises Exception when the HTTP method is POST and '--data' is empty
        """

        if self.args.url and self.args.httpMethod == "GET":
            parameters = self.roughParameters()[1]
        elif self.args.url and self.args.httpMethod == "POST":
            # Perform checks on '--data' parameter value
            if self.args.data:
                parameters = self.args.data
            else:
                errMsg = "HTTP POST method depends on '--data' value"
                raise Exception(errMsg)

        # Perform checks on testable parameters (-p)
        if self.args.urlParameter:
            self.args.urlParameter = self.args.urlParameter.split(",")
        else:
            self.args.urlParameter = []

        # Set the dictionary of parameters to perform SQL injection
        # in the format:
        #     { 'paramName': 'paramValue' }
        if self.args.url:
            self.args.parameters = self.paramDict(parameters)

        # Finalize target url value
        if self.args.url and self.args.httpMethod == "POST":
            if "?" in self.args.url:
                self.args.url = self.args.url.split("?")[0]
            self.args.url += "?" + parameters


    def __setHTTPMethod(self):
        """
        This method checks and set the HTTP method to perform HTTP
        requests with; anything but GET or POST falls back to GET
        """

        self.args.httpMethod = self.args.httpMethod.upper()

        if self.args.httpMethod not in ("GET", "POST"):
            warnMsg  = "'%s' " % self.args.httpMethod
            warnMsg += "is an unsupported HTTP method, "
            warnMsg += "setting to default method, GET"
            self.logger.warning(warnMsg)

            self.args.httpMethod = "GET"


    def __defaultHTTPUserAgent(self):
        """
        This method returns the default sqlmap HTTP User-Agent
        header string
        """

        return "sqlmap/%s (http://sqlmap.sourceforge.net)" % version


    def __setHTTPUserAgent(self):
        """
        This method extracts a random HTTP User-Agent header string
        from a text file (one per line, blank lines ignored). It falls
        back to the default sqlmap User-Agent when the file cannot be
        read or has no usable line
        """

        logMsg  = "fetching random HTTP User-Agent header from "
        logMsg += "file '%s'" % self.args.userAgentsFile
        self.logger.info(logMsg)

        try:
            fd = open(self.args.userAgentsFile)
        except (IOError, OSError):
            warnMsg  = "unable to read HTTP User-Agent header "
            warnMsg += "file '%s'" % self.args.userAgentsFile
            self.logger.warning(warnMsg)

            self.httpHeaders.append(("User-Agent", self.__defaultHTTPUserAgent()))

            return

        try:
            userAgents = [line.replace("\n", "").replace("\r", "") for line in fd]
        finally:
            fd.close()

        # An empty User-Agent header is never useful
        userAgents = [userAgent for userAgent in userAgents if userAgent.strip()]

        if not userAgents:
            warnMsg  = "no HTTP User-Agent header found in file "
            warnMsg += "'%s', using the default one" % self.args.userAgentsFile
            self.logger.warning(warnMsg)

            self.httpHeaders.append(("User-Agent", self.__defaultHTTPUserAgent()))

            return

        # Every line (the first one included) can be picked
        self.httpHeaders.append(("User-Agent", random.choice(userAgents)))

        logMsg  = "fetched random HTTP User-Agent header from "
        logMsg += "file '%s'" % self.args.userAgentsFile
        self.logger.info(logMsg)


    def __splitURL(self, url):
        """
        This method splits an absolute http(s) url into its transport
        (lower-cased), host and port. The port defaults to 80 for http
        and 443 for https when the url does not specify one

        Raises Exception when the url cannot be parsed
        """

        urlRegExp = re.search(r"^(https?)://([\w\.\-\_]+)(?::(\d+))?", url, re.I)

        if not urlRegExp:
            errMsg = "unable to parse target url '%s'" % url
            raise Exception(errMsg)

        transport, host, port = urlRegExp.groups()
        transport = transport.lower()

        if port:
            port = int(port)
        elif transport == "https":
            port = 443
        else:
            port = 80

        return transport, host, port


    def __setURL(self):
        """
        This method checks and set the target url and port. A url
        without a scheme gets 'https://' when its explicit port is 443
        and 'http://' otherwise
        """

        if not re.search(r"^https?://[\w\.\-\_]+", self.args.url, re.I):
            portRegExp = re.search(r"^[\w\.\-\_]+:(\d+)", self.args.url)

            if portRegExp and int(portRegExp.group(1)) == 443:
                self.args.url = "https://" + self.args.url
            else:
                self.args.url = "http://" + self.args.url

        # Scheme comparison is case-insensitive: 'HTTP://host' is port 80
        self.args.port = self.__splitURL(self.args.url)[2]


    def __setVerbose(self):
        """
        This method set the verbosity of sqlmap output messages
        (1: INFO, greater than 1: DEBUG)
        """

        verbose = int(self.args.verbose)

        if verbose == 1:
            self.logger.setLevel(logging.INFO)
        elif verbose > 1:
            self.logger.setLevel(logging.DEBUG)


    def __effectiveRun(self):
        """
        This method checks, parses and set options based upon command
        line parameters values and extra file
        """

        if self.args.verbose:
            self.__setVerbose()

        if self.args.url:
            self.__setURL()

        if self.args.userAgentsFile:
            self.__setHTTPUserAgent()
        else:
            self.httpHeaders.append(("User-Agent", self.__defaultHTTPUserAgent()))

        if self.args.httpMethod:
            self.__setHTTPMethod()
        else:
            self.args.httpMethod = "GET"

        self.__setURLParameters()

        if self.args.bAuth or self.args.dAuth:
            self.__setHTTPAuthentication()
        else:
            self.authHandler = ""

        if self.args.cookie:
            self.__setHTTPCookies()

        if self.args.proxy:
            self.__setHTTPProxy()
        else:
            self.proxyHandler = ""

        if self.args.dbms:
            self.__setRemoteDBMS()

        self.__setOutputResume()

        if self.args.googleDork:
            self.__setGoogleDorking()


    def __setVariables(self):
        """
        This method set some needed variables
        """

        self.args.fingerprint = None
        self.args.parameters = {}
        self.args.resumedQueries = {}
        self.args.unionComment = None
        self.args.unionCount = None
        self.args.writeFile = None
        self.httpHeaders = []


    def run(self, shellArgs):
        """
        This method is the core of this class and call three other
        main methods which perform the work

        @param shellArgs: parsed command line options object
        @return: the same object with checked and derived options set
        """

        self.args = shellArgs

        self.__setVariables()
        self.__effectiveRun()
        self.__urllib2Opener()

        return self.args

