#!/usr/bin/env python


import logging

from lib.algorithm import Algorithm


# Version string; presumably the sqlmap release this file belongs to.
version = "0.4-rc16"


class Common(Algorithm):
    """
    This class defines common methods used in many other parts
    of sqlmap: building the injection statement, rewriting the
    target url, splitting the url parameters and formatting the
    DBMS fingerprint.

    The methods read their settings from self.args, which is not
    assigned in this class (presumably set by a subclass or by the
    caller before these methods are used).

    @author: Bernardo Damele
    """

    def __init__(self):
        self.logger = logging.getLogger("sqlmapLog")


    def injectionStm(self, string, singleString=False, doubleString=False):
        """
        This method defines how the input string has to be escaped
        to perform the injection depending on the injection method
        identified as valid (self.args.injectionMethod).

        @param string: the statement to inject
        @param singleString: when truthy, used instead of string for the
                             "stringsingle" method
        @param doubleString: when truthy, used instead of string for the
                             "stringdouble" method
        @return: the statement unchanged for "numeric", prefixed with a
                 single quote for "stringsingle" or with a double quote
                 for "stringdouble"
        @raise Exception: if the injection method is none of the above
        """

        method = self.args.injectionMethod

        if method == "numeric":
            stm = "%s" % string
        elif method == "stringsingle":
            stm = "'%s" % (singleString or string)
        elif method == "stringdouble":
            stm = "\"%s" % (doubleString or string)
        else:
            # Without this branch 'stm' would be unbound and the return
            # below would fail with an obscure UnboundLocalError
            errMsg  = "unsupported injection method '%s'" % method
            raise Exception(errMsg)

        return stm


    def urlReplace(self, parameter="", value="", newValue=""):
        """
        This method replaces the affected url parameter with the SQL
        injection statement to request.

        @param parameter: name of the parameter to rewrite (ignored when
                          self.args.injParameter is set)
        @param value: current value of that parameter (ignored when
                      self.args.injParameter is set)
        @param newValue: when self.args.injParameter is not set, the value
                         that replaces 'value'; otherwise the string that
                         is appended to the injectable parameter's
                         original value
        @return: the rewritten url
        """

        # NOTE(review): str.replace rewrites every occurrence, including
        # substrings of other parameters (e.g. "id=1" inside "uid=1")
        if not self.args.injParameter:
            return self.args.url.replace("%s=%s" % (parameter, value),
                                         "%s=%s" % (parameter, newValue))

        injParameter = self.args.injParameter
        origValue = self.args.parameters[injParameter]

        return self.args.url.replace("%s=%s" % (injParameter, origValue),
                                     "%s=%s" % (injParameter, origValue + newValue))


    def roughParameters(self):
        """
        This method checks if the target url has parameters and, if so,
        splits the url from its query string.

        @return: a (url, parameters) tuple; (None, None) when the url has
                 no parameters and self.args.googleDork is set
        @raise Exception: when the url has no parameters and
                          self.args.googleDork is not set
        """

        if "?" in self.args.url:
            # Split only on the first '?': a literal '?' inside the query
            # string made the two-value unpacking raise ValueError
            url, parameters = self.args.url.split("?", 1)
            return url, parameters

        if self.args.googleDork:
            warnMsg  = "the target url has not parameters "
            warnMsg += "so it is not possible to test SQL "
            warnMsg += "injection on it, skipping to next url"
            self.logger.warning(warnMsg)

            return None, None

        errMsg  = "you did not provide the parameters "
        errMsg += "in the target url"
        raise Exception(errMsg)


    def paramDict(self, parameters):
        """
        This method splits the parameter names from their values and
        saves the data into a dictionary.

        Only "name=value" pairs with a non-empty value are kept; when
        self.args.urlParameter (presumably a list of parameter names) is
        non-empty, only the parameters it lists are kept.

        @param parameters: the query string, e.g. "id=1&name=foo"
        @return: dictionary mapping parameter names to their values
        @raise Exception: when self.args.urlParameter is set, none of its
                          parameters are in the url and
                          self.args.googleDork is not set
        """

        testableParams = {}
        urlParameter = self.args.urlParameter

        for element in parameters.split("&"):
            # Split only on the first '=' so values containing '=' (e.g.
            # base64 data) are kept instead of being silently dropped
            elem = element.split("=", 1)

            if len(elem) != 2:
                continue

            parameter, value = elem

            # Short-circuit: look into urlParameter only when it is set
            if urlParameter and parameter not in urlParameter:
                continue

            if value:
                testableParams[parameter] = value

        if urlParameter and not testableParams:
            if len(urlParameter) > 1:
                warnMsg  = "the testable parameters you provided "
                warnMsg += "are not into the target url"
            else:
                parameter = urlParameter[0]

                warnMsg  = "the testable parameter '%s' " % parameter
                warnMsg += "you provided is not into the target url"

            if self.args.googleDork:
                warnMsg += ", skipping to next url"
                self.logger.warning(warnMsg)
            else:
                raise Exception(warnMsg)
        elif urlParameter:
            # Some of the requested parameters were found: warn about
            # each one that is missing from the target url
            for parameter in urlParameter:
                if parameter not in testableParams:
                    warnMsg =  "the testable parameter '%s' " % parameter
                    warnMsg += "you provided is not into the target url"
                    self.logger.warning(warnMsg)

        return testableParams


    def parseFp(self, dbms, fingerprint):
        """
        This method formats the remote DBMS fingerprint output.

        @param dbms: the DBMS name
        @param fingerprint: sequence of fingerprint details
        @return: "<dbms>" when fingerprint is empty, otherwise
                 "<dbms> <v1> and <v2> and ..."
        """

        if not fingerprint:
            return "%s" % dbms

        details = " and ".join(["%s" % value for value in fingerprint])

        return "%s %s" % (dbms, details)

