#pragma once
//////////////////////////////////////////////////////////////////////
// Copyright (c) 2010, Oliver 'kfs1' Smith <oliver@kfs.org>
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// - Neither the name of KingFisher Software nor the names of its contributors
// may be used to endorse or promote products derived from this software
// without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
//////////////////////////////////////////////////////////////////////
//
// Abstraction layer for interacting with databases.
// Currently provides APIs for MySQL and SQLite3.
//
// Components:
//
// DBA
// {
//   Connection :- Abstract base class encapsulating a connection
//                 to a database server.
//   MySQLConnection :- MySQL specific API
//   SQLiteConnection :- SQLite specific API
//
//   ResultSet :- Interface class that provides access to the
//                results of a query from a database, and
//                ensures that resources are freed when it
//                leaves scope.
// }
//
// CAUTION:
//  These interfaces do NOT provide asynchronous operation.

/*

DBA::Connection and DBA::ResultSet.

1. Errors.

    All DBA::Connection/DBA::ResultSet functions handle non-trivial errors
    by throwing an exception derived from std::exception.

2. Connecting to a database.

    a. Create an instance of one of the derived Connection classes
       (DBA::MySQLConnection or DBA::SQLiteConnection).
    b. Construct an instance of the "DBA::Credentials" structure with your
       database connection parameters.
    c. Pass the Credentials to the Connect() member of your connection.

3. Executing SQL.

    If you are not interested in the results of a statement, call
        connection.Do(size_t length, const char* statement) ;
    If you are not interested in its success or failure either, call
        connection.DoQuietly(size_t length, const char* statement) ;
    which discards any exception the statement raises. Both functions also
    accept a const std::string& statement.
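    A minimal sketch of the difference between the two calls. FakeDo() is a
    hypothetical stand-in for connection.Do() that throws the way DBA
    functions do on error; the DoQuietly template only mirrors the shape of
    the real member, which likewise swallows every exception with catch (...).

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for Connection::Do(): throws on failure, the way
// the DBA interfaces report non-trivial errors.
void FakeDo(const std::string& statement)
{
    if (statement.find("missing_table") != std::string::npos)
        throw std::logic_error("Table doesn't exist");
}

// Mirrors the shape of Connection::DoQuietly(): run the statement and
// discard any exception, so the caller cannot tell success from failure.
template <typename DoFn>
void DoQuietly(DoFn doFn, const std::string& statement)
{
    try { doFn(statement); }
    catch (...) { }
}
```

    Quiet execution suits clean-up statements such as dropping a table that
    may not exist, where a failure carries no useful information.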

    In most cases, however, you will want the results from the database.
    For this you must use the interface class, DBA::ResultSet.

        DBA::ResultSet rs(connection, size_t length, const char* statement) ;
    or
        DBA::ResultSet rs(connection, "format", size_t maxStatementLength, parameters...) ;

    The second variant performs snprintf-style formatting of "format"
    into a temporary buffer large enough to hold "maxStatementLength" bytes.
    maxStatementLength must be at least the length of format, and should be
    large enough to hold the fully formatted statement.
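    The formatting step can be sketched as follows. FormatStatement() is a
    hypothetical helper, not part of DBA, and throwing on overflow is an
    assumption; the real ResultSet may handle an over-long statement
    differently.

```cpp
#include <cassert>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: formats `format` with printf-style arguments into a
// buffer of maxStatementLength bytes plus a terminating NUL. Assumption:
// an over-long statement is rejected rather than silently truncated.
std::string FormatStatement(const char* format, size_t maxStatementLength, ...)
{
    std::vector<char> buffer(maxStatementLength + 1);
    va_list args;
    va_start(args, maxStatementLength);
    int written = vsnprintf(&buffer[0], buffer.size(), format, args);
    va_end(args);
    if (written < 0)
        throw std::runtime_error("Invalid format string");
    // vsnprintf returns the length the full output would have had.
    if (static_cast<size_t>(written) > maxStatementLength)
        throw std::length_error("Statement exceeds maxStatementLength");
    return std::string(&buffer[0], static_cast<size_t>(written));
}
```

    For example, FormatStatement("SELECT %d", 32, 5) yields "SELECT 5".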

    If the SQL query fails or no results are returned, an exception will be
    thrown.

    CAVEAT: Result sets are held in memory, and any locks acquired by the
    query may persist, until the ResultSet destructor runs. "static"
    DBA::ResultSet objects are therefore strongly discouraged.
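    A sketch of the scoping idea, using hypothetical MiniConnection and
    ScopedResult types rather than the real classes. It assumes, as the
    Connection::m_currentResult member suggests, that a connection refuses a
    second result set while one is still alive.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-in for a connection: it only tracks whether a result
// set currently owns it (compare Connection::m_currentResult).
struct MiniConnection
{
    MiniConnection() : busy(false), released(0) {}
    bool busy;      // true while a ScopedResult is alive
    int released;   // how many times results have been released
};

// Hypothetical RAII guard: the constructor claims the connection and the
// destructor always releases it, even when an exception unwinds the scope.
class ScopedResult
{
public:
    explicit ScopedResult(MiniConnection& conn) : m_conn(conn)
    {
        if (m_conn.busy)
            throw std::logic_error("Connection already has an active result set");
        m_conn.busy = true;
    }
    ~ScopedResult()
    {
        m_conn.busy = false;
        ++m_conn.released;
    }

private:
    ScopedResult(const ScopedResult&);             // non-copyable (C++03 style)
    ScopedResult& operator=(const ScopedResult&);
    MiniConnection& m_conn;
};
```

    A static result object would never reach its destructor before program
    exit, so the connection would stay claimed for the whole run.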

4. Result Details.

    Once you have obtained a ResultSet object from a connection, the following
    members are provided to query the status of the ResultSet:

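        const char* conn.GetError()
            Returns the database-specific text of the last error. ResultSet
            itself has no error accessor; failures are reported by throwing
            an exception (see 1. Errors).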

        bool resultset.HasRows()
            Returns true if the query returned any rows, false otherwise.

        size_t resultset.Rows()
            MySQL: Returns the number of rows in the current result set.
            Others: Returns the current row + 1 (so that CurrentRow() < Rows()).
            Call conn.HasAccurateRowCount() to determine whether your database
            provides an exact row count.

        size_t resultset.Cols()
            Returns the number of columns-per-row in the current result set,
            if available.

        size_t resultset.CurrentRow()
            Returns the current row index within the result set.

    If the statement modified data (INSERT, UPDATE, etc.):

        size_t resultset.AffectedRows()
            Returns the number of rows affected by the statement.

        size_t resultset.LastInsertID()
            If the statement generated a new automatic row ID, returns
            the ID. Returns zero if no ID was generated or the database
            does not support automatic IDs.

    To clear the current result set and perform another query with the
    same ResultSet instance, use

        void resultset.Do(size_t statementLength, const char* statement) ;
    or
        void resultset.Do(const char* format, size_t maxStatementLen, ...) ;

5. Retrieving rows of data.

    After executing a query statement (SELECT, CALL, etc.) and obtaining
    a valid ResultSet object, you can retrieve the data from the SQL
    server one row at a time.

        bool resultset.FetchRow()
            Advances to the next row in the result set and returns true.
            Returns false once there are no more rows to retrieve.
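    The row-at-a-time protocol can be modelled with a hypothetical in-memory
    MockRows class (not part of DBA). As with the debug check in
    ResultSet::operator>>, its row counter stays at zero until FetchRow()
    has been called once.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical in-memory result set illustrating the FetchRow() protocol;
// it is not part of the DBA API.
class MockRows
{
public:
    explicit MockRows(const std::vector<std::string>& rows)
        : m_rows(rows), m_currentRow(0) {}

    // Advance to the next row; returns false once every row has been consumed.
    bool FetchRow()
    {
        if (m_currentRow >= m_rows.size())
            return false;
        ++m_currentRow;
        return true;
    }

    // Value of the most recently fetched row (FetchRow() must succeed first).
    const std::string& Current() const { return m_rows[m_currentRow - 1]; }

private:
    std::vector<std::string> m_rows;
    size_t m_currentRow;   // number of rows fetched so far
};

// Typical consumption loop: call FetchRow() before reading each row and
// stop as soon as it returns false.
std::vector<std::string> ReadAll(MockRows& rs)
{
    std::vector<std::string> out;
    while (rs.FetchRow())
        out.push_back(rs.Current());
    return out;
}
```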

6. Retrieving values (columns).

    You can retrieve the columns of the current row as strings by index
    using
        const char* resultset.operator[](unsigned int index)
    e.g.
        const char* id = resultset[0] ;

    Alternatively, the ResultSet class provides operator>> for most
    base types: const char*, const unsigned char*, const void*, char,
    unsigned char, short, unsigned short, int, unsigned int, long long,
    unsigned long long, float and double. Each operator>> reads the next
    column of the current row and advances the column cursor.

    e.g.
        const char query[] = "SELECT playerID, custID, name, last_login FROM players" ;
        DBA::ResultSet rs(connection, sizeof(query) - 1, query) ;
        int playerID, customerID ;
        const char *name, *lastLogin ;
        while ( rs.FetchRow() )
        {
            rs >> playerID >> customerID >> name >> lastLogin ;
            printf("%u: player#%d cust#%d %s (%s)\n", (unsigned int)rs.CurrentRow(), playerID, customerID, name, lastLogin) ;
        }

    CAVEAT: operator[] does /not/ interact with the column cursor: it neither
    reads from nor advances the position used by operator>>.
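    To illustrate the caveat, the hypothetical MockRow class below (not part
    of DBA) keeps a column cursor the way Connection::NextColumn() does:
    operator>> consumes columns left to right and throws "Passed end of row"
    past the last one, while operator[] reads by index without moving the
    cursor.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical single-row model of the DBA column cursor; not the real API.
class MockRow
{
public:
    explicit MockRow(const std::vector<std::string>& cols)
        : m_cols(cols), m_currentColumn(0) {}

    // Sequential access: each extraction consumes the next column.
    MockRow& operator>>(int& into)
    {
        into = std::atoi(NextColumn().c_str());
        return *this;
    }
    MockRow& operator>>(std::string& into)
    {
        into = NextColumn();
        return *this;
    }

    // Indexed access: does not touch m_currentColumn (at() throws on a bad index).
    const char* operator[](size_t index) const { return m_cols.at(index).c_str(); }
    size_t CurrentCol() const { return m_currentColumn; }

private:
    // Same rule as Connection::NextColumn(): reading past the last column throws.
    const std::string& NextColumn()
    {
        if (m_currentColumn >= m_cols.size())
            throw std::logic_error("Passed end of row");
        return m_cols[m_currentColumn++];
    }

    std::vector<std::string> m_cols;
    size_t m_currentColumn;   // index of the next column operator>> will read
};
```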

7. Examples

    #include <cstdio>
    #include <cstring>
    #include <stdexcept>
    #define DBA_ENABLE_MYSQL
    #include "dbaConn.h"

    // Declares the database connection as a static global.
    static DBA::MySQLConnection dbc ;

    ...

    try
    {
        // We want the 'test' database on localhost (127.0.0.1) and use the default port (0)
        DBA::Credentials creds("127.0.0.1", 0, "test", "testuser", "testpass") ;

        // Connect with the above credentials, report errors and enable autoreconnect.
        dbc.Connect(creds, true, true) ;

        // Select the current database time. If the SQL is
        // invalid or the database handle is invalid or there
        // is a problem with the query, an exception will be
        // thrown.
        char select[] = "SELECT NOW()" ;
        DBA::ResultSet rs(dbc, sizeof(select) - 1, select) ;

        // We're pretty sure this should return a row.
        if ( rs.HasRows() == false )
            throw std::logic_error("Database returned nothing");

        // We must retrieve the row
        if ( rs.FetchRow() == false )
            throw std::logic_error("Unable to retrieve date");

        // column 0 should be the time field.
        printf("Database time is: %s\n", rs[0]) ;

        // Now we'll re-use the result set for a second query.

        // Assuming "t_foo" is a table containing several rows:
        //  (1, "one"), (2, "two"), (3, "three"), (4, "four")
        const char* statement = "SELECT id, label FROM t_foo";
        rs.Do(strlen(statement), statement);

        if ( rs.HasRows() == false )
            throw std::logic_error("Nothing in t_foo") ;

        unsigned int id ;
        const char* label ;
        while ( rs.FetchRow() == true )
        {
            // Retrieve the columns as native types.
            rs >> id >> label ;
            printf("Native: id %u, label %s\n", id, label);

            // Retrieve them as char*s using operator[]
            printf("char*: id %s, label %s\n", rs[0], rs[1]);
        }

        // When we exit scope, the ResultSet object is destructed
        // and any resources it was using are cleaned up.
    }
    catch (std::exception& e)
    {
        printf("ERROR: %s\n", e.what());
    }

    ...
*/

#include <string.h>
#include <stdexcept>
#include <string>

namespace DBA
{
    struct Credentials
    {
        // Full constructor.
        Credentials(const char* host_, unsigned int port_, const char* database_
                    , const char* username_, const char* password_
                    , const char* driver_=""
                    )
            : host(host_), port(port_), database(database_)
            , username(username_), password(password_)
            , driver(driver_)
            {}
        // Partial constructor omitting remote host/port
        Credentials(const char* database_, const char* driver_="", const char* username_="", const char* password_="")
            : host(""), port(0), database(database_)
            , username(username_), password(password_)
            , driver(driver_)
            {}

        const char* host ;      // Host on which database resides, if applicable
        unsigned int port ;     // Port to connect to, if applicable
        const char* database ;  // Name of database/file
        const char* username ;  // Username to authenticate as, if applicable
        const char* password ;  // Password to authenticate with, if applicable
        const char* driver ;    // Driver to use (e.g. vfs for SQLite), if applicable
    } ;

    // Connection abstraction
    class Connection
    {
    public:
        Connection() ;
        virtual ~Connection() ;

        // Connect to a database.
        virtual void Connect(const Credentials&, unsigned int flags=0, bool autoReconnect=false) = 0 ;
        // Disconnect from a database.
        virtual void Disconnect() = 0 ;
        // true if a connection is established.
        virtual bool IsConnected() const = 0 ;
        void CheckConnected() const { if ( IsConnected() == false ) throw std::logic_error("No database connected") ; }

        // Implementation details
        virtual bool HasAccurateRowCount() const = 0 ;

        // Ensure a string is escaped so it won't mess up your SQL.
        virtual void EscapeString(const char* src, char* dest, size_t destSize) = 0 ;

        // Execute an SQL statement and discard any results.
        void Do(size_t stlen, const char* statement)
        {
            CheckConnected() ;
            Execute(stlen, statement) ;
            ReleaseResult() ;
        }
        inline void Do(const std::string& statement)
        {
            Do(statement.length(), statement.c_str()) ;
        }

        // Execute an SQL statement, discarding both its results and any errors.
        void DoQuietly(size_t queryLen, const char* query)
        {
            try { Do(queryLen, query) ; }
            catch (...) { }
        }
        inline void DoQuietly(const std::string& statement)
        {
            DoQuietly(statement.length(), statement.c_str()) ;
        }

        virtual const char* GetError() = 0 ;
        void ThrowError(const char* message = NULL)
        {
            std::string error = (message ? message : "") ;
            error += GetError() ;
            throw std::logic_error(error) ;
        }

    protected:
        int NextColumn()
        {
            if ( AtEndOfRow() )
                throw std::logic_error("Passed end of row") ;
            return m_currentColumn++ ;
        }

    protected:
        // Execute a static SQL statement
        void Execute(size_t stlen, const char* query) ;
        // Implementation specific execution steps
            // Execute SQL, does not try to retrieve results
            virtual void _execute(size_t stlen, const char* statement) = 0 ;
            // Retrieve/process results from execution.
            virtual void _processResultSet() = 0 ;
            // Clean up the result set
            virtual void _releaseResultSet() = 0 ;

        // Release resources used by a query.
        void ReleaseResult() ;

        // Fetch the next row.
        virtual bool FetchRow() = 0 ;

        // Some databases (e.g. MySQL) support multiple result sets being
        // returned by one query (e.g. for a "CALL" statement).
        virtual bool NextResultSet() { ReleaseResult() ; return false ; }

        virtual const char* GetColumn(unsigned int index) const = 0 ;
        const char* GetNextString()
        {
            const char* source = GetColumn(NextColumn()) ;
            return ( source != NULL ? source : "" ) ;
        }
        virtual Connection& operator >> (const void*& into) = 0 ;
        virtual Connection& operator >> (const char*& into) = 0 ;
        virtual Connection& operator >> (const unsigned char*& into) = 0 ;
        virtual Connection& operator >> (char& into) = 0 ;
        virtual Connection& operator >> (unsigned char& into) = 0 ;
        virtual Connection& operator >> (short& into) = 0 ;
        virtual Connection& operator >> (unsigned short& into) = 0 ;
        virtual Connection& operator >> (int& into) = 0 ;
        virtual Connection& operator >> (unsigned int& into) = 0 ;
        virtual Connection& operator >> (long long& into) = 0 ;
        virtual Connection& operator >> (unsigned long long& into) = 0 ;
        virtual Connection& operator >> (float& into) = 0 ;
        virtual Connection& operator >> (double& into) = 0 ;

        size_t AffectedRows() const { return m_affectedRows ; }
        bool HasRows() const { return m_rows > m_currentRow ; }
        size_t Rows() const { return m_rows ; }
        size_t CurrentRow() const { return m_currentRow ; }
        size_t Cols() const { return m_cols ; }
        size_t CurrentCol() const { return m_currentColumn ; }
        size_t LastInsertID() const { return m_lastInsertID ; }
        bool AtEndOfRow() const { return m_currentColumn >= m_cols ; }

    protected:
        // Current data
        bool        m_executed ;            // Have we executed a statement?
        int         m_error ;               // Last error

        // For DML statements.
        size_t      m_affectedRows ;
        size_t      m_lastInsertID ;

        // For queries
        size_t      m_rows ;                // Not all interfaces know this.
        size_t      m_currentRow ;          // 0-based cursor.
        size_t      m_cols ;                // Number of columns in result set.
        size_t      m_currentColumn ;       // Column cursor

        class ResultSet* m_currentResult ;  // To prevent multiple queries.
        friend class ResultSet ;
    } ;

    // The result class is merely a wrapper for the query
    // operations. Its purpose is to ensure scoping; when
    // the result object is destroyed, all of the database
    // specific functionality for releasing the results
    // is executed automatically.

    //////////////////////////////////////////////////////////////////////
    // ResultSet interface class
    // Provides scoping of resources associated with a given connection,
    // and ensures that they are freed when the object goes out of scope.
    // Constructor will fail if there are no results available on the
    // current connection.
    class ResultSet
    {
    private:
        ResultSet() {}

    public:
        // Most SQL interfaces want to know the length of the statement
        // being supplied. This requirement is preserved to make the
        // user conscious of the possible performance overheads if they
        // know the length of the string in advance.
        ResultSet(Connection& conn, size_t stLen, const char* statement) ;

        // Variadic, snprintf-style constructor. maxStatementLen is the maximum
        // length of the resulting statement you wish to allow. The inclusion of
        // this parameter helps (a) efficiency and (b) disambiguate the resulting
        // function signature from that of the future ResultSet(conn, statement).
        ResultSet(Connection& conn, const char* format, size_t maxStatementLen, ...) ;

#ifdef _CONSTRUCTOR_DELEGATION_IMPLEMENTED_
        ResultSet(Connection& conn, const char* statement)
            : ResultSet(conn, strlen(statement), statement)
            {}

        ResultSet(Connection& conn, const std::string& statement)
            : ResultSet(conn, statement.length(), statement.c_str())
            {}
#endif

        // Destructor.
        ~ResultSet() ;

        // Execute additional SQL statements:
        void Do(size_t stlen, const char* statement) ;

        // Variadic argument variation of Do.
        void Do(const char* format, size_t maxStatementLen, ...) ;

        // ResultSet is primarily an interface class: that is, it
        // provides control over access to protected members of the
        // Connection class, thus ensuring these fields can only be
        // accessed when you have properly acquired a ResultSet
        // object.

        inline size_t AffectedRows() const { return m_conn->AffectedRows() ; }
        inline bool HasRows() const { return m_conn->HasRows() ; }
        inline size_t Rows() const { return m_conn->Rows() ; }
        inline size_t CurrentRow() const { return m_conn->CurrentRow() ; }
        inline size_t Cols() const { return m_conn->Cols() ; }
        inline size_t CurrentCol() const { return m_conn->CurrentCol() ; }
        inline size_t LastInsertID() const { return m_conn->LastInsertID() ; }
        inline bool AtEndOfRow() const { return m_conn->AtEndOfRow() ; }

        inline bool FetchRow() { return m_conn->FetchRow() ; }
        inline bool NextResultSet() { return m_conn->NextResultSet() ; }

        // Retrieve columns using "operator >>", uses friend status with
        // the Connection class to access the implementation-specific
        // operator >>s.
        template<typename _Type> ResultSet& operator >> (_Type& into)
        {
#ifndef NDEBUG
            if ( m_conn->m_currentRow == 0 )
                throw std::domain_error("Attempted to access row before calling FetchRow") ;
#endif
            m_conn->operator>>(into) ;
            return *this ;
        }

        // Retrieve a pointer to the string version of the indexed column.
        inline const char* operator[](unsigned int index) const
        {
#ifndef NDEBUG
            if ( m_conn->m_currentRow == 0 )
                throw std::domain_error("Attempted to access row before calling FetchRow") ;
#endif
            return m_conn->GetColumn(index) ;
        }

        // Release the resources held by this result set and free up the connection.
        void Release() ;

    private:
        Connection* m_conn ;
    } ;

}

// Include the implementation specific members; only those you have enabled
// will actually be compiled.
#include "dbaConnMySQL.h"       // Requires DBA_ENABLE_MYSQL
#include "dbaConnSQLite.h"      // Requires DBA_ENABLE_SQLITE

// In most cases, not defining DBA_ENABLE_xxx is probably an error,
// so we should notify the user.
// If you actually want just the core routines, define DBA_WITHOUT_DATABASES
#if !defined(_DBA_HAVE_DATABASES) && !defined(DBA_WITHOUT_DATABASES)
# error "You have not defined any databases (DBA_ENABLE_MYSQL, DBA_ENABLE_SQLITE, etc) or DBA_WITHOUT_DATABASES"
#endif