feat(Core/Database): port TrinityCore database API (#5611)

This commit is contained in:
Kargatum
2021-06-22 11:21:07 +07:00
committed by GitHub
parent 2a2e54d8c5
commit 9ac6fddcae
155 changed files with 5818 additions and 4321 deletions

View File

@@ -6,6 +6,7 @@ CollectSourceFiles(
${CMAKE_CURRENT_SOURCE_DIR}
PRIVATE_SOURCES
# Exclude
${CMAKE_CURRENT_SOURCE_DIR}/Updater
${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders)
if(USE_COREPCH)
@@ -40,12 +41,11 @@ target_include_directories(database
${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(database
# PRIVATE
PRIVATE
# acore-core-interface
# mysql
mysql
PUBLIC
common
mysql)
common)
set_target_properties(database
PROPERTIES

View File

@@ -1,29 +1,30 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "AdhocStatement.h"
#include "Errors.h"
#include "MySQLConnection.h"
#include "QueryResult.h"
#include <cstdlib>
#include <cstring>
/*! Basic, ad-hoc queries. */
BasicStatementTask::BasicStatementTask(const char* sql) :
m_has_result(false)
{
m_sql = strdup(sql);
}
BasicStatementTask::BasicStatementTask(const char* sql, QueryResultFuture result) :
m_has_result(true),
m_result(result)
BasicStatementTask::BasicStatementTask(char const* sql, bool async) :
m_result(nullptr)
{
m_sql = strdup(sql);
m_has_result = async; // If the operation is async, then there's a result
if (async)
m_result = new QueryResultPromise();
}
BasicStatementTask::~BasicStatementTask()
{
free((void*)m_sql);
if (m_has_result && m_result != nullptr)
delete m_result;
}
bool BasicStatementTask::Execute()
@@ -31,14 +32,14 @@ bool BasicStatementTask::Execute()
if (m_has_result)
{
ResultSet* result = m_conn->Query(m_sql);
if (!result || !result->GetRowCount())
if (!result || !result->GetRowCount() || !result->NextRow())
{
delete result;
m_result.set(QueryResult(nullptr));
m_result->set_value(QueryResult(nullptr));
return false;
}
result->NextRow();
m_result.set(QueryResult(result));
m_result->set_value(QueryResult(result));
return true;
}

View File

@@ -1,30 +1,29 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _ADHOCSTATEMENT_H
#define _ADHOCSTATEMENT_H
#include <ace/Future.h>
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include "SQLOperation.h"
typedef ACE_Future<QueryResult> QueryResultFuture;
/*! Raw, ad-hoc query. */
class BasicStatementTask : public SQLOperation
class AC_DATABASE_API BasicStatementTask : public SQLOperation
{
public:
BasicStatementTask(const char* sql);
BasicStatementTask(const char* sql, QueryResultFuture result);
~BasicStatementTask() override;
BasicStatementTask(char const* sql, bool async = false);
~BasicStatementTask();
bool Execute() override;
QueryResultFuture GetFuture() const { return m_result->get_future(); }
private:
const char* m_sql; //- Raw query to be executed
char const* m_sql; //- Raw query to be executed
bool m_has_result;
QueryResultFuture m_result;
QueryResultPromise* m_result;
};
#endif

View File

@@ -1,10 +1,10 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "DatabaseEnv.h"
WorldDatabaseWorkerPool WorldDatabase;
CharacterDatabaseWorkerPool CharacterDatabase;
LoginDatabaseWorkerPool LoginDatabase;
DatabaseWorkerPool<WorldDatabaseConnection> WorldDatabase;
DatabaseWorkerPool<CharacterDatabaseConnection> CharacterDatabase;
DatabaseWorkerPool<LoginDatabaseConnection> LoginDatabase;

View File

@@ -1,38 +1,29 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef DATABASEENV_H
#define DATABASEENV_H
#include "Common.h"
#include "Errors.h"
#include "Log.h"
#include "DatabaseWorkerPool.h"
#include "Define.h"
#include "Implementation/CharacterDatabase.h"
#include "Implementation/LoginDatabase.h"
#include "Implementation/WorldDatabase.h"
#include "Field.h"
#include "PreparedStatement.h"
#include "QueryCallback.h"
#include "QueryResult.h"
#include "MySQLThreading.h"
#include "Transaction.h"
#define _LIKE_ "LIKE"
#define _TABLE_SIM_ "`"
#define _CONCAT3_(A, B, C) "CONCAT( " A ", " B ", " C " )"
#define _OFFSET_ "LIMIT %d, 1"
#include "LoginDatabase.h"
#include "CharacterDatabase.h"
#include "WorldDatabase.h"
/// Accessor to the world database
extern WorldDatabaseWorkerPool WorldDatabase;
AC_DATABASE_API extern DatabaseWorkerPool<WorldDatabaseConnection> WorldDatabase;
/// Accessor to the character database
extern CharacterDatabaseWorkerPool CharacterDatabase;
AC_DATABASE_API extern DatabaseWorkerPool<CharacterDatabaseConnection> CharacterDatabase;
/// Accessor to the realm/login database
extern LoginDatabaseWorkerPool LoginDatabase;
AC_DATABASE_API extern DatabaseWorkerPool<LoginDatabaseConnection> LoginDatabase;
#endif

View File

@@ -0,0 +1,82 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef DatabaseEnvFwd_h__
#define DatabaseEnvFwd_h__
#include <future>
#include <memory>
struct QueryResultFieldMetadata;
class Field;
class ResultSet;
using QueryResult = std::shared_ptr<ResultSet>;
using QueryResultFuture = std::future<QueryResult>;
using QueryResultPromise = std::promise<QueryResult>;
class CharacterDatabaseConnection;
class LoginDatabaseConnection;
class WorldDatabaseConnection;
class PreparedStatementBase;
template<typename T>
class PreparedStatement;
using CharacterDatabasePreparedStatement = PreparedStatement<CharacterDatabaseConnection>;
using LoginDatabasePreparedStatement = PreparedStatement<LoginDatabaseConnection>;
using WorldDatabasePreparedStatement = PreparedStatement<WorldDatabaseConnection>;
class PreparedResultSet;
using PreparedQueryResult = std::shared_ptr<PreparedResultSet>;
using PreparedQueryResultFuture = std::future<PreparedQueryResult>;
using PreparedQueryResultPromise = std::promise<PreparedQueryResult>;
class QueryCallback;
template<typename T>
class AsyncCallbackProcessor;
using QueryCallbackProcessor = AsyncCallbackProcessor<QueryCallback>;
class TransactionBase;
using TransactionFuture = std::future<bool>;
using TransactionPromise = std::promise<bool>;
template<typename T>
class Transaction;
class TransactionCallback;
template<typename T>
using SQLTransaction = std::shared_ptr<Transaction<T>>;
using CharacterDatabaseTransaction = SQLTransaction<CharacterDatabaseConnection>;
using LoginDatabaseTransaction = SQLTransaction<LoginDatabaseConnection>;
using WorldDatabaseTransaction = SQLTransaction<WorldDatabaseConnection>;
class SQLQueryHolderBase;
using QueryResultHolderFuture = std::future<void>;
using QueryResultHolderPromise = std::promise<void>;
template<typename T>
class SQLQueryHolder;
using CharacterDatabaseQueryHolder = SQLQueryHolder<CharacterDatabaseConnection>;
using LoginDatabaseQueryHolder = SQLQueryHolder<LoginDatabaseConnection>;
using WorldDatabaseQueryHolder = SQLQueryHolder<WorldDatabaseConnection>;
class SQLQueryHolderCallback;
// mysql
struct MySQLHandle;
struct MySQLResult;
struct MySQLField;
struct MySQLBind;
struct MySQLStmt;
#endif // DatabaseEnvFwd_h__

View File

@@ -5,12 +5,18 @@
#include "DatabaseLoader.h"
#include "Config.h"
// #include "DBUpdater.h" not implement
#include "DatabaseEnv.h"
#include "Duration.h"
#include "Log.h"
#include "Duration.h"
#include <errmsg.h>
#include <mysqld_error.h>
#include <thread>
DatabaseLoader::DatabaseLoader(std::string const& logger)
: _logger(logger) { }
template <class T>
DatabaseLoader& DatabaseLoader::AddDatabase(DatabaseWorkerPool<T>& pool, std::string const& name)
{
@@ -19,14 +25,15 @@ DatabaseLoader& DatabaseLoader::AddDatabase(DatabaseWorkerPool<T>& pool, std::st
std::string const dbString = sConfigMgr->GetOption<std::string>(name + "DatabaseInfo", "");
if (dbString.empty())
{
LOG_INFO("sql.driver", "Database %s not specified in configuration file!", name.c_str());
LOG_ERROR(_logger, "Database %s not specified in configuration file!", name.c_str());
return false;
}
uint8 const asyncThreads = sConfigMgr->GetOption<uint8>(name + "Database.WorkerThreads", 1);
if (asyncThreads < 1 || asyncThreads > 32)
{
LOG_INFO("sql.driver", "%s database: invalid number of worker threads specified. Please pick a value between 1 and 32.", name.c_str());
LOG_ERROR(_logger, "%s database: invalid number of worker threads specified. "
"Please pick a value between 1 and 32.", name.c_str());
return false;
}
@@ -39,36 +46,36 @@ DatabaseLoader& DatabaseLoader::AddDatabase(DatabaseWorkerPool<T>& pool, std::st
// Try reconnect
if (error == CR_CONNECTION_ERROR)
{
uint8 const ATTEMPTS = sConfigMgr->GetOption<uint8>("Database.Reconnect.Attempts", 20);
Seconds RECONNECT_SECONDS = Seconds(sConfigMgr->GetOption<uint8>("Database.Reconnect.Seconds", 15));
uint8 count = 0;
uint8 const attempts = sConfigMgr->GetOption<uint8>("Database.Reconnect.Attempts", 20);
Seconds reconnectSeconds = Seconds(sConfigMgr->GetOption<uint8>("Database.Reconnect.Seconds", 15));
uint8 reconnectCount = 0;
while (count < ATTEMPTS)
while (reconnectCount < attempts)
{
LOG_INFO("sql.driver", "> Retrying after %u seconds", static_cast<uint32>(RECONNECT_SECONDS.count()));
std::this_thread::sleep_for(RECONNECT_SECONDS);
LOG_INFO(_logger, "> Retrying after %u seconds", static_cast<uint32>(reconnectSeconds.count()));
std::this_thread::sleep_for(reconnectSeconds);
error = pool.Open();
if (error == CR_CONNECTION_ERROR)
{
count++;
reconnectCount++;
}
else
{
break;
}
}
}
// If the error wasn't handled quit
if (error)
{
LOG_ERROR("sql.driver", "DatabasePool %s NOT opened. There were errors opening the MySQL connections. Check your SQLDriverLogFile for specific errors", name.c_str());
LOG_ERROR(_logger, "DatabasePool %s NOT opened. There were errors opening the MySQL connections. "
"Check your log file for specific errors", name.c_str());
return false;
}
}
// Add the close operation
_close.push([&pool]
{
@@ -78,11 +85,11 @@ DatabaseLoader& DatabaseLoader::AddDatabase(DatabaseWorkerPool<T>& pool, std::st
return true;
});
_prepare.push([name, &pool]() -> bool
_prepare.push([this, name, &pool]() -> bool
{
if (!pool.PrepareStatements())
{
LOG_ERROR("sql.driver", "Could not prepare statements of the %s database, see log for details.", name.c_str());
LOG_ERROR(_logger, "Could not prepare statements of the %s database, see log for details.", name.c_str());
return false;
}
@@ -129,6 +136,9 @@ bool DatabaseLoader::Process(std::queue<Predicate>& queue)
return true;
}
template DatabaseLoader& DatabaseLoader::AddDatabase<LoginDatabaseConnection>(DatabaseWorkerPool<LoginDatabaseConnection>&, std::string const&);
template DatabaseLoader& DatabaseLoader::AddDatabase<CharacterDatabaseConnection>(DatabaseWorkerPool<CharacterDatabaseConnection>&, std::string const&);
template DatabaseLoader& DatabaseLoader::AddDatabase<WorldDatabaseConnection>(DatabaseWorkerPool<WorldDatabaseConnection>&, std::string const&);
template AC_DATABASE_API
DatabaseLoader& DatabaseLoader::AddDatabase<LoginDatabaseConnection>(DatabaseWorkerPool<LoginDatabaseConnection>&, std::string const&);
template AC_DATABASE_API
DatabaseLoader& DatabaseLoader::AddDatabase<CharacterDatabaseConnection>(DatabaseWorkerPool<CharacterDatabaseConnection>&, std::string const&);
template AC_DATABASE_API
DatabaseLoader& DatabaseLoader::AddDatabase<WorldDatabaseConnection>(DatabaseWorkerPool<WorldDatabaseConnection>&, std::string const&);

View File

@@ -17,9 +17,11 @@ class DatabaseWorkerPool;
// A helper class to initiate all database worker pools,
// handles updating, delays preparing of statements and cleans up on failure.
class DatabaseLoader
class AC_DATABASE_API DatabaseLoader
{
public:
DatabaseLoader(std::string const& logger);
// Register a database to the loader (lazy implemented)
template <class T>
DatabaseLoader& AddDatabase(DatabaseWorkerPool<T>& pool, std::string const& name);
@@ -49,6 +51,8 @@ private:
// Returns false when there was an error.
bool Process(std::queue<Predicate>& queue);
std::string const _logger;
std::queue<Predicate> _open, _prepare;
std::stack<Closer> _close;
};

View File

@@ -1,39 +1,46 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "DatabaseWorker.h"
#include "ProducerConsumerQueue.h"
#include "SQLOperation.h"
#include "MySQLConnection.h"
#include "MySQLThreading.h"
DatabaseWorker::DatabaseWorker(ACE_Activation_Queue* new_queue, MySQLConnection* con) :
m_queue(new_queue),
m_conn(con)
DatabaseWorker::DatabaseWorker(ProducerConsumerQueue<SQLOperation*>* newQueue, MySQLConnection* connection)
{
/// Assign thread to task
activate();
_connection = connection;
_queue = newQueue;
_cancelationToken = false;
_workerThread = std::thread(&DatabaseWorker::WorkerThread, this);
}
int DatabaseWorker::svc()
DatabaseWorker::~DatabaseWorker()
{
if (!m_queue)
return -1;
_cancelationToken = true;
SQLOperation* request = nullptr;
while (1)
_queue->Cancel();
_workerThread.join();
}
void DatabaseWorker::WorkerThread()
{
if (!_queue)
return;
for (;;)
{
request = (SQLOperation*)(m_queue->dequeue());
if (!request)
break;
SQLOperation* operation = nullptr;
request->SetConnection(m_conn);
request->call();
_queue->WaitAndPop(operation);
delete request;
if (_cancelationToken || !operation)
return;
operation->SetConnection(_connection);
operation->call();
delete operation;
}
return 0;
}

View File

@@ -1,30 +1,38 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _WORKERTHREAD_H
#define _WORKERTHREAD_H
#include <ace/Task.h>
#include <ace/Activation_Queue.h>
#include "Define.h"
#include <atomic>
#include <thread>
template <typename T>
class ProducerConsumerQueue;
class MySQLConnection;
class SQLOperation;
class DatabaseWorker : protected ACE_Task_Base
class AC_DATABASE_API DatabaseWorker
{
public:
DatabaseWorker(ACE_Activation_Queue* new_queue, MySQLConnection* con);
///- Inherited from ACE_Task_Base
int svc() override;
int wait() override { return ACE_Task_Base::wait(); }
DatabaseWorker(ProducerConsumerQueue<SQLOperation*>* newQueue, MySQLConnection* connection);
~DatabaseWorker();
private:
DatabaseWorker() : ACE_Task_Base() { }
ACE_Activation_Queue* m_queue;
MySQLConnection* m_conn;
ProducerConsumerQueue<SQLOperation*>* _queue;
MySQLConnection* _connection;
void WorkerThread();
std::thread _workerThread;
std::atomic<bool> _cancelationToken;
DatabaseWorker(DatabaseWorker const& right) = delete;
DatabaseWorker& operator=(DatabaseWorker const& right) = delete;
};
#endif

View File

@@ -1,33 +1,65 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "DatabaseWorkerPool.h"
#include "DatabaseEnv.h"
#include "AdhocStatement.h"
#include "Common.h"
#include "Errors.h"
#include "Implementation/CharacterDatabase.h"
#include "Implementation/LoginDatabase.h"
#include "Implementation/WorldDatabase.h"
#include "Log.h"
#include "MySQLPreparedStatement.h"
#include "MySQLWorkaround.h"
#include "PreparedStatement.h"
#include "ProducerConsumerQueue.h"
#include "QueryCallback.h"
#include "QueryHolder.h"
#include "QueryResult.h"
#include "SQLOperation.h"
#include "Transaction.h"
#include <mysqld_error.h>
#ifdef ACORE_DEBUG
#include <boost/stacktrace.hpp>
#include <sstream>
#endif
#define MIN_MYSQL_SERVER_VERSION 50700u
#define MIN_MYSQL_CLIENT_VERSION 50700u
template <class T> DatabaseWorkerPool<T>::DatabaseWorkerPool() :
_mqueue(new ACE_Message_Queue<ACE_SYNCH>(2 * 1024 * 1024, 2 * 1024 * 1024)),
_queue(new ACE_Activation_Queue(_mqueue)),
_async_threads(0),
_synch_threads(0)
class PingOperation : public SQLOperation
{
memset(_connectionCount, 0, sizeof(_connectionCount));
_connections.resize(IDX_SIZE);
//! Operation for idle delaythreads
bool Execute() override
{
m_conn->Ping();
return true;
}
};
template <class T>
DatabaseWorkerPool<T>::DatabaseWorkerPool()
: _queue(new ProducerConsumerQueue<SQLOperation*>()),
_async_threads(0), _synch_threads(0)
{
WPFatal(mysql_thread_safe(), "Used MySQL library isn't thread-safe.");
WPFatal(mysql_get_client_version() >= MIN_MYSQL_CLIENT_VERSION, "AzerothCore does not support MySQL versions below 5.7");
WPFatal(mysql_get_client_version() == MYSQL_VERSION_ID, "Used MySQL library version (%s id %lu) does not match the version id used to compile AzerothCore (id %u)",
mysql_get_client_info(), mysql_get_client_version(), MYSQL_VERSION_ID);
}
template <class T>
DatabaseWorkerPool<T>::~DatabaseWorkerPool()
{
_queue->Cancel();
}
template <class T>
void DatabaseWorkerPool<T>::SetConnectionInfo(std::string const& infoString,
uint8 const asyncThreads, uint8 const synchThreads)
uint8 const asyncThreads, uint8 const synchThreads)
{
_connectionInfo = std::make_unique<MySQLConnectionInfo>(infoString);
@@ -40,22 +72,22 @@ uint32 DatabaseWorkerPool<T>::Open()
{
WPFatal(_connectionInfo.get(), "Connection info was not set!");
LOG_INFO("sql.driver", "Opening DatabasePool '%s'. Asynchronous connections: %u, synchronous connections: %u.",
LOG_INFO("sql.driver", "Opening DatabasePool '%s'. "
"Asynchronous connections: %u, synchronous connections: %u.",
GetDatabaseName(), _async_threads, _synch_threads);
uint32 error = OpenConnections(IDX_ASYNC, _async_threads);
if (error)
{
return error;
}
error = OpenConnections(IDX_SYNCH, _synch_threads);
if (!error)
{
LOG_INFO("sql.driver", "DatabasePool '%s' opened successfully. %u total connections running.",
GetDatabaseName(), (_connectionCount[IDX_SYNCH] + _connectionCount[IDX_ASYNC]));
LOG_INFO("sql.driver", "DatabasePool '%s' opened successfully. " SZFMTD
" total connections running.", GetDatabaseName(),
(_connections[IDX_SYNCH].size() + _connections[IDX_ASYNC].size()));
}
LOG_INFO("sql.driver", " ");
@@ -68,109 +100,59 @@ void DatabaseWorkerPool<T>::Close()
{
LOG_INFO("sql.driver", "Closing down DatabasePool '%s'.", GetDatabaseName());
//! Shuts down delaythreads for this connection pool by underlying deactivate().
//! The next dequeue attempt in the worker thread tasks will result in an error,
//! ultimately ending the worker thread task.
_queue->queue()->close();
//! Closes the actualy MySQL connection.
_connections[IDX_ASYNC].clear();
for (uint8 i = 0; i < _connectionCount[IDX_ASYNC]; ++i)
{
T* t = _connections[IDX_ASYNC][i];
DatabaseWorker* worker = t->m_worker;
worker->wait(); //! Block until no more threads are running this task.
delete worker;
t->Close(); //! Closes the actualy MySQL connection.
}
LOG_INFO("sql.driver", "Asynchronous connections on DatabasePool '%s' terminated. Proceeding with synchronous connections.",
LOG_INFO("sql.driver", "Asynchronous connections on DatabasePool '%s' terminated. "
"Proceeding with synchronous connections.",
GetDatabaseName());
//! Shut down the synchronous connections
//! There's no need for locking the connection, because DatabaseWorkerPool<>::Close
//! should only be called after any other thread tasks in the core have exited,
//! meaning there can be no concurrent access at this point.
for (uint8 i = 0; i < _connectionCount[IDX_SYNCH]; ++i)
_connections[IDX_SYNCH][i]->Close();
//! Deletes the ACE_Activation_Queue object and its underlying ACE_Message_Queue
delete _queue;
delete _mqueue;
_connections[IDX_SYNCH].clear();
LOG_INFO("sql.driver", "All connections on DatabasePool '%s' closed.", GetDatabaseName());
}
template <class T>
uint32 DatabaseWorkerPool<T>::OpenConnections(InternalIndex type, uint8 numConnections)
{
_connections[type].resize(numConnections);
for (uint8 i = 0; i < numConnections; ++i)
{
T* t;
if (type == IDX_ASYNC)
{
t = new T(_queue, *_connectionInfo);
}
else if (type == IDX_SYNCH)
{
t = new T(*_connectionInfo);
}
else
{
ASSERT(false, "> Incorrect InternalIndex (%u)", static_cast<uint32>(type));
}
_connections[type][i] = t;
++_connectionCount[type];
uint32 error = t->Open();
if (!error)
{
if (mysql_get_server_version(t->GetHandle()) < MIN_MYSQL_SERVER_VERSION)
{
LOG_ERROR("sql.driver", "Not support MySQL versions below 5.7");
error = 1;
}
}
// Failed to open a connection or invalid version, abort and cleanup
if (error)
{
while (_connectionCount[type] != 0)
{
T* t = _connections[type][i--];
delete t;
--_connectionCount[type];
}
return error;
}
}
// Everything is fine
return 0;
}
template <class T>
bool DatabaseWorkerPool<T>::PrepareStatements()
{
for (uint8 i = 0; i < IDX_SIZE; ++i)
for (auto& connections : _connections)
{
for (uint32 c = 0; c < _connectionCount[i]; ++c)
for (auto& connection : connections)
{
T* t = _connections[i][c];
t->LockIfReady();
if (!t->PrepareStatements())
connection->LockIfReady();
if (!connection->PrepareStatements())
{
t->Unlock();
connection->Unlock();
Close();
return false;
}
else
connection->Unlock();
size_t const preparedSize = connection->m_stmts.size();
if (_preparedStatementSize.size() < preparedSize)
_preparedStatementSize.resize(preparedSize);
for (size_t i = 0; i < preparedSize; ++i)
{
t->Unlock();
// already set by another connection
// (each connection only has prepared statements of it's own type sync/async)
if (_preparedStatementSize[i] > 0)
continue;
if (MySQLPreparedStatement* stmt = connection->m_stmts[i].get())
{
uint32 const paramCount = stmt->GetParameterCount();
// TC only supports uint8 indices.
ASSERT(paramCount < std::numeric_limits<uint8>::max());
_preparedStatementSize[i] = static_cast<uint8>(paramCount);
}
}
}
}
@@ -179,74 +161,28 @@ bool DatabaseWorkerPool<T>::PrepareStatements()
}
template <class T>
char const* DatabaseWorkerPool<T>::GetDatabaseName() const
QueryResult DatabaseWorkerPool<T>::Query(char const* sql, T* connection /*= nullptr*/)
{
return _connectionInfo->database.c_str();
}
if (!connection)
connection = GetFreeConnection();
template <class T>
void DatabaseWorkerPool<T>::Execute(const char* sql)
{
if (!sql)
return;
BasicStatementTask* task = new BasicStatementTask(sql);
Enqueue(task);
}
template <class T>
void DatabaseWorkerPool<T>::Execute(PreparedStatement* stmt)
{
PreparedStatementTask* task = new PreparedStatementTask(stmt);
Enqueue(task);
}
template <class T>
void DatabaseWorkerPool<T>::DirectExecute(const char* sql)
{
if (!sql)
return;
T* t = GetFreeConnection();
t->Execute(sql);
t->Unlock();
}
template <class T>
void DatabaseWorkerPool<T>::DirectExecute(PreparedStatement* stmt)
{
T* t = GetFreeConnection();
t->Execute(stmt);
t->Unlock();
//! Delete proxy-class. Not needed anymore
delete stmt;
}
template <class T>
QueryResult DatabaseWorkerPool<T>::Query(const char* sql, T* conn /* = nullptr*/)
{
if (!conn)
conn = GetFreeConnection();
ResultSet* result = conn->Query(sql);
conn->Unlock();
if (!result || !result->GetRowCount())
ResultSet* result = connection->Query(sql);
connection->Unlock();
if (!result || !result->GetRowCount() || !result->NextRow())
{
delete result;
return QueryResult(nullptr);
}
result->NextRow();
return QueryResult(result);
}
template <class T>
PreparedQueryResult DatabaseWorkerPool<T>::Query(PreparedStatement* stmt)
PreparedQueryResult DatabaseWorkerPool<T>::Query(PreparedStatement<T>* stmt)
{
T* t = GetFreeConnection();
PreparedResultSet* ret = t->Query(stmt);
t->Unlock();
auto connection = GetFreeConnection();
PreparedResultSet* ret = connection->Query(stmt);
connection->Unlock();
//! Delete proxy-class. Not needed anymore
delete stmt;
@@ -261,40 +197,66 @@ PreparedQueryResult DatabaseWorkerPool<T>::Query(PreparedStatement* stmt)
}
template <class T>
QueryResultFuture DatabaseWorkerPool<T>::AsyncQuery(const char* sql)
QueryCallback DatabaseWorkerPool<T>::AsyncQuery(char const* sql)
{
QueryResultFuture res;
BasicStatementTask* task = new BasicStatementTask(sql, res);
BasicStatementTask* task = new BasicStatementTask(sql, true);
// Store future result before enqueueing - task might get already processed and deleted before returning from this method
QueryResultFuture result = task->GetFuture();
Enqueue(task);
return res; //! Actual return value has no use yet
return QueryCallback(std::move(result));
}
template <class T>
PreparedQueryResultFuture DatabaseWorkerPool<T>::AsyncQuery(PreparedStatement* stmt)
QueryCallback DatabaseWorkerPool<T>::AsyncQuery(PreparedStatement<T>* stmt)
{
PreparedQueryResultFuture res;
PreparedStatementTask* task = new PreparedStatementTask(stmt, res);
PreparedStatementTask* task = new PreparedStatementTask(stmt, true);
// Store future result before enqueueing - task might get already processed and deleted before returning from this method
PreparedQueryResultFuture result = task->GetFuture();
Enqueue(task);
return res;
return QueryCallback(std::move(result));
}
template <class T>
QueryResultHolderFuture DatabaseWorkerPool<T>::DelayQueryHolder(SQLQueryHolder* holder)
SQLQueryHolderCallback DatabaseWorkerPool<T>::DelayQueryHolder(std::shared_ptr<SQLQueryHolder<T>> holder)
{
QueryResultHolderFuture res;
SQLQueryHolderTask* task = new SQLQueryHolderTask(holder, res);
SQLQueryHolderTask* task = new SQLQueryHolderTask(holder);
// Store future result before enqueueing - task might get already processed and deleted before returning from this method
QueryResultHolderFuture result = task->GetFuture();
Enqueue(task);
return res; //! Fool compiler, has no use yet
return { std::move(holder), std::move(result) };
}
template <class T>
SQLTransaction DatabaseWorkerPool<T>::BeginTransaction()
SQLTransaction<T> DatabaseWorkerPool<T>::BeginTransaction()
{
return SQLTransaction(new Transaction);
return std::make_shared<Transaction<T>>();
}
template <class T>
void DatabaseWorkerPool<T>::CommitTransaction(SQLTransaction transaction)
void DatabaseWorkerPool<T>::CommitTransaction(SQLTransaction<T> transaction)
{
#ifdef ACORE_DEBUG
//! Only analyze transaction weaknesses in Debug mode.
//! Ideally we catch the faults in Debug mode and then correct them,
//! so there's no need to waste these CPU cycles in Release mode.
switch (transaction->GetSize())
{
case 0:
LOG_DEBUG("sql.driver", "Transaction contains 0 queries. Not executing.");
return;
case 1:
LOG_DEBUG("sql.driver", "Warning: Transaction only holds 1 query, consider removing Transaction context in code.");
break;
default:
break;
}
#endif // ACORE_DEBUG
Enqueue(new TransactionTask(transaction));
}
template <class T>
TransactionCallback DatabaseWorkerPool<T>::AsyncCommitTransaction(SQLTransaction<T> transaction)
{
#ifdef ACORE_DEBUG
//! Only analyze transaction weaknesses in Debug mode.
@@ -303,38 +265,42 @@ void DatabaseWorkerPool<T>::CommitTransaction(SQLTransaction transaction)
switch (transaction->GetSize())
{
case 0:
LOG_INFO("sql.driver", "Transaction contains 0 queries. Not executing.");
return;
LOG_DEBUG("sql.driver", "Transaction contains 0 queries. Not executing.");
break;
case 1:
LOG_INFO("sql.driver", "Warning: Transaction only holds 1 query, consider removing Transaction context in code.");
LOG_DEBUG("sql.driver", "Warning: Transaction only holds 1 query, consider removing Transaction context in code.");
break;
default:
break;
}
#endif // ACORE_DEBUG
Enqueue(new TransactionTask(transaction));
TransactionWithResultTask* task = new TransactionWithResultTask(transaction);
TransactionFuture result = task->GetFuture();
Enqueue(task);
return TransactionCallback(std::move(result));
}
template <class T>
void DatabaseWorkerPool<T>::DirectCommitTransaction(SQLTransaction& transaction)
void DatabaseWorkerPool<T>::DirectCommitTransaction(SQLTransaction<T>& transaction)
{
T* con = GetFreeConnection();
int errorCode = con->ExecuteTransaction(transaction);
T* connection = GetFreeConnection();
int errorCode = connection->ExecuteTransaction(transaction);
if (!errorCode)
{
con->Unlock(); // OK, operation succesful
connection->Unlock(); // OK, operation succesful
return;
}
//! Handle MySQL Errno 1213 without extending deadlock to the core itself
//! TODO: More elegant way
/// @todo More elegant way
if (errorCode == ER_LOCK_DEADLOCK)
{
//todo: handle multiple sync threads deadlocking in a similar way as async threads
uint8 loopBreaker = 5;
for (uint8 i = 0; i < loopBreaker; ++i)
{
if (!con->ExecuteTransaction(transaction))
if (!connection->ExecuteTransaction(transaction))
break;
}
}
@@ -342,20 +308,177 @@ void DatabaseWorkerPool<T>::DirectCommitTransaction(SQLTransaction& transaction)
//! Clean up now.
transaction->Cleanup();
con->Unlock();
connection->Unlock();
}
template <class T>
void DatabaseWorkerPool<T>::ExecuteOrAppend(SQLTransaction& trans, PreparedStatement* stmt)
PreparedStatement<T>* DatabaseWorkerPool<T>::GetPreparedStatement(PreparedStatementIndex index)
{
if (!trans)
Execute(stmt);
else
trans->Append(stmt);
return new PreparedStatement<T>(index, _preparedStatementSize[index]);
}
template <class T>
void DatabaseWorkerPool<T>::ExecuteOrAppend(SQLTransaction& trans, const char* sql)
void DatabaseWorkerPool<T>::EscapeString(std::string& str)
{
if (str.empty())
return;
char* buf = new char[str.size() * 2 + 1];
EscapeString(buf, str.c_str(), uint32(str.size()));
str = buf;
delete[] buf;
}
template <class T>
void DatabaseWorkerPool<T>::KeepAlive()
{
//! Ping synchronous connections
for (auto& connection : _connections[IDX_SYNCH])
{
if (connection->LockIfReady())
{
connection->Ping();
connection->Unlock();
}
}
//! Assuming all worker threads are free, every worker thread will receive 1 ping operation request
//! If one or more worker threads are busy, the ping operations will not be split evenly, but this doesn't matter
//! as the sole purpose is to prevent connections from idling.
auto const count = _connections[IDX_ASYNC].size();
for (uint8 i = 0; i < count; ++i)
Enqueue(new PingOperation);
}
template <class T>
uint32 DatabaseWorkerPool<T>::OpenConnections(InternalIndex type, uint8 numConnections)
{
for (uint8 i = 0; i < numConnections; ++i)
{
// Create the connection
auto connection = [&] {
switch (type)
{
case IDX_ASYNC:
return std::make_unique<T>(_queue.get(), *_connectionInfo);
case IDX_SYNCH:
return std::make_unique<T>(*_connectionInfo);
default:
ABORT();
}
}();
if (uint32 error = connection->Open())
{
// Failed to open a connection or invalid version, abort and cleanup
_connections[type].clear();
return error;
}
else if (connection->GetServerVersion() < MIN_MYSQL_SERVER_VERSION)
{
LOG_ERROR("sql.driver", "AzerothCore does not support MySQL versions below 5.7");
return 1;
}
else
{
_connections[type].push_back(std::move(connection));
}
}
// Everything is fine
return 0;
}
template <class T>
unsigned long DatabaseWorkerPool<T>::EscapeString(char* to, char const* from, unsigned long length)
{
if (!to || !from || !length)
return 0;
return _connections[IDX_SYNCH].front()->EscapeString(to, from, length);
}
template <class T>
void DatabaseWorkerPool<T>::Enqueue(SQLOperation* op)
{
_queue->Push(op);
}
template <class T>
T* DatabaseWorkerPool<T>::GetFreeConnection()
{
#ifdef ACORE_DEBUG
if (_warnSyncQueries)
{
std::ostringstream ss;
ss << boost::stacktrace::stacktrace();
LOG_WARN("sql.performances", "Sync query at:\n%s", ss.str().c_str());
}
#endif
uint8 i = 0;
auto const num_cons = _connections[IDX_SYNCH].size();
T* connection = nullptr;
//! Block forever until a connection is free
for (;;)
{
connection = _connections[IDX_SYNCH][++i % num_cons].get();
//! Must be matched with t->Unlock() or you will get deadlocks
if (connection->LockIfReady())
break;
}
return connection;
}
template <class T>
char const* DatabaseWorkerPool<T>::GetDatabaseName() const
{
return _connectionInfo->database.c_str();
}
template <class T>
void DatabaseWorkerPool<T>::Execute(char const* sql)
{
if (Acore::IsFormatEmptyOrNull(sql))
return;
BasicStatementTask* task = new BasicStatementTask(sql);
Enqueue(task);
}
template <class T>
void DatabaseWorkerPool<T>::Execute(PreparedStatement<T>* stmt)
{
PreparedStatementTask* task = new PreparedStatementTask(stmt);
Enqueue(task);
}
template <class T>
void DatabaseWorkerPool<T>::DirectExecute(char const* sql)
{
if (Acore::IsFormatEmptyOrNull(sql))
return;
T* connection = GetFreeConnection();
connection->Execute(sql);
connection->Unlock();
}
template <class T>
void DatabaseWorkerPool<T>::DirectExecute(PreparedStatement<T>* stmt)
{
T* connection = GetFreeConnection();
connection->Execute(stmt);
connection->Unlock();
//! Delete proxy-class. Not needed anymore
delete stmt;
}
template <class T>
void DatabaseWorkerPool<T>::ExecuteOrAppend(SQLTransaction<T>& trans, char const* sql)
{
if (!trans)
Execute(sql);
@@ -364,51 +487,14 @@ void DatabaseWorkerPool<T>::ExecuteOrAppend(SQLTransaction& trans, const char* s
}
template <class T>
PreparedStatement* DatabaseWorkerPool<T>::GetPreparedStatement(uint32 index)
void DatabaseWorkerPool<T>::ExecuteOrAppend(SQLTransaction<T>& trans, PreparedStatement<T>* stmt)
{
return new PreparedStatement(index);
if (!trans)
Execute(stmt);
else
trans->Append(stmt);
}
template <class T>
void DatabaseWorkerPool<T>::KeepAlive()
{
//! Ping synchronous connections
for (uint8 i = 0; i < _connectionCount[IDX_SYNCH]; ++i)
{
T* t = _connections[IDX_SYNCH][i];
if (t->LockIfReady())
{
t->Ping();
t->Unlock();
}
}
//! Assuming all worker threads are free, every worker thread will receive 1 ping operation request
//! If one or more worker threads are busy, the ping operations will not be split evenly, but this doesn't matter
//! as the sole purpose is to prevent connections from idling.
for (size_t i = 0; i < _connections[IDX_ASYNC].size(); ++i)
Enqueue(new PingOperation);
}
template <class T>
T* DatabaseWorkerPool<T>::GetFreeConnection()
{
uint8 i = 0;
size_t num_cons = _connectionCount[IDX_SYNCH];
T* t = nullptr;
//! Block forever until a connection is free
for (;;)
{
t = _connections[IDX_SYNCH][++i % num_cons];
//! Must be matched with t->Unlock() or you will get deadlocks
if (t->LockIfReady())
break;
}
return t;
}
template class DatabaseWorkerPool<LoginDatabaseConnection>;
template class DatabaseWorkerPool<WorldDatabaseConnection>;
template class DatabaseWorkerPool<CharacterDatabaseConnection>;
template class AC_DATABASE_API DatabaseWorkerPool<LoginDatabaseConnection>;
template class AC_DATABASE_API DatabaseWorkerPool<WorldDatabaseConnection>;
template class AC_DATABASE_API DatabaseWorkerPool<CharacterDatabaseConnection>;

View File

@@ -1,45 +1,45 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _DATABASEWORKERPOOL_H
#define _DATABASEWORKERPOOL_H
#include "Common.h"
#include "Callback.h"
#include "MySQLConnection.h"
#include "Transaction.h"
#include "DatabaseWorker.h"
#include "PreparedStatement.h"
#include "Log.h"
#include "QueryResult.h"
#include "QueryHolder.h"
#include "AdhocStatement.h"
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include "StringFormat.h"
#include <mutex>
#include <array>
#include <string>
#include <vector>
class PingOperation : public SQLOperation
{
//! Operation for idle delaythreads
bool Execute() override
{
m_conn->Ping();
return true;
}
};
template <typename T>
class ProducerConsumerQueue;
class SQLOperation;
struct MySQLConnectionInfo;
template <class T>
class DatabaseWorkerPool
{
private:
enum InternalIndex
{
IDX_ASYNC,
IDX_SYNCH,
IDX_SIZE
};
public:
/* Activity state */
DatabaseWorkerPool();
~DatabaseWorkerPool() = default;
~DatabaseWorkerPool();
void SetConnectionInfo(std::string const& infoString, uint8 const asyncThreads, uint8 const synchThreads);
uint32 Open();
void Close();
//! Prepares all prepared statements
@@ -56,12 +56,12 @@ public:
//! Enqueues a one-way SQL operation in string format that will be executed asynchronously.
//! This method should only be used for queries that are only executed once, e.g during startup.
void Execute(const char* sql);
void Execute(char const* sql);
//! Enqueues a one-way SQL operation in string format -with variable args- that will be executed asynchronously.
//! This method should only be used for queries that are only executed once, e.g during startup.
template<typename Format, typename... Args>
void PExecute(Format&& sql, Args&& ... args)
void PExecute(Format&& sql, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
return;
@@ -71,7 +71,7 @@ public:
//! Enqueues a one-way SQL operation in prepared statement format that will be executed asynchronously.
//! Statement must be prepared with CONNECTION_ASYNC flag.
void Execute(PreparedStatement* stmt);
void Execute(PreparedStatement<T>* stmt);
/**
Direct synchronous one-way statement methods.
@@ -79,12 +79,12 @@ public:
//! Directly executes a one-way SQL operation in string format, that will block the calling thread until finished.
//! This method should only be used for queries that are only executed once, e.g during startup.
void DirectExecute(const char* sql);
void DirectExecute(char const* sql);
//! Directly executes a one-way SQL operation in string format -with variable args-, that will block the calling thread until finished.
//! This method should only be used for queries that are only executed once, e.g during startup.
template<typename Format, typename... Args>
void DirectPExecute(Format&& sql, Args&& ... args)
void DirectPExecute(Format&& sql, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
return;
@@ -94,7 +94,7 @@ public:
//! Directly executes a one-way SQL operation in prepared statement format, that will block the calling thread until finished.
//! Statement must be prepared with the CONNECTION_SYNCH flag.
void DirectExecute(PreparedStatement* stmt);
void DirectExecute(PreparedStatement<T>* stmt);
/**
Synchronous query (with resultset) methods.
@@ -102,12 +102,12 @@ public:
//! Directly executes an SQL query in string format that will block the calling thread until finished.
//! Returns reference counted auto pointer, no need for manual memory management in upper level code.
QueryResult Query(const char* sql, T* conn = nullptr);
QueryResult Query(char const* sql, T* connection = nullptr);
//! Directly executes an SQL query in string format -with variable args- that will block the calling thread until finished.
//! Returns reference counted auto pointer, no need for manual memory management in upper level code.
template<typename Format, typename... Args>
QueryResult PQuery(Format&& sql, T* conn, Args&& ... args)
QueryResult PQuery(Format&& sql, T* conn, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
return QueryResult(nullptr);
@@ -118,7 +118,7 @@ public:
//! Directly executes an SQL query in string format -with variable args- that will block the calling thread until finished.
//! Returns reference counted auto pointer, no need for manual memory management in upper level code.
template<typename Format, typename... Args>
QueryResult PQuery(Format&& sql, Args&& ... args)
QueryResult PQuery(Format&& sql, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
return QueryResult(nullptr);
@@ -129,7 +129,7 @@ public:
//! Directly executes an SQL query in prepared format that will block the calling thread until finished.
//! Returns reference counted auto pointer, no need for manual memory management in upper level code.
//! Statement must be prepared with CONNECTION_SYNCH flag.
PreparedQueryResult Query(PreparedStatement* stmt);
PreparedQueryResult Query(PreparedStatement<T>* stmt);
/**
Asynchronous query (with resultset) methods.
@@ -137,113 +137,92 @@ public:
//! Enqueues a query in string format that will set the value of the QueryResultFuture return object as soon as the query is executed.
//! The return value is then processed in ProcessQueryCallback methods.
QueryResultFuture AsyncQuery(const char* sql);
//! Enqueues a query in string format -with variable args- that will set the value of the QueryResultFuture return object as soon as the query is executed.
//! The return value is then processed in ProcessQueryCallback methods.
template<typename Format, typename... Args>
QueryResultFuture AsyncPQuery(Format&& sql, Args&& ... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
return QueryResult(nullptr);
return AsyncQuery(Acore::StringFormat(std::forward<Format>(sql), std::forward<Args>(args)...).c_str());
}
QueryCallback AsyncQuery(char const* sql);
//! Enqueues a query in prepared format that will set the value of the PreparedQueryResultFuture return object as soon as the query is executed.
//! The return value is then processed in ProcessQueryCallback methods.
//! Statement must be prepared with CONNECTION_ASYNC flag.
PreparedQueryResultFuture AsyncQuery(PreparedStatement* stmt);
QueryCallback AsyncQuery(PreparedStatement<T>* stmt);
//! Enqueues a vector of SQL operations (can be both adhoc and prepared) that will set the value of the QueryResultHolderFuture
//! return object as soon as the query is executed.
//! The return value is then processed in ProcessQueryCallback methods.
//! Any prepared statements added to this holder need to be prepared with the CONNECTION_ASYNC flag.
QueryResultHolderFuture DelayQueryHolder(SQLQueryHolder* holder);
SQLQueryHolderCallback DelayQueryHolder(std::shared_ptr<SQLQueryHolder<T>> holder);
/**
Transaction context methods.
*/
//! Begins an automanaged transaction pointer that will automatically rollback if not commited. (Autocommit=0)
SQLTransaction BeginTransaction();
SQLTransaction<T> BeginTransaction();
//! Enqueues a collection of one-way SQL operations (can be both adhoc and prepared). The order in which these operations
//! were appended to the transaction will be respected during execution.
void CommitTransaction(SQLTransaction transaction);
void CommitTransaction(SQLTransaction<T> transaction);
//! Enqueues a collection of one-way SQL operations (can be both adhoc and prepared). The order in which these operations
//! were appended to the transaction will be respected during execution.
TransactionCallback AsyncCommitTransaction(SQLTransaction<T> transaction);
//! Directly executes a collection of one-way SQL operations (can be both adhoc and prepared). The order in which these operations
//! were appended to the transaction will be respected during execution.
void DirectCommitTransaction(SQLTransaction& transaction);
//! Method used to execute prepared statements in a diverse context.
//! Will be wrapped in a transaction if valid object is present, otherwise executed standalone.
void ExecuteOrAppend(SQLTransaction& trans, PreparedStatement* stmt);
void DirectCommitTransaction(SQLTransaction<T>& transaction);
//! Method used to execute ad-hoc statements in a diverse context.
//! Will be wrapped in a transaction if valid object is present, otherwise executed standalone.
void ExecuteOrAppend(SQLTransaction& trans, const char* sql);
void ExecuteOrAppend(SQLTransaction<T>& trans, char const* sql);
//! Method used to execute prepared statements in a diverse context.
//! Will be wrapped in a transaction if valid object is present, otherwise executed standalone.
void ExecuteOrAppend(SQLTransaction<T>& trans, PreparedStatement<T>* stmt);
/**
Other
*/
typedef typename T::Statements PreparedStatementIndex;
//! Automanaged (internally) pointer to a prepared statement object for usage in upper level code.
//! Pointer is deleted in this->DirectExecute(PreparedStatement*), this->Query(PreparedStatement*) or PreparedStatementTask::~PreparedStatementTask.
//! This object is not tied to the prepared statement on the MySQL context yet until execution.
PreparedStatement* GetPreparedStatement(uint32 index);
PreparedStatement<T>* GetPreparedStatement(PreparedStatementIndex index);
//! Apply escape string'ing for current collation. (utf8)
unsigned long EscapeString(char* to, const char* from, unsigned long length)
{
if (!to || !from || !length)
return 0;
return mysql_real_escape_string(_connections[IDX_SYNCH][0]->GetHandle(), to, from, length);
}
void EscapeString(std::string& str);
//! Keeps all our MySQL connections alive, prevent the server from disconnecting us.
void KeepAlive();
void EscapeString(std::string& str)
void WarnAboutSyncQueries([[maybe_unused]] bool warn)
{
if (str.empty())
return;
char* buf = new char[str.size() * 2 + 1];
EscapeString(buf, str.c_str(), str.size());
str = buf;
delete[] buf;
#ifdef ACORE_DEBUG
_warnSyncQueries = warn;
#endif
}
private:
enum InternalIndex
{
IDX_ASYNC,
IDX_SYNCH,
IDX_SIZE
};
uint32 OpenConnections(InternalIndex type, uint8 numConnections);
void Enqueue(SQLOperation* op)
{
_queue->enqueue(op);
}
unsigned long EscapeString(char* to, char const* from, unsigned long length);
[[nodiscard]] char const* GetDatabaseName() const;
void Enqueue(SQLOperation* op);
//! Gets a free connection in the synchronous connection pool.
//! Caller MUST call t->Unlock() after touching the MySQL context to prevent deadlocks.
T* GetFreeConnection();
ACE_Message_Queue<ACE_SYNCH>* _mqueue;
ACE_Activation_Queue* _queue; //! Queue shared by async worker threads.
std::vector<std::vector<T*>> _connections;
uint32 _connectionCount[IDX_SIZE]; //! Counter of MySQL connections;
char const* GetDatabaseName() const;
//! Queue shared by async worker threads.
std::unique_ptr<ProducerConsumerQueue<SQLOperation*>> _queue;
std::array<std::vector<std::unique_ptr<T>>, IDX_SIZE> _connections;
std::unique_ptr<MySQLConnectionInfo> _connectionInfo;
std::vector<uint8> _preparedStatementSize;
uint8 _async_threads, _synch_threads;
#ifdef ACORE_DEBUG
static inline thread_local bool _warnSyncQueries = false;
#endif
};
#endif

View File

@@ -1,57 +1,240 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "Field.h"
#include "Errors.h"
#include "Log.h"
#include "MySQLHacks.h"
Field::Field()
{
data.value = nullptr;
data.type = MYSQL_TYPE_NULL;
data.length = 0;
data.raw = false;
meta = nullptr;
}
Field::~Field()
{
CleanUp();
}
Field::~Field() = default;
void Field::SetByteValue(const void* newValue, const size_t newSize, enum_field_types newType, uint32 length)
uint8 Field::GetUInt8() const
{
if (data.value)
CleanUp();
if (!data.value)
return 0;
// This value stores raw bytes that have to be explicitly cast later
if (newValue)
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int8))
{
data.value = new char[newSize];
memcpy(data.value, newValue, newSize);
data.length = length;
LogWrongType(__FUNCTION__);
return 0;
}
data.type = newType;
data.raw = true;
#endif
if (data.raw)
return *reinterpret_cast<uint8 const*>(data.value);
return static_cast<uint8>(strtoul(data.value, nullptr, 10));
}
void Field::SetStructuredValue(char* newValue, enum_field_types newType, uint32 length)
int8 Field::GetInt8() const
{
if (data.value)
CleanUp();
if (!data.value)
return 0;
// This value stores somewhat structured data that needs function style casting
if (newValue)
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int8))
{
data.value = new char[length + 1];
memcpy(data.value, newValue, length);
*(reinterpret_cast<char*>(data.value) + length) = '\0';
data.length = length;
LogWrongType(__FUNCTION__);
return 0;
}
#endif
data.type = newType;
data.raw = false;
if (data.raw)
return *reinterpret_cast<int8 const*>(data.value);
return static_cast<int8>(strtol(data.value, nullptr, 10));
}
uint16 Field::GetUInt16() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int16))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint16 const*>(data.value);
return static_cast<uint16>(strtoul(data.value, nullptr, 10));
}
int16 Field::GetInt16() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int16))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int16 const*>(data.value);
return static_cast<int16>(strtol(data.value, nullptr, 10));
}
uint32 Field::GetUInt32() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int32))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint32 const*>(data.value);
return static_cast<uint32>(strtoul(data.value, nullptr, 10));
}
int32 Field::GetInt32() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int32))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int32 const*>(data.value);
return static_cast<int32>(strtol(data.value, nullptr, 10));
}
uint64 Field::GetUInt64() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int64))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint64 const*>(data.value);
return static_cast<uint64>(strtoull(data.value, nullptr, 10));
}
int64 Field::GetInt64() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int64))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int64 const*>(data.value);
return static_cast<int64>(strtoll(data.value, nullptr, 10));
}
float Field::GetFloat() const
{
if (!data.value)
return 0.0f;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Float))
{
LogWrongType(__FUNCTION__);
return 0.0f;
}
#endif
if (data.raw)
return *reinterpret_cast<float const*>(data.value);
return static_cast<float>(atof(data.value));
}
double Field::GetDouble() const
{
if (!data.value)
return 0.0f;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Double) && !IsType(DatabaseFieldTypes::Decimal))
{
LogWrongType(__FUNCTION__);
return 0.0f;
}
#endif
if (data.raw && !IsType(DatabaseFieldTypes::Decimal))
return *reinterpret_cast<double const*>(data.value);
return static_cast<double>(atof(data.value));
}
char const* Field::GetCString() const
{
if (!data.value)
return nullptr;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (IsNumeric() && data.raw)
{
LogWrongType(__FUNCTION__);
return nullptr;
}
#endif
return static_cast<char const*>(data.value);
}
std::string Field::GetString() const
{
if (!data.value)
return "";
char const* string = GetCString();
if (!string)
return "";
return std::string(string, data.length);
}
std::string_view Field::GetStringView() const
{
if (!data.value)
return {};
char const* const string = GetCString();
if (!string)
return {};
return { string, data.length };
}
std::vector<uint8> Field::GetBinary() const
@@ -67,6 +250,48 @@ std::vector<uint8> Field::GetBinary() const
void Field::GetBinarySizeChecked(uint8* buf, size_t length) const
{
ASSERT(data.value && (data.length == length));
ASSERT(data.value && (data.length == length), "Expected %zu-byte binary blob, got %sdata (%u bytes) instead", length, data.value ? "" : "no ", data.length);
memcpy(buf, data.value, length);
}
void Field::SetByteValue(char const* newValue, uint32 length)
{
// This value stores raw bytes that have to be explicitly cast later
data.value = newValue;
data.length = length;
data.raw = true;
}
void Field::SetStructuredValue(char const* newValue, uint32 length)
{
// This value stores somewhat structured data that needs function style casting
data.value = newValue;
data.length = length;
data.raw = false;
}
bool Field::IsType(DatabaseFieldTypes type) const
{
return meta->Type == type;
}
bool Field::IsNumeric() const
{
return (meta->Type == DatabaseFieldTypes::Int8 ||
meta->Type == DatabaseFieldTypes::Int16 ||
meta->Type == DatabaseFieldTypes::Int32 ||
meta->Type == DatabaseFieldTypes::Int64 ||
meta->Type == DatabaseFieldTypes::Float ||
meta->Type == DatabaseFieldTypes::Double);
}
void Field::LogWrongType(char const* getter) const
{
LOG_WARN("sql.sql", "Warning: %s on %s field %s.%s (%s.%s) at index %u.",
getter, meta->TypeName, meta->TableAlias, meta->Alias, meta->TableName, meta->Name, meta->Index);
}
void Field::SetMetadata(QueryResultFieldMetadata const* fieldMeta)
{
meta = fieldMeta;
}

View File

@@ -1,437 +1,135 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef AZEROTHCORE_FIELD_H
#define AZEROTHCORE_FIELD_H
#ifndef _FIELD_H
#define _FIELD_H
#include "Common.h"
#include "Log.h"
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include <array>
#include <mysql.h>
#include <string>
#include <string_view>
#include <vector>
class Field
enum class DatabaseFieldTypes : uint8
{
friend class ResultSet;
friend class PreparedResultSet;
Null,
Int8,
Int16,
Int32,
Int64,
Float,
Double,
Decimal,
Date,
Binary
};
struct QueryResultFieldMetadata
{
char const* TableName = nullptr;
char const* TableAlias = nullptr;
char const* Name = nullptr;
char const* Alias = nullptr;
char const* TypeName = nullptr;
uint32 Index = 0;
DatabaseFieldTypes Type = DatabaseFieldTypes::Null;
};
/**
@class Field
@brief Class used to access individual fields of database query result
Guideline on field type matching:
| MySQL type | method to use |
|------------------------|----------------------------------------|
| TINYINT | GetBool, GetInt8, GetUInt8 |
| SMALLINT | GetInt16, GetUInt16 |
| MEDIUMINT, INT | GetInt32, GetUInt32 |
| BIGINT | GetInt64, GetUInt64 |
| FLOAT | GetFloat |
| DOUBLE, DECIMAL | GetDouble |
| CHAR, VARCHAR, | GetCString, GetString |
| TINYTEXT, MEDIUMTEXT, | GetCString, GetString |
| TEXT, LONGTEXT | GetCString, GetString |
| TINYBLOB, MEDIUMBLOB, | GetBinary, GetString |
| BLOB, LONGBLOB | GetBinary, GetString |
| BINARY, VARBINARY | GetBinary |
Return types of aggregate functions:
| Function | Type |
|----------|-------------------|
| MIN, MAX | Same as the field |
| SUM, AVG | DECIMAL |
| COUNT | BIGINT |
*/
class AC_DATABASE_API Field
{
friend class ResultSet;
friend class PreparedResultSet;
public:
[[nodiscard]] bool GetBool() const // Wrapper, actually gets integer
Field();
~Field();
bool GetBool() const // Wrapper, actually gets integer
{
return (GetUInt8() == 1);
return GetUInt8() == 1 ? true : false;
}
[[nodiscard]] uint8 GetUInt8() const
{
if (!data.value)
return 0;
uint8 GetUInt8() const;
int8 GetInt8() const;
uint16 GetUInt16() const;
int16 GetInt16() const;
uint32 GetUInt32() const;
int32 GetInt32() const;
uint64 GetUInt64() const;
int64 GetInt64() const;
float GetFloat() const;
double GetDouble() const;
char const* GetCString() const;
std::string GetString() const;
std::string_view GetStringView() const;
std::vector<uint8> GetBinary() const;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_TINY))
{
LOG_INFO("sql.driver", "Warning: GetUInt8() on non-tinyint field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint8*>(data.value);
return static_cast<uint8>(atol((char*)data.value));
}
[[nodiscard]] int8 GetInt8() const
{
if (!data.value)
return 0;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_TINY))
{
LOG_INFO("sql.driver", "Warning: GetInt8() on non-tinyint field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int8*>(data.value);
return static_cast<int8>(atol((char*)data.value));
}
#ifdef ELUNA
enum_field_types GetType() const
{
return data.type;
}
#endif
[[nodiscard]] uint16 GetUInt16() const
{
if (!data.value)
return 0;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_SHORT) && !IsType(MYSQL_TYPE_YEAR))
{
LOG_INFO("sql.driver", "Warning: GetUInt16() on non-smallint field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint16*>(data.value);
return static_cast<uint16>(atol((char*)data.value));
}
[[nodiscard]] int16 GetInt16() const
{
if (!data.value)
return 0;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_SHORT) && !IsType(MYSQL_TYPE_YEAR))
{
LOG_INFO("sql.driver", "Warning: GetInt16() on non-smallint field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int16*>(data.value);
return static_cast<int16>(atol((char*)data.value));
}
[[nodiscard]] uint32 GetUInt32() const
{
if (!data.value)
return 0;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_INT24) && !IsType(MYSQL_TYPE_LONG))
{
LOG_INFO("sql.driver", "Warning: GetUInt32() on non-(medium)int field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint32*>(data.value);
return static_cast<uint32>(atol((char*)data.value));
}
[[nodiscard]] int32 GetInt32() const
{
if (!data.value)
return 0;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_INT24) && !IsType(MYSQL_TYPE_LONG))
{
LOG_INFO("sql.driver", "Warning: GetInt32() on non-(medium)int field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int32*>(data.value);
return static_cast<int32>(atol((char*)data.value));
}
[[nodiscard]] uint64 GetUInt64() const
{
if (!data.value)
return 0;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_LONGLONG) && !IsType(MYSQL_TYPE_BIT))
{
LOG_INFO("sql.driver", "Warning: GetUInt64() on non-bigint field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint64*>(data.value);
return static_cast<uint64>(atol((char*)data.value));
}
[[nodiscard]] int64 GetInt64() const
{
if (!data.value)
return 0;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_LONGLONG) && !IsType(MYSQL_TYPE_BIT))
{
LOG_INFO("sql.driver", "Warning: GetInt64() on non-bigint field. Using type: %s.", FieldTypeToString(data.type));
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int64*>(data.value);
return static_cast<int64>(strtol((char*)data.value, nullptr, 10));
}
[[nodiscard]] float GetFloat() const
{
if (!data.value)
return 0.0f;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_FLOAT))
{
LOG_INFO("sql.driver", "Warning: GetFloat() on non-float field. Using type: %s.", FieldTypeToString(data.type));
return 0.0f;
}
#endif
if (data.raw)
return *reinterpret_cast<float*>(data.value);
return static_cast<float>(atof((char*)data.value));
}
[[nodiscard]] double GetDouble() const
{
if (!data.value)
return 0.0f;
#ifdef ACORE_DEBUG
if (!IsType(MYSQL_TYPE_DOUBLE))
{
LOG_INFO("sql.driver", "Warning: GetDouble() on non-double field. Using type: %s.", FieldTypeToString(data.type));
return 0.0f;
}
#endif
if (data.raw)
return *reinterpret_cast<double*>(data.value);
return static_cast<double>(atof((char*)data.value));
}
[[nodiscard]] char const* GetCString() const
{
if (!data.value)
return nullptr;
#ifdef ACORE_DEBUG
if (IsNumeric())
{
LOG_INFO("sql.driver", "Error: GetCString() on numeric field. Using type: %s.", FieldTypeToString(data.type));
return nullptr;
}
#endif
return static_cast<char const*>(data.value);
}
[[nodiscard]] std::string GetString() const
{
if (!data.value)
return "";
if (data.raw)
{
char const* string = GetCString();
if (!string)
string = "";
return std::string(string, data.length);
}
return std::string((char*)data.value);
}
[[nodiscard]] bool IsNull() const
{
if (IsBinary() && data.length == 0)
{
return true;
}
return data.value == nullptr;
}
[[nodiscard]] std::vector<uint8> GetBinary() const;
template<size_t S>
[[nodiscard]] std::array<uint8, S> GetBinary() const
template <size_t S>
std::array<uint8, S> GetBinary() const
{
std::array<uint8, S> buf;
GetBinarySizeChecked(buf.data(), S);
return buf;
}
protected:
Field();
~Field();
bool IsNull() const
{
return data.value == nullptr;
}
#if defined(__GNUC__)
#pragma pack(1)
#else
#pragma pack(push, 1)
#endif
DatabaseFieldTypes GetType() { return meta->Type; }
protected:
struct
{
uint32 length; // Length (prepared strings only)
void* value; // Actual data in memory
enum_field_types type; // Field type
char const* value; // Actual data in memory
uint32 length; // Length
bool raw; // Raw bytes? (Prepared statement or ad hoc)
} data;
#if defined(__GNUC__)
#pragma pack()
#else
#pragma pack(pop)
#endif
void SetByteValue(void const* newValue, size_t const newSize, enum_field_types newType, uint32 length);
void SetStructuredValue(char* newValue, enum_field_types newType, uint32 length);
void CleanUp()
{
delete[] ((char*)data.value);
data.value = nullptr;
}
static size_t SizeForType(MYSQL_FIELD* field)
{
switch (field->type)
{
case MYSQL_TYPE_NULL:
return 0;
case MYSQL_TYPE_TINY:
return 1;
case MYSQL_TYPE_YEAR:
case MYSQL_TYPE_SHORT:
return 2;
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_LONG:
case MYSQL_TYPE_FLOAT:
return 4;
case MYSQL_TYPE_DOUBLE:
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_BIT:
return 8;
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
case MYSQL_TYPE_DATETIME:
return sizeof(MYSQL_TIME);
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
return field->max_length + 1;
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return 64;
case MYSQL_TYPE_GEOMETRY:
/*
Following types are not sent over the wire:
MYSQL_TYPE_ENUM:
MYSQL_TYPE_SET:
*/
default:
LOG_INFO("sql.driver", "SQL::SizeForType(): invalid field type %u", uint32(field->type));
return 0;
}
}
[[nodiscard]] bool IsType(enum_field_types type) const
{
return data.type == type;
}
[[nodiscard]] bool IsNumeric() const
{
return (data.type == MYSQL_TYPE_TINY ||
data.type == MYSQL_TYPE_SHORT ||
data.type == MYSQL_TYPE_INT24 ||
data.type == MYSQL_TYPE_LONG ||
data.type == MYSQL_TYPE_FLOAT ||
data.type == MYSQL_TYPE_DOUBLE ||
data.type == MYSQL_TYPE_LONGLONG );
}
[[nodiscard]] bool IsBinary() const
{
return (
data.type == MYSQL_TYPE_TINY_BLOB ||
data.type == MYSQL_TYPE_MEDIUM_BLOB ||
data.type == MYSQL_TYPE_LONG_BLOB ||
data.type == MYSQL_TYPE_BLOB ||
data.type == MYSQL_TYPE_VAR_STRING ||
data.type == MYSQL_TYPE_STRING
);
}
void GetBinarySizeChecked(uint8* buf, size_t size) const;
void SetByteValue(char const* newValue, uint32 length);
void SetStructuredValue(char const* newValue, uint32 length);
bool IsType(DatabaseFieldTypes type) const;
bool IsNumeric() const;
private:
#ifdef ACORE_DEBUG
static char const* FieldTypeToString(enum_field_types type)
{
switch (type)
{
case MYSQL_TYPE_BIT:
return "BIT";
case MYSQL_TYPE_BLOB:
return "BLOB";
case MYSQL_TYPE_DATE:
return "DATE";
case MYSQL_TYPE_DATETIME:
return "DATETIME";
case MYSQL_TYPE_NEWDECIMAL:
return "NEWDECIMAL";
case MYSQL_TYPE_DECIMAL:
return "DECIMAL";
case MYSQL_TYPE_DOUBLE:
return "DOUBLE";
case MYSQL_TYPE_ENUM:
return "ENUM";
case MYSQL_TYPE_FLOAT:
return "FLOAT";
case MYSQL_TYPE_GEOMETRY:
return "GEOMETRY";
case MYSQL_TYPE_INT24:
return "INT24";
case MYSQL_TYPE_LONG:
return "LONG";
case MYSQL_TYPE_LONGLONG:
return "LONGLONG";
case MYSQL_TYPE_LONG_BLOB:
return "LONG_BLOB";
case MYSQL_TYPE_MEDIUM_BLOB:
return "MEDIUM_BLOB";
case MYSQL_TYPE_NEWDATE:
return "NEWDATE";
case MYSQL_TYPE_NULL:
return "nullptr";
case MYSQL_TYPE_SET:
return "SET";
case MYSQL_TYPE_SHORT:
return "SHORT";
case MYSQL_TYPE_STRING:
return "STRING";
case MYSQL_TYPE_TIME:
return "TIME";
case MYSQL_TYPE_TIMESTAMP:
return "TIMESTAMP";
case MYSQL_TYPE_TINY:
return "TINY";
case MYSQL_TYPE_TINY_BLOB:
return "TINY_BLOB";
case MYSQL_TYPE_VAR_STRING:
return "VAR_STRING";
case MYSQL_TYPE_YEAR:
return "YEAR";
default:
return "-Unknown-";
}
}
#endif
QueryResultFieldMetadata const* meta;
void LogWrongType(char const* getter) const;
void SetMetadata(QueryResultFieldMetadata const* fieldMeta);
void GetBinarySizeChecked(uint8* buf, size_t size) const;
};
#endif

View File

@@ -1,10 +1,10 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "CharacterDatabase.h"
#include "MySQLPreparedStatement.h"
void CharacterDatabaseConnection::DoPrepareStatements()
{
@@ -37,7 +37,7 @@ void CharacterDatabaseConnection::DoPrepareStatements()
"cb.guid, c.extra_flags, cd.genitive FROM characters AS c LEFT JOIN character_pet AS cp ON c.guid = cp.owner AND cp.slot = ? "
"LEFT JOIN character_declinedname AS cd ON c.guid = cd.guid LEFT JOIN guild_member AS gm ON c.guid = gm.guid "
"LEFT JOIN character_banned AS cb ON c.guid = cb.guid AND cb.active = 1 WHERE c.account = ? AND c.deleteInfos_Name IS NULL ORDER BY c.guid", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_FREE_NAME, "SELECT guid, name FROM characters WHERE guid = ? AND account = ? AND (at_login & ?) = ? AND NOT EXISTS (SELECT NULL FROM characters WHERE name = ?)", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_FREE_NAME, "SELECT guid, name, at_login FROM characters WHERE guid = ? AND account = ? AND NOT EXISTS (SELECT NULL FROM characters WHERE name = ?)", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_CHAR_ZONE, "SELECT zone FROM characters WHERE guid = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHARACTER_NAME_DATA, "SELECT race, class, gender, level FROM characters WHERE guid = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHAR_POSITION_XYZ, "SELECT map, position_x, position_y, position_z FROM characters WHERE guid = ?", CONNECTION_SYNCH);
@@ -138,7 +138,6 @@ void CharacterDatabaseConnection::DoPrepareStatements()
PrepareStatement(CHAR_INS_ACCOUNT_INSTANCE_LOCK_TIMES, "INSERT INTO account_instance_times (accountId, instanceId, releaseTime) VALUES (?, ?, ?)", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_MATCH_MAKER_RATING, "SELECT matchMakerRating, maxMMR FROM character_arena_stats WHERE guid = ? AND slot = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHARACTER_COUNT, "SELECT ? AS account,(SELECT COUNT(*) FROM characters WHERE account =?) AS cnt", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_NAME, "UPDATE characters set name = ?, at_login = at_login & ~ ? WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_DEL_DECLINED_NAME, "DELETE FROM character_declinedname WHERE guid = ?", CONNECTION_ASYNC);
// Guild handling
@@ -361,7 +360,7 @@ void CharacterDatabaseConnection::DoPrepareStatements()
PrepareStatement(CHAR_DEL_INVALID_PET_SPELL, "DELETE FROM pet_spell WHERE spell = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_GLOBAL_INSTANCE_RESETTIME, "UPDATE instance_reset SET resettime = ? WHERE mapid = ? AND difficulty = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_CHAR_ONLINE, "UPDATE characters SET online = 1 WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_CHAR_NAME_AT_LOGIN, "UPDATE characters set name = ?, at_login = at_login & ~ ? WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_CHAR_NAME_AT_LOGIN, "UPDATE characters set name = ?, at_login = ? WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_WORLDSTATE, "UPDATE worldstates SET value = ? WHERE entry = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_INS_WORLDSTATE, "INSERT INTO worldstates (entry, value) VALUES (?, ?)", CONNECTION_ASYNC);
PrepareStatement(CHAR_DEL_CHAR_INSTANCE_BY_INSTANCE, "DELETE FROM character_instance WHERE instance = ?", CONNECTION_ASYNC);
@@ -396,6 +395,8 @@ void CharacterDatabaseConnection::DoPrepareStatements()
PrepareStatement(CHAR_SEL_POOL_QUEST_SAVE, "SELECT quest_id FROM pool_quest_save WHERE pool_id = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHARACTER_AT_LOGIN, "SELECT at_login FROM characters WHERE guid = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHAR_CLASS_LVL_AT_LOGIN, "SELECT class, level, at_login, knownTitles FROM characters WHERE guid = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHAR_CUSTOMIZE_INFO, "SELECT name, race, class, gender, at_login FROM characters WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_CHAR_RACE_OR_FACTION_CHANGE_INFOS, "SELECT at_login, knownTitles, money FROM characters WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_CHAR_AT_LOGIN_TITLES_MONEY, "SELECT at_login, knownTitles, money FROM characters WHERE guid = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHAR_COD_ITEM_MAIL, "SELECT id, messageType, mailTemplateId, sender, subject, body, money, has_items FROM mail WHERE receiver = ? AND has_items <> 0 AND cod <> 0", CONNECTION_SYNCH);
PrepareStatement(CHAR_SEL_CHAR_SOCIAL, "SELECT DISTINCT guid FROM character_social WHERE friend = ?", CONNECTION_SYNCH);
@@ -437,7 +438,7 @@ void CharacterDatabaseConnection::DoPrepareStatements()
PrepareStatement(CHAR_DEL_PETITION_SIGNATURE_BY_GUID, "DELETE FROM petition_sign WHERE petitionguid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_DEL_CHAR_DECLINED_NAME, "DELETE FROM character_declinedname WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_INS_CHAR_DECLINED_NAME, "INSERT INTO character_declinedname (guid, genitive, dative, accusative, instrumental, prepositional) VALUES (?, ?, ?, ?, ?, ?)", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_FACTION_OR_RACE, "UPDATE characters SET name = ?, race = ?, at_login = at_login & ~ ? WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_CHAR_RACE, "UPDATE characters SET race = ? WHERE guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_DEL_CHAR_SKILL_LANGUAGES, "DELETE FROM character_skills WHERE skill IN (98, 113, 759, 111, 313, 109, 115, 315, 673, 137) AND guid = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_INS_CHAR_SKILL_LANGUAGE, "INSERT INTO `character_skills` (guid, skill, value, max) VALUES (?, ?, 300, 300)", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_CHAR_TAXI_PATH, "UPDATE characters SET taxi_path = '' WHERE guid = ?", CONNECTION_ASYNC);
@@ -546,6 +547,7 @@ void CharacterDatabaseConnection::DoPrepareStatements()
PrepareStatement(CHAR_SEL_CHAR_PET_BY_ENTRY_AND_SLOT_2, "SELECT id, entry, owner, modelid, level, exp, Reactstate, slot, name, renamed, curhealth, curmana, curhappiness, abdata, savetime, CreatedBySpell, PetType FROM character_pet WHERE owner = ? AND entry = ? AND (slot = ? OR slot > ?)", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_CHAR_PET_BY_SLOT, "SELECT id, entry, owner, modelid, level, exp, Reactstate, slot, name, renamed, curhealth, curmana, curhappiness, abdata, savetime, CreatedBySpell, PetType FROM character_pet WHERE owner = ? AND (slot = ? OR slot > ?) ", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_CHAR_PET_BY_ENTRY_AND_SLOT, "SELECT id, entry, owner, modelid, level, exp, Reactstate, slot, name, renamed, curhealth, curmana, curhappiness, abdata, savetime, CreatedBySpell, PetType FROM character_pet WHERE owner = ? AND slot = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_SEL_CHAR_PET_BY_ENTRY_AND_SLOT_SYNS, "SELECT id, entry, owner, modelid, level, exp, Reactstate, slot, name, renamed, curhealth, curmana, curhappiness, abdata, savetime, CreatedBySpell, PetType FROM character_pet WHERE owner = ? AND slot = ?", CONNECTION_SYNCH);
PrepareStatement(CHAR_DEL_CHAR_PET_BY_OWNER, "DELETE FROM character_pet WHERE owner = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UPD_CHAR_PET_NAME, "UPDATE character_pet SET name = ?, renamed = 1 WHERE owner = ? AND id = ?", CONNECTION_ASYNC);
PrepareStatement(CHAR_UDP_CHAR_PET_SLOT_BY_SLOT_EXCLUDE_ID, "UPDATE character_pet SET slot = ? WHERE owner = ? AND slot = ? AND id <> ?", CONNECTION_ASYNC);
@@ -578,3 +580,15 @@ void CharacterDatabaseConnection::DoPrepareStatements()
// Character names
PrepareStatement(CHAR_INS_RESERVED_PLAYER_NAME, "INSERT IGNORE INTO reserved_name (name) VALUES (?)", CONNECTION_ASYNC);
}
CharacterDatabaseConnection::CharacterDatabaseConnection(MySQLConnectionInfo& connInfo) : MySQLConnection(connInfo)
{
}
CharacterDatabaseConnection::CharacterDatabaseConnection(ProducerConsumerQueue<SQLOperation*>* q, MySQLConnectionInfo& connInfo) : MySQLConnection(q, connInfo)
{
}
CharacterDatabaseConnection::~CharacterDatabaseConnection()
{
}

View File

@@ -1,29 +1,14 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _CHARACTERDATABASE_H
#define _CHARACTERDATABASE_H
#include "DatabaseWorkerPool.h"
#include "MySQLConnection.h"
class CharacterDatabaseConnection : public MySQLConnection
{
public:
//- Constructors for sync and async connections
CharacterDatabaseConnection(MySQLConnectionInfo& connInfo) : MySQLConnection(connInfo) {}
CharacterDatabaseConnection(ACE_Activation_Queue* q, MySQLConnectionInfo& connInfo) : MySQLConnection(q, connInfo) {}
//- Loads database type specific prepared statements
void DoPrepareStatements() override;
};
typedef DatabaseWorkerPool<CharacterDatabaseConnection> CharacterDatabaseWorkerPool;
enum CharacterDatabaseStatements
enum CharacterDatabaseStatements : uint32
{
/* Naming standard for defines:
{DB}_{SEL/INS/UPD/DEL/REP}_{Summary of data changed}
@@ -137,7 +122,6 @@ enum CharacterDatabaseStatements
CHAR_INS_ACCOUNT_INSTANCE_LOCK_TIMES,
CHAR_SEL_MATCH_MAKER_RATING,
CHAR_SEL_CHARACTER_COUNT,
CHAR_UPD_NAME,
CHAR_DEL_DECLINED_NAME,
CHAR_INS_GUILD,
@@ -342,6 +326,8 @@ enum CharacterDatabaseStatements
CHAR_SEL_POOL_QUEST_SAVE,
CHAR_SEL_CHARACTER_AT_LOGIN,
CHAR_SEL_CHAR_CLASS_LVL_AT_LOGIN,
CHAR_SEL_CHAR_CUSTOMIZE_INFO,
CHAR_SEL_CHAR_RACE_OR_FACTION_CHANGE_INFOS,
CHAR_SEL_CHAR_AT_LOGIN_TITLES_MONEY,
CHAR_SEL_CHAR_COD_ITEM_MAIL,
CHAR_SEL_CHAR_SOCIAL,
@@ -378,7 +364,7 @@ enum CharacterDatabaseStatements
CHAR_DEL_PETITION_SIGNATURE_BY_GUID,
CHAR_DEL_CHAR_DECLINED_NAME,
CHAR_INS_CHAR_DECLINED_NAME,
CHAR_UPD_FACTION_OR_RACE,
CHAR_UPD_CHAR_RACE,
CHAR_DEL_CHAR_SKILL_LANGUAGES,
CHAR_INS_CHAR_SKILL_LANGUAGE,
CHAR_UPD_CHAR_TAXI_PATH,
@@ -466,6 +452,7 @@ enum CharacterDatabaseStatements
CHAR_DEL_CHAR_PET_BY_OWNER,
CHAR_DEL_CHAR_PET_DECLINEDNAME_BY_OWNER,
CHAR_SEL_CHAR_PET_BY_ENTRY_AND_SLOT,
CHAR_SEL_CHAR_PET_BY_ENTRY_AND_SLOT_SYNS,
CHAR_SEL_PET_SLOTS,
CHAR_SEL_PET_SLOTS_DETAIL,
CHAR_SEL_PET_ENTRY,
@@ -512,4 +499,18 @@ enum CharacterDatabaseStatements
MAX_CHARACTERDATABASE_STATEMENTS
};
class AC_DATABASE_API CharacterDatabaseConnection : public MySQLConnection
{
public:
typedef CharacterDatabaseStatements Statements;
//- Constructors for sync and async connections
CharacterDatabaseConnection(MySQLConnectionInfo& connInfo);
CharacterDatabaseConnection(ProducerConsumerQueue<SQLOperation*>* q, MySQLConnectionInfo& connInfo);
~CharacterDatabaseConnection();
//- Loads database type specific prepared statements
void DoPrepareStatements() override;
};
#endif

View File

@@ -1,10 +1,10 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "LoginDatabase.h"
#include "MySQLPreparedStatement.h"
void LoginDatabaseConnection::DoPrepareStatements()
{
@@ -110,3 +110,15 @@ void LoginDatabaseConnection::DoPrepareStatements()
PrepareStatement(LOGIN_SEL_ACCOUNT_TOTP_SECRET, "SELECT totp_secret FROM account WHERE id = ?", CONNECTION_SYNCH);
PrepareStatement(LOGIN_UPD_ACCOUNT_TOTP_SECRET, "UPDATE account SET totp_secret = ? WHERE id = ?", CONNECTION_ASYNC);
}
LoginDatabaseConnection::LoginDatabaseConnection(MySQLConnectionInfo& connInfo) : MySQLConnection(connInfo)
{
}
LoginDatabaseConnection::LoginDatabaseConnection(ProducerConsumerQueue<SQLOperation*>* q, MySQLConnectionInfo& connInfo) : MySQLConnection(q, connInfo)
{
}
LoginDatabaseConnection::~LoginDatabaseConnection()
{
}

View File

@@ -1,29 +1,14 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _LOGINDATABASE_H
#define _LOGINDATABASE_H
#include "DatabaseWorkerPool.h"
#include "MySQLConnection.h"
class LoginDatabaseConnection : public MySQLConnection
{
public:
//- Constructors for sync and async connections
LoginDatabaseConnection(MySQLConnectionInfo& connInfo) : MySQLConnection(connInfo) { }
LoginDatabaseConnection(ACE_Activation_Queue* q, MySQLConnectionInfo& connInfo) : MySQLConnection(q, connInfo) { }
//- Loads database type specific prepared statements
void DoPrepareStatements() override;
};
typedef DatabaseWorkerPool<LoginDatabaseConnection> LoginDatabaseWorkerPool;
enum LoginDatabaseStatements
enum LoginDatabaseStatements : uint32
{
/* Naming standard for defines:
{DB}_{SEL/INS/UPD/DEL/REP}_{Summary of data changed}
@@ -121,4 +106,18 @@ enum LoginDatabaseStatements
MAX_LOGINDATABASE_STATEMENTS
};
class AC_DATABASE_API LoginDatabaseConnection : public MySQLConnection
{
public:
typedef LoginDatabaseStatements Statements;
//- Constructors for sync and async connections
LoginDatabaseConnection(MySQLConnectionInfo& connInfo);
LoginDatabaseConnection(ProducerConsumerQueue<SQLOperation*>* q, MySQLConnectionInfo& connInfo);
~LoginDatabaseConnection();
//- Loads database type specific prepared statements
void DoPrepareStatements() override;
};
#endif

View File

@@ -1,10 +1,10 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "WorldDatabase.h"
#include "MySQLPreparedStatement.h"
void WorldDatabaseConnection::DoPrepareStatements()
{
@@ -84,3 +84,15 @@ void WorldDatabaseConnection::DoPrepareStatements()
// 0: uint8
PrepareStatement(WORLD_SEL_REQ_XP, "SELECT Experience FROM player_xp_for_level WHERE Level = ?", CONNECTION_SYNCH);
}
WorldDatabaseConnection::WorldDatabaseConnection(MySQLConnectionInfo& connInfo) : MySQLConnection(connInfo)
{
}
WorldDatabaseConnection::WorldDatabaseConnection(ProducerConsumerQueue<SQLOperation*>* q, MySQLConnectionInfo& connInfo) : MySQLConnection(q, connInfo)
{
}
WorldDatabaseConnection::~WorldDatabaseConnection()
{
}

View File

@@ -1,29 +1,14 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _WORLDDATABASE_H
#define _WORLDDATABASE_H
#include "DatabaseWorkerPool.h"
#include "MySQLConnection.h"
class WorldDatabaseConnection : public MySQLConnection
{
public:
//- Constructors for sync and async connections
WorldDatabaseConnection(MySQLConnectionInfo& connInfo) : MySQLConnection(connInfo) { }
WorldDatabaseConnection(ACE_Activation_Queue* q, MySQLConnectionInfo& connInfo) : MySQLConnection(q, connInfo) { }
//- Loads database type specific prepared statements
void DoPrepareStatements() override;
};
typedef DatabaseWorkerPool<WorldDatabaseConnection> WorldDatabaseWorkerPool;
enum WorldDatabaseStatements
enum WorldDatabaseStatements : uint32
{
/* Naming standard for defines:
{DB}_{SEL/INS/UPD/DEL/REP}_{Summary of data changed}
@@ -107,4 +92,18 @@ enum WorldDatabaseStatements
MAX_WORLDDATABASE_STATEMENTS
};
class AC_DATABASE_API WorldDatabaseConnection : public MySQLConnection
{
public:
typedef WorldDatabaseStatements Statements;
//- Constructors for sync and async connections
WorldDatabaseConnection(MySQLConnectionInfo& connInfo);
WorldDatabaseConnection(ProducerConsumerQueue<SQLOperation*>* q, MySQLConnectionInfo& connInfo);
~WorldDatabaseConnection();
//- Loads database type specific prepared statements
void DoPrepareStatements() override;
};
#endif

View File

@@ -1,54 +1,71 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "Common.h"
#include "MySQLConnection.h"
#include "MySQLThreading.h"
#include "QueryResult.h"
#include "SQLOperation.h"
#include "PreparedStatement.h"
#include "Common.h"
#include "DatabaseWorker.h"
#include "Timer.h"
#include "Log.h"
#include "Duration.h"
#include <mysql.h>
#include <mysqld_error.h>
#include "MySQLHacks.h"
#include "MySQLPreparedStatement.h"
#include "MySQLWorkaround.h"
#include "PreparedStatement.h"
#include "QueryResult.h"
#include "Timer.h"
#include "Tokenize.h"
#include "Transaction.h"
#include "Util.h"
#include <errmsg.h>
#include <thread>
#include <mysqld_error.h>
#ifdef _WIN32
#include <winsock2.h>
#endif
MySQLConnection::MySQLConnection(MySQLConnectionInfo& connInfo) :
m_reconnecting(false),
m_prepareError(false),
m_queue(nullptr),
m_worker(nullptr),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_SYNCH)
MySQLConnectionInfo::MySQLConnectionInfo(std::string const& infoString)
{
std::vector<std::string_view> tokens = Acore::Tokenize(infoString, ';', true);
if (tokens.size() != 5 && tokens.size() != 6)
return;
host.assign(tokens[0]);
port_or_socket.assign(tokens[1]);
user.assign(tokens[2]);
password.assign(tokens[3]);
database.assign(tokens[4]);
if (tokens.size() == 6)
ssl.assign(tokens[5]);
}
MySQLConnection::MySQLConnection(ACE_Activation_Queue* queue, MySQLConnectionInfo& connInfo) :
m_reconnecting(false),
m_prepareError(false),
m_queue(queue),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_ASYNC)
MySQLConnection::MySQLConnection(MySQLConnectionInfo& connInfo) :
m_reconnecting(false),
m_prepareError(false),
m_queue(nullptr),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_SYNCH) { }
MySQLConnection::MySQLConnection(ProducerConsumerQueue<SQLOperation*>* queue, MySQLConnectionInfo& connInfo) :
m_reconnecting(false),
m_prepareError(false),
m_queue(queue),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_ASYNC)
{
m_worker = new DatabaseWorker(m_queue, this);
m_worker = std::make_unique<DatabaseWorker>(m_queue, this);
}
MySQLConnection::~MySQLConnection()
{
for (auto stmt : m_stmts)
delete stmt;
Close();
}
void MySQLConnection::Close()
{
// Stop the worker thread before the statements are cleared
m_worker.reset();
m_stmts.clear();
if (m_Mysql)
{
@@ -57,19 +74,14 @@ MySQLConnection::~MySQLConnection()
}
}
void MySQLConnection::Close()
{
/// Only close us if we're not operating
delete this;
}
uint32 MySQLConnection::Open()
{
MYSQL* mysqlInit = mysql_init(nullptr);
MYSQL *mysqlInit;
mysqlInit = mysql_init(nullptr);
if (!mysqlInit)
{
LOG_ERROR("sql.sql", "Could not initialize Mysql connection to database `%s`", m_connectionInfo.database.c_str());
return false;
return CR_UNKNOWN_ERROR;
}
int port;
@@ -78,7 +90,7 @@ uint32 MySQLConnection::Open()
mysql_options(mysqlInit, MYSQL_SET_CHARSET_NAME, "utf8");
//mysql_options(mysqlInit, MYSQL_OPT_READ_TIMEOUT, (char const*)&timeout);
#ifdef _WIN32
#ifdef _WIN32
if (m_connectionInfo.host == ".") // named pipe use option (Windows)
{
unsigned int opt = MYSQL_PROTOCOL_PIPE;
@@ -91,7 +103,7 @@ uint32 MySQLConnection::Open()
port = atoi(m_connectionInfo.port_or_socket.c_str());
unix_socket = 0;
}
#else
#else
if (m_connectionInfo.host == ".") // socket use option (Unix/Linux)
{
unsigned int opt = MYSQL_PROTOCOL_SOCKET;
@@ -103,19 +115,31 @@ uint32 MySQLConnection::Open()
else // generic case
{
port = atoi(m_connectionInfo.port_or_socket.c_str());
unix_socket = 0;
unix_socket = nullptr;
}
#endif
#endif
m_Mysql = mysql_real_connect(
mysqlInit,
m_connectionInfo.host.c_str(),
m_connectionInfo.user.c_str(),
m_connectionInfo.password.c_str(),
m_connectionInfo.database.c_str(),
port,
unix_socket,
0);
if (m_connectionInfo.ssl != "")
{
#if !defined(MARIADB_VERSION_ID) && MYSQL_VERSION_ID >= 80000
mysql_ssl_mode opt_use_ssl = SSL_MODE_DISABLED;
if (m_connectionInfo.ssl == "ssl")
{
opt_use_ssl = SSL_MODE_REQUIRED;
}
mysql_options(mysqlInit, MYSQL_OPT_SSL_MODE, (char const*)&opt_use_ssl);
#else
MySQLBool opt_use_ssl = MySQLBool(0);
if (m_connectionInfo.ssl == "ssl")
{
opt_use_ssl = MySQLBool(1);
}
mysql_options(mysqlInit, MYSQL_OPT_SSL_ENFORCE, (char const*)&opt_use_ssl);
#endif
}
m_Mysql = reinterpret_cast<MySQLHandle*>(mysql_real_connect(mysqlInit, m_connectionInfo.host.c_str(), m_connectionInfo.user.c_str(),
m_connectionInfo.password.c_str(), m_connectionInfo.database.c_str(), port, unix_socket, 0));
if (m_Mysql)
{
@@ -123,15 +147,12 @@ uint32 MySQLConnection::Open()
{
LOG_INFO("sql.sql", "MySQL client library: %s", mysql_get_client_info());
LOG_INFO("sql.sql", "MySQL server ver: %s ", mysql_get_server_info(m_Mysql));
if (mysql_get_server_version(m_Mysql) != mysql_get_client_version())
{
LOG_WARN("sql.sql", "[WARNING] MySQL client/server version mismatch; may conflict with behaviour of prepared statements.");
}
// MySQL version above 5.1 IS required in both client and server and there is no known issue with different versions above 5.1
// if (mysql_get_server_version(m_Mysql) != mysql_get_client_version())
// LOG_INFO("sql.sql", "[WARNING] MySQL client/server version mismatch; may conflict with behaviour of prepared statements.");
}
LOG_INFO("sql.sql", "Connected to MySQL database at %s", m_connectionInfo.host.c_str());
mysql_autocommit(m_Mysql, 1);
// set connection properties to UTF8 to properly handle locales for different
@@ -139,11 +160,13 @@ uint32 MySQLConnection::Open()
mysql_set_character_set(m_Mysql, "utf8");
return 0;
}
LOG_ERROR("sql.sql", "Could not connect to MySQL database at %s: %s", m_connectionInfo.host.c_str(), mysql_error(mysqlInit));
uint32 errorCode = mysql_errno(mysqlInit);
mysql_close(mysqlInit);
return errorCode;
else
{
LOG_ERROR("sql.sql", "Could not connect to MySQL database at %s: %s", m_connectionInfo.host.c_str(), mysql_error(mysqlInit));
uint32 errorCode = mysql_errno(mysqlInit);
mysql_close(mysqlInit);
return errorCode;
}
}
bool MySQLConnection::PrepareStatements()
@@ -152,154 +175,7 @@ bool MySQLConnection::PrepareStatements()
return !m_prepareError;
}
bool MySQLConnection::Execute(const char* sql)
{
if (!m_Mysql)
return false;
uint32 _s = getMSTime();
if (mysql_query(m_Mysql, sql))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL: %s", sql);
LOG_ERROR("sql.sql", "ERROR: [%u] %s", lErrno, mysql_error(m_Mysql));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return Execute(sql); // Try again
return false;
}
LOG_DEBUG("sql.sql", "[%u ms] SQL: %s", getMSTimeDiff(_s, getMSTime()), sql);
return true;
}
bool MySQLConnection::Execute(PreparedStatement* stmt)
{
if (!m_Mysql)
return false;
uint32 index = stmt->m_index;
{
MySQLPreparedStatement* m_mStmt = GetPreparedStatement(index);
ASSERT(m_mStmt); // Can only be null if preparation failed, server side error or bad query
m_mStmt->m_stmt = stmt; // Cross reference them for debug output
stmt->m_stmt = m_mStmt; // TODO: Cleaner way
stmt->BindParameters();
MYSQL_STMT* msql_STMT = m_mStmt->GetSTMT();
MYSQL_BIND* msql_BIND = m_mStmt->GetBind();
uint32 _s = getMSTime();
if (mysql_stmt_bind_param(msql_STMT, msql_BIND))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s", m_mStmt->getQueryString(m_queries[index].first).c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return Execute(stmt); // Try again
m_mStmt->ClearParameters();
return false;
}
if (mysql_stmt_execute(msql_STMT))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s", m_mStmt->getQueryString(m_queries[index].first).c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return Execute(stmt); // Try again
m_mStmt->ClearParameters();
return false;
}
LOG_DEBUG("sql.sql", "[%u ms] SQL(p): %s", getMSTimeDiff(_s, getMSTime()), m_mStmt->getQueryString(m_queries[index].first).c_str());
m_mStmt->ClearParameters();
return true;
}
}
bool MySQLConnection::_Query(PreparedStatement* stmt, MYSQL_RES** pResult, uint64* pRowCount, uint32* pFieldCount)
{
if (!m_Mysql)
return false;
uint32 index = stmt->m_index;
{
MySQLPreparedStatement* m_mStmt = GetPreparedStatement(index);
ASSERT(m_mStmt); // Can only be null if preparation failed, server side error or bad query
m_mStmt->m_stmt = stmt; // Cross reference them for debug output
stmt->m_stmt = m_mStmt; // TODO: Cleaner way
stmt->BindParameters();
MYSQL_STMT* msql_STMT = m_mStmt->GetSTMT();
MYSQL_BIND* msql_BIND = m_mStmt->GetBind();
uint32 _s = getMSTime();
if (mysql_stmt_bind_param(msql_STMT, msql_BIND))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s", m_mStmt->getQueryString(m_queries[index].first).c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return _Query(stmt, pResult, pRowCount, pFieldCount); // Try again
m_mStmt->ClearParameters();
return false;
}
if (mysql_stmt_execute(msql_STMT))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s",
m_mStmt->getQueryString(m_queries[index].first).c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return _Query(stmt, pResult, pRowCount, pFieldCount); // Try again
m_mStmt->ClearParameters();
return false;
}
LOG_DEBUG("sql.sql", "[%u ms] SQL(p): %s", getMSTimeDiff(_s, getMSTime()), m_mStmt->getQueryString(m_queries[index].first).c_str());
m_mStmt->ClearParameters();
*pResult = mysql_stmt_result_metadata(msql_STMT);
*pRowCount = mysql_stmt_num_rows(msql_STMT);
*pFieldCount = mysql_stmt_field_count(msql_STMT);
return true;
}
}
ResultSet* MySQLConnection::Query(const char* sql)
{
if (!sql)
return nullptr;
MYSQL_RES* result = nullptr;
MYSQL_FIELD* fields = nullptr;
uint64 rowCount = 0;
uint32 fieldCount = 0;
if (!_Query(sql, &result, &fields, &rowCount, &fieldCount))
return nullptr;
return new ResultSet(result, fields, rowCount, fieldCount);
}
bool MySQLConnection::_Query(const char* sql, MYSQL_RES** pResult, MYSQL_FIELD** pFields, uint64* pRowCount, uint32* pFieldCount)
bool MySQLConnection::Execute(char const* sql)
{
if (!m_Mysql)
return false;
@@ -310,23 +186,167 @@ bool MySQLConnection::_Query(const char* sql, MYSQL_RES** pResult, MYSQL_FIELD**
if (mysql_query(m_Mysql, sql))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL: %s", sql);
LOG_ERROR("sql.sql", "ERROR: [%u] %s", lErrno, mysql_error(m_Mysql));
LOG_INFO("sql.sql", "SQL: %s", sql);
LOG_ERROR("sql.sql", "[%u] %s", lErrno, mysql_error(m_Mysql));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return Execute(sql); // Try again
return false;
}
else
LOG_DEBUG("sql.sql", "[%u ms] SQL: %s", getMSTimeDiff(_s, getMSTime()), sql);
}
return true;
}
bool MySQLConnection::Execute(PreparedStatementBase* stmt)
{
if (!m_Mysql)
return false;
uint32 index = stmt->GetIndex();
MySQLPreparedStatement* m_mStmt = GetPreparedStatement(index);
ASSERT(m_mStmt); // Can only be null if preparation failed, server side error or bad query
m_mStmt->BindParameters(stmt);
MYSQL_STMT* msql_STMT = m_mStmt->GetSTMT();
MYSQL_BIND* msql_BIND = m_mStmt->GetBind();
uint32 _s = getMSTime();
if (mysql_stmt_bind_param(msql_STMT, msql_BIND))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s", m_mStmt->getQueryString().c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return Execute(stmt); // Try again
m_mStmt->ClearParameters();
return false;
}
if (mysql_stmt_execute(msql_STMT))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s", m_mStmt->getQueryString().c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return Execute(stmt); // Try again
m_mStmt->ClearParameters();
return false;
}
LOG_DEBUG("sql.sql", "[%u ms] SQL(p): %s", getMSTimeDiff(_s, getMSTime()), m_mStmt->getQueryString().c_str());
m_mStmt->ClearParameters();
return true;
}
bool MySQLConnection::_Query(PreparedStatementBase* stmt, MySQLPreparedStatement** mysqlStmt, MySQLResult** pResult, uint64* pRowCount, uint32* pFieldCount)
{
if (!m_Mysql)
return false;
uint32 index = stmt->GetIndex();
MySQLPreparedStatement* m_mStmt = GetPreparedStatement(index);
ASSERT(m_mStmt); // Can only be null if preparation failed, server side error or bad query
m_mStmt->BindParameters(stmt);
*mysqlStmt = m_mStmt;
MYSQL_STMT* msql_STMT = m_mStmt->GetSTMT();
MYSQL_BIND* msql_BIND = m_mStmt->GetBind();
uint32 _s = getMSTime();
if (mysql_stmt_bind_param(msql_STMT, msql_BIND))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s", m_mStmt->getQueryString().c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return _Query(stmt, mysqlStmt, pResult, pRowCount, pFieldCount); // Try again
m_mStmt->ClearParameters();
return false;
}
if (mysql_stmt_execute(msql_STMT))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): %s\n [ERROR]: [%u] %s",
m_mStmt->getQueryString().c_str(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return _Query(stmt, mysqlStmt, pResult, pRowCount, pFieldCount); // Try again
m_mStmt->ClearParameters();
return false;
}
LOG_DEBUG("sql.sql", "[%u ms] SQL(p): %s", getMSTimeDiff(_s, getMSTime()), m_mStmt->getQueryString().c_str());
m_mStmt->ClearParameters();
*pResult = reinterpret_cast<MySQLResult*>(mysql_stmt_result_metadata(msql_STMT));
*pRowCount = mysql_stmt_num_rows(msql_STMT);
*pFieldCount = mysql_stmt_field_count(msql_STMT);
return true;
}
ResultSet* MySQLConnection::Query(char const* sql)
{
if (!sql)
return nullptr;
MySQLResult* result = nullptr;
MySQLField* fields = nullptr;
uint64 rowCount = 0;
uint32 fieldCount = 0;
if (!_Query(sql, &result, &fields, &rowCount, &fieldCount))
return nullptr;
return new ResultSet(result, fields, rowCount, fieldCount);
}
bool MySQLConnection::_Query(const char* sql, MySQLResult** pResult, MySQLField** pFields, uint64* pRowCount, uint32* pFieldCount)
{
if (!m_Mysql)
return false;
{
uint32 _s = getMSTime();
if (mysql_query(m_Mysql, sql))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_INFO("sql.sql", "SQL: %s", sql);
LOG_ERROR("sql.sql", "[%u] %s", lErrno, mysql_error(m_Mysql));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return _Query(sql, pResult, pFields, pRowCount, pFieldCount); // We try again
return false;
}
else
LOG_DEBUG("sql.sql", "[%u ms] SQL: %s", getMSTimeDiff(_s, getMSTime()), sql);
LOG_DEBUG("sql.sql", "[%u ms] SQL: %s", getMSTimeDiff(_s, getMSTime()), sql);
*pResult = mysql_store_result(m_Mysql);
*pResult = reinterpret_cast<MySQLResult*>(mysql_store_result(m_Mysql));
*pRowCount = mysql_affected_rows(m_Mysql);
*pFieldCount = mysql_field_count(m_Mysql);
}
if (!*pResult)
if (!*pResult )
return false;
if (!*pRowCount)
@@ -335,7 +355,7 @@ bool MySQLConnection::_Query(const char* sql, MYSQL_RES** pResult, MYSQL_FIELD**
return false;
}
*pFields = mysql_fetch_fields(*pResult);
*pFields = reinterpret_cast<MySQLField*>(mysql_fetch_fields(*pResult));
return true;
}
@@ -355,27 +375,26 @@ void MySQLConnection::CommitTransaction()
Execute("COMMIT");
}
int MySQLConnection::ExecuteTransaction(SQLTransaction& transaction)
int MySQLConnection::ExecuteTransaction(std::shared_ptr<TransactionBase> transaction)
{
std::list<SQLElementData> const& queries = transaction->m_queries;
std::vector<SQLElementData> const& queries = transaction->m_queries;
if (queries.empty())
return -1;
BeginTransaction();
std::list<SQLElementData>::const_iterator itr;
for (itr = queries.begin(); itr != queries.end(); ++itr)
for (auto itr = queries.begin(); itr != queries.end(); ++itr)
{
SQLElementData const& data = *itr;
switch (itr->type)
{
case SQL_ELEMENT_PREPARED:
{
PreparedStatement* stmt = data.element.stmt;
PreparedStatementBase* stmt = data.element.stmt;
ASSERT(stmt);
if (!Execute(stmt))
{
LOG_INFO("sql.driver", "[Warning] Transaction aborted. %u queries not executed.", (uint32)queries.size());
LOG_WARN("sql.sql", "Transaction aborted. %u queries not executed.", (uint32)queries.size());
int errorCode = GetLastError();
RollbackTransaction();
return errorCode;
@@ -384,11 +403,11 @@ int MySQLConnection::ExecuteTransaction(SQLTransaction& transaction)
break;
case SQL_ELEMENT_RAW:
{
const char* sql = data.element.query;
char const* sql = data.element.query;
ASSERT(sql);
if (!Execute(sql))
{
LOG_INFO("sql.driver", "[Warning] Transaction aborted. %u queries not executed.", (uint32)queries.size());
LOG_WARN("sql.sql", "Transaction aborted. %u queries not executed.", (uint32)queries.size());
int errorCode = GetLastError();
RollbackTransaction();
return errorCode;
@@ -407,103 +426,157 @@ int MySQLConnection::ExecuteTransaction(SQLTransaction& transaction)
return 0;
}
size_t MySQLConnection::EscapeString(char* to, const char* from, size_t length)
{
return mysql_real_escape_string(m_Mysql, to, from, length);
}
void MySQLConnection::Ping()
{
mysql_ping(m_Mysql);
}
uint32 MySQLConnection::GetLastError()
{
return mysql_errno(m_Mysql);
}
bool MySQLConnection::LockIfReady()
{
return m_Mutex.try_lock();
}
void MySQLConnection::Unlock()
{
m_Mutex.unlock();
}
uint32 MySQLConnection::GetServerVersion() const
{
return mysql_get_server_version(m_Mysql);
}
MySQLPreparedStatement* MySQLConnection::GetPreparedStatement(uint32 index)
{
ASSERT(index < m_stmts.size());
MySQLPreparedStatement* ret = m_stmts[index];
ASSERT(index < m_stmts.size(), "Tried to access invalid prepared statement index %u (max index " SZFMTD ") on database `%s`, connection type: %s",
index, m_stmts.size(), m_connectionInfo.database.c_str(), (m_connectionFlags & CONNECTION_ASYNC) ? "asynchronous" : "synchronous");
MySQLPreparedStatement* ret = m_stmts[index].get();
if (!ret)
LOG_INFO("sql.driver", "ERROR: Could not fetch prepared statement %u on database `%s`, connection type: %s.",
index, m_connectionInfo.database.c_str(), (m_connectionFlags & CONNECTION_ASYNC) ? "asynchronous" : "synchronous");
LOG_ERROR("sql.sql", "Could not fetch prepared statement %u on database `%s`, connection type: %s.",
index, m_connectionInfo.database.c_str(), (m_connectionFlags & CONNECTION_ASYNC) ? "asynchronous" : "synchronous");
return ret;
}
void MySQLConnection::PrepareStatement(uint32 index, const char* sql, ConnectionFlags flags)
void MySQLConnection::PrepareStatement(uint32 index, std::string const& sql, ConnectionFlags flags)
{
m_queries.insert(PreparedStatementMap::value_type(index, std::make_pair(sql, flags)));
// For reconnection case
if (m_reconnecting)
delete m_stmts[index];
// Check if specified query should be prepared on this connection
// i.e. don't prepare async statements on synchronous connections
// to save memory that will not be used.
if (!(m_connectionFlags & flags))
{
m_stmts[index] = nullptr;
m_stmts[index].reset();
return;
}
MYSQL_STMT* stmt = mysql_stmt_init(m_Mysql);
if (!stmt)
{
LOG_INFO("sql.driver", "[ERROR]: In mysql_stmt_init() id: %u, sql: \"%s\"", index, sql);
LOG_INFO("sql.driver", "[ERROR]: %s", mysql_error(m_Mysql));
LOG_ERROR("sql.sql", "In mysql_stmt_init() id: %u, sql: \"%s\"", index, sql.c_str());
LOG_ERROR("sql.sql", "%s", mysql_error(m_Mysql));
m_prepareError = true;
}
else
{
if (mysql_stmt_prepare(stmt, sql, static_cast<unsigned long>(strlen(sql))))
if (mysql_stmt_prepare(stmt, sql.c_str(), static_cast<unsigned long>(sql.size())))
{
LOG_INFO("sql.driver", "[ERROR]: In mysql_stmt_prepare() id: %u, sql: \"%s\"", index, sql);
LOG_INFO("sql.driver", "[ERROR]: %s", mysql_stmt_error(stmt));
LOG_ERROR("sql.sql", "In mysql_stmt_prepare() id: %u, sql: \"%s\"", index, sql.c_str());
LOG_ERROR("sql.sql", "%s", mysql_stmt_error(stmt));
mysql_stmt_close(stmt);
m_prepareError = true;
}
else
{
MySQLPreparedStatement* mStmt = new MySQLPreparedStatement(stmt);
m_stmts[index] = mStmt;
}
m_stmts[index] = std::make_unique<MySQLPreparedStatement>(reinterpret_cast<MySQLStmt*>(stmt), sql);
}
}
PreparedResultSet* MySQLConnection::Query(PreparedStatement* stmt)
PreparedResultSet* MySQLConnection::Query(PreparedStatementBase* stmt)
{
MYSQL_RES* result = nullptr;
MySQLPreparedStatement* mysqlStmt = nullptr;
MySQLResult* result = nullptr;
uint64 rowCount = 0;
uint32 fieldCount = 0;
if (!_Query(stmt, &result, &rowCount, &fieldCount))
if (!_Query(stmt, &mysqlStmt, &result, &rowCount, &fieldCount))
return nullptr;
if (mysql_more_results(m_Mysql))
{
mysql_next_result(m_Mysql);
}
return new PreparedResultSet(stmt->m_stmt->GetSTMT(), result, rowCount, fieldCount);
return new PreparedResultSet(mysqlStmt->GetSTMT(), result, rowCount, fieldCount);
}
bool MySQLConnection::_HandleMySQLErrno(uint32 errNo)
bool MySQLConnection::_HandleMySQLErrno(uint32 errNo, uint8 attempts /*= 5*/)
{
switch (errNo)
{
case CR_SERVER_GONE_ERROR:
case CR_SERVER_LOST:
case CR_SERVER_LOST_EXTENDED:
#if !(MARIADB_VERSION_ID >= 100200)
case CR_INVALID_CONN_HANDLE:
#endif
{
m_reconnecting = true;
uint64 oldThreadId = mysql_thread_id(GetHandle());
mysql_close(GetHandle());
if (this->Open()) // Don't remove 'this' pointer unless you want to skip loading all prepared statements....
if (m_Mysql)
{
LOG_INFO("sql.driver", "Connection to the MySQL server is active.");
if (oldThreadId != mysql_thread_id(GetHandle()))
LOG_INFO("sql.driver", "Successfully reconnected to %s @%s:%s (%s).",
m_connectionInfo.database.c_str(), m_connectionInfo.host.c_str(), m_connectionInfo.port_or_socket.c_str(),
(m_connectionFlags & CONNECTION_ASYNC) ? "asynchronous" : "synchronous");
LOG_ERROR("sql.sql", "Lost the connection to the MySQL server!");
mysql_close(m_Mysql);
m_Mysql = nullptr;
}
[[fallthrough]];
}
case CR_CONN_HOST_ERROR:
{
LOG_INFO("sql.sql", "Attempting to reconnect to the MySQL server...");
m_reconnecting = true;
uint32 const lErrno = Open();
if (!lErrno)
{
// Don't remove 'this' pointer unless you want to skip loading all prepared statements...
if (!this->PrepareStatements())
{
LOG_FATAL("sql.sql", "Could not re-prepare statements!");
std::this_thread::sleep_for(std::chrono::seconds(10));
std::abort();
}
LOG_INFO("sql.sql", "Successfully reconnected to %s @%s:%s (%s).",
m_connectionInfo.database.c_str(), m_connectionInfo.host.c_str(), m_connectionInfo.port_or_socket.c_str(),
(m_connectionFlags & CONNECTION_ASYNC) ? "asynchronous" : "synchronous");
m_reconnecting = false;
return true;
}
uint32 lErrno = mysql_errno(GetHandle()); // It's possible this attempted reconnect throws 2006 at us. To prevent crazy recursive calls, sleep here.
std::this_thread::sleep_for(3s); // Sleep 3 seconds
return _HandleMySQLErrno(lErrno); // Call self (recursive)
if ((--attempts) == 0)
{
// Shut down the server when the mysql server isn't
// reachable for some time
LOG_FATAL("sql.sql", "Failed to reconnect to the MySQL server, "
"terminating the server to prevent data corruption!");
// We could also initiate a shutdown through using std::raise(SIGTERM)
std::this_thread::sleep_for(std::chrono::seconds(10));
std::abort();
}
else
{
// It's possible this attempted reconnect throws 2006 at us.
// To prevent crazy recursive calls, sleep here.
std::this_thread::sleep_for(std::chrono::seconds(3)); // Sleep 3 seconds
return _HandleMySQLErrno(lErrno, attempts); // Call self (recursive)
}
}
case ER_LOCK_DEADLOCK:
@@ -516,17 +589,17 @@ bool MySQLConnection::_HandleMySQLErrno(uint32 errNo)
// Outdated table or database structure - terminate core
case ER_BAD_FIELD_ERROR:
case ER_NO_SUCH_TABLE:
LOG_ERROR("server", "Your database structure is not up to date. Please make sure you've executed all queries in the sql/updates folders.");
std::this_thread::sleep_for(10s);
LOG_ERROR("sql.sql", "Your database structure is not up to date. Please make sure you've executed all queries in the sql/updates folders.");
std::this_thread::sleep_for(std::chrono::seconds(10));
std::abort();
return false;
case ER_PARSE_ERROR:
LOG_ERROR("server", "Error while parsing SQL. Core fix required.");
std::this_thread::sleep_for(10s);
LOG_ERROR("sql.sql", "Error while parsing SQL. Core fix required.");
std::this_thread::sleep_for(std::chrono::seconds(10));
std::abort();
return false;
default:
LOG_ERROR("server", "Unhandled MySQL errno %u. Unexpected behaviour possible.", errNo);
LOG_ERROR("sql.sql", "Unhandled MySQL errno %u. Unexpected behaviour possible.", errNo);
return false;
}
}

View File

@@ -1,22 +1,25 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include <ace/Activation_Queue.h>
#include "DatabaseWorkerPool.h"
#include "Transaction.h"
#include "Util.h"
#ifndef _MYSQLCONNECTION_H
#define _MYSQLCONNECTION_H
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
template <typename T>
class ProducerConsumerQueue;
class DatabaseWorker;
class PreparedStatement;
class MySQLPreparedStatement;
class PingOperation;
class SQLOperation;
enum ConnectionFlags
{
@@ -25,99 +28,75 @@ enum ConnectionFlags
CONNECTION_BOTH = CONNECTION_ASYNC | CONNECTION_SYNCH
};
struct MySQLConnectionInfo
struct AC_DATABASE_API MySQLConnectionInfo
{
MySQLConnectionInfo() = default;
MySQLConnectionInfo(const std::string& infoString)
{
Tokenizer tokens(infoString, ';');
if (tokens.size() != 5)
return;
uint8 i = 0;
host.assign(tokens[i++]);
port_or_socket.assign(tokens[i++]);
user.assign(tokens[i++]);
password.assign(tokens[i++]);
database.assign(tokens[i++]);
}
explicit MySQLConnectionInfo(std::string const& infoString);
std::string user;
std::string password;
std::string database;
std::string host;
std::string port_or_socket;
std::string ssl;
};
typedef std::map<uint32 /*index*/, std::pair<std::string /*query*/, ConnectionFlags /*sync/async*/>> PreparedStatementMap;
class MySQLConnection
class AC_DATABASE_API MySQLConnection
{
template <class T> friend class DatabaseWorkerPool;
friend class PingOperation;
template <class T> friend class DatabaseWorkerPool;
friend class PingOperation;
public:
MySQLConnection(MySQLConnectionInfo& connInfo); //! Constructor for synchronous connections.
MySQLConnection(ACE_Activation_Queue* queue, MySQLConnectionInfo& connInfo); //! Constructor for asynchronous connections.
MySQLConnection(ProducerConsumerQueue<SQLOperation*>* queue, MySQLConnectionInfo& connInfo); //! Constructor for asynchronous connections.
virtual ~MySQLConnection();
virtual uint32 Open();
void Close();
bool PrepareStatements();
public:
bool Execute(const char* sql);
bool Execute(PreparedStatement* stmt);
ResultSet* Query(const char* sql);
PreparedResultSet* Query(PreparedStatement* stmt);
bool _Query(const char* sql, MYSQL_RES** pResult, MYSQL_FIELD** pFields, uint64* pRowCount, uint32* pFieldCount);
bool _Query(PreparedStatement* stmt, MYSQL_RES** pResult, uint64* pRowCount, uint32* pFieldCount);
bool Execute(char const* sql);
bool Execute(PreparedStatementBase* stmt);
ResultSet* Query(char const* sql);
PreparedResultSet* Query(PreparedStatementBase* stmt);
bool _Query(char const* sql, MySQLResult** pResult, MySQLField** pFields, uint64* pRowCount, uint32* pFieldCount);
bool _Query(PreparedStatementBase* stmt, MySQLPreparedStatement** mysqlStmt, MySQLResult** pResult, uint64* pRowCount, uint32* pFieldCount);
void BeginTransaction();
void RollbackTransaction();
void CommitTransaction();
int ExecuteTransaction(SQLTransaction& transaction);
int ExecuteTransaction(std::shared_ptr<TransactionBase> transaction);
size_t EscapeString(char* to, const char* from, size_t length);
void Ping();
operator bool () const { return m_Mysql != nullptr; }
void Ping() { mysql_ping(m_Mysql); }
uint32 GetLastError() { return mysql_errno(m_Mysql); }
uint32 GetLastError();
protected:
bool LockIfReady()
{
/// Tries to acquire lock. If lock is acquired by another thread
/// the calling parent will just try another connection
return m_Mutex.try_lock();
}
/// Tries to acquire lock. If lock is acquired by another thread
/// the calling parent will just try another connection
bool LockIfReady();
void Unlock()
{
/// Called by parent databasepool. Will let other threads access this connection
m_Mutex.unlock();
}
/// Called by parent databasepool. Will let other threads access this connection
void Unlock();
MYSQL* GetHandle() { return m_Mysql; }
uint32 GetServerVersion() const;
MySQLPreparedStatement* GetPreparedStatement(uint32 index);
void PrepareStatement(uint32 index, const char* sql, ConnectionFlags flags);
void PrepareStatement(uint32 index, std::string const& sql, ConnectionFlags flags);
virtual void DoPrepareStatements() = 0;
protected:
std::vector<MySQLPreparedStatement*> m_stmts; //! PreparedStatements storage
PreparedStatementMap m_queries; //! Query storage
typedef std::vector<std::unique_ptr<MySQLPreparedStatement>> PreparedStatementContainer;
PreparedStatementContainer m_stmts; //! PreparedStatements storage
bool m_reconnecting; //! Are we reconnecting?
bool m_prepareError; //! Was there any error while preparing statements?
private:
bool _HandleMySQLErrno(uint32 errNo);
bool _HandleMySQLErrno(uint32 errNo, uint8 attempts = 5);
private:
ACE_Activation_Queue* m_queue; //! Queue shared with other asynchronous connections.
DatabaseWorker* m_worker; //! Core worker task.
MYSQL* m_Mysql; //! MySQL Handle.
ProducerConsumerQueue<SQLOperation*>* m_queue; //! Queue shared with other asynchronous connections.
std::unique_ptr<DatabaseWorker> m_worker; //! Core worker task.
MySQLHandle* m_Mysql; //! MySQL Handle.
MySQLConnectionInfo& m_connectionInfo; //! Connection info (used for logging)
ConnectionFlags m_connectionFlags; //! Connection flags (for preparing relevant statements)
std::mutex m_Mutex;

View File

@@ -0,0 +1,22 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef MySQLHacks_h__
#define MySQLHacks_h__
#include "MySQLWorkaround.h"
#include <type_traits>
struct MySQLHandle : MYSQL { };
struct MySQLResult : MYSQL_RES { };
struct MySQLField : MYSQL_FIELD { };
struct MySQLBind : MYSQL_BIND { };
struct MySQLStmt : MYSQL_STMT { };
// mysql 8 removed my_bool typedef (it was char) and started using bools directly
// to maintain compatibility we use this trick to retrieve which type is being used
using MySQLBool = std::remove_pointer_t<decltype(std::declval<MYSQL_BIND>().is_null)>;
#endif // MySQLHacks_h__

View File

@@ -0,0 +1,188 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "MySQLPreparedStatement.h"
#include "Errors.h"
#include "Log.h"
#include "MySQLHacks.h"
#include "PreparedStatement.h"
template<typename T>
struct MySQLType { };
template<> struct MySQLType<uint8> : std::integral_constant<enum_field_types, MYSQL_TYPE_TINY> { };
template<> struct MySQLType<uint16> : std::integral_constant<enum_field_types, MYSQL_TYPE_SHORT> { };
template<> struct MySQLType<uint32> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONG> { };
template<> struct MySQLType<uint64> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONGLONG> { };
template<> struct MySQLType<int8> : std::integral_constant<enum_field_types, MYSQL_TYPE_TINY> { };
template<> struct MySQLType<int16> : std::integral_constant<enum_field_types, MYSQL_TYPE_SHORT> { };
template<> struct MySQLType<int32> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONG> { };
template<> struct MySQLType<int64> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONGLONG> { };
template<> struct MySQLType<float> : std::integral_constant<enum_field_types, MYSQL_TYPE_FLOAT> { };
template<> struct MySQLType<double> : std::integral_constant<enum_field_types, MYSQL_TYPE_DOUBLE> { };
MySQLPreparedStatement::MySQLPreparedStatement(MySQLStmt* stmt, std::string queryString) :
m_stmt(nullptr), m_Mstmt(stmt), m_bind(nullptr), m_queryString(std::move(queryString))
{
/// Initialize variable parameters
m_paramCount = mysql_stmt_param_count(stmt);
m_paramsSet.assign(m_paramCount, false);
m_bind = new MySQLBind[m_paramCount];
memset(m_bind, 0, sizeof(MySQLBind) * m_paramCount);
/// "If set to 1, causes mysql_stmt_store_result() to update the metadata MYSQL_FIELD->max_length value."
MySQLBool bool_tmp = MySQLBool(1);
mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, &bool_tmp);
}
MySQLPreparedStatement::~MySQLPreparedStatement()
{
ClearParameters();
if (m_Mstmt->bind_result_done)
{
delete[] m_Mstmt->bind->length;
delete[] m_Mstmt->bind->is_null;
}
mysql_stmt_close(m_Mstmt);
delete[] m_bind;
}
void MySQLPreparedStatement::BindParameters(PreparedStatementBase* stmt)
{
m_stmt = stmt; // Cross reference them for debug output
uint8 pos = 0;
for (PreparedStatementData const& data : stmt->GetParameters())
{
std::visit([&](auto&& param)
{
SetParameter(pos, param);
}, data.data);
++pos;
}
#ifdef _DEBUG
if (pos < m_paramCount)
LOG_WARN("sql.sql", "[WARNING]: BindParameters() for statement %u did not bind all allocated parameters", stmt->GetIndex());
#endif
}
void MySQLPreparedStatement::ClearParameters()
{
for (uint32 i=0; i < m_paramCount; ++i)
{
delete m_bind[i].length;
m_bind[i].length = nullptr;
delete[] (char*) m_bind[i].buffer;
m_bind[i].buffer = nullptr;
m_paramsSet[i] = false;
}
}
static bool ParamenterIndexAssertFail(uint32 stmtIndex, uint8 index, uint32 paramCount)
{
LOG_ERROR("sql.driver", "Attempted to bind parameter %u%s on a PreparedStatement %u (statement has only %u parameters)", uint32(index) + 1, (index == 1 ? "st" : (index == 2 ? "nd" : (index == 3 ? "rd" : "nd"))), stmtIndex, paramCount);
return false;
}
//- Bind on mysql level
void MySQLPreparedStatement::AssertValidIndex(uint8 index)
{
ASSERT(index < m_paramCount || ParamenterIndexAssertFail(m_stmt->GetIndex(), index, m_paramCount));
if (m_paramsSet[index])
LOG_ERROR("sql.sql", "[ERROR] Prepared Statement (id: %u) trying to bind value on already bound index (%u).", m_stmt->GetIndex(), index);
}
void MySQLPreparedStatement::SetParameter(uint8 index, std::nullptr_t)
{
AssertValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
param->buffer_type = MYSQL_TYPE_NULL;
delete[] static_cast<char*>(param->buffer);
param->buffer = nullptr;
param->buffer_length = 0;
param->is_null_value = 1;
delete param->length;
param->length = nullptr;
}
void MySQLPreparedStatement::SetParameter(uint8 index, bool value)
{
SetParameter(index, uint8(value ? 1 : 0));
}
template<typename T>
void MySQLPreparedStatement::SetParameter(uint8 index, T value)
{
AssertValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(sizeof(T));
param->buffer_type = MySQLType<T>::value;
delete[] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = 0;
param->is_null_value = 0;
param->length = nullptr; // Only != NULL for strings
param->is_unsigned = std::is_unsigned_v<T>;
memcpy(param->buffer, &value, len);
}
void MySQLPreparedStatement::SetParameter(uint8 index, std::string const& value)
{
AssertValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(value.size());
param->buffer_type = MYSQL_TYPE_VAR_STRING;
delete [] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = len;
param->is_null_value = 0;
delete param->length;
param->length = new unsigned long(len);
memcpy(param->buffer, value.c_str(), len);
}
void MySQLPreparedStatement::SetParameter(uint8 index, std::vector<uint8> const& value)
{
AssertValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(value.size());
param->buffer_type = MYSQL_TYPE_BLOB;
delete [] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = len;
param->is_null_value = 0;
delete param->length;
param->length = new unsigned long(len);
memcpy(param->buffer, value.data(), len);
}
std::string MySQLPreparedStatement::getQueryString() const
{
std::string queryString(m_queryString);
size_t pos = 0;
for (PreparedStatementData const& data : m_stmt->GetParameters())
{
pos = queryString.find('?', pos);
std::string replaceStr = std::visit([&](auto&& data)
{
return PreparedStatementData::ToString(data);
}, data.data);
queryString.replace(pos, 1, replaceStr);
pos += replaceStr.length();
}
return queryString;
}

View File

@@ -0,0 +1,60 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef MySQLPreparedStatement_h__
#define MySQLPreparedStatement_h__
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include "MySQLWorkaround.h"
#include <string>
#include <vector>
class MySQLConnection;
class PreparedStatementBase;
//- Class of which the instances are unique per MySQLConnection
//- access to these class objects is only done when a prepared statement task
//- is executed.
class AC_DATABASE_API MySQLPreparedStatement
{
friend class MySQLConnection;
friend class PreparedStatementBase;
public:
MySQLPreparedStatement(MySQLStmt* stmt, std::string queryString);
~MySQLPreparedStatement();
void BindParameters(PreparedStatementBase* stmt);
uint32 GetParameterCount() const { return m_paramCount; }
protected:
void SetParameter(uint8 index, std::nullptr_t);
void SetParameter(uint8 index, bool value);
template<typename T>
void SetParameter(uint8 index, T value);
void SetParameter(uint8 index, std::string const& value);
void SetParameter(uint8 index, std::vector<uint8> const& value);
MySQLStmt* GetSTMT() { return m_Mstmt; }
MySQLBind* GetBind() { return m_bind; }
PreparedStatementBase* m_stmt;
void ClearParameters();
void AssertValidIndex(uint8 index);
std::string getQueryString() const;
private:
MySQLStmt* m_Mstmt;
uint32 m_paramCount;
std::vector<bool> m_paramsSet;
MySQLBind* m_bind;
std::string const m_queryString;
MySQLPreparedStatement(MySQLPreparedStatement const& right) = delete;
MySQLPreparedStatement& operator=(MySQLPreparedStatement const& right) = delete;
};
#endif // MySQLPreparedStatement_h__

View File

@@ -1,11 +1,10 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2021 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "MySQLThreading.h"
#include "Log.h"
#include <mysql.h>
#include "MySQLWorkaround.h"
void MySQL::Library_Init()
{

View File

@@ -1,19 +1,18 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _MYSQLTHREADING_H
#define _MYSQLTHREADING_H
#include "Log.h"
#include "Define.h"
namespace MySQL
{
void Library_Init();
void Library_End();
uint32 GetLibraryVersion();
AC_DATABASE_API void Library_Init();
AC_DATABASE_API void Library_End();
AC_DATABASE_API uint32 GetLibraryVersion();
}
#endif

View File

@@ -0,0 +1,9 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifdef _WIN32 // hack for broken mysql.h not including the correct winsock header for SOCKET definition, fixed in 5.7
#include <winsock2.h>
#endif
#include <mysql.h>

View File

@@ -1,481 +1,126 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "PreparedStatement.h"
#include "MySQLConnection.h"
#include "Errors.h"
#include "Log.h"
#include <sstream>
#include "MySQLConnection.h"
#include "MySQLPreparedStatement.h"
#include "MySQLWorkaround.h"
#include "QueryResult.h"
PreparedStatement::PreparedStatement(uint32 index) :
m_stmt(nullptr),
m_index(index)
{
}
PreparedStatementBase::PreparedStatementBase(uint32 index, uint8 capacity) :
m_index(index), statement_data(capacity) { }
PreparedStatement::~PreparedStatement()
{
}
void PreparedStatement::BindParameters()
{
ASSERT (m_stmt);
uint8 i = 0;
for (; i < statement_data.size(); i++)
{
switch (statement_data[i].type)
{
case TYPE_BOOL:
m_stmt->setBool(i, statement_data[i].data.boolean);
break;
case TYPE_UI8:
m_stmt->setUInt8(i, statement_data[i].data.ui8);
break;
case TYPE_UI16:
m_stmt->setUInt16(i, statement_data[i].data.ui16);
break;
case TYPE_UI32:
m_stmt->setUInt32(i, statement_data[i].data.ui32);
break;
case TYPE_I8:
m_stmt->setInt8(i, statement_data[i].data.i8);
break;
case TYPE_I16:
m_stmt->setInt16(i, statement_data[i].data.i16);
break;
case TYPE_I32:
m_stmt->setInt32(i, statement_data[i].data.i32);
break;
case TYPE_UI64:
m_stmt->setUInt64(i, statement_data[i].data.ui64);
break;
case TYPE_I64:
m_stmt->setInt64(i, statement_data[i].data.i64);
break;
case TYPE_FLOAT:
m_stmt->setFloat(i, statement_data[i].data.f);
break;
case TYPE_DOUBLE:
m_stmt->setDouble(i, statement_data[i].data.d);
break;
case TYPE_STRING:
m_stmt->setBinary(i, statement_data[i].binary, true);
break;
case TYPE_BINARY:
m_stmt->setBinary(i, statement_data[i].binary, false);
break;
case TYPE_NULL:
m_stmt->setNull(i);
break;
}
}
#ifdef _DEBUG
if (i < m_stmt->m_paramCount)
LOG_INFO("sql.driver", "[WARNING]: BindParameters() for statement %u did not bind all allocated parameters", m_index);
#endif
}
PreparedStatementBase::~PreparedStatementBase() { }
//- Bind to buffer
void PreparedStatement::setBool(const uint8 index, const bool value)
void PreparedStatementBase::setBool(const uint8 index, const bool value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.boolean = value;
statement_data[index].type = TYPE_BOOL;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setUInt8(const uint8 index, const uint8 value)
void PreparedStatementBase::setUInt8(const uint8 index, const uint8 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.ui8 = value;
statement_data[index].type = TYPE_UI8;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setUInt16(const uint8 index, const uint16 value)
void PreparedStatementBase::setUInt16(const uint8 index, const uint16 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.ui16 = value;
statement_data[index].type = TYPE_UI16;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setUInt32(const uint8 index, const uint32 value)
void PreparedStatementBase::setUInt32(const uint8 index, const uint32 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.ui32 = value;
statement_data[index].type = TYPE_UI32;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setUInt64(const uint8 index, const uint64 value)
void PreparedStatementBase::setUInt64(const uint8 index, const uint64 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.ui64 = value;
statement_data[index].type = TYPE_UI64;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setInt8(const uint8 index, const int8 value)
void PreparedStatementBase::setInt8(const uint8 index, const int8 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.i8 = value;
statement_data[index].type = TYPE_I8;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setInt16(const uint8 index, const int16 value)
void PreparedStatementBase::setInt16(const uint8 index, const int16 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.i16 = value;
statement_data[index].type = TYPE_I16;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setInt32(const uint8 index, const int32 value)
void PreparedStatementBase::setInt32(const uint8 index, const int32 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.i32 = value;
statement_data[index].type = TYPE_I32;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setInt64(const uint8 index, const int64 value)
void PreparedStatementBase::setInt64(const uint8 index, const int64 value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.i64 = value;
statement_data[index].type = TYPE_I64;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setFloat(const uint8 index, const float value)
void PreparedStatementBase::setFloat(const uint8 index, const float value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.f = value;
statement_data[index].type = TYPE_FLOAT;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setDouble(const uint8 index, const double value)
void PreparedStatementBase::setDouble(const uint8 index, const double value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].data.d = value;
statement_data[index].type = TYPE_DOUBLE;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setString(const uint8 index, const std::string& value)
void PreparedStatementBase::setString(const uint8 index, const std::string& value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].binary.resize(value.length() + 1);
memcpy(statement_data[index].binary.data(), value.c_str(), value.length() + 1);
statement_data[index].type = TYPE_STRING;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatement::setBinary(const uint8 index, const std::vector<uint8>& value)
void PreparedStatementBase::setStringView(const uint8 index, const std::string_view value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].binary = value;
statement_data[index].type = TYPE_BINARY;
ASSERT(index < statement_data.size());
statement_data[index].data.emplace<std::string>(value);
}
void PreparedStatement::setNull(const uint8 index)
void PreparedStatementBase::setBinary(const uint8 index, const std::vector<uint8>& value)
{
if (index >= statement_data.size())
statement_data.resize(index + 1);
statement_data[index].type = TYPE_NULL;
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
MySQLPreparedStatement::MySQLPreparedStatement(MYSQL_STMT* stmt) :
m_stmt(nullptr),
m_Mstmt(stmt),
m_bind(nullptr)
void PreparedStatementBase::setNull(const uint8 index)
{
/// Initialize variable parameters
m_paramCount = mysql_stmt_param_count(stmt);
m_paramsSet.assign(m_paramCount, false);
m_bind = new MYSQL_BIND[m_paramCount];
memset(m_bind, 0, sizeof(MYSQL_BIND)*m_paramCount);
/// "If set to 1, causes mysql_stmt_store_result() to update the metadata MYSQL_FIELD->max_length value."
my_bool bool_tmp = 1;
mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, &bool_tmp);
}
MySQLPreparedStatement::~MySQLPreparedStatement()
{
ClearParameters();
if (m_Mstmt->bind_result_done)
{
delete[] m_Mstmt->bind->length;
delete[] m_Mstmt->bind->is_null;
}
mysql_stmt_close(m_Mstmt);
delete[] m_bind;
}
void MySQLPreparedStatement::ClearParameters()
{
for (uint32 i = 0; i < m_paramCount; ++i)
{
delete m_bind[i].length;
m_bind[i].length = nullptr;
delete[] (char*) m_bind[i].buffer;
m_bind[i].buffer = nullptr;
m_paramsSet[i] = false;
}
}
static bool ParamenterIndexAssertFail(uint32 stmtIndex, uint8 index, uint32 paramCount)
{
LOG_ERROR("server", "Attempted to bind parameter %u%s on a PreparedStatement %u (statement has only %u parameters)", uint32(index) + 1, (index == 1 ? "st" : (index == 2 ? "nd" : (index == 3 ? "rd" : "nd"))), stmtIndex, paramCount);
return false;
}
//- Bind on mysql level
bool MySQLPreparedStatement::CheckValidIndex(uint8 index)
{
ASSERT(index < m_paramCount || ParamenterIndexAssertFail(m_stmt->m_index, index, m_paramCount));
if (m_paramsSet[index])
LOG_INFO("sql.driver", "[WARNING] Prepared Statement (id: %u) trying to bind value on already bound index (%u).", m_stmt->m_index, index);
return true;
}
void MySQLPreparedStatement::setBool(const uint8 index, const bool value)
{
setUInt8(index, value ? 1 : 0);
}
void MySQLPreparedStatement::setUInt8(const uint8 index, const uint8 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_TINY, &value, sizeof(uint8), true);
}
void MySQLPreparedStatement::setUInt16(const uint8 index, const uint16 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_SHORT, &value, sizeof(uint16), true);
}
void MySQLPreparedStatement::setUInt32(const uint8 index, const uint32 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_LONG, &value, sizeof(uint32), true);
}
void MySQLPreparedStatement::setUInt64(const uint8 index, const uint64 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_LONGLONG, &value, sizeof(uint64), true);
}
void MySQLPreparedStatement::setInt8(const uint8 index, const int8 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_TINY, &value, sizeof(int8), false);
}
void MySQLPreparedStatement::setInt16(const uint8 index, const int16 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_SHORT, &value, sizeof(int16), false);
}
void MySQLPreparedStatement::setInt32(const uint8 index, const int32 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_LONG, &value, sizeof(int32), false);
}
void MySQLPreparedStatement::setInt64(const uint8 index, const int64 value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_LONGLONG, &value, sizeof(int64), false);
}
void MySQLPreparedStatement::setFloat(const uint8 index, const float value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_FLOAT, &value, sizeof(float), (value > 0.0f));
}
void MySQLPreparedStatement::setDouble(const uint8 index, const double value)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
setValue(param, MYSQL_TYPE_DOUBLE, &value, sizeof(double), (value > 0.0f));
}
void MySQLPreparedStatement::setBinary(const uint8 index, const std::vector<uint8>& value, bool isString)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(value.size());
param->buffer_type = MYSQL_TYPE_BLOB;
delete [] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = len;
param->is_null_value = 0;
delete param->length;
param->length = new unsigned long(len);
if (isString)
{
*param->length -= 1;
param->buffer_type = MYSQL_TYPE_VAR_STRING;
}
memcpy(param->buffer, value.data(), len);
}
void MySQLPreparedStatement::setNull(const uint8 index)
{
CheckValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
param->buffer_type = MYSQL_TYPE_NULL;
delete [] static_cast<char*>(param->buffer);
param->buffer = nullptr;
param->buffer_length = 0;
param->is_null_value = 1;
delete param->length;
param->length = nullptr;
}
void MySQLPreparedStatement::setValue(MYSQL_BIND* param, enum_field_types type, const void* value, uint32 len, bool isUnsigned)
{
param->buffer_type = type;
delete [] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = 0;
param->is_null_value = 0;
param->length = nullptr; // Only != nullptr for strings
param->is_unsigned = isUnsigned;
memcpy(param->buffer, value, len);
}
std::string MySQLPreparedStatement::getQueryString(std::string const& sqlPattern) const
{
std::string queryString = sqlPattern;
size_t pos = 0;
for (uint32 i = 0; i < m_stmt->statement_data.size(); i++)
{
pos = queryString.find('?', pos);
std::stringstream ss;
switch (m_stmt->statement_data[i].type)
{
case TYPE_BOOL:
ss << uint16(m_stmt->statement_data[i].data.boolean);
break;
case TYPE_UI8:
ss << uint16(m_stmt->statement_data[i].data.ui8); // stringstream will append a character with that code instead of numeric representation
break;
case TYPE_UI16:
ss << m_stmt->statement_data[i].data.ui16;
break;
case TYPE_UI32:
ss << m_stmt->statement_data[i].data.ui32;
break;
case TYPE_I8:
ss << int16(m_stmt->statement_data[i].data.i8); // stringstream will append a character with that code instead of numeric representation
break;
case TYPE_I16:
ss << m_stmt->statement_data[i].data.i16;
break;
case TYPE_I32:
ss << m_stmt->statement_data[i].data.i32;
break;
case TYPE_UI64:
ss << m_stmt->statement_data[i].data.ui64;
break;
case TYPE_I64:
ss << m_stmt->statement_data[i].data.i64;
break;
case TYPE_FLOAT:
ss << m_stmt->statement_data[i].data.f;
break;
case TYPE_DOUBLE:
ss << m_stmt->statement_data[i].data.d;
break;
case TYPE_STRING:
ss << '\'' << (char const*)m_stmt->statement_data[i].binary.data() << '\'';
break;
case TYPE_BINARY:
ss << "BINARY";
break;
case TYPE_NULL:
ss << "nullptr";
break;
}
std::string replaceStr = ss.str();
queryString.replace(pos, 1, replaceStr);
pos += replaceStr.length();
}
return queryString;
ASSERT(index < statement_data.size());
statement_data[index].data = nullptr;
}
//- Execution
PreparedStatementTask::PreparedStatementTask(PreparedStatement* stmt) :
m_stmt(stmt),
m_has_result(false)
{
}
PreparedStatementTask::PreparedStatementTask(PreparedStatement* stmt, PreparedQueryResultFuture result) :
m_stmt(stmt),
m_has_result(true),
m_result(result)
PreparedStatementTask::PreparedStatementTask(PreparedStatementBase* stmt, bool async) :
m_stmt(stmt), m_result(nullptr)
{
m_has_result = async; // If it's async, then there's a result
if (async)
m_result = new PreparedQueryResultPromise();
}
PreparedStatementTask::~PreparedStatementTask()
{
delete m_stmt;
if (m_has_result && m_result != nullptr)
delete m_result;
}
bool PreparedStatementTask::Execute()
@@ -486,12 +131,58 @@ bool PreparedStatementTask::Execute()
if (!result || !result->GetRowCount())
{
delete result;
m_result.set(PreparedQueryResult(nullptr));
m_result->set_value(PreparedQueryResult(nullptr));
return false;
}
m_result.set(PreparedQueryResult(result));
m_result->set_value(PreparedQueryResult(result));
return true;
}
return m_conn->Execute(m_stmt);
}
template<typename T>
std::string PreparedStatementData::ToString(T value)
{
return fmt::format("{}", value);
}
std::string PreparedStatementData::ToString(bool value)
{
return ToString<uint32>(value);
}
std::string PreparedStatementData::ToString(uint8 value)
{
return ToString<uint32>(value);
}
template std::string PreparedStatementData::ToString<uint16>(uint16);
template std::string PreparedStatementData::ToString<uint32>(uint32);
template std::string PreparedStatementData::ToString<uint64>(uint64);
std::string PreparedStatementData::ToString(int8 value)
{
return ToString<int32>(value);
}
template std::string PreparedStatementData::ToString<int16>(int16);
template std::string PreparedStatementData::ToString<int32>(int32);
template std::string PreparedStatementData::ToString<int64>(int64);
template std::string PreparedStatementData::ToString<float>(float);
template std::string PreparedStatementData::ToString<double>(double);
std::string PreparedStatementData::ToString(std::string const& value)
{
return fmt::format("'{}'", value);
}
std::string PreparedStatementData::ToString(std::vector<uint8> const& /*value*/)
{
return "BINARY";
}
std::string PreparedStatementData::ToString(std::nullptr_t)
{
return "NULL";
}

View File

@@ -1,75 +1,57 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _PREPAREDSTATEMENT_H
#define _PREPAREDSTATEMENT_H
#include "Define.h"
#include "SQLOperation.h"
#include <ace/Future.h>
#ifdef __APPLE__
#undef TYPE_BOOL
#endif
//- Union for data buffer (upper-level bind -> queue -> lower-level bind)
union PreparedStatementDataUnion
{
bool boolean;
uint8 ui8;
int8 i8;
uint16 ui16;
int16 i16;
uint32 ui32;
int32 i32;
uint64 ui64;
int64 i64;
float f;
double d;
};
//- This enum helps us differ data held in above union
enum PreparedStatementValueType
{
TYPE_BOOL,
TYPE_UI8,
TYPE_UI16,
TYPE_UI32,
TYPE_UI64,
TYPE_I8,
TYPE_I16,
TYPE_I32,
TYPE_I64,
TYPE_FLOAT,
TYPE_DOUBLE,
TYPE_STRING,
TYPE_BINARY,
TYPE_NULL
};
#include <future>
#include <variant>
#include <vector>
struct PreparedStatementData
{
PreparedStatementDataUnion data;
PreparedStatementValueType type;
std::vector<uint8> binary;
std::variant<
bool,
uint8,
uint16,
uint32,
uint64,
int8,
int16,
int32,
int64,
float,
double,
std::string,
std::vector<uint8>,
std::nullptr_t
> data;
template<typename T>
static std::string ToString(T value);
static std::string ToString(bool value);
static std::string ToString(uint8 value);
static std::string ToString(int8 value);
static std::string ToString(std::string const& value);
static std::string ToString(std::vector<uint8> const& value);
static std::string ToString(std::nullptr_t);
};
//- Forward declare
class MySQLPreparedStatement;
//- Upper-level class that is used in code
class PreparedStatement
class AC_DATABASE_API PreparedStatementBase
{
friend class PreparedStatementTask;
friend class MySQLPreparedStatement;
friend class MySQLConnection;
friend class PreparedStatementTask;
public:
explicit PreparedStatement(uint32 index);
~PreparedStatement();
explicit PreparedStatementBase(uint32 index, uint8 capacity);
virtual ~PreparedStatementBase();
void setNull(const uint8 index);
void setBool(const uint8 index, const bool value);
void setUInt8(const uint8 index, const uint8 value);
void setUInt16(const uint8 index, const uint16 value);
@@ -82,84 +64,55 @@ public:
void setFloat(const uint8 index, const float value);
void setDouble(const uint8 index, const double value);
void setString(const uint8 index, const std::string& value);
void setStringView(const uint8 index, const std::string_view value);
void setBinary(const uint8 index, const std::vector<uint8>& value);
template<size_t Size>
template <size_t Size>
void setBinary(const uint8 index, std::array<uint8, Size> const& value)
{
std::vector<uint8> vec(value.begin(), value.end());
setBinary(index, vec);
}
void setNull(const uint8 index);
uint32 GetIndex() const { return m_index; }
std::vector<PreparedStatementData> const& GetParameters() const { return statement_data; }
protected:
void BindParameters();
protected:
MySQLPreparedStatement* m_stmt;
uint32 m_index;
std::vector<PreparedStatementData> statement_data; //- Buffer of parameters, not tied to MySQL in any way yet
//- Buffer of parameters, not tied to MySQL in any way yet
std::vector<PreparedStatementData> statement_data;
PreparedStatementBase(PreparedStatementBase const& right) = delete;
PreparedStatementBase& operator=(PreparedStatementBase const& right) = delete;
};
//- Class of which the instances are unique per MySQLConnection
//- access to these class objects is only done when a prepared statement task
//- is executed.
class MySQLPreparedStatement
template<typename T>
class PreparedStatement : public PreparedStatementBase
{
friend class MySQLConnection;
friend class PreparedStatement;
public:
MySQLPreparedStatement(MYSQL_STMT* stmt);
~MySQLPreparedStatement();
void setBool(const uint8 index, const bool value);
void setUInt8(const uint8 index, const uint8 value);
void setUInt16(const uint8 index, const uint16 value);
void setUInt32(const uint8 index, const uint32 value);
void setUInt64(const uint8 index, const uint64 value);
void setInt8(const uint8 index, const int8 value);
void setInt16(const uint8 index, const int16 value);
void setInt32(const uint8 index, const int32 value);
void setInt64(const uint8 index, const int64 value);
void setFloat(const uint8 index, const float value);
void setDouble(const uint8 index, const double value);
void setBinary(const uint8 index, const std::vector<uint8>& value, bool isString);
void setNull(const uint8 index);
protected:
MYSQL_STMT* GetSTMT() { return m_Mstmt; }
MYSQL_BIND* GetBind() { return m_bind; }
PreparedStatement* m_stmt;
void ClearParameters();
bool CheckValidIndex(uint8 index);
[[nodiscard]] std::string getQueryString(std::string const& sqlPattern) const;
explicit PreparedStatement(uint32 index, uint8 capacity) : PreparedStatementBase(index, capacity)
{
}
private:
void setValue(MYSQL_BIND* param, enum_field_types type, const void* value, uint32 len, bool isUnsigned);
private:
MYSQL_STMT* m_Mstmt;
uint32 m_paramCount;
std::vector<bool> m_paramsSet;
MYSQL_BIND* m_bind;
PreparedStatement(PreparedStatement const& right) = delete;
PreparedStatement& operator=(PreparedStatement const& right) = delete;
};
typedef ACE_Future<PreparedQueryResult> PreparedQueryResultFuture;
//- Lower-level class, enqueuable operation
class PreparedStatementTask : public SQLOperation
class AC_DATABASE_API PreparedStatementTask : public SQLOperation
{
public:
PreparedStatementTask(PreparedStatement* stmt);
PreparedStatementTask(PreparedStatement* stmt, PreparedQueryResultFuture result);
~PreparedStatementTask() override;
PreparedStatementTask(PreparedStatementBase* stmt, bool async = false);
~PreparedStatementTask();
bool Execute() override;
PreparedQueryResultFuture GetFuture() { return m_result->get_future(); }
protected:
PreparedStatement* m_stmt;
PreparedStatementBase* m_stmt;
bool m_has_result;
PreparedQueryResultFuture m_result;
PreparedQueryResultPromise* m_result;
};
#endif

View File

@@ -0,0 +1,209 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "QueryCallback.h"
#include "Errors.h"
template<typename T, typename... Args>
inline void Construct(T& t, Args&&... args)
{
new (&t) T(std::forward<Args>(args)...);
}
template<typename T>
inline void Destroy(T& t)
{
t.~T();
}
template<typename T>
inline void ConstructActiveMember(T* obj)
{
if (!obj->_isPrepared)
Construct(obj->_string);
else
Construct(obj->_prepared);
}
template<typename T>
inline void DestroyActiveMember(T* obj)
{
if (!obj->_isPrepared)
Destroy(obj->_string);
else
Destroy(obj->_prepared);
}
template<typename T>
inline void MoveFrom(T* to, T&& from)
{
ASSERT(to->_isPrepared == from._isPrepared);
if (!to->_isPrepared)
to->_string = std::move(from._string);
else
to->_prepared = std::move(from._prepared);
}
struct QueryCallback::QueryCallbackData
{
public:
friend class QueryCallback;
QueryCallbackData(std::function<void(QueryCallback&, QueryResult)>&& callback) : _string(std::move(callback)), _isPrepared(false) { }
QueryCallbackData(std::function<void(QueryCallback&, PreparedQueryResult)>&& callback) : _prepared(std::move(callback)), _isPrepared(true) { }
QueryCallbackData(QueryCallbackData&& right)
{
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
MoveFrom(this, std::move(right));
}
QueryCallbackData& operator=(QueryCallbackData&& right)
{
if (this != &right)
{
if (_isPrepared != right._isPrepared)
{
DestroyActiveMember(this);
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
}
MoveFrom(this, std::move(right));
}
return *this;
}
~QueryCallbackData() { DestroyActiveMember(this); }
private:
QueryCallbackData(QueryCallbackData const&) = delete;
QueryCallbackData& operator=(QueryCallbackData const&) = delete;
template<typename T> friend void ConstructActiveMember(T* obj);
template<typename T> friend void DestroyActiveMember(T* obj);
template<typename T> friend void MoveFrom(T* to, T&& from);
union
{
std::function<void(QueryCallback&, QueryResult)> _string;
std::function<void(QueryCallback&, PreparedQueryResult)> _prepared;
};
bool _isPrepared;
};
// Not using initialization lists to work around segmentation faults when compiling with clang without precompiled headers
QueryCallback::QueryCallback(std::future<QueryResult>&& result)
{
_isPrepared = false;
Construct(_string, std::move(result));
}
QueryCallback::QueryCallback(std::future<PreparedQueryResult>&& result)
{
_isPrepared = true;
Construct(_prepared, std::move(result));
}
QueryCallback::QueryCallback(QueryCallback&& right)
{
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
MoveFrom(this, std::move(right));
_callbacks = std::move(right._callbacks);
}
QueryCallback& QueryCallback::operator=(QueryCallback&& right)
{
if (this != &right)
{
if (_isPrepared != right._isPrepared)
{
DestroyActiveMember(this);
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
}
MoveFrom(this, std::move(right));
_callbacks = std::move(right._callbacks);
}
return *this;
}
QueryCallback::~QueryCallback()
{
DestroyActiveMember(this);
}
QueryCallback&& QueryCallback::WithCallback(std::function<void(QueryResult)>&& callback)
{
return WithChainingCallback([callback](QueryCallback& /*this*/, QueryResult result) { callback(std::move(result)); });
}
QueryCallback&& QueryCallback::WithPreparedCallback(std::function<void(PreparedQueryResult)>&& callback)
{
return WithChainingPreparedCallback([callback](QueryCallback& /*this*/, PreparedQueryResult result) { callback(std::move(result)); });
}
QueryCallback&& QueryCallback::WithChainingCallback(std::function<void(QueryCallback&, QueryResult)>&& callback)
{
ASSERT(!_callbacks.empty() || !_isPrepared, "Attempted to set callback function for string query on a prepared async query");
_callbacks.emplace(std::move(callback));
return std::move(*this);
}
QueryCallback&& QueryCallback::WithChainingPreparedCallback(std::function<void(QueryCallback&, PreparedQueryResult)>&& callback)
{
ASSERT(!_callbacks.empty() || _isPrepared, "Attempted to set callback function for prepared query on a string async query");
_callbacks.emplace(std::move(callback));
return std::move(*this);
}
void QueryCallback::SetNextQuery(QueryCallback&& next)
{
MoveFrom(this, std::move(next));
}
bool QueryCallback::InvokeIfReady()
{
QueryCallbackData& callback = _callbacks.front();
auto checkStateAndReturnCompletion = [this]()
{
_callbacks.pop();
bool hasNext = !_isPrepared ? _string.valid() : _prepared.valid();
if (_callbacks.empty())
{
ASSERT(!hasNext);
return true;
}
// abort chain
if (!hasNext)
return true;
ASSERT(_isPrepared == _callbacks.front()._isPrepared);
return false;
};
if (!_isPrepared)
{
if (_string.valid() && _string.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
{
QueryResultFuture f(std::move(_string));
std::function<void(QueryCallback&, QueryResult)> cb(std::move(callback._string));
cb(*this, f.get());
return checkStateAndReturnCompletion();
}
}
else
{
if (_prepared.valid() && _prepared.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
{
PreparedQueryResultFuture f(std::move(_prepared));
std::function<void(QueryCallback&, PreparedQueryResult)> cb(std::move(callback._prepared));
cb(*this, f.get());
return checkStateAndReturnCompletion();
}
}
return false;
}

View File

@@ -0,0 +1,57 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _QUERY_CALLBACK_H
#define _QUERY_CALLBACK_H
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include <functional>
#include <future>
#include <list>
#include <queue>
#include <utility>
class AC_DATABASE_API QueryCallback
{
public:
explicit QueryCallback(QueryResultFuture&& result);
explicit QueryCallback(PreparedQueryResultFuture&& result);
QueryCallback(QueryCallback&& right);
QueryCallback& operator=(QueryCallback&& right);
~QueryCallback();
QueryCallback&& WithCallback(std::function<void(QueryResult)>&& callback);
QueryCallback&& WithPreparedCallback(std::function<void(PreparedQueryResult)>&& callback);
QueryCallback&& WithChainingCallback(std::function<void(QueryCallback&, QueryResult)>&& callback);
QueryCallback&& WithChainingPreparedCallback(std::function<void(QueryCallback&, PreparedQueryResult)>&& callback);
// Moves std::future from next to this object
void SetNextQuery(QueryCallback&& next);
// returns true when completed
bool InvokeIfReady();
private:
QueryCallback(QueryCallback const& right) = delete;
QueryCallback& operator=(QueryCallback const& right) = delete;
template<typename T> friend void ConstructActiveMember(T* obj);
template<typename T> friend void DestroyActiveMember(T* obj);
template<typename T> friend void MoveFrom(T* to, T&& from);
union
{
QueryResultFuture _string;
PreparedQueryResultFuture _prepared;
};
bool _isPrepared;
struct QueryCallbackData;
std::queue<QueryCallbackData, std::list<QueryCallbackData>> _callbacks;
};
#endif // _QUERY_CALLBACK_H

View File

@@ -1,109 +1,37 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "MySQLConnection.h"
#include "QueryHolder.h"
#include "PreparedStatement.h"
#include "Errors.h"
#include "Log.h"
#include "MySQLConnection.h"
#include "PreparedStatement.h"
#include "QueryResult.h"
bool SQLQueryHolder::SetQuery(size_t index, const char* sql)
bool SQLQueryHolderBase::SetPreparedQueryImpl(size_t index, PreparedStatementBase* stmt)
{
if (m_queries.size() <= index)
{
LOG_ERROR("server", "Query index (%u) out of range (size: %u) for query: %s", uint32(index), (uint32)m_queries.size(), sql);
LOG_ERROR("sql.sql", "Query index (%u) out of range (size: %u) for prepared statement", uint32(index), (uint32)m_queries.size());
return false;
}
/// not executed yet, just stored (it's not called a holder for nothing)
SQLElementData element;
element.type = SQL_ELEMENT_RAW;
element.element.query = strdup(sql);
SQLResultSetUnion result;
result.qresult = nullptr;
m_queries[index] = SQLResultPair(element, result);
m_queries[index].first = stmt;
return true;
}
bool SQLQueryHolder::SetPQuery(size_t index, const char* format, ...)
{
if (!format)
{
LOG_ERROR("server", "Query (index: %u) is empty.", uint32(index));
return false;
}
va_list ap;
char szQuery [MAX_QUERY_LEN];
va_start(ap, format);
int res = vsnprintf(szQuery, MAX_QUERY_LEN, format, ap);
va_end(ap);
if (res == -1)
{
LOG_ERROR("server", "SQL Query truncated (and not execute) for format: %s", format);
return false;
}
return SetQuery(index, szQuery);
}
bool SQLQueryHolder::SetPreparedQuery(size_t index, PreparedStatement* stmt)
{
if (m_queries.size() <= index)
{
LOG_ERROR("server", "Query index (%u) out of range (size: %u) for prepared statement", uint32(index), (uint32)m_queries.size());
return false;
}
/// not executed yet, just stored (it's not called a holder for nothing)
SQLElementData element;
element.type = SQL_ELEMENT_PREPARED;
element.element.stmt = stmt;
SQLResultSetUnion result;
result.presult = nullptr;
m_queries[index] = SQLResultPair(element, result);
return true;
}
QueryResult SQLQueryHolder::GetResult(size_t index)
{
// Don't call to this function if the index is of an ad-hoc statement
if (index < m_queries.size())
{
ResultSet* result = m_queries[index].second.qresult;
if (!result || !result->GetRowCount())
return QueryResult(nullptr);
result->NextRow();
return QueryResult(result);
}
else
return QueryResult(nullptr);
}
PreparedQueryResult SQLQueryHolder::GetPreparedResult(size_t index)
PreparedQueryResult SQLQueryHolderBase::GetPreparedResult(size_t index) const
{
// Don't call to this function if the index is of a prepared statement
if (index < m_queries.size())
{
PreparedResultSet* result = m_queries[index].second.presult;
if (!result || !result->GetRowCount())
return PreparedQueryResult(nullptr);
ASSERT(index < m_queries.size(), "Query holder result index out of range, tried to access index " SZFMTD " but there are only " SZFMTD " results",
index, m_queries.size());
return PreparedQueryResult(result);
}
else
return PreparedQueryResult(nullptr);
return m_queries[index].second;
}
void SQLQueryHolder::SetResult(size_t index, ResultSet* result)
void SQLQueryHolderBase::SetPreparedResult(size_t index, PreparedResultSet* result)
{
if (result && !result->GetRowCount())
{
@@ -113,85 +41,45 @@ void SQLQueryHolder::SetResult(size_t index, ResultSet* result)
/// store the result in the holder
if (index < m_queries.size())
m_queries[index].second.qresult = result;
m_queries[index].second = PreparedQueryResult(result);
}
void SQLQueryHolder::SetPreparedResult(size_t index, PreparedResultSet* result)
SQLQueryHolderBase::~SQLQueryHolderBase()
{
if (result && !result->GetRowCount())
{
delete result;
result = nullptr;
}
/// store the result in the holder
if (index < m_queries.size())
m_queries[index].second.presult = result;
}
SQLQueryHolder::~SQLQueryHolder()
{
for (size_t i = 0; i < m_queries.size(); i++)
for (std::pair<PreparedStatementBase*, PreparedQueryResult>& query : m_queries)
{
/// if the result was never used, free the resources
/// results used already (getresult called) are expected to be deleted
if (SQLElementData* data = &m_queries[i].first)
{
switch (data->type)
{
case SQL_ELEMENT_RAW:
free((void*)(const_cast<char*>(data->element.query)));
break;
case SQL_ELEMENT_PREPARED:
delete data->element.stmt;
break;
}
}
delete query.first;
}
}
void SQLQueryHolder::SetSize(size_t size)
void SQLQueryHolderBase::SetSize(size_t size)
{
/// to optimize push_back, reserve the number of queries about to be executed
m_queries.resize(size);
}
SQLQueryHolderTask::~SQLQueryHolderTask() = default;
bool SQLQueryHolderTask::Execute()
{
//the result can't be ready as we are processing it right now
ASSERT(!m_result.ready());
/// execute all queries in the holder and pass the results
for (size_t i = 0; i < m_holder->m_queries.size(); ++i)
if (PreparedStatementBase* stmt = m_holder->m_queries[i].first)
m_holder->SetPreparedResult(i, m_conn->Query(stmt));
if (!m_holder)
return false;
/// we can do this, we are friends
std::vector<SQLQueryHolder::SQLResultPair>& queries = m_holder->m_queries;
for (size_t i = 0; i < queries.size(); i++)
{
/// execute all queries in the holder and pass the results
if (SQLElementData* data = &queries[i].first)
{
switch (data->type)
{
case SQL_ELEMENT_RAW:
{
char const* sql = data->element.query;
if (sql)
m_holder->SetResult(i, m_conn->Query(sql));
break;
}
case SQL_ELEMENT_PREPARED:
{
PreparedStatement* stmt = data->element.stmt;
if (stmt)
m_holder->SetPreparedResult(i, m_conn->Query(stmt));
break;
}
}
}
}
m_result.set(m_holder);
m_result.set_value();
return true;
}
bool SQLQueryHolderCallback::InvokeIfReady()
{
if (m_future.valid() && m_future.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
{
m_callback(*m_holder);
return true;
}
return false;
}

View File

@@ -1,45 +1,76 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _QUERYHOLDER_H
#define _QUERYHOLDER_H
#include <ace/Future.h>
#include "SQLOperation.h"
#include <vector>
class SQLQueryHolder
class AC_DATABASE_API SQLQueryHolderBase
{
friend class SQLQueryHolderTask;
friend class SQLQueryHolderTask;
private:
typedef std::pair<SQLElementData, SQLResultSetUnion> SQLResultPair;
std::vector<SQLResultPair> m_queries;
std::vector<std::pair<PreparedStatementBase*, PreparedQueryResult>> m_queries;
public:
SQLQueryHolder() = default;
~SQLQueryHolder();
bool SetQuery(size_t index, const char* sql);
bool SetPQuery(size_t index, const char* format, ...) ATTR_PRINTF(3, 4);
bool SetPreparedQuery(size_t index, PreparedStatement* stmt);
SQLQueryHolderBase() = default;
virtual ~SQLQueryHolderBase();
void SetSize(size_t size);
QueryResult GetResult(size_t index);
PreparedQueryResult GetPreparedResult(size_t index);
void SetResult(size_t index, ResultSet* result);
PreparedQueryResult GetPreparedResult(size_t index) const;
void SetPreparedResult(size_t index, PreparedResultSet* result);
protected:
bool SetPreparedQueryImpl(size_t index, PreparedStatementBase* stmt);
};
typedef ACE_Future<SQLQueryHolder*> QueryResultHolderFuture;
template<typename T>
class SQLQueryHolder : public SQLQueryHolderBase
{
public:
bool SetPreparedQuery(size_t index, PreparedStatement<T>* stmt)
{
return SetPreparedQueryImpl(index, stmt);
}
};
class SQLQueryHolderTask : public SQLOperation
class AC_DATABASE_API SQLQueryHolderTask : public SQLOperation
{
private:
SQLQueryHolder* m_holder;
QueryResultHolderFuture m_result;
std::shared_ptr<SQLQueryHolderBase> m_holder;
QueryResultHolderPromise m_result;
public:
SQLQueryHolderTask(SQLQueryHolder* holder, QueryResultHolderFuture res)
: m_holder(holder), m_result(res) { };
explicit SQLQueryHolderTask(std::shared_ptr<SQLQueryHolderBase> holder)
: m_holder(std::move(holder)) { }
~SQLQueryHolderTask();
bool Execute() override;
QueryResultHolderFuture GetFuture() { return m_result.get_future(); }
};
class AC_DATABASE_API SQLQueryHolderCallback
{
public:
SQLQueryHolderCallback(std::shared_ptr<SQLQueryHolderBase>&& holder, QueryResultHolderFuture&& future)
: m_holder(std::move(holder)), m_future(std::move(future)) { }
SQLQueryHolderCallback(SQLQueryHolderCallback&&) = default;
SQLQueryHolderCallback& operator=(SQLQueryHolderCallback&&) = default;
void AfterComplete(std::function<void(SQLQueryHolderBase const&)> callback) &
{
m_callback = std::move(callback);
}
bool InvokeIfReady();
std::shared_ptr<SQLQueryHolderBase> m_holder;
QueryResultHolderFuture m_future;
std::function<void(SQLQueryHolderBase const&)> m_callback;
};
#endif

View File

@@ -1,33 +1,181 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "DatabaseEnv.h"
#include "QueryResult.h"
#include "Errors.h"
#include "Field.h"
#include "Log.h"
#include "MySQLHacks.h"
#include "MySQLWorkaround.h"
ResultSet::ResultSet(MYSQL_RES* result, MYSQL_FIELD* fields, uint64 rowCount, uint32 fieldCount) :
_rowCount(rowCount),
_fieldCount(fieldCount),
_result(result),
_fields(fields)
namespace
{
_currentRow = new Field[_fieldCount];
ASSERT(_currentRow);
static uint32 SizeForType(MYSQL_FIELD* field)
{
switch (field->type)
{
case MYSQL_TYPE_NULL:
return 0;
case MYSQL_TYPE_TINY:
return 1;
case MYSQL_TYPE_YEAR:
case MYSQL_TYPE_SHORT:
return 2;
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_LONG:
case MYSQL_TYPE_FLOAT:
return 4;
case MYSQL_TYPE_DOUBLE:
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_BIT:
return 8;
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
case MYSQL_TYPE_DATETIME:
return sizeof(MYSQL_TIME);
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
return field->max_length + 1;
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return 64;
case MYSQL_TYPE_GEOMETRY:
/*
Following types are not sent over the wire:
MYSQL_TYPE_ENUM:
MYSQL_TYPE_SET:
*/
default:
LOG_WARN("sql.sql", "SQL::SizeForType(): invalid field type %u", uint32(field->type));
return 0;
}
}
PreparedResultSet::PreparedResultSet(MYSQL_STMT* stmt, MYSQL_RES* result, uint64 rowCount, uint32 fieldCount) :
m_rowCount(rowCount),
m_rowPosition(0),
m_fieldCount(fieldCount),
m_rBind(nullptr),
m_stmt(stmt),
m_res(result),
m_isNull(nullptr),
m_length(nullptr)
DatabaseFieldTypes MysqlTypeToFieldType(enum_field_types type)
{
if (!m_res)
switch (type)
{
case MYSQL_TYPE_NULL:
return DatabaseFieldTypes::Null;
case MYSQL_TYPE_TINY:
return DatabaseFieldTypes::Int8;
case MYSQL_TYPE_YEAR:
case MYSQL_TYPE_SHORT:
return DatabaseFieldTypes::Int16;
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_LONG:
return DatabaseFieldTypes::Int32;
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_BIT:
return DatabaseFieldTypes::Int64;
case MYSQL_TYPE_FLOAT:
return DatabaseFieldTypes::Float;
case MYSQL_TYPE_DOUBLE:
return DatabaseFieldTypes::Double;
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return DatabaseFieldTypes::Decimal;
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
case MYSQL_TYPE_DATETIME:
return DatabaseFieldTypes::Date;
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
return DatabaseFieldTypes::Binary;
default:
LOG_WARN("sql.sql", "MysqlTypeToFieldType(): invalid field type %u", uint32(type));
break;
}
return DatabaseFieldTypes::Null;
}
static char const* FieldTypeToString(enum_field_types type)
{
switch (type)
{
case MYSQL_TYPE_BIT: return "BIT";
case MYSQL_TYPE_BLOB: return "BLOB";
case MYSQL_TYPE_DATE: return "DATE";
case MYSQL_TYPE_DATETIME: return "DATETIME";
case MYSQL_TYPE_NEWDECIMAL: return "NEWDECIMAL";
case MYSQL_TYPE_DECIMAL: return "DECIMAL";
case MYSQL_TYPE_DOUBLE: return "DOUBLE";
case MYSQL_TYPE_ENUM: return "ENUM";
case MYSQL_TYPE_FLOAT: return "FLOAT";
case MYSQL_TYPE_GEOMETRY: return "GEOMETRY";
case MYSQL_TYPE_INT24: return "INT24";
case MYSQL_TYPE_LONG: return "LONG";
case MYSQL_TYPE_LONGLONG: return "LONGLONG";
case MYSQL_TYPE_LONG_BLOB: return "LONG_BLOB";
case MYSQL_TYPE_MEDIUM_BLOB: return "MEDIUM_BLOB";
case MYSQL_TYPE_NEWDATE: return "NEWDATE";
case MYSQL_TYPE_NULL: return "NULL";
case MYSQL_TYPE_SET: return "SET";
case MYSQL_TYPE_SHORT: return "SHORT";
case MYSQL_TYPE_STRING: return "STRING";
case MYSQL_TYPE_TIME: return "TIME";
case MYSQL_TYPE_TIMESTAMP: return "TIMESTAMP";
case MYSQL_TYPE_TINY: return "TINY";
case MYSQL_TYPE_TINY_BLOB: return "TINY_BLOB";
case MYSQL_TYPE_VAR_STRING: return "VAR_STRING";
case MYSQL_TYPE_YEAR: return "YEAR";
default: return "-Unknown-";
}
}
void InitializeDatabaseFieldMetadata(QueryResultFieldMetadata* meta, MySQLField const* field, uint32 fieldIndex)
{
meta->TableName = field->org_table;
meta->TableAlias = field->table;
meta->Name = field->org_name;
meta->Alias = field->name;
meta->TypeName = FieldTypeToString(field->type);
meta->Index = fieldIndex;
meta->Type = MysqlTypeToFieldType(field->type);
}
}
ResultSet::ResultSet(MySQLResult* result, MySQLField* fields, uint64 rowCount, uint32 fieldCount) :
_rowCount(rowCount),
_fieldCount(fieldCount),
_result(result),
_fields(fields)
{
_fieldMetadata.resize(_fieldCount);
_currentRow = new Field[_fieldCount];
for (uint32 i = 0; i < _fieldCount; i++)
{
InitializeDatabaseFieldMetadata(&_fieldMetadata[i], &_fields[i], i);
_currentRow[i].SetMetadata(&_fieldMetadata[i]);
}
}
PreparedResultSet::PreparedResultSet(MySQLStmt* stmt, MySQLResult* result, uint64 rowCount, uint32 fieldCount) :
m_rowCount(rowCount),
m_rowPosition(0),
m_fieldCount(fieldCount),
m_rBind(nullptr),
m_stmt(stmt),
m_metadataResult(result)
{
if (!m_metadataResult)
return;
if (m_stmt->bind_result_done)
@@ -36,48 +184,22 @@ PreparedResultSet::PreparedResultSet(MYSQL_STMT* stmt, MYSQL_RES* result, uint64
delete[] m_stmt->bind->is_null;
}
m_rBind = new MYSQL_BIND[m_fieldCount];
m_isNull = new my_bool[m_fieldCount];
m_length = new unsigned long[m_fieldCount];
m_rBind = new MySQLBind[m_fieldCount];
memset(m_isNull, 0, sizeof(my_bool) * m_fieldCount);
memset(m_rBind, 0, sizeof(MYSQL_BIND) * m_fieldCount);
//- for future readers wondering where the fuck this is freed - mysql_stmt_bind_result moves pointers to these
// from m_rBind to m_stmt->bind and it is later freed by the `if (m_stmt->bind_result_done)` block just above here
// MYSQL_STMT lifetime is equal to connection lifetime
MySQLBool* m_isNull = new MySQLBool[m_fieldCount];
unsigned long* m_length = new unsigned long[m_fieldCount];
memset(m_isNull, 0, sizeof(MySQLBool) * m_fieldCount);
memset(m_rBind, 0, sizeof(MySQLBind) * m_fieldCount);
memset(m_length, 0, sizeof(unsigned long) * m_fieldCount);
//- This is where we store the (entire) resultset
if (mysql_stmt_store_result(m_stmt))
{
LOG_INFO("sql.driver", "%s:mysql_stmt_store_result, cannot bind result from MySQL server. Error: %s", __FUNCTION__, mysql_stmt_error(m_stmt));
delete[] m_rBind;
delete[] m_isNull;
delete[] m_length;
return;
}
//- This is where we prepare the buffer based on metadata
uint32 i = 0;
MYSQL_FIELD* field = mysql_fetch_field(m_res);
while (field)
{
size_t size = Field::SizeForType(field);
m_rBind[i].buffer_type = field->type;
m_rBind[i].buffer = malloc(size);
memset(m_rBind[i].buffer, 0, size);
m_rBind[i].buffer_length = size;
m_rBind[i].length = &m_length[i];
m_rBind[i].is_null = &m_isNull[i];
m_rBind[i].error = nullptr;
m_rBind[i].is_unsigned = field->flags & UNSIGNED_FLAG;
++i;
field = mysql_fetch_field(m_res);
}
//- This is where we bind the bind the buffer to the statement
if (mysql_stmt_bind_result(m_stmt, m_rBind))
{
LOG_INFO("sql.driver", "%s:mysql_stmt_bind_result, cannot bind result from MySQL server. Error: %s", __FUNCTION__, mysql_stmt_error(m_stmt));
LOG_WARN("sql.sql", "%s:mysql_stmt_store_result, cannot bind result from MySQL server. Error: %s", __FUNCTION__, mysql_stmt_error(m_stmt));
delete[] m_rBind;
delete[] m_isNull;
delete[] m_length;
@@ -86,18 +208,55 @@ PreparedResultSet::PreparedResultSet(MYSQL_STMT* stmt, MYSQL_RES* result, uint64
m_rowCount = mysql_stmt_num_rows(m_stmt);
m_rows.resize(uint32(m_rowCount));
//- This is where we prepare the buffer based on metadata
MySQLField* field = reinterpret_cast<MySQLField*>(mysql_fetch_fields(m_metadataResult));
m_fieldMetadata.resize(m_fieldCount);
std::size_t rowSize = 0;
for (uint32 i = 0; i < m_fieldCount; ++i)
{
uint32 size = SizeForType(&field[i]);
rowSize += size;
InitializeDatabaseFieldMetadata(&m_fieldMetadata[i], &field[i], i);
m_rBind[i].buffer_type = field[i].type;
m_rBind[i].buffer_length = size;
m_rBind[i].length = &m_length[i];
m_rBind[i].is_null = &m_isNull[i];
m_rBind[i].error = nullptr;
m_rBind[i].is_unsigned = field[i].flags & UNSIGNED_FLAG;
}
char* dataBuffer = new char[rowSize * m_rowCount];
for (uint32 i = 0, offset = 0; i < m_fieldCount; ++i)
{
m_rBind[i].buffer = dataBuffer + offset;
offset += m_rBind[i].buffer_length;
}
//- This is where we bind the bind the buffer to the statement
if (mysql_stmt_bind_result(m_stmt, m_rBind))
{
LOG_WARN("sql.sql", "%s:mysql_stmt_bind_result, cannot bind result from MySQL server. Error: %s", __FUNCTION__, mysql_stmt_error(m_stmt));
mysql_stmt_free_result(m_stmt);
CleanUp();
delete[] m_isNull;
delete[] m_length;
return;
}
m_rows.resize(uint32(m_rowCount) * m_fieldCount);
while (_NextRow())
{
m_rows[uint32(m_rowPosition)] = new Field[m_fieldCount];
for (uint64 fIndex = 0; fIndex < m_fieldCount; ++fIndex)
for (uint32 fIndex = 0; fIndex < m_fieldCount; ++fIndex)
{
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetMetadata(&m_fieldMetadata[fIndex]);
unsigned long buffer_length = m_rBind[fIndex].buffer_length;
unsigned long fetched_length = *m_rBind[fIndex].length;
if (!*m_rBind[fIndex].is_null)
m_rows[uint32(m_rowPosition)][fIndex].SetByteValue( m_rBind[fIndex].buffer,
m_rBind[fIndex].buffer_length,
m_rBind[fIndex].buffer_type,
*m_rBind[fIndex].length );
else
{
void* buffer = m_stmt->bind[fIndex].buffer;
switch (m_rBind[fIndex].buffer_type)
{
case MYSQL_TYPE_TINY_BLOB:
@@ -106,24 +265,38 @@ PreparedResultSet::PreparedResultSet(MYSQL_STMT* stmt, MYSQL_RES* result, uint64
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
m_rows[uint32(m_rowPosition)][fIndex].SetByteValue( "",
m_rBind[fIndex].buffer_length,
m_rBind[fIndex].buffer_type,
*m_rBind[fIndex].length );
// warning - the string will not be null-terminated if there is no space for it in the buffer
// when mysql_stmt_fetch returned MYSQL_DATA_TRUNCATED
// we cannot blindly null-terminate the data either as it may be retrieved as binary blob and not specifically a string
// in this case using Field::GetCString will result in garbage
// TODO: remove Field::GetCString and use std::string_view in C++17
if (fetched_length < buffer_length)
*((char*)buffer + fetched_length) = '\0';
break;
default:
m_rows[uint32(m_rowPosition)][fIndex].SetByteValue( 0,
m_rBind[fIndex].buffer_length,
m_rBind[fIndex].buffer_type,
*m_rBind[fIndex].length );
break;
}
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetByteValue(
(char const*)buffer,
fetched_length);
// move buffer pointer to next part
m_stmt->bind[fIndex].buffer = (char*)buffer + rowSize;
}
else
{
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetByteValue(
nullptr,
*m_rBind[fIndex].length);
}
}
m_rowPosition++;
}
m_rowPosition = 0;
/// All data is buffered, let go of mysql c api structures
CleanUp();
mysql_stmt_free_result(m_stmt);
}
ResultSet::~ResultSet()
@@ -133,8 +306,7 @@ ResultSet::~ResultSet()
PreparedResultSet::~PreparedResultSet()
{
for (uint32 i = 0; i < uint32(m_rowCount); ++i)
delete[] m_rows[i];
CleanUp();
}
bool ResultSet::NextRow()
@@ -154,16 +326,23 @@ bool ResultSet::NextRow()
unsigned long* lengths = mysql_fetch_lengths(_result);
if (!lengths)
{
LOG_WARN("sql.sql", "%s:mysql_fetch_lengths, cannot retrieve value lengths. Error %s.", __FUNCTION__, mysql_error(_result->handle));
CleanUp();
return false;
}
for (uint32 i = 0; i < _fieldCount; i++)
_currentRow[i].SetStructuredValue(row[i], _fields[i].type, lengths[i]);
_currentRow[i].SetStructuredValue(row[i], lengths[i]);
return true;
}
std::string ResultSet::GetFieldName(uint32 index) const
{
ASSERT(index < _fieldCount);
return _fields[index].name;
}
bool PreparedResultSet::NextRow()
{
/// Only updates the m_rowPosition so upper level code knows in which element
@@ -181,25 +360,10 @@ bool PreparedResultSet::_NextRow()
if (m_rowPosition >= m_rowCount)
return false;
int retval = mysql_stmt_fetch( m_stmt );
if (!retval || retval == MYSQL_DATA_TRUNCATED)
retval = true;
if (retval == MYSQL_NO_DATA)
retval = false;
return retval;
int retval = mysql_stmt_fetch(m_stmt);
return retval == 0 || retval == MYSQL_DATA_TRUNCATED;
}
#ifdef ELUNA
std::string ResultSet::GetFieldName(uint32 index) const
{
ASSERT(index < _fieldCount);
return _fields[index].name;
}
#endif
void ResultSet::CleanUp()
{
if (_currentRow)
@@ -215,20 +379,34 @@ void ResultSet::CleanUp()
}
}
Field const& ResultSet::operator[](std::size_t index) const
{
ASSERT(index < _fieldCount);
return _currentRow[index];
}
Field* PreparedResultSet::Fetch() const
{
ASSERT(m_rowPosition < m_rowCount);
return const_cast<Field*>(&m_rows[uint32(m_rowPosition) * m_fieldCount]);
}
Field const& PreparedResultSet::operator[](std::size_t index) const
{
ASSERT(m_rowPosition < m_rowCount);
ASSERT(index < m_fieldCount);
return m_rows[uint32(m_rowPosition) * m_fieldCount + index];
}
void PreparedResultSet::CleanUp()
{
/// More of the in our code allocated sources are deallocated by the poorly documented mysql c api
if (m_res)
mysql_free_result(m_res);
if (m_metadataResult)
mysql_free_result(m_metadataResult);
FreeBindBuffer();
mysql_stmt_free_result(m_stmt);
delete[] m_rBind;
}
void PreparedResultSet::FreeBindBuffer()
{
for (uint32 i = 0; i < m_fieldCount; ++i)
free (m_rBind[i].buffer);
if (m_rBind)
{
delete[](char*)m_rBind->buffer;
delete[] m_rBind;
m_rBind = nullptr;
}
}

View File

@@ -1,99 +1,74 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef QUERYRESULT_H
#define QUERYRESULT_H
#include "Errors.h"
#include "Field.h"
#include <mutex>
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include <vector>
#ifdef _WIN32
#include <winsock2.h>
#endif
#include <mysql.h>
#if !defined(MARIADB_VERSION_ID) && MYSQL_VERSION_ID >= 80001
typedef bool my_bool;
#endif
class ResultSet
class AC_DATABASE_API ResultSet
{
public:
ResultSet(MYSQL_RES* result, MYSQL_FIELD* fields, uint64 rowCount, uint32 fieldCount);
ResultSet(MySQLResult* result, MySQLField* fields, uint64 rowCount, uint32 fieldCount);
~ResultSet();
bool NextRow();
[[nodiscard]] uint64 GetRowCount() const { return _rowCount; }
[[nodiscard]] uint32 GetFieldCount() const { return _fieldCount; }
#ifdef ELUNA
uint64 GetRowCount() const { return _rowCount; }
uint32 GetFieldCount() const { return _fieldCount; }
std::string GetFieldName(uint32 index) const;
#endif
[[nodiscard]] Field* Fetch() const { return _currentRow; }
const Field& operator [] (uint32 index) const
{
ASSERT(index < _fieldCount);
return _currentRow[index];
}
Field* Fetch() const { return _currentRow; }
Field const& operator[](std::size_t index) const;
protected:
std::vector<QueryResultFieldMetadata> _fieldMetadata;
uint64 _rowCount;
Field* _currentRow;
uint32 _fieldCount;
private:
void CleanUp();
MYSQL_RES* _result;
MYSQL_FIELD* _fields;
MySQLResult* _result;
MySQLField* _fields;
ResultSet(ResultSet const& right) = delete;
ResultSet& operator=(ResultSet const& right) = delete;
};
typedef std::shared_ptr<ResultSet> QueryResult;
class PreparedResultSet
class AC_DATABASE_API PreparedResultSet
{
public:
PreparedResultSet(MYSQL_STMT* stmt, MYSQL_RES* result, uint64 rowCount, uint32 fieldCount);
PreparedResultSet(MySQLStmt* stmt, MySQLResult* result, uint64 rowCount, uint32 fieldCount);
~PreparedResultSet();
bool NextRow();
[[nodiscard]] uint64 GetRowCount() const { return m_rowCount; }
[[nodiscard]] uint32 GetFieldCount() const { return m_fieldCount; }
uint64 GetRowCount() const { return m_rowCount; }
uint32 GetFieldCount() const { return m_fieldCount; }
[[nodiscard]] Field* Fetch() const
{
ASSERT(m_rowPosition < m_rowCount);
return m_rows[uint32(m_rowPosition)];
}
const Field& operator [] (uint32 index) const
{
ASSERT(m_rowPosition < m_rowCount);
ASSERT(index < m_fieldCount);
return m_rows[uint32(m_rowPosition)][index];
}
Field* Fetch() const;
Field const& operator[](std::size_t index) const;
protected:
std::vector<Field*> m_rows;
std::vector<QueryResultFieldMetadata> m_fieldMetadata;
std::vector<Field> m_rows;
uint64 m_rowCount;
uint64 m_rowPosition;
uint32 m_fieldCount;
private:
MYSQL_BIND* m_rBind;
MYSQL_STMT* m_stmt;
MYSQL_RES* m_res;
MySQLBind* m_rBind;
MySQLStmt* m_stmt;
MySQLResult* m_metadataResult; ///< Field metadata, returned by mysql_stmt_result_metadata
my_bool* m_isNull;
unsigned long* m_length;
void FreeBindBuffer();
void CleanUp();
bool _NextRow();
PreparedResultSet(PreparedResultSet const& right) = delete;
PreparedResultSet& operator=(PreparedResultSet const& right) = delete;
};
typedef std::shared_ptr<PreparedResultSet> PreparedQueryResult;
#endif

View File

@@ -1,25 +1,19 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _SQLOPERATION_H
#define _SQLOPERATION_H
#include <ace/Method_Request.h>
#include <ace/Activation_Queue.h>
#include "QueryResult.h"
//- Forward declare (don't include header to prevent circular includes)
class PreparedStatement;
#include "DatabaseEnvFwd.h"
#include "Define.h"
//- Union that holds element data
union SQLElementUnion
{
PreparedStatement* stmt;
const char* query;
PreparedStatementBase* stmt;
char const* query;
};
//- Type specifier of our element data
@@ -36,20 +30,15 @@ struct SQLElementData
SQLElementDataType type;
};
//- For ambigious resultsets
union SQLResultSetUnion
{
PreparedResultSet* presult;
ResultSet* qresult;
};
class MySQLConnection;
class SQLOperation : public ACE_Method_Request
class AC_DATABASE_API SQLOperation
{
public:
SQLOperation(): m_conn(nullptr) { }
int call() override
virtual ~SQLOperation() { }
virtual int call()
{
Execute();
return 0;
@@ -58,6 +47,10 @@ public:
virtual void SetConnection(MySQLConnection* con) { m_conn = con; }
MySQLConnection* m_conn;
private:
SQLOperation(SQLOperation const& right) = delete;
SQLOperation& operator=(SQLOperation const& right) = delete;
};
#endif

View File

@@ -1,16 +1,23 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "DatabaseEnv.h"
#include "Transaction.h"
#include "Log.h"
#include "MySQLConnection.h"
#include "PreparedStatement.h"
#include "Timer.h"
#include <mysqld_error.h>
#include <sstream>
#include <thread>
std::mutex TransactionTask::_deadlockLock;
#define DEADLOCK_MAX_RETRY_TIME_MS 60000
//- Append a raw ad-hoc query to the transaction
void Transaction::Append(const char* sql)
void TransactionBase::Append(char const* sql)
{
SQLElementData data;
data.type = SQL_ELEMENT_RAW;
@@ -18,19 +25,8 @@ void Transaction::Append(const char* sql)
m_queries.push_back(data);
}
void Transaction::PAppend(const char* sql, ...)
{
va_list ap;
char szQuery [MAX_QUERY_LEN];
va_start(ap, sql);
vsnprintf(szQuery, MAX_QUERY_LEN, sql, ap);
va_end(ap);
Append(szQuery);
}
//- Append a prepared statement to the transaction
void Transaction::Append(PreparedStatement* stmt)
void TransactionBase::AppendPreparedStatement(PreparedStatementBase* stmt)
{
SQLElementData data;
data.type = SQL_ELEMENT_PREPARED;
@@ -38,47 +34,116 @@ void Transaction::Append(PreparedStatement* stmt)
m_queries.push_back(data);
}
void Transaction::Cleanup()
void TransactionBase::Cleanup()
{
// This might be called by explicit calls to Cleanup or by the auto-destructor
if (_cleanedUp)
return;
while (!m_queries.empty())
for (SQLElementData const& data : m_queries)
{
SQLElementData const& data = m_queries.front();
switch (data.type)
{
case SQL_ELEMENT_PREPARED:
delete data.element.stmt;
break;
break;
case SQL_ELEMENT_RAW:
free((void*)(data.element.query));
break;
break;
}
m_queries.pop_front();
}
m_queries.clear();
_cleanedUp = true;
}
bool TransactionTask::Execute()
{
int errorCode = m_conn->ExecuteTransaction(m_trans);
int errorCode = TryExecute();
if (!errorCode)
return true;
if (errorCode == ER_LOCK_DEADLOCK)
{
uint8 loopBreaker = 5; // Handle MySQL Errno 1213 without extending deadlock to the core itself
for (uint8 i = 0; i < loopBreaker; ++i)
if (!m_conn->ExecuteTransaction(m_trans))
std::ostringstream threadIdStream;
threadIdStream << std::this_thread::get_id();
std::string threadId = threadIdStream.str();
// Make sure only 1 async thread retries a transaction so they don't keep dead-locking each other
std::lock_guard<std::mutex> lock(_deadlockLock);
for (uint32 loopDuration = 0, startMSTime = getMSTime(); loopDuration <= DEADLOCK_MAX_RETRY_TIME_MS; loopDuration = GetMSTimeDiffToNow(startMSTime))
{
if (!TryExecute())
return true;
LOG_WARN("sql.sql", "Deadlocked SQL Transaction, retrying. Loop timer: %u ms, Thread Id: %s", loopDuration, threadId.c_str());
}
LOG_ERROR("sql.sql", "Fatal deadlocked SQL Transaction, it will not be retried anymore. Thread Id: %s", threadId.c_str());
}
// Clean up now.
m_trans->Cleanup();
CleanupOnFailure();
return false;
}
int TransactionTask::TryExecute()
{
return m_conn->ExecuteTransaction(m_trans);
}
void TransactionTask::CleanupOnFailure()
{
m_trans->Cleanup();
}
bool TransactionWithResultTask::Execute()
{
int errorCode = TryExecute();
if (!errorCode)
{
m_result.set_value(true);
return true;
}
if (errorCode == ER_LOCK_DEADLOCK)
{
std::ostringstream threadIdStream;
threadIdStream << std::this_thread::get_id();
std::string threadId = threadIdStream.str();
// Make sure only 1 async thread retries a transaction so they don't keep dead-locking each other
std::lock_guard<std::mutex> lock(_deadlockLock);
for (uint32 loopDuration = 0, startMSTime = getMSTime(); loopDuration <= DEADLOCK_MAX_RETRY_TIME_MS; loopDuration = GetMSTimeDiffToNow(startMSTime))
{
if (!TryExecute())
{
m_result.set_value(true);
return true;
}
LOG_WARN("sql.sql", "Deadlocked SQL Transaction, retrying. Loop timer: %u ms, Thread Id: %s", loopDuration, threadId.c_str());
}
LOG_ERROR("sql.sql", "Fatal deadlocked SQL Transaction, it will not be retried anymore. Thread Id: %s", threadId.c_str());
}
// Clean up now.
CleanupOnFailure();
m_result.set_value(false);
return false;
}
bool TransactionCallback::InvokeIfReady()
{
if (m_future.valid() && m_future.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
{
m_callback(m_future.get());
return true;
}
return false;
}

View File

@@ -1,62 +1,111 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
* Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/>
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU GPL v2 license, you may redistribute it and/or modify it under version 2 of the License, or (at your option), any later version.
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef _TRANSACTION_H
#define _TRANSACTION_H
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include "SQLOperation.h"
#include <list>
#include <utility>
//- Forward declare (don't include header to prevent circular includes)
class PreparedStatement;
#include "StringFormat.h"
#include <functional>
#include <mutex>
#include <vector>
/*! Transactions, high level class. */
class Transaction
class AC_DATABASE_API TransactionBase
{
friend class TransactionTask;
friend class MySQLConnection;
friend class TransactionTask;
friend class MySQLConnection;
template <typename T>
friend class DatabaseWorkerPool;
template <typename T>
friend class DatabaseWorkerPool;
public:
Transaction() { }
~Transaction() { Cleanup(); }
TransactionBase() : _cleanedUp(false) { }
virtual ~TransactionBase() { Cleanup(); }
void Append(PreparedStatement* statement);
void Append(const char* sql);
void PAppend(const char* sql, ...);
void Append(char const* sql);
template<typename Format, typename... Args>
void PAppend(Format&& sql, Args&&... args)
{
Append(Acore::StringFormat(std::forward<Format>(sql), std::forward<Args>(args)...).c_str());
}
[[nodiscard]] size_t GetSize() const { return m_queries.size(); }
std::size_t GetSize() const { return m_queries.size(); }
protected:
void AppendPreparedStatement(PreparedStatementBase* statement);
void Cleanup();
std::list<SQLElementData> m_queries;
std::vector<SQLElementData> m_queries;
private:
bool _cleanedUp{false};
bool _cleanedUp;
};
typedef std::shared_ptr<Transaction> SQLTransaction;
template<typename T>
class Transaction : public TransactionBase
{
public:
using TransactionBase::Append;
void Append(PreparedStatement<T>* statement)
{
AppendPreparedStatement(statement);
}
};
/*! Low level class*/
class TransactionTask : public SQLOperation
class AC_DATABASE_API TransactionTask : public SQLOperation
{
template <class T> friend class DatabaseWorkerPool;
friend class DatabaseWorker;
template <class T> friend class DatabaseWorkerPool;
friend class DatabaseWorker;
friend class TransactionCallback;
public:
TransactionTask(SQLTransaction trans) : m_trans(std::move(trans)) { } ;
~TransactionTask() override = default;
TransactionTask(std::shared_ptr<TransactionBase> trans) : m_trans(trans) { }
~TransactionTask() { }
protected:
bool Execute() override;
int TryExecute();
void CleanupOnFailure();
std::shared_ptr<TransactionBase> m_trans;
static std::mutex _deadlockLock;
};
class AC_DATABASE_API TransactionWithResultTask : public TransactionTask
{
public:
TransactionWithResultTask(std::shared_ptr<TransactionBase> trans) : TransactionTask(trans) { }
TransactionFuture GetFuture() { return m_result.get_future(); }
protected:
bool Execute() override;
SQLTransaction m_trans;
TransactionPromise m_result;
};
class AC_DATABASE_API TransactionCallback
{
public:
TransactionCallback(TransactionFuture&& future) : m_future(std::move(future)) { }
TransactionCallback(TransactionCallback&&) = default;
TransactionCallback& operator=(TransactionCallback&&) = default;
void AfterComplete(std::function<void(bool)> callback) &
{
m_callback = std::move(callback);
}
bool InvokeIfReady();
TransactionFuture m_future;
std::function<void(bool)> m_callback;
};
#endif

View File

@@ -19,7 +19,7 @@ void AppenderDB::_write(LogMessage const* message)
if (!enabled || (message->type.find("sql") != std::string::npos))
return;
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_LOG);
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_LOG);
stmt->setUInt64(0, message->mtime);
stmt->setUInt32(1, realmId);
stmt->setString(2, message->type);

View File

@@ -1,19 +1,23 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "Define.h"
// #include "DatabaseEnvFwd.h"
#include "DatabaseEnvFwd.h"
#include "Errors.h"
#include "Field.h"
#include "Log.h"
#include "MySQLConnection.h"
// #include "MySQLPreparedStatement.h"
// #include "MySQLWorkaround.h"
#include "MySQLPreparedStatement.h"
#include "MySQLWorkaround.h"
#include "PreparedStatement.h"
#include "QueryResult.h"
#include "SQLOperation.h"
#include "Transaction.h"
#ifdef _WIN32 // hack for broken mysql.h not including the correct winsock header for SOCKET definition, fixed in 5.7
#include <winsock2.h>
#endif
#include <mysql.h>
#include <string>
#include <vector>

View File

@@ -0,0 +1,433 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "DBUpdater.h"
#include "BuiltInConfig.h"
#include "Config.h"
#include "DatabaseEnv.h"
#include "DatabaseLoader.h"
#include "GitRevision.h"
#include "Log.h"
#include "QueryResult.h"
#include "StartProcess.h"
#include "UpdateFetcher.h"
#include <boost/filesystem/operations.hpp>
#include <fstream>
#include <iostream>
std::string DBUpdaterUtil::GetCorrectedMySQLExecutable()
{
if (!corrected_path().empty())
return corrected_path();
else
return BuiltInConfig::GetMySQLExecutable();
}
bool DBUpdaterUtil::CheckExecutable()
{
boost::filesystem::path exe(GetCorrectedMySQLExecutable());
if (!is_regular_file(exe))
{
exe = Warhead::SearchExecutableInPath("mysql");
if (!exe.empty() && is_regular_file(exe))
{
// Correct the path to the cli
corrected_path() = absolute(exe).generic_string();
return true;
}
LOG_FATAL("sql.updates", "Didn't find any executable MySQL binary at \'%s\' or in path, correct the path in the *.conf (\"MySQLExecutable\").",
absolute(exe).generic_string().c_str());
return false;
}
return true;
}
std::string& DBUpdaterUtil::corrected_path()
{
static std::string path;
return path;
}
// Auth Database
template<>
std::string DBUpdater<LoginDatabaseConnection>::GetConfigEntry()
{
return "Updates.Auth";
}
template<>
std::string DBUpdater<LoginDatabaseConnection>::GetTableName()
{
return "Auth";
}
template<>
std::string DBUpdater<LoginDatabaseConnection>::GetBaseFilesDirectory()
{
return BuiltInConfig::GetSourceDirectory() + "/data/sql/base/db_auth/";
}
template<>
bool DBUpdater<LoginDatabaseConnection>::IsEnabled(uint32 const updateMask)
{
// This way silences warnings under msvc
return (updateMask & DatabaseLoader::DATABASE_LOGIN) ? true : false;
}
template<>
std::string DBUpdater<LoginDatabaseConnection>::GetDBModuleName()
{
return "db_auth";
}
// World Database
template<>
std::string DBUpdater<WorldDatabaseConnection>::GetConfigEntry()
{
return "Updates.World";
}
template<>
std::string DBUpdater<WorldDatabaseConnection>::GetTableName()
{
return "World";
}
template<>
std::string DBUpdater<WorldDatabaseConnection>::GetBaseFilesDirectory()
{
return BuiltInConfig::GetSourceDirectory() + "/data/sql/base/db_world/";
}
template<>
bool DBUpdater<WorldDatabaseConnection>::IsEnabled(uint32 const updateMask)
{
// This way silences warnings under msvc
return (updateMask & DatabaseLoader::DATABASE_WORLD) ? true : false;
}
template<>
std::string DBUpdater<WorldDatabaseConnection>::GetDBModuleName()
{
return "db_world";
}
// Character Database
template<>
std::string DBUpdater<CharacterDatabaseConnection>::GetConfigEntry()
{
return "Updates.Character";
}
template<>
std::string DBUpdater<CharacterDatabaseConnection>::GetTableName()
{
return "Character";
}
template<>
std::string DBUpdater<CharacterDatabaseConnection>::GetBaseFilesDirectory()
{
return BuiltInConfig::GetSourceDirectory() + "/data/sql/base/db_characters/";
}
template<>
bool DBUpdater<CharacterDatabaseConnection>::IsEnabled(uint32 const updateMask)
{
// This way silences warnings under msvc
return (updateMask & DatabaseLoader::DATABASE_CHARACTER) ? true : false;
}
template<>
std::string DBUpdater<CharacterDatabaseConnection>::GetDBModuleName()
{
return "db_characters";
}
// All
template<class T>
BaseLocation DBUpdater<T>::GetBaseLocationType()
{
return LOCATION_REPOSITORY;
}
template<class T>
bool DBUpdater<T>::Create(DatabaseWorkerPool<T>& pool)
{
LOG_WARN("sql.updates", "Database \"%s\" does not exist, do you want to create it? [yes (default) / no]: ",
pool.GetConnectionInfo()->database.c_str());
std::string answer;
std::getline(std::cin, answer);
if (!sConfigMgr->isDryRun() && !answer.empty() && !(answer.substr(0, 1) == "y"))
return false;
LOG_INFO("sql.updates", "Creating database \"%s\"...", pool.GetConnectionInfo()->database.c_str());
// Path of temp file
static Path const temp("create_table.sql");
// Create temporary query to use external MySQL CLi
std::ofstream file(temp.generic_string());
if (!file.is_open())
{
LOG_FATAL("sql.updates", "Failed to create temporary query file \"%s\"!", temp.generic_string().c_str());
return false;
}
file << "CREATE DATABASE `" << pool.GetConnectionInfo()->database << "` DEFAULT CHARACTER SET UTF8MB4 COLLATE utf8mb4_general_ci;\n\n";
file.close();
try
{
DBUpdater<T>::ApplyFile(pool, pool.GetConnectionInfo()->host, pool.GetConnectionInfo()->user, pool.GetConnectionInfo()->password,
pool.GetConnectionInfo()->port_or_socket, "", pool.GetConnectionInfo()->ssl, temp);
}
catch (UpdateException&)
{
LOG_FATAL("sql.updates", "Failed to create database %s! Does the user (named in *.conf) have `CREATE`, `ALTER`, `DROP`, `INSERT` and `DELETE` privileges on the MySQL server?", pool.GetConnectionInfo()->database.c_str());
boost::filesystem::remove(temp);
return false;
}
LOG_INFO("sql.updates", "Done.");
LOG_INFO("sql.updates", " ");
boost::filesystem::remove(temp);
return true;
}
template<class T>
bool DBUpdater<T>::Update(DatabaseWorkerPool<T>& pool)
{
if (!DBUpdaterUtil::CheckExecutable())
return false;
LOG_INFO("sql.updates", "Updating %s database...", DBUpdater<T>::GetTableName().c_str());
Path const sourceDirectory(BuiltInConfig::GetSourceDirectory());
if (!is_directory(sourceDirectory))
{
LOG_ERROR("sql.updates", "DBUpdater: The given source directory %s does not exist, change the path to the directory where your sql directory exists (for example c:\\source\\trinitycore). Shutting down.",
sourceDirectory.generic_string().c_str());
return false;
}
auto CheckUpdateTable = [&](std::string const& tableName)
{
auto checkTable = DBUpdater<T>::Retrieve(pool, Warhead::StringFormat("SHOW TABLES LIKE '%s'", tableName.c_str()));
if (!checkTable)
{
LOG_WARN("sql.updates", "> Table '%s' not exist! Try add based table", tableName.c_str());
Path const temp(GetBaseFilesDirectory() + tableName + ".sql");
try
{
DBUpdater<T>::ApplyFile(pool, temp);
}
catch (UpdateException&)
{
LOG_FATAL("sql.updates", "Failed apply file to database %s! Does the user (named in *.conf) have `INSERT` and `DELETE` privileges on the MySQL server?", pool.GetConnectionInfo()->database.c_str());
return false;
}
return true;
}
return true;
};
if (!CheckUpdateTable("updates") || !CheckUpdateTable("updates_include"))
return false;
UpdateFetcher updateFetcher(sourceDirectory, [&](std::string const & query) { DBUpdater<T>::Apply(pool, query); },
[&](Path const & file) { DBUpdater<T>::ApplyFile(pool, file); },
[&](std::string const & query) -> QueryResult { return DBUpdater<T>::Retrieve(pool, query); }, DBUpdater<T>::GetDBModuleName());
UpdateResult result;
try
{
result = updateFetcher.Update(
sConfigMgr->GetBoolDefault("Updates.Redundancy", true),
sConfigMgr->GetBoolDefault("Updates.AllowRehash", true),
sConfigMgr->GetBoolDefault("Updates.ArchivedRedundancy", false),
sConfigMgr->GetIntDefault("Updates.CleanDeadRefMaxCount", 3));
}
catch (UpdateException&)
{
return false;
}
std::string const info = Warhead::StringFormat("Containing " SZFMTD " new and " SZFMTD " archived updates.",
result.recent, result.archived);
if (!result.updated)
LOG_INFO("sql.updates", ">> %s database is up-to-date! %s", DBUpdater<T>::GetTableName().c_str(), info.c_str());
else
LOG_INFO("sql.updates", ">> Applied " SZFMTD " %s. %s", result.updated, result.updated == 1 ? "query" : "queries", info.c_str());
LOG_INFO("sql.updates", " ");
return true;
}
template<class T>
bool DBUpdater<T>::Populate(DatabaseWorkerPool<T>& pool)
{
{
QueryResult const result = Retrieve(pool, "SHOW TABLES");
if (result && (result->GetRowCount() > 0))
return true;
}
if (!DBUpdaterUtil::CheckExecutable())
return false;
LOG_INFO("sql.updates", "Database %s is empty, auto populating it...", DBUpdater<T>::GetTableName().c_str());
std::string const DirPathStr = DBUpdater<T>::GetBaseFilesDirectory();
Path const DirPath(DirPathStr);
if (!boost::filesystem::is_directory(DirPath))
{
LOG_ERROR("sql.updates", ">> Directory \"%s\" not exist", DirPath.generic_string().c_str());
return false;
}
if (DirPath.empty())
{
LOG_ERROR("sql.updates", ">> Directory \"%s\" is empty", DirPath.generic_string().c_str());
return false;
}
boost::filesystem::directory_iterator const DirItr;
uint32 FilesCount = 0;
for (boost::filesystem::directory_iterator itr(DirPath); itr != DirItr; ++itr)
{
if (itr->path().extension() == ".sql")
FilesCount++;
}
if (!FilesCount)
{
LOG_ERROR("sql.updates", ">> In directory \"%s\" not exist '*.sql' files", DirPath.generic_string().c_str());
return false;
}
for (boost::filesystem::directory_iterator itr(DirPath); itr != DirItr; ++itr)
{
if (itr->path().extension() != ".sql")
continue;
LOG_INFO("sql.updates", ">> Applying \'%s\'...", itr->path().filename().generic_string().c_str());
try
{
ApplyFile(pool, itr->path());
}
catch (UpdateException&)
{
return false;
}
}
LOG_INFO("sql.updates", ">> Done!");
LOG_INFO("sql.updates", " ");
return true;
}
template<class T>
QueryResult DBUpdater<T>::Retrieve(DatabaseWorkerPool<T>& pool, std::string const& query)
{
return pool.Query(query.c_str());
}
template<class T>
void DBUpdater<T>::Apply(DatabaseWorkerPool<T>& pool, std::string const& query)
{
pool.DirectExecute(query.c_str());
}
template<class T>
void DBUpdater<T>::ApplyFile(DatabaseWorkerPool<T>& pool, Path const& path)
{
DBUpdater<T>::ApplyFile(pool, pool.GetConnectionInfo()->host, pool.GetConnectionInfo()->user, pool.GetConnectionInfo()->password,
pool.GetConnectionInfo()->port_or_socket, pool.GetConnectionInfo()->database, pool.GetConnectionInfo()->ssl, path);
}
template<class T>
void DBUpdater<T>::ApplyFile(DatabaseWorkerPool<T>& pool, std::string const& host, std::string const& user,
std::string const& password, std::string const& port_or_socket, std::string const& database, std::string const& ssl, Path const& path)
{
std::vector<std::string> args;
args.reserve(7);
// CLI Client connection info
args.emplace_back("-h" + host);
args.emplace_back("-u" + user);
if (!password.empty())
args.emplace_back("-p" + password);
// Check if we want to connect through ip or socket (Unix only)
#ifdef _WIN32
if (host == ".")
args.emplace_back("--protocol=PIPE");
else
args.emplace_back("-P" + port_or_socket);
#else
if (!std::isdigit(port_or_socket[0]))
{
// We can't check if host == "." here, because it is named localhost if socket option is enabled
args.emplace_back("-P0");
args.emplace_back("--protocol=SOCKET");
args.emplace_back("-S" + port_or_socket);
}
else
// generic case
args.emplace_back("-P" + port_or_socket);
#endif
// Set the default charset to utf8
args.emplace_back("--default-character-set=utf8");
// Set max allowed packet to 1 GB
args.emplace_back("--max-allowed-packet=1GB");
if (ssl == "ssl")
args.emplace_back("--ssl");
// Database
if (!database.empty())
args.emplace_back(database);
// Invokes a mysql process which doesn't leak credentials to logs
int const ret = Warhead::StartProcess(DBUpdaterUtil::GetCorrectedMySQLExecutable(), args,
"sql.updates", path.generic_string(), true);
if (ret != EXIT_SUCCESS)
{
LOG_FATAL("sql.updates", "Applying of file \'%s\' to database \'%s\' failed!" \
" If you are a user, please pull the latest revision from the repository. "
"Also make sure you have not applied any of the databases with your sql client. "
"You cannot use auto-update system and import sql files from WarheadCore repository with your sql client. "
"If you are a developer, please fix your sql query.",
path.generic_string().c_str(), pool.GetConnectionInfo()->database.c_str());
throw UpdateException("update failed");
}
}
template class AC_DATABASE_API DBUpdater<LoginDatabaseConnection>;
template class AC_DATABASE_API DBUpdater<WorldDatabaseConnection>;
template class AC_DATABASE_API DBUpdater<CharacterDatabaseConnection>;

View File

@@ -0,0 +1,79 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef DBUpdater_h__
#define DBUpdater_h__
#include "DatabaseEnv.h"
#include "Define.h"
#include <string>
template <class T>
class DatabaseWorkerPool;
namespace boost
{
namespace filesystem
{
class path;
}
}
class AC_DATABASE_API UpdateException : public std::exception
{
public:
UpdateException(std::string const& msg) : _msg(msg) { }
~UpdateException() throw() { }
char const* what() const throw() override { return _msg.c_str(); }
private:
std::string const _msg;
};
enum BaseLocation
{
LOCATION_REPOSITORY,
LOCATION_DOWNLOAD
};
class AC_DATABASE_API DBUpdaterUtil
{
public:
static std::string GetCorrectedMySQLExecutable();
static bool CheckExecutable();
private:
static std::string& corrected_path();
};
template <class T>
class AC_DATABASE_API DBUpdater
{
public:
using Path = boost::filesystem::path;
static inline std::string GetConfigEntry();
static inline std::string GetTableName();
static std::string GetBaseFilesDirectory();
static bool IsEnabled(uint32 const updateMask);
static BaseLocation GetBaseLocationType();
static bool Create(DatabaseWorkerPool<T>& pool);
static bool Update(DatabaseWorkerPool<T>& pool);
static bool Populate(DatabaseWorkerPool<T>& pool);
// module
static std::string GetDBModuleName();
private:
static QueryResult Retrieve(DatabaseWorkerPool<T>& pool, std::string const& query);
static void Apply(DatabaseWorkerPool<T>& pool, std::string const& query);
static void ApplyFile(DatabaseWorkerPool<T>& pool, Path const& path);
static void ApplyFile(DatabaseWorkerPool<T>& pool, std::string const& host, std::string const& user,
std::string const& password, std::string const& port_or_socket, std::string const& database, std::string const& ssl, Path const& path);
};
#endif // DBUpdater_h__

View File

@@ -0,0 +1,433 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#include "UpdateFetcher.h"
#include "Common.h"
#include "CryptoHash.h"
#include "DBUpdater.h"
#include "Field.h"
#include "Log.h"
#include "QueryResult.h"
#include "Tokenize.h"
#include "Util.h"
#include <boost/filesystem/operations.hpp>
#include <fstream>
#include <sstream>
using namespace boost::filesystem;
struct UpdateFetcher::DirectoryEntry
{
DirectoryEntry(Path const& path_, State state_) : path(path_), state(state_) { }
Path const path;
State const state;
};
UpdateFetcher::UpdateFetcher(Path const& sourceDirectory,
std::function<void(std::string const&)> const& apply,
std::function<void(Path const& path)> const& applyFile,
std::function<QueryResult(std::string const&)> const& retrieve, std::string const& dbModuleName_) :
_sourceDirectory(std::make_unique<Path>(sourceDirectory)), _apply(apply), _applyFile(applyFile),
_retrieve(retrieve), _dbModuleName(dbModuleName_)
{
}
UpdateFetcher::~UpdateFetcher()
{
}
UpdateFetcher::LocaleFileStorage UpdateFetcher::GetFileList() const
{
LocaleFileStorage files;
DirectoryStorage directories = ReceiveIncludedDirectories();
for (auto const& entry : directories)
FillFileListRecursively(entry.path, files, entry.state, 1);
return files;
}
void UpdateFetcher::FillFileListRecursively(Path const& path, LocaleFileStorage& storage, State const state, uint32 const depth) const
{
static uint32 const MAX_DEPTH = 10;
static directory_iterator const end;
for (directory_iterator itr(path); itr != end; ++itr)
{
if (is_directory(itr->path()))
{
if (depth < MAX_DEPTH)
FillFileListRecursively(itr->path(), storage, state, depth + 1);
}
else if (itr->path().extension() == ".sql")
{
LOG_TRACE("sql.updates", "Added locale file \"%s\".", itr->path().filename().generic_string().c_str());
LocaleFileEntry const entry = { itr->path(), state };
// Check for doubled filenames
// Because elements are only compared by their filenames, this is ok
if (storage.find(entry) != storage.end())
{
LOG_FATAL("sql.updates", "Duplicate filename \"%s\" occurred. Because updates are ordered " \
"by their filenames, every name needs to be unique!", itr->path().generic_string().c_str());
throw UpdateException("Updating failed, see the log for details.");
}
storage.insert(entry);
}
}
}
UpdateFetcher::DirectoryStorage UpdateFetcher::ReceiveIncludedDirectories() const
{
DirectoryStorage directories;
QueryResult const result = _retrieve("SELECT `path`, `state` FROM `updates_include`");
if (!result)
return directories;
do
{
Field* fields = result->Fetch();
std::string path = fields[0].GetString();
if (path.substr(0, 1) == "$")
path = _sourceDirectory->generic_string() + path.substr(1);
Path const p(path);
if (!is_directory(p))
{
LOG_WARN("sql.updates", "DBUpdater: Given update include directory \"%s\" does not exist, skipped!", p.generic_string().c_str());
continue;
}
DirectoryEntry const entry = { p, AppliedFileEntry::StateConvert(fields[1].GetString()) };
directories.push_back(entry);
LOG_TRACE("sql.updates", "Added applied file \"%s\" from remote.", p.filename().generic_string().c_str());
} while (result->NextRow());
std::vector<std::string> moduleList;
auto const& _modulesTokens = Warhead::Tokenize(WH_MODULES_LIST, ',', true);
for (auto const& itr : _modulesTokens)
moduleList.push_back(std::string(itr));
for (auto const& itr : moduleList)
{
std::string path = _sourceDirectory->generic_string() + "/modules/" + itr + "/sql/" + _dbModuleName; // module/mod-name/sql/db_world
Path const p(path);
if (!is_directory(p))
continue;
DirectoryEntry const entry = { p, AppliedFileEntry::StateConvert("RELEASED") };
directories.push_back(entry);
LOG_TRACE("sql.updates", "Added applied modules file \"%s\" from remote.", p.filename().generic_string().c_str());
}
return directories;
}
UpdateFetcher::AppliedFileStorage UpdateFetcher::ReceiveAppliedFiles() const
{
AppliedFileStorage map;
QueryResult result = _retrieve("SELECT `name`, `hash`, `state`, UNIX_TIMESTAMP(`timestamp`) FROM `updates` ORDER BY `name` ASC");
if (!result)
return map;
do
{
Field* fields = result->Fetch();
AppliedFileEntry const entry = { fields[0].GetString(), fields[1].GetString(),
AppliedFileEntry::StateConvert(fields[2].GetString()), fields[3].GetUInt64()
};
map.insert(std::make_pair(entry.name, entry));
} while (result->NextRow());
return map;
}
std::string UpdateFetcher::ReadSQLUpdate(Path const& file) const
{
std::ifstream in(file.c_str());
if (!in.is_open())
{
LOG_FATAL("sql.updates", "Failed to open the sql update \"%s\" for reading! "
"Stopping the server to keep the database integrity, "
"try to identify and solve the issue or disable the database updater.",
file.generic_string().c_str());
throw UpdateException("Opening the sql update failed!");
}
auto update = [&in]
{
std::ostringstream ss;
ss << in.rdbuf();
return ss.str();
}();
in.close();
return update;
}
UpdateResult UpdateFetcher::Update(bool const redundancyChecks,
bool const allowRehash,
bool const archivedRedundancy,
int32 const cleanDeadReferencesMaxCount) const
{
LocaleFileStorage const available = GetFileList();
AppliedFileStorage applied = ReceiveAppliedFiles();
size_t countRecentUpdates = 0;
size_t countArchivedUpdates = 0;
// Count updates
for (auto const& entry : applied)
if (entry.second.state == RELEASED)
++countRecentUpdates;
else
++countArchivedUpdates;
// Fill hash to name cache
HashToFileNameStorage hashToName;
for (auto entry : applied)
hashToName.insert(std::make_pair(entry.second.hash, entry.first));
size_t importedUpdates = 0;
for (auto const& availableQuery : available)
{
LOG_DEBUG("sql.updates", "Checking update \"%s\"...", availableQuery.first.filename().generic_string().c_str());
AppliedFileStorage::const_iterator iter = applied.find(availableQuery.first.filename().string());
if (iter != applied.end())
{
// If redundancy is disabled, skip it, because the update is already applied.
if (!redundancyChecks)
{
LOG_DEBUG("sql.updates", ">> Update is already applied, skipping redundancy checks.");
applied.erase(iter);
continue;
}
// If the update is in an archived directory and is marked as archived in our database, skip redundancy checks (archived updates never change).
if (!archivedRedundancy && (iter->second.state == ARCHIVED) && (availableQuery.second == ARCHIVED))
{
LOG_DEBUG("sql.updates", ">> Update is archived and marked as archived in database, skipping redundancy checks.");
applied.erase(iter);
continue;
}
}
std::string const hash = ByteArrayToHexStr(Warhead::Crypto::SHA1::GetDigestOf(ReadSQLUpdate(availableQuery.first)));
UpdateMode mode = MODE_APPLY;
// Update is not in our applied list
if (iter == applied.end())
{
// Catch renames (different filename, but same hash)
HashToFileNameStorage::const_iterator const hashIter = hashToName.find(hash);
if (hashIter != hashToName.end())
{
// Check if the original file was removed. If not, we've got a problem.
LocaleFileStorage::const_iterator localeIter;
// Push localeIter forward
for (localeIter = available.begin(); (localeIter != available.end()) &&
(localeIter->first.filename().string() != hashIter->second); ++localeIter);
// Conflict!
if (localeIter != available.end())
{
LOG_WARN("sql.updates", ">> It seems like the update \"%s\" \'%s\' was renamed, but the old file is still there! " \
"Treating it as a new file! (It is probably an unmodified copy of the file \"%s\")",
availableQuery.first.filename().string().c_str(), hash.substr(0, 7).c_str(),
localeIter->first.filename().string().c_str());
}
// It is safe to treat the file as renamed here
else
{
LOG_INFO("sql.updates", ">> Renaming update \"%s\" to \"%s\" \'%s\'.",
hashIter->second.c_str(), availableQuery.first.filename().string().c_str(), hash.substr(0, 7).c_str());
RenameEntry(hashIter->second, availableQuery.first.filename().string());
applied.erase(hashIter->second);
continue;
}
}
// Apply the update if it was never seen before.
else
{
LOG_INFO("sql.updates", ">> Applying update \"%s\" \'%s\'...",
availableQuery.first.filename().string().c_str(), hash.substr(0, 7).c_str());
}
}
// Rehash the update entry if it exists in our database with an empty hash.
else if (allowRehash && iter->second.hash.empty())
{
mode = MODE_REHASH;
LOG_INFO("sql.updates", ">> Re-hashing update \"%s\" \'%s\'...", availableQuery.first.filename().string().c_str(),
hash.substr(0, 7).c_str());
}
else
{
// If the hash of the files differs from the one stored in our database, reapply the update (because it changed).
if (iter->second.hash != hash)
{
LOG_INFO("sql.updates", ">> Reapplying update \"%s\" \'%s\' -> \'%s\' (it changed)...", availableQuery.first.filename().string().c_str(),
iter->second.hash.substr(0, 7).c_str(), hash.substr(0, 7).c_str());
}
else
{
// If the file wasn't changed and just moved, update its state (if necessary).
if (iter->second.state != availableQuery.second)
{
LOG_DEBUG("sql.updates", ">> Updating the state of \"%s\" to \'%s\'...",
availableQuery.first.filename().string().c_str(), AppliedFileEntry::StateConvert(availableQuery.second).c_str());
UpdateState(availableQuery.first.filename().string(), availableQuery.second);
}
LOG_DEBUG("sql.updates", ">> Update is already applied and matches the hash \'%s\'.", hash.substr(0, 7).c_str());
applied.erase(iter);
continue;
}
}
uint32 speed = 0;
AppliedFileEntry const file = { availableQuery.first.filename().string(), hash, availableQuery.second, 0 };
switch (mode)
{
case MODE_APPLY:
speed = Apply(availableQuery.first);
/* fallthrough */
case MODE_REHASH:
UpdateEntry(file, speed);
break;
}
if (iter != applied.end())
applied.erase(iter);
if (mode == MODE_APPLY)
++importedUpdates;
}
// Cleanup up orphaned entries (if enabled)
if (!applied.empty())
{
bool const doCleanup = (cleanDeadReferencesMaxCount < 0) || (applied.size() <= static_cast<size_t>(cleanDeadReferencesMaxCount));
for (auto const& entry : applied)
{
LOG_WARN("sql.updates", ">> The file \'%s\' was applied to the database, but is missing in" \
" your update directory now!", entry.first.c_str());
if (doCleanup)
LOG_INFO("sql.updates", "Deleting orphaned entry \'%s\'...", entry.first.c_str());
}
if (doCleanup)
CleanUp(applied);
else
{
LOG_ERROR("sql.updates", "Cleanup is disabled! There were " SZFMTD " dirty files applied to your database, " \
"but they are now missing in your source directory!", applied.size());
}
}
return UpdateResult(importedUpdates, countRecentUpdates, countArchivedUpdates);
}
uint32 UpdateFetcher::Apply(Path const& path) const
{
using Time = std::chrono::high_resolution_clock;
// Benchmark query speed
auto const begin = Time::now();
// Update database
_applyFile(path);
// Return the time it took the query to apply
return uint32(std::chrono::duration_cast<std::chrono::milliseconds>(Time::now() - begin).count());
}
void UpdateFetcher::UpdateEntry(AppliedFileEntry const& entry, uint32 const speed) const
{
std::string const update = "REPLACE INTO `updates` (`name`, `hash`, `state`, `speed`) VALUES (\"" +
entry.name + "\", \"" + entry.hash + "\", \'" + entry.GetStateAsString() + "\', " + std::to_string(speed) + ")";
// Update database
_apply(update);
}
void UpdateFetcher::RenameEntry(std::string const& from, std::string const& to) const
{
// Delete the target if it exists
{
std::string const update = "DELETE FROM `updates` WHERE `name`=\"" + to + "\"";
// Update database
_apply(update);
}
// Rename
{
std::string const update = "UPDATE `updates` SET `name`=\"" + to + "\" WHERE `name`=\"" + from + "\"";
// Update database
_apply(update);
}
}
void UpdateFetcher::CleanUp(AppliedFileStorage const& storage) const
{
if (storage.empty())
return;
std::stringstream update;
size_t remaining = storage.size();
update << "DELETE FROM `updates` WHERE `name` IN(";
for (auto const& entry : storage)
{
update << "\"" << entry.first << "\"";
if ((--remaining) > 0)
update << ", ";
}
update << ")";
// Update database
_apply(update.str());
}
void UpdateFetcher::UpdateState(std::string const& name, State const state) const
{
std::string const update = "UPDATE `updates` SET `state`=\'" + AppliedFileEntry::StateConvert(state) + "\' WHERE `name`=\"" + name + "\"";
// Update database
_apply(update);
}
bool UpdateFetcher::PathCompare::operator()(LocaleFileEntry const& left, LocaleFileEntry const& right) const
{
return left.first.filename().string() < right.first.filename().string();
}

View File

@@ -0,0 +1,131 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef UpdateFetcher_h__
#define UpdateFetcher_h__
#include "DatabaseEnv.h"
#include "Define.h"
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
namespace boost::filesystem
{
class path;
}
struct AC_DATABASE_API UpdateResult
{
UpdateResult()
: updated(0), recent(0), archived(0) { }
UpdateResult(size_t const updated_, size_t const recent_, size_t const archived_)
: updated(updated_), recent(recent_), archived(archived_) { }
size_t updated;
size_t recent;
size_t archived;
};
class AC_DATABASE_API UpdateFetcher
{
typedef boost::filesystem::path Path;
public:
UpdateFetcher(Path const& updateDirectory,
std::function<void(std::string const&)> const& apply,
std::function<void(Path const& path)> const& applyFile,
std::function<QueryResult(std::string const&)> const& retrieve, std::string const& dbModuleName);
~UpdateFetcher();
UpdateResult Update(bool const redundancyChecks, bool const allowRehash,
bool const archivedRedundancy, int32 const cleanDeadReferencesMaxCount) const;
private:
enum UpdateMode
{
MODE_APPLY,
MODE_REHASH
};
enum State
{
RELEASED,
ARCHIVED
};
struct AppliedFileEntry
{
AppliedFileEntry(std::string const& name_, std::string const& hash_, State state_, uint64 timestamp_)
: name(name_), hash(hash_), state(state_), timestamp(timestamp_) { }
std::string const name;
std::string const hash;
State const state;
uint64 const timestamp;
static inline State StateConvert(std::string const& state)
{
return (state == "RELEASED") ? RELEASED : ARCHIVED;
}
static inline std::string StateConvert(State const state)
{
return (state == RELEASED) ? "RELEASED" : "ARCHIVED";
}
std::string GetStateAsString() const
{
return StateConvert(state);
}
};
struct DirectoryEntry;
typedef std::pair<Path, State> LocaleFileEntry;
struct PathCompare
{
bool operator()(LocaleFileEntry const& left, LocaleFileEntry const& right) const;
};
typedef std::set<LocaleFileEntry, PathCompare> LocaleFileStorage;
typedef std::unordered_map<std::string, std::string> HashToFileNameStorage;
typedef std::unordered_map<std::string, AppliedFileEntry> AppliedFileStorage;
typedef std::vector<UpdateFetcher::DirectoryEntry> DirectoryStorage;
LocaleFileStorage GetFileList() const;
void FillFileListRecursively(Path const& path, LocaleFileStorage& storage,
State const state, uint32 const depth) const;
DirectoryStorage ReceiveIncludedDirectories() const;
AppliedFileStorage ReceiveAppliedFiles() const;
std::string ReadSQLUpdate(Path const& file) const;
uint32 Apply(Path const& path) const;
void UpdateEntry(AppliedFileEntry const& entry, uint32 const speed = 0) const;
void RenameEntry(std::string const& from, std::string const& to) const;
void CleanUp(AppliedFileStorage const& storage) const;
void UpdateState(std::string const& name, State const state) const;
std::unique_ptr<Path> const _sourceDirectory;
std::function<void(std::string const&)> const _apply;
std::function<void(Path const& path)> const _applyFile;
std::function<QueryResult(std::string const&)> const _retrieve;
// modules
std::string const _dbModuleName;
};
#endif // UpdateFetcher_h__