feat(Core/DBLayer): replace char const* to std::string_view (#10211)

* feat(Core/DBLayer): replace `char const*` to `std::string_view`

* CString

* 1

* chore(Core/Misc): code cleanup

* cl

* db fix

* fmt style sql

* to fmt

* py

* del old

* 1

* 2

* 3

* 1

* 1
This commit is contained in:
Kargatum
2022-02-05 06:37:11 +07:00
committed by GitHub
parent d6ead1d1e0
commit de13bf426e
140 changed files with 5055 additions and 4882 deletions

View File

@@ -19,23 +19,21 @@
#include "Errors.h"
#include "MySQLConnection.h"
#include "QueryResult.h"
#include <cstdlib>
#include <cstring>
/*! Basic, ad-hoc queries. */
BasicStatementTask::BasicStatementTask(char const* sql, bool async) :
m_result(nullptr)
BasicStatementTask::BasicStatementTask(std::string_view sql, bool async) : m_result(nullptr)
{
m_sql = strdup(sql);
m_sql = std::string(sql);
m_has_result = async; // If the operation is async, then there's a result
if (async)
m_result = new QueryResultPromise();
}
BasicStatementTask::~BasicStatementTask()
{
free((void*)m_sql);
if (m_has_result && m_result != nullptr)
m_sql.clear();
if (m_has_result && m_result)
delete m_result;
}

View File

@@ -26,14 +26,14 @@
class AC_DATABASE_API BasicStatementTask : public SQLOperation
{
public:
BasicStatementTask(char const* sql, bool async = false);
BasicStatementTask(std::string_view sql, bool async = false);
~BasicStatementTask();
bool Execute() override;
QueryResultFuture GetFuture() const { return m_result->get_future(); }
private:
char const* m_sql; //- Raw query to be executed
std::string m_sql; //- Raw query to be executed
bool m_has_result;
QueryResultPromise* m_result;
};

View File

@@ -68,7 +68,7 @@ DatabaseLoader& DatabaseLoader::AddDatabase(DatabaseWorkerPool<T>& pool, std::st
while (reconnectCount < attempts)
{
LOG_INFO(_logger, "> Retrying after {} seconds", static_cast<uint32>(reconnectSeconds.count()));
LOG_WARN(_logger, "> Retrying after {} seconds", static_cast<uint32>(reconnectSeconds.count()));
std::this_thread::sleep_for(reconnectSeconds);
error = pool.Open();
@@ -153,7 +153,22 @@ DatabaseLoader& DatabaseLoader::AddDatabase(DatabaseWorkerPool<T>& pool, std::st
bool DatabaseLoader::Load()
{
return OpenDatabases() && PopulateDatabases() && UpdateDatabases() && PrepareStatements();
if (!_updateFlags)
LOG_INFO("sql.updates", "Automatic database updates are disabled for all databases!");
if (!OpenDatabases())
return false;
if (!PopulateDatabases())
return false;
if (!UpdateDatabases())
return false;
if (!PrepareStatements())
return false;
return true;
}
bool DatabaseLoader::OpenDatabases()

View File

@@ -17,11 +17,10 @@
#include "DatabaseWorkerPool.h"
#include "AdhocStatement.h"
#include "CharacterDatabase.h"
#include "Errors.h"
#include "Implementation/CharacterDatabase.h"
#include "Implementation/LoginDatabase.h"
#include "Implementation/WorldDatabase.h"
#include "Log.h"
#include "LoginDatabase.h"
#include "MySQLPreparedStatement.h"
#include "MySQLWorkaround.h"
#include "PCQueue.h"
@@ -31,7 +30,9 @@
#include "QueryResult.h"
#include "SQLOperation.h"
#include "Transaction.h"
#include "WorldDatabase.h"
#include <mysqld_error.h>
#include <limits>
#ifdef ACORE_DEBUG
#include <boost/stacktrace.hpp>
@@ -57,9 +58,10 @@ class PingOperation : public SQLOperation
};
template <class T>
DatabaseWorkerPool<T>::DatabaseWorkerPool()
: _queue(new ProducerConsumerQueue<SQLOperation*>()),
_async_threads(0), _synch_threads(0)
DatabaseWorkerPool<T>::DatabaseWorkerPool() :
_queue(new ProducerConsumerQueue<SQLOperation*>()),
_async_threads(0),
_synch_threads(0)
{
WPFatal(mysql_thread_safe(), "Used MySQL library isn't thread-safe.");
@@ -72,7 +74,7 @@ DatabaseWorkerPool<T>::DatabaseWorkerPool()
#endif
WPFatal(isSupportClientDB, "AzerothCore does not support MySQL versions below 5.7 and MariaDB 10.5\nSearch the wiki for ACE00043 in Common Errors (https://www.azerothcore.org/wiki/common-errors).");
WPFatal(isSameClientDB, "Used MySQL library version (%s id %lu) does not match the version id used to compile AzerothCore (id %u).\nSearch the wiki for ACE00046 in Common Errors (https://www.azerothcore.org/wiki/common-errors).",
WPFatal(isSameClientDB, "Used MySQL library version ({} id {}) does not match the version id used to compile AzerothCore (id {}).\nSearch the wiki for ACE00046 in Common Errors (https://www.azerothcore.org/wiki/common-errors).",
mysql_get_client_info(), mysql_get_client_version(), MYSQL_VERSION_ID);
}
@@ -83,8 +85,7 @@ DatabaseWorkerPool<T>::~DatabaseWorkerPool()
}
template <class T>
void DatabaseWorkerPool<T>::SetConnectionInfo(std::string const& infoString,
uint8 const asyncThreads, uint8 const synchThreads)
void DatabaseWorkerPool<T>::SetConnectionInfo(std::string_view infoString, uint8 const asyncThreads, uint8 const synchThreads)
{
_connectionInfo = std::make_unique<MySQLConnectionInfo>(infoString);
@@ -141,9 +142,9 @@ void DatabaseWorkerPool<T>::Close()
template <class T>
bool DatabaseWorkerPool<T>::PrepareStatements()
{
for (auto& connections : _connections)
for (auto const& connections : _connections)
{
for (auto& connection : connections)
for (auto const& connection : connections)
{
connection->LockIfReady();
if (!connection->PrepareStatements())
@@ -170,8 +171,8 @@ bool DatabaseWorkerPool<T>::PrepareStatements()
{
uint32 const paramCount = stmt->GetParameterCount();
// TC only supports uint8 indices.
ASSERT(paramCount < (std::numeric_limits<uint8>::max)());
// WH only supports uint8 indices.
ASSERT(paramCount < std::numeric_limits<uint8>::max());
_preparedStatementSize[i] = static_cast<uint8>(paramCount);
}
@@ -183,13 +184,13 @@ bool DatabaseWorkerPool<T>::PrepareStatements()
}
template <class T>
QueryResult DatabaseWorkerPool<T>::Query(char const* sql, T* connection /*= nullptr*/)
QueryResult DatabaseWorkerPool<T>::Query(std::string_view sql)
{
if (!connection)
connection = GetFreeConnection();
auto connection = GetFreeConnection();
ResultSet* result = connection->Query(sql);
connection->Unlock();
if (!result || !result->GetRowCount() || !result->NextRow())
{
delete result;
@@ -219,7 +220,7 @@ PreparedQueryResult DatabaseWorkerPool<T>::Query(PreparedStatement<T>* stmt)
}
template <class T>
QueryCallback DatabaseWorkerPool<T>::AsyncQuery(char const* sql)
QueryCallback DatabaseWorkerPool<T>::AsyncQuery(std::string_view sql)
{
BasicStatementTask* task = new BasicStatementTask(sql, true);
// Store future result before enqueueing - task might get already processed and deleted before returning from this method
@@ -308,6 +309,7 @@ void DatabaseWorkerPool<T>::DirectCommitTransaction(SQLTransaction<T>& transacti
{
T* connection = GetFreeConnection();
int errorCode = connection->ExecuteTransaction(transaction);
if (!errorCode)
{
connection->Unlock(); // OK, operation succesful
@@ -320,6 +322,7 @@ void DatabaseWorkerPool<T>::DirectCommitTransaction(SQLTransaction<T>& transacti
{
//todo: handle multiple sync threads deadlocking in a similar way as async threads
uint8 loopBreaker = 5;
for (uint8 i = 0; i < loopBreaker; ++i)
{
if (!connection->ExecuteTransaction(transaction))
@@ -368,6 +371,7 @@ void DatabaseWorkerPool<T>::KeepAlive()
//! If one or more worker threads are busy, the ping operations will not be split evenly, but this doesn't matter
//! as the sole purpose is to prevent connections from idling.
auto const count = _connections[IDX_ASYNC].size();
for (uint8 i = 0; i < count; ++i)
Enqueue(new PingOperation);
}
@@ -378,7 +382,8 @@ uint32 DatabaseWorkerPool<T>::OpenConnections(InternalIndex type, uint8 numConne
for (uint8 i = 0; i < numConnections; ++i)
{
// Create the connection
auto connection = [&] {
auto connection = [&]
{
switch (type)
{
case IDX_ASYNC:
@@ -461,15 +466,15 @@ T* DatabaseWorkerPool<T>::GetFreeConnection()
}
template <class T>
char const* DatabaseWorkerPool<T>::GetDatabaseName() const
std::string_view DatabaseWorkerPool<T>::GetDatabaseName() const
{
return _connectionInfo->database.c_str();
return std::string_view{ _connectionInfo->database };
}
template <class T>
void DatabaseWorkerPool<T>::Execute(char const* sql)
void DatabaseWorkerPool<T>::Execute(std::string_view sql)
{
if (Acore::IsFormatEmptyOrNull(sql))
if (sql.empty())
return;
BasicStatementTask* task = new BasicStatementTask(sql);
@@ -484,9 +489,9 @@ void DatabaseWorkerPool<T>::Execute(PreparedStatement<T>* stmt)
}
template <class T>
void DatabaseWorkerPool<T>::DirectExecute(char const* sql)
void DatabaseWorkerPool<T>::DirectExecute(std::string_view sql)
{
if (Acore::IsFormatEmptyOrNull(sql))
if (sql.empty())
return;
T* connection = GetFreeConnection();
@@ -506,7 +511,7 @@ void DatabaseWorkerPool<T>::DirectExecute(PreparedStatement<T>* stmt)
}
template <class T>
void DatabaseWorkerPool<T>::ExecuteOrAppend(SQLTransaction<T>& trans, char const* sql)
void DatabaseWorkerPool<T>::ExecuteOrAppend(SQLTransaction<T>& trans, std::string_view sql)
{
if (!trans)
Execute(sql);

View File

@@ -22,7 +22,6 @@
#include "Define.h"
#include "StringFormat.h"
#include <array>
#include <string>
#include <vector>
template <typename T>
@@ -45,13 +44,11 @@ private:
public:
/* Activity state */
DatabaseWorkerPool();
~DatabaseWorkerPool();
void SetConnectionInfo(std::string const& infoString, uint8 const asyncThreads, uint8 const synchThreads);
void SetConnectionInfo(std::string_view infoString, uint8 const asyncThreads, uint8 const synchThreads);
uint32 Open();
void Close();
//! Prepares all prepared statements
@@ -68,17 +65,17 @@ public:
//! Enqueues a one-way SQL operation in string format that will be executed asynchronously.
//! This method should only be used for queries that are only executed once, e.g during startup.
void Execute(char const* sql);
void Execute(std::string_view sql);
//! Enqueues a one-way SQL operation in string format -with variable args- that will be executed asynchronously.
//! This method should only be used for queries that are only executed once, e.g during startup.
template<typename Format, typename... Args>
void PExecute(Format&& sql, Args&&... args)
template<typename... Args>
void Execute(std::string_view sql, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
if (sql.empty())
return;
Execute(Acore::StringFormat(std::forward<Format>(sql), std::forward<Args>(args)...).c_str());
Execute(Acore::StringFormatFmt(sql, std::forward<Args>(args)...));
}
//! Enqueues a one-way SQL operation in prepared statement format that will be executed asynchronously.
@@ -91,17 +88,17 @@ public:
//! Directly executes a one-way SQL operation in string format, that will block the calling thread until finished.
//! This method should only be used for queries that are only executed once, e.g during startup.
void DirectExecute(char const* sql);
void DirectExecute(std::string_view sql);
//! Directly executes a one-way SQL operation in string format -with variable args-, that will block the calling thread until finished.
//! This method should only be used for queries that are only executed once, e.g during startup.
template<typename Format, typename... Args>
void DirectPExecute(Format&& sql, Args&&... args)
template<typename... Args>
void DirectExecute(std::string_view sql, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
if (sql.empty())
return;
DirectExecute(Acore::StringFormat(std::forward<Format>(sql), std::forward<Args>(args)...).c_str());
DirectExecute(Acore::StringFormatFmt(sql, std::forward<Args>(args)...));
}
//! Directly executes a one-way SQL operation in prepared statement format, that will block the calling thread until finished.
@@ -114,28 +111,17 @@ public:
//! Directly executes an SQL query in string format that will block the calling thread until finished.
//! Returns reference counted auto pointer, no need for manual memory management in upper level code.
QueryResult Query(char const* sql, T* connection = nullptr);
QueryResult Query(std::string_view sql);
//! Directly executes an SQL query in string format -with variable args- that will block the calling thread until finished.
//! Returns reference counted auto pointer, no need for manual memory management in upper level code.
template<typename Format, typename... Args>
QueryResult PQuery(Format&& sql, T* conn, Args&&... args)
template<typename... Args>
QueryResult Query(std::string_view sql, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
if (sql.empty())
return QueryResult(nullptr);
return Query(Acore::StringFormat(std::forward<Format>(sql), std::forward<Args>(args)...).c_str(), conn);
}
//! Directly executes an SQL query in string format -with variable args- that will block the calling thread until finished.
//! Returns reference counted auto pointer, no need for manual memory management in upper level code.
template<typename Format, typename... Args>
QueryResult PQuery(Format&& sql, Args&&... args)
{
if (Acore::IsFormatEmptyOrNull(sql))
return QueryResult(nullptr);
return Query(Acore::StringFormat(std::forward<Format>(sql), std::forward<Args>(args)...).c_str());
return Query(Acore::StringFormatFmt(sql, std::forward<Args>(args)...));
}
//! Directly executes an SQL query in prepared format that will block the calling thread until finished.
@@ -149,7 +135,7 @@ public:
//! Enqueues a query in string format that will set the value of the QueryResultFuture return object as soon as the query is executed.
//! The return value is then processed in ProcessQueryCallback methods.
QueryCallback AsyncQuery(char const* sql);
QueryCallback AsyncQuery(std::string_view sql);
//! Enqueues a query in prepared format that will set the value of the PreparedQueryResultFuture return object as soon as the query is executed.
//! The return value is then processed in ProcessQueryCallback methods.
@@ -183,7 +169,7 @@ public:
//! Method used to execute ad-hoc statements in a diverse context.
//! Will be wrapped in a transaction if valid object is present, otherwise executed standalone.
void ExecuteOrAppend(SQLTransaction<T>& trans, char const* sql);
void ExecuteOrAppend(SQLTransaction<T>& trans, std::string_view sql);
//! Method used to execute prepared statements in a diverse context.
//! Will be wrapped in a transaction if valid object is present, otherwise executed standalone.
@@ -226,7 +212,7 @@ private:
//! Caller MUST call t->Unlock() after touching the MySQL context to prevent deadlocks.
T* GetFreeConnection();
[[nodiscard]] char const* GetDatabaseName() const;
[[nodiscard]] std::string_view GetDatabaseName() const;
//! Queue shared by async worker threads.
std::unique_ptr<ProducerConsumerQueue<SQLOperation*>> _queue;

View File

@@ -19,6 +19,8 @@
#include "Errors.h"
#include "Log.h"
#include "MySQLHacks.h"
#include "StringConvert.h"
#include "Types.h"
Field::Field()
{
@@ -28,236 +30,118 @@ Field::Field()
meta = nullptr;
}
Field::~Field() = default;
uint8 Field::GetUInt8() const
namespace
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int8))
template<typename T>
constexpr T GetDefaultValue()
{
LogWrongType(__FUNCTION__);
return 0;
if constexpr (std::is_same_v<T, bool>)
return false;
else if constexpr (std::is_integral_v<T>)
return 0;
else if constexpr (std::is_floating_point_v<T>)
return 1.0f;
else if constexpr (std::is_same_v<T, std::vector<uint8>> || std::is_same_v<std::string_view, T>)
return {};
else
return "";
}
#endif
if (data.raw)
return *reinterpret_cast<uint8 const*>(data.value);
return static_cast<uint8>(strtoul(data.value, nullptr, 10));
}
int8 Field::GetInt8() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int8))
template<typename T>
inline bool IsCorrectFieldType(DatabaseFieldTypes type)
{
LogWrongType(__FUNCTION__);
return 0;
// Int8
if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, int8> || std::is_same_v<T, uint8>)
{
if (type == DatabaseFieldTypes::Int8)
return true;
}
// In16
if constexpr (std::is_same_v<T, uint16> || std::is_same_v<T, int16>)
{
if (type == DatabaseFieldTypes::Int16)
return true;
}
// Int32
if constexpr (std::is_same_v<T, uint32> || std::is_same_v<T, int32>)
{
if (type == DatabaseFieldTypes::Int32)
return true;
}
// Int64
if constexpr (std::is_same_v<T, uint64> || std::is_same_v<T, int64>)
{
if (type == DatabaseFieldTypes::Int64)
return true;
}
// float
if constexpr (std::is_same_v<T, float>)
{
if (type == DatabaseFieldTypes::Float)
return true;
}
// dobule
if constexpr (std::is_same_v<T, double>)
{
if (type == DatabaseFieldTypes::Double || type == DatabaseFieldTypes::Decimal)
return true;
}
// Binary
if constexpr (std::is_same_v<T, Binary>)
{
if (type == DatabaseFieldTypes::Binary)
return true;
}
return false;
}
#endif
if (data.raw)
return *reinterpret_cast<int8 const*>(data.value);
return static_cast<int8>(strtol(data.value, nullptr, 10));
}
uint16 Field::GetUInt16() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int16))
inline Optional<std::string_view> GetCleanAliasName(std::string_view alias)
{
LogWrongType(__FUNCTION__);
return 0;
if (alias.empty())
return {};
auto pos = alias.find_first_of('(');
if (pos == std::string_view::npos)
return {};
alias.remove_suffix(alias.length() - pos);
return { alias };
}
#endif
if (data.raw)
return *reinterpret_cast<uint16 const*>(data.value);
return static_cast<uint16>(strtoul(data.value, nullptr, 10));
}
int16 Field::GetInt16() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int16))
template<typename T>
inline bool IsCorrectAlias(DatabaseFieldTypes type, std::string_view alias)
{
LogWrongType(__FUNCTION__);
return 0;
if constexpr (std::is_same_v<T, double>)
{
if ((StringEqualI(alias, "sum") || StringEqualI(alias, "avg")) && type == DatabaseFieldTypes::Decimal)
return true;
return false;
}
if constexpr (std::is_same_v<T, uint64>)
{
if (StringEqualI(alias, "count") && type == DatabaseFieldTypes::Int64)
return true;
return false;
}
if ((StringEqualI(alias, "min") || StringEqualI(alias, "max")) && IsCorrectFieldType<T>(type))
{
return true;
}
return false;
}
#endif
if (data.raw)
return *reinterpret_cast<int16 const*>(data.value);
return static_cast<int16>(strtol(data.value, nullptr, 10));
}
uint32 Field::GetUInt32() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int32))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint32 const*>(data.value);
return static_cast<uint32>(strtoul(data.value, nullptr, 10));
}
int32 Field::GetInt32() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int32))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int32 const*>(data.value);
return static_cast<int32>(strtol(data.value, nullptr, 10));
}
uint64 Field::GetUInt64() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int64))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<uint64 const*>(data.value);
return static_cast<uint64>(strtoull(data.value, nullptr, 10));
}
int64 Field::GetInt64() const
{
if (!data.value)
return 0;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Int64))
{
LogWrongType(__FUNCTION__);
return 0;
}
#endif
if (data.raw)
return *reinterpret_cast<int64 const*>(data.value);
return static_cast<int64>(strtoll(data.value, nullptr, 10));
}
float Field::GetFloat() const
{
if (!data.value)
return 0.0f;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Float))
{
LogWrongType(__FUNCTION__);
return 0.0f;
}
#endif
if (data.raw)
return *reinterpret_cast<float const*>(data.value);
return static_cast<float>(atof(data.value));
}
double Field::GetDouble() const
{
if (!data.value)
return 0.0f;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsType(DatabaseFieldTypes::Double) && !IsType(DatabaseFieldTypes::Decimal))
{
LogWrongType(__FUNCTION__);
return 0.0f;
}
#endif
if (data.raw && !IsType(DatabaseFieldTypes::Decimal))
return *reinterpret_cast<double const*>(data.value);
return static_cast<double>(atof(data.value));
}
char const* Field::GetCString() const
{
if (!data.value)
return nullptr;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (IsNumeric() && data.raw)
{
LogWrongType(__FUNCTION__);
return nullptr;
}
#endif
return static_cast<char const*>(data.value);
}
std::string Field::GetString() const
{
if (!data.value)
return "";
char const* string = GetCString();
if (!string)
return "";
return std::string(string, data.length);
}
std::string_view Field::GetStringView() const
{
if (!data.value)
return {};
char const* const string = GetCString();
if (!string)
return {};
return { string, data.length };
}
std::vector<uint8> Field::GetBinary() const
{
std::vector<uint8> result;
if (!data.value || !data.length)
return result;
result.resize(data.length);
memcpy(result.data(), data.value, data.length);
return result;
}
void Field::GetBinarySizeChecked(uint8* buf, size_t length) const
@@ -297,13 +181,159 @@ bool Field::IsNumeric() const
meta->Type == DatabaseFieldTypes::Double);
}
void Field::LogWrongType(char const* getter) const
void Field::LogWrongType(std::string_view getter, std::string_view typeName) const
{
LOG_WARN("sql.sql", "Warning: {} on {} field {}.{} ({}.{}) at index {}.",
getter, meta->TypeName, meta->TableAlias, meta->Alias, meta->TableName, meta->Name, meta->Index);
LOG_WARN("sql.sql", "Warning: {}<{}> on {} field {}.{} ({}.{}) at index {}.",
getter, typeName, meta->TypeName, meta->TableAlias, meta->Alias, meta->TableName, meta->Name, meta->Index);
}
void Field::SetMetadata(QueryResultFieldMetadata const* fieldMeta)
{
meta = fieldMeta;
}
template<typename T>
T Field::GetData() const
{
static_assert(std::is_arithmetic_v<T>, "Unsurropt type for Field::GetData()");
if (!data.value)
return GetDefaultValue<T>();
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsCorrectFieldType<T>(meta->Type))
{
LogWrongType(__FUNCTION__, typeid(T).name());
//return GetDefaultValue<T>();
}
#endif
Optional<T> result = {};
if (data.raw)
result = *reinterpret_cast<T const*>(data.value);
else
result = Acore::StringTo<T>(data.value);
// Correct double fields... this undefined behavior :/
if constexpr (std::is_same_v<T, double>)
{
if (data.raw && !IsType(DatabaseFieldTypes::Decimal))
result = *reinterpret_cast<double const*>(data.value);
else
result = Acore::StringTo<float>(data.value);
}
// Check -1 for *_dbc db tables
if constexpr (std::is_same_v<T, uint32>)
{
std::string_view tableName{ meta->TableName };
if (!tableName.empty() && tableName.size() > 4)
{
auto signedResult = Acore::StringTo<int32>(data.value);
if (signedResult && !result && tableName.substr(tableName.length() - 4) == "_dbc")
{
LOG_DEBUG("sql.sql", "> Found incorrect value '{}' for type '{}' in _dbc table.", data.value, typeid(T).name());
LOG_DEBUG("sql.sql", "> Table name '{}'. Field name '{}'. Try return int32 value", meta->TableName, meta->Name);
return GetData<int32>();
}
}
}
if (auto alias = GetCleanAliasName(meta->Alias))
{
if ((StringEqualI(*alias, "min") || StringEqualI(*alias, "max")) && !IsCorrectAlias<T>(meta->Type, *alias))
{
LogWrongType(__FUNCTION__, typeid(T).name());
}
if ((StringEqualI(*alias, "sum") || StringEqualI(*alias, "avg")) && !IsCorrectAlias<T>(meta->Type, *alias))
{
LogWrongType(__FUNCTION__, typeid(T).name());
LOG_WARN("sql.sql", "> Please use GetData<double>()");
return GetData<double>();
}
if (StringEqualI(*alias, "count") && !IsCorrectAlias<T>(meta->Type, *alias))
{
LogWrongType(__FUNCTION__, typeid(T).name());
LOG_WARN("sql.sql", "> Please use GetData<uint64>()");
return GetData<uint64>();
}
}
if (!result)
{
LOG_FATAL("sql.sql", "> Incorrect value '{}' for type '{}'. Value is raw ? '{}'", data.value, typeid(T).name(), data.raw);
LOG_FATAL("sql.sql", "> Table name '{}'. Field name '{}'", meta->TableName, meta->Name);
//ABORT();
return GetDefaultValue<T>();
}
return *result;
}
template bool Field::GetData() const;
template uint8 Field::GetData() const;
template uint16 Field::GetData() const;
template uint32 Field::GetData() const;
template uint64 Field::GetData() const;
template int8 Field::GetData() const;
template int16 Field::GetData() const;
template int32 Field::GetData() const;
template int64 Field::GetData() const;
template float Field::GetData() const;
template double Field::GetData() const;
std::string Field::GetDataString() const
{
if (!data.value)
return "";
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (IsNumeric() && data.raw)
{
LogWrongType(__FUNCTION__, "std::string");
return "";
}
#endif
return { data.value, data.length };
}
std::string_view Field::GetDataStringView() const
{
if (!data.value)
return {};
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (IsNumeric() && data.raw)
{
LogWrongType(__FUNCTION__, "std::string_view");
return {};
}
#endif
return { data.value, data.length };
}
Binary Field::GetDataBinary() const
{
Binary result = {};
if (!data.value || !data.length)
return result;
#ifdef ACORE_STRICT_DATABASE_TYPE_CHECKS
if (!IsCorrectFieldType<Binary>(meta->Type))
{
LogWrongType(__FUNCTION__, "Binary");
return {};
}
#endif
result.resize(data.length);
memcpy(result.data(), data.value, data.length);
return result;
}

View File

@@ -20,11 +20,26 @@
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include "Duration.h"
#include <array>
#include <string>
#include <string_view>
#include <vector>
namespace Acore::Types
{
template <typename T>
using is_chrono_v = std::enable_if_t<std::is_same_v<Milliseconds, T>
|| std::is_same_v<Seconds, T>
|| std::is_same_v<Minutes, T>
|| std::is_same_v<Hours, T>
|| std::is_same_v<Days, T>
|| std::is_same_v<Weeks, T>
|| std::is_same_v<Years, T>
|| std::is_same_v<Months, T>, T>;
}
using Binary = std::vector<uint8>;
enum class DatabaseFieldTypes : uint8
{
Null,
@@ -41,11 +56,11 @@ enum class DatabaseFieldTypes : uint8
struct QueryResultFieldMetadata
{
char const* TableName = nullptr;
char const* TableAlias = nullptr;
char const* Name = nullptr;
char const* Alias = nullptr;
char const* TypeName = nullptr;
std::string TableName{};
std::string TableAlias{};
std::string Name{};
std::string Alias{};
std::string TypeName{};
uint32 Index = 0;
DatabaseFieldTypes Type = DatabaseFieldTypes::Null;
};
@@ -57,20 +72,20 @@ struct QueryResultFieldMetadata
Guideline on field type matching:
| MySQL type | method to use |
|------------------------|----------------------------------------|
| TINYINT | GetBool, GetInt8, GetUInt8 |
| SMALLINT | GetInt16, GetUInt16 |
| MEDIUMINT, INT | GetInt32, GetUInt32 |
| BIGINT | GetInt64, GetUInt64 |
| FLOAT | GetFloat |
| DOUBLE, DECIMAL | GetDouble |
| CHAR, VARCHAR, | GetCString, GetString |
| TINYTEXT, MEDIUMTEXT, | GetCString, GetString |
| TEXT, LONGTEXT | GetCString, GetString |
| TINYBLOB, MEDIUMBLOB, | GetBinary, GetString |
| BLOB, LONGBLOB | GetBinary, GetString |
| BINARY, VARBINARY | GetBinary |
| MySQL type | method to use |
|------------------------|-----------------------------------------|
| TINYINT | Get<bool>, Get<int8>, Get<uint8> |
| SMALLINT | Get<int16>, Get<uint16> |
| MEDIUMINT, INT | Get<int32>, Get<uint32> |
| BIGINT | Get<int64>, Get<uint64> |
| FLOAT | Get<float> |
| DOUBLE, DECIMAL | Get<double> |
| CHAR, VARCHAR, | Get<std::string>, Get<std::string_view> |
| TINYTEXT, MEDIUMTEXT, | Get<std::string>, Get<std::string_view> |
| TEXT, LONGTEXT | Get<std::string>, Get<std::string_view> |
| TINYBLOB, MEDIUMBLOB, | Get<Binary>, Get<std::string> |
| BLOB, LONGBLOB | Get<Binary>, Get<std::string> |
| BINARY, VARBINARY | Get<Binary> |
Return types of aggregate functions:
@@ -87,39 +102,49 @@ friend class PreparedResultSet;
public:
Field();
~Field();
~Field() = default;
[[nodiscard]] bool GetBool() const // Wrapper, actually gets integer
[[nodiscard]] inline bool IsNull() const
{
return GetUInt8() == 1 ? true : false;
return data.value == nullptr;
}
[[nodiscard]] uint8 GetUInt8() const;
[[nodiscard]] int8 GetInt8() const;
[[nodiscard]] uint16 GetUInt16() const;
[[nodiscard]] int16 GetInt16() const;
[[nodiscard]] uint32 GetUInt32() const;
[[nodiscard]] int32 GetInt32() const;
[[nodiscard]] uint64 GetUInt64() const;
[[nodiscard]] int64 GetInt64() const;
[[nodiscard]] float GetFloat() const;
[[nodiscard]] double GetDouble() const;
[[nodiscard]] char const* GetCString() const;
[[nodiscard]] std::string GetString() const;
[[nodiscard]] std::string_view GetStringView() const;
[[nodiscard]] std::vector<uint8> GetBinary() const;
template <size_t S>
std::array<uint8, S> GetBinary() const
template<typename T>
inline std::enable_if_t<std::is_arithmetic_v<T>, T> Get() const
{
std::array<uint8, S> buf;
return GetData<T>();
}
template<typename T>
inline std::enable_if_t<std::is_same_v<std::string, T>, T> Get() const
{
return GetDataString();
}
template<typename T>
inline std::enable_if_t<std::is_same_v<std::string_view, T>, T> Get() const
{
return GetDataStringView();
}
template<typename T>
inline std::enable_if_t<std::is_same_v<Binary, T>, T> Get() const
{
return GetDataBinary();
}
template <typename T, size_t S>
inline std::enable_if_t<std::is_same_v<Binary, T>, std::array<uint8, S>> Get() const
{
std::array<uint8, S> buf = {};
GetBinarySizeChecked(buf.data(), S);
return buf;
}
[[nodiscard]] bool IsNull() const
template<typename T>
inline Acore::Types::is_chrono_v<T> Get(bool convertToUin32 = true) const
{
return data.value == nullptr;
return convertToUin32 ? T(GetData<uint32>()) : T(GetData<uint64>());
}
DatabaseFieldTypes GetType() { return meta->Type; }
@@ -138,8 +163,15 @@ protected:
[[nodiscard]] bool IsNumeric() const;
private:
template<typename T>
T GetData() const;
std::string GetDataString() const;
std::string_view GetDataStringView() const;
Binary GetDataBinary() const;
QueryResultFieldMetadata const* meta;
void LogWrongType(char const* getter) const;
void LogWrongType(std::string_view getter, std::string_view typeName) const;
void SetMetadata(QueryResultFieldMetadata const* fieldMeta);
void GetBinarySizeChecked(uint8* buf, size_t size) const;
};

View File

@@ -23,46 +23,46 @@
#include "MySQLWorkaround.h"
#include "PreparedStatement.h"
#include "QueryResult.h"
#include "StringConvert.h"
#include "Timer.h"
#include "Tokenize.h"
#include "Transaction.h"
#include "Util.h"
#include "StringConvert.h"
#include <errmsg.h>
#include <mysqld_error.h>
MySQLConnectionInfo::MySQLConnectionInfo(std::string const& infoString)
MySQLConnectionInfo::MySQLConnectionInfo(std::string_view infoString)
{
std::vector<std::string_view> tokens = Acore::Tokenize(infoString, ';', true);
if (tokens.size() != 5 && tokens.size() != 6)
return;
host.assign(tokens[0]);
port_or_socket.assign(tokens[1]);
user.assign(tokens[2]);
password.assign(tokens[3]);
database.assign(tokens[4]);
host.assign(tokens.at(0));
port_or_socket.assign(tokens.at(1));
user.assign(tokens.at(2));
password.assign(tokens.at(3));
database.assign(tokens.at(4));
if (tokens.size() == 6)
ssl.assign(tokens[5]);
ssl.assign(tokens.at(5));
}
MySQLConnection::MySQLConnection(MySQLConnectionInfo& connInfo) :
m_reconnecting(false),
m_prepareError(false),
m_queue(nullptr),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_SYNCH) { }
m_reconnecting(false),
m_prepareError(false),
m_queue(nullptr),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_SYNCH) { }
MySQLConnection::MySQLConnection(ProducerConsumerQueue<SQLOperation*>* queue, MySQLConnectionInfo& connInfo) :
m_reconnecting(false),
m_prepareError(false),
m_queue(queue),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_ASYNC)
m_reconnecting(false),
m_prepareError(false),
m_queue(queue),
m_Mysql(nullptr),
m_connectionInfo(connInfo),
m_connectionFlags(CONNECTION_ASYNC)
{
m_worker = std::make_unique<DatabaseWorker>(m_queue, this);
}
@@ -76,7 +76,6 @@ void MySQLConnection::Close()
{
// Stop the worker thread before the statements are cleared
m_worker.reset();
m_stmts.clear();
if (m_Mysql)
@@ -99,7 +98,8 @@ uint32 MySQLConnection::Open()
char const* unix_socket;
mysql_options(mysqlInit, MYSQL_SET_CHARSET_NAME, "utf8");
#ifdef _WIN32
#ifdef _WIN32
if (m_connectionInfo.host == ".") // named pipe use option (Windows)
{
unsigned int opt = MYSQL_PROTOCOL_PIPE;
@@ -112,7 +112,7 @@ uint32 MySQLConnection::Open()
port = *Acore::StringTo<uint32>(m_connectionInfo.port_or_socket);
unix_socket = 0;
}
#else
#else
if (m_connectionInfo.host == ".") // socket use option (Unix/Linux)
{
unsigned int opt = MYSQL_PROTOCOL_SOCKET;
@@ -126,7 +126,7 @@ uint32 MySQLConnection::Open()
port = *Acore::StringTo<uint32>(m_connectionInfo.port_or_socket);
unix_socket = nullptr;
}
#endif
#endif
if (m_connectionInfo.ssl != "")
{
@@ -136,6 +136,7 @@ uint32 MySQLConnection::Open()
{
opt_use_ssl = SSL_MODE_REQUIRED;
}
mysql_options(mysqlInit, MYSQL_OPT_SSL_MODE, (char const*)&opt_use_ssl);
#else
MySQLBool opt_use_ssl = MySQLBool(0);
@@ -143,6 +144,7 @@ uint32 MySQLConnection::Open()
{
opt_use_ssl = MySQLBool(1);
}
mysql_options(mysqlInit, MYSQL_OPT_SSL_ENFORCE, (char const*)&opt_use_ssl);
#endif
}
@@ -156,9 +158,6 @@ uint32 MySQLConnection::Open()
{
LOG_INFO("sql.sql", "MySQL client library: {}", mysql_get_client_info());
LOG_INFO("sql.sql", "MySQL server ver: {} ", mysql_get_server_info(m_Mysql));
// MySQL version above 5.1 IS required in both client and server and there is no known issue with different versions above 5.1
// if (mysql_get_server_version(m_Mysql) != mysql_get_client_version())
// LOG_INFO("sql.sql", "[WARNING] MySQL client/server version mismatch; may conflict with behaviour of prepared statements.");
}
LOG_INFO("sql.sql", "Connected to MySQL database at {}", m_connectionInfo.host);
@@ -166,7 +165,7 @@ uint32 MySQLConnection::Open()
// set connection properties to UTF8 to properly handle locales for different
// server configs - core sends data in UTF8, so MySQL must expect UTF8 too
mysql_set_character_set(m_Mysql, "utf8");
mysql_set_character_set(m_Mysql, "utf8mb4");
return 0;
}
else
@@ -184,7 +183,7 @@ bool MySQLConnection::PrepareStatements()
return !m_prepareError;
}
bool MySQLConnection::Execute(char const* sql)
bool MySQLConnection::Execute(std::string_view sql)
{
if (!m_Mysql)
return false;
@@ -192,7 +191,7 @@ bool MySQLConnection::Execute(char const* sql)
{
uint32 _s = getMSTime();
if (mysql_query(m_Mysql, sql))
if (mysql_query(m_Mysql, std::string(sql).c_str()))
{
uint32 lErrno = mysql_errno(m_Mysql);
@@ -219,7 +218,7 @@ bool MySQLConnection::Execute(PreparedStatementBase* stmt)
uint32 index = stmt->GetIndex();
MySQLPreparedStatement* m_mStmt = GetPreparedStatement(index);
ASSERT(m_mStmt); // Can only be null if preparation failed, server side error or bad query
ASSERT(m_mStmt); // Can only be null if preparation failed, server side error or bad query
m_mStmt->BindParameters(stmt);
@@ -291,8 +290,7 @@ bool MySQLConnection::_Query(PreparedStatementBase* stmt, MySQLPreparedStatement
if (mysql_stmt_execute(msql_STMT))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_ERROR("sql.sql", "SQL(p): {}\n [ERROR]: [{}] {}",
m_mStmt->getQueryString(), lErrno, mysql_stmt_error(msql_STMT));
LOG_ERROR("sql.sql", "SQL(p): {}\n [ERROR]: [{}] {}", m_mStmt->getQueryString(), lErrno, mysql_stmt_error(msql_STMT));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return _Query(stmt, mysqlStmt, pResult, pRowCount, pFieldCount); // Try again
@@ -312,9 +310,9 @@ bool MySQLConnection::_Query(PreparedStatementBase* stmt, MySQLPreparedStatement
return true;
}
ResultSet* MySQLConnection::Query(char const* sql)
ResultSet* MySQLConnection::Query(std::string_view sql)
{
if (!sql)
if (sql.empty())
return nullptr;
MySQLResult* result = nullptr;
@@ -328,7 +326,7 @@ ResultSet* MySQLConnection::Query(char const* sql)
return new ResultSet(result, fields, rowCount, fieldCount);
}
bool MySQLConnection::_Query(const char* sql, MySQLResult** pResult, MySQLField** pFields, uint64* pRowCount, uint32* pFieldCount)
bool MySQLConnection::_Query(std::string_view sql, MySQLResult** pResult, MySQLField** pFields, uint64* pRowCount, uint32* pFieldCount)
{
if (!m_Mysql)
return false;
@@ -336,13 +334,13 @@ bool MySQLConnection::_Query(const char* sql, MySQLResult** pResult, MySQLField*
{
uint32 _s = getMSTime();
if (mysql_query(m_Mysql, sql))
if (mysql_query(m_Mysql, std::string(sql).c_str()))
{
uint32 lErrno = mysql_errno(m_Mysql);
LOG_INFO("sql.sql", "SQL: {}", sql);
LOG_ERROR("sql.sql", "[{}] {}", lErrno, mysql_error(m_Mysql));
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
if (_HandleMySQLErrno(lErrno)) // If it returns true, an error was handled successfully (i.e. reconnection)
return _Query(sql, pResult, pFields, pRowCount, pFieldCount); // We try again
return false;
@@ -355,7 +353,7 @@ bool MySQLConnection::_Query(const char* sql, MySQLResult** pResult, MySQLField*
*pFieldCount = mysql_field_count(m_Mysql);
}
if (!*pResult )
if (!*pResult)
return false;
if (!*pRowCount)
@@ -392,18 +390,29 @@ int MySQLConnection::ExecuteTransaction(std::shared_ptr<TransactionBase> transac
BeginTransaction();
for (auto itr = queries.begin(); itr != queries.end(); ++itr)
for (auto const& data : queries)
{
SQLElementData const& data = *itr;
switch (itr->type)
switch (data.type)
{
case SQL_ELEMENT_PREPARED:
{
PreparedStatementBase* stmt = data.element.stmt;
PreparedStatementBase* stmt = nullptr;
try
{
stmt = std::get<PreparedStatementBase*>(data.element);
}
catch (const std::bad_variant_access& ex)
{
LOG_FATAL("sql.sql", "> PreparedStatementBase not found in SQLElementData. {}", ex.what());
ABORT();
}
ASSERT(stmt);
if (!Execute(stmt))
{
LOG_WARN("sql.sql", "Transaction aborted. {} queries not executed.", (uint32)queries.size());
LOG_WARN("sql.sql", "Transaction aborted. {} queries not executed.", queries.size());
int errorCode = GetLastError();
RollbackTransaction();
return errorCode;
@@ -412,12 +421,24 @@ int MySQLConnection::ExecuteTransaction(std::shared_ptr<TransactionBase> transac
break;
case SQL_ELEMENT_RAW:
{
char const* sql = data.element.query;
ASSERT(sql);
std::string sql{};
try
{
sql = std::get<std::string>(data.element);
}
catch (const std::bad_variant_access& ex)
{
LOG_FATAL("sql.sql", "> std::string not found in SQLElementData. {}", ex.what());
ABORT();
}
ASSERT(!sql.empty());
if (!Execute(sql))
{
LOG_WARN("sql.sql", "Transaction aborted. {} queries not executed.", (uint32)queries.size());
int errorCode = GetLastError();
LOG_WARN("sql.sql", "Transaction aborted. {} queries not executed.", queries.size());
uint32 errorCode = GetLastError();
RollbackTransaction();
return errorCode;
}
@@ -469,7 +490,9 @@ MySQLPreparedStatement* MySQLConnection::GetPreparedStatement(uint32 index)
{
ASSERT(index < m_stmts.size(), "Tried to access invalid prepared statement index {} (max index {}) on database `{}`, connection type: {}",
index, m_stmts.size(), m_connectionInfo.database, (m_connectionFlags & CONNECTION_ASYNC) ? "asynchronous" : "synchronous");
MySQLPreparedStatement* ret = m_stmts[index].get();
if (!ret)
LOG_ERROR("sql.sql", "Could not fetch prepared statement {} on database `{}`, connection type: {}.",
index, m_connectionInfo.database, (m_connectionFlags & CONNECTION_ASYNC) ? "asynchronous" : "synchronous");
@@ -477,7 +500,7 @@ MySQLPreparedStatement* MySQLConnection::GetPreparedStatement(uint32 index)
return ret;
}
void MySQLConnection::PrepareStatement(uint32 index, std::string const& sql, ConnectionFlags flags)
void MySQLConnection::PrepareStatement(uint32 index, std::string_view sql, ConnectionFlags flags)
{
// Check if specified query should be prepared on this connection
// i.e. don't prepare async statements on synchronous connections
@@ -497,7 +520,7 @@ void MySQLConnection::PrepareStatement(uint32 index, std::string const& sql, Con
}
else
{
if (mysql_stmt_prepare(stmt, sql.c_str(), static_cast<unsigned long>(sql.size())))
if (mysql_stmt_prepare(stmt, std::string(sql).c_str(), static_cast<unsigned long>(sql.size())))
{
LOG_ERROR("sql.sql", "In mysql_stmt_prepare() id: {}, sql: \"{}\"", index, sql);
LOG_ERROR("sql.sql", "{}", mysql_stmt_error(stmt));
@@ -523,6 +546,7 @@ PreparedResultSet* MySQLConnection::Query(PreparedStatementBase* stmt)
{
mysql_next_result(m_Mysql);
}
return new PreparedResultSet(mysqlStmt->GetSTMT(), result, rowCount, fieldCount);
}
@@ -556,7 +580,7 @@ bool MySQLConnection::_HandleMySQLErrno(uint32 errNo, uint8 attempts /*= 5*/)
if (!this->PrepareStatements())
{
LOG_FATAL("sql.sql", "Could not re-prepare statements!");
std::this_thread::sleep_for(std::chrono::seconds(10));
std::this_thread::sleep_for(10s);
std::abort();
}
@@ -572,24 +596,24 @@ bool MySQLConnection::_HandleMySQLErrno(uint32 errNo, uint8 attempts /*= 5*/)
{
// Shut down the server when the mysql server isn't
// reachable for some time
LOG_FATAL("sql.sql", "Failed to reconnect to the MySQL server, "
"terminating the server to prevent data corruption!");
LOG_FATAL("sql.sql", "Failed to reconnect to the MySQL server, terminating the server to prevent data corruption!");
// We could also initiate a shutdown through using std::raise(SIGTERM)
std::this_thread::sleep_for(std::chrono::seconds(10));
std::this_thread::sleep_for(10s);
std::abort();
}
else
{
// It's possible this attempted reconnect throws 2006 at us.
// To prevent crazy recursive calls, sleep here.
std::this_thread::sleep_for(std::chrono::seconds(3)); // Sleep 3 seconds
std::this_thread::sleep_for(3s); // Sleep 3 seconds
return _HandleMySQLErrno(lErrno, attempts); // Call self (recursive)
}
}
case ER_LOCK_DEADLOCK:
return false; // Implemented in TransactionTask::Execute and DatabaseWorkerPool<T>::DirectCommitTransaction
return false; // Implemented in TransactionTask::Execute and DatabaseWorkerPool<T>::DirectCommitTransaction
// Query related errors - skip query
case ER_WRONG_VALUE_COUNT:
case ER_DUP_ENTRY:
@@ -599,12 +623,12 @@ bool MySQLConnection::_HandleMySQLErrno(uint32 errNo, uint8 attempts /*= 5*/)
case ER_BAD_FIELD_ERROR:
case ER_NO_SUCH_TABLE:
LOG_ERROR("sql.sql", "Your database structure is not up to date. Please make sure you've executed all queries in the sql/updates folders.");
std::this_thread::sleep_for(std::chrono::seconds(10));
std::this_thread::sleep_for(10s);
std::abort();
return false;
case ER_PARSE_ERROR:
LOG_ERROR("sql.sql", "Error while parsing SQL. Core fix required.");
std::this_thread::sleep_for(std::chrono::seconds(10));
std::this_thread::sleep_for(10s);
std::abort();
return false;
default:

View File

@@ -42,7 +42,7 @@ enum ConnectionFlags
struct AC_DATABASE_API MySQLConnectionInfo
{
explicit MySQLConnectionInfo(std::string const& infoString);
explicit MySQLConnectionInfo(std::string_view infoString);
std::string user;
std::string password;
@@ -54,7 +54,9 @@ struct AC_DATABASE_API MySQLConnectionInfo
class AC_DATABASE_API MySQLConnection
{
template <class T> friend class DatabaseWorkerPool;
template <class T>
friend class DatabaseWorkerPool;
friend class PingOperation;
public:
@@ -67,11 +69,11 @@ public:
bool PrepareStatements();
bool Execute(char const* sql);
bool Execute(std::string_view sql);
bool Execute(PreparedStatementBase* stmt);
ResultSet* Query(char const* sql);
ResultSet* Query(std::string_view sql);
PreparedResultSet* Query(PreparedStatementBase* stmt);
bool _Query(char const* sql, MySQLResult** pResult, MySQLField** pFields, uint64* pRowCount, uint32* pFieldCount);
bool _Query(std::string_view sql, MySQLResult** pResult, MySQLField** pFields, uint64* pRowCount, uint32* pFieldCount);
bool _Query(PreparedStatementBase* stmt, MySQLPreparedStatement** mysqlStmt, MySQLResult** pResult, uint64* pRowCount, uint32* pFieldCount);
void BeginTransaction();
@@ -93,25 +95,25 @@ protected:
[[nodiscard]] uint32 GetServerVersion() const;
MySQLPreparedStatement* GetPreparedStatement(uint32 index);
void PrepareStatement(uint32 index, std::string const& sql, ConnectionFlags flags);
void PrepareStatement(uint32 index, std::string_view sql, ConnectionFlags flags);
virtual void DoPrepareStatements() = 0;
typedef std::vector<std::unique_ptr<MySQLPreparedStatement>> PreparedStatementContainer;
PreparedStatementContainer m_stmts; //! PreparedStatements storage
bool m_reconnecting; //! Are we reconnecting?
bool m_prepareError; //! Was there any error while preparing statements?
PreparedStatementContainer m_stmts; //! PreparedStatements storage
bool m_reconnecting; //! Are we reconnecting?
bool m_prepareError; //! Was there any error while preparing statements?
private:
bool _HandleMySQLErrno(uint32 errNo, uint8 attempts = 5);
ProducerConsumerQueue<SQLOperation*>* m_queue; //! Queue shared with other asynchronous connections.
std::unique_ptr<DatabaseWorker> m_worker; //! Core worker task.
MySQLHandle* m_Mysql; //! MySQL Handle.
MySQLConnectionInfo& m_connectionInfo; //! Connection info (used for logging)
ConnectionFlags m_connectionFlags; //! Connection flags (for preparing relevant statements)
std::mutex m_Mutex;
MySQLHandle* m_Mysql; //! MySQL Handle.
MySQLConnectionInfo& m_connectionInfo; //! Connection info (used for logging)
ConnectionFlags m_connectionFlags; //! Connection flags (for preparing relevant statements)
std::mutex m_Mutex;
MySQLConnection(MySQLConnection const& right) = delete;
MySQLConnection& operator=(MySQLConnection const& right) = delete;

View File

@@ -35,8 +35,11 @@ template<> struct MySQLType<int64> : std::integral_constant<enum_field_types, MY
template<> struct MySQLType<float> : std::integral_constant<enum_field_types, MYSQL_TYPE_FLOAT> { };
template<> struct MySQLType<double> : std::integral_constant<enum_field_types, MYSQL_TYPE_DOUBLE> { };
MySQLPreparedStatement::MySQLPreparedStatement(MySQLStmt* stmt, std::string queryString) :
m_stmt(nullptr), m_Mstmt(stmt), m_bind(nullptr), m_queryString(std::move(queryString))
MySQLPreparedStatement::MySQLPreparedStatement(MySQLStmt* stmt, std::string_view queryString) :
m_stmt(nullptr),
m_Mstmt(stmt),
m_bind(nullptr),
m_queryString(std::string(queryString))
{
/// Initialize variable parameters
m_paramCount = mysql_stmt_param_count(stmt);
@@ -57,6 +60,7 @@ MySQLPreparedStatement::~MySQLPreparedStatement()
delete[] m_Mstmt->bind->length;
delete[] m_Mstmt->bind->is_null;
}
mysql_stmt_close(m_Mstmt);
delete[] m_bind;
}
@@ -72,8 +76,10 @@ void MySQLPreparedStatement::BindParameters(PreparedStatementBase* stmt)
{
SetParameter(pos, param);
}, data.data);
++pos;
}
#ifdef _DEBUG
if (pos < m_paramCount)
LOG_WARN("sql.sql", "[WARNING]: BindParameters() for statement {} did not bind all allocated parameters", stmt->GetIndex());
@@ -86,7 +92,7 @@ void MySQLPreparedStatement::ClearParameters()
{
delete m_bind[i].length;
m_bind[i].length = nullptr;
delete[] (char*) m_bind[i].buffer;
delete[] (char*)m_bind[i].buffer;
m_bind[i].buffer = nullptr;
m_paramsSet[i] = false;
}
@@ -94,7 +100,9 @@ void MySQLPreparedStatement::ClearParameters()
static bool ParamenterIndexAssertFail(uint32 stmtIndex, uint8 index, uint32 paramCount)
{
LOG_ERROR("sql.driver", "Attempted to bind parameter {}{} on a PreparedStatement {} (statement has only {} parameters)", uint32(index) + 1, (index == 1 ? "st" : (index == 2 ? "nd" : (index == 3 ? "rd" : "nd"))), stmtIndex, paramCount);
LOG_ERROR("sql.driver", "Attempted to bind parameter {}{} on a PreparedStatement {} (statement has only {} parameters)",
uint32(index) + 1, (index == 1 ? "st" : (index == 2 ? "nd" : (index == 3 ? "rd" : "nd"))), stmtIndex, paramCount);
return false;
}
@@ -107,7 +115,30 @@ void MySQLPreparedStatement::AssertValidIndex(uint8 index)
LOG_ERROR("sql.sql", "[ERROR] Prepared Statement (id: {}) trying to bind value on already bound index ({}).", m_stmt->GetIndex(), index);
}
void MySQLPreparedStatement::SetParameter(uint8 index, std::nullptr_t)
template<typename T>
void MySQLPreparedStatement::SetParameter(const uint8 index, T value)
{
AssertValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(sizeof(T));
param->buffer_type = MySQLType<T>::value;
delete[] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = 0;
param->is_null_value = 0;
param->length = nullptr; // Only != NULL for strings
param->is_unsigned = std::is_unsigned_v<T>;
memcpy(param->buffer, &value, len);
}
void MySQLPreparedStatement::SetParameter(const uint8 index, bool value)
{
SetParameter(index, uint8(value ? 1 : 0));
}
void MySQLPreparedStatement::SetParameter(const uint8 index, std::nullptr_t /*value*/)
{
AssertValidIndex(index);
m_paramsSet[index] = true;
@@ -121,29 +152,6 @@ void MySQLPreparedStatement::SetParameter(uint8 index, std::nullptr_t)
param->length = nullptr;
}
void MySQLPreparedStatement::SetParameter(uint8 index, bool value)
{
SetParameter(index, uint8(value ? 1 : 0));
}
template<typename T>
void MySQLPreparedStatement::SetParameter(uint8 index, T value)
{
AssertValidIndex(index);
m_paramsSet[index] = true;
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(sizeof(T));
param->buffer_type = MySQLType<T>::value;
delete[] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = 0;
param->is_null_value = 0;
param->length = nullptr; // Only != NULL for strings
param->is_unsigned = std::is_unsigned_v<T>;
memcpy(param->buffer, &value, len);
}
void MySQLPreparedStatement::SetParameter(uint8 index, std::string const& value)
{
AssertValidIndex(index);
@@ -151,7 +159,7 @@ void MySQLPreparedStatement::SetParameter(uint8 index, std::string const& value)
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(value.size());
param->buffer_type = MYSQL_TYPE_VAR_STRING;
delete [] static_cast<char*>(param->buffer);
delete[] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = len;
param->is_null_value = 0;
@@ -168,7 +176,7 @@ void MySQLPreparedStatement::SetParameter(uint8 index, std::vector<uint8> const&
MYSQL_BIND* param = &m_bind[index];
uint32 len = uint32(value.size());
param->buffer_type = MYSQL_TYPE_BLOB;
delete [] static_cast<char*>(param->buffer);
delete[] static_cast<char*>(param->buffer);
param->buffer = new char[len];
param->buffer_length = len;
param->is_null_value = 0;
@@ -183,6 +191,7 @@ std::string MySQLPreparedStatement::getQueryString() const
std::string queryString(m_queryString);
size_t pos = 0;
for (PreparedStatementData const& data : m_stmt->GetParameters())
{
pos = queryString.find('?', pos);

View File

@@ -36,7 +36,7 @@ friend class MySQLConnection;
friend class PreparedStatementBase;
public:
MySQLPreparedStatement(MySQLStmt* stmt, std::string queryString);
MySQLPreparedStatement(MySQLStmt* stmt, std::string_view queryString);
~MySQLPreparedStatement();
void BindParameters(PreparedStatementBase* stmt);
@@ -44,18 +44,19 @@ public:
uint32 GetParameterCount() const { return m_paramCount; }
protected:
void SetParameter(uint8 index, std::nullptr_t);
void SetParameter(uint8 index, bool value);
void SetParameter(const uint8 index, bool value);
void SetParameter(const uint8 index, std::nullptr_t /*value*/);
void SetParameter(const uint8 index, std::string const& value);
void SetParameter(const uint8 index, std::vector<uint8> const& value);
template<typename T>
void SetParameter(uint8 index, T value);
void SetParameter(uint8 index, std::string const& value);
void SetParameter(uint8 index, std::vector<uint8> const& value);
void SetParameter(const uint8 index, T value);
MySQLStmt* GetSTMT() { return m_Mstmt; }
MySQLBind* GetBind() { return m_bind; }
PreparedStatementBase* m_stmt;
void ClearParameters();
void AssertValidIndex(uint8 index);
void AssertValidIndex(const uint8 index);
std::string getQueryString() const;
private:
@@ -63,7 +64,7 @@ private:
uint32 m_paramCount;
std::vector<bool> m_paramsSet;
MySQLBind* m_bind;
std::string const m_queryString;
std::string m_queryString{};
MySQLPreparedStatement(MySQLPreparedStatement const& right) = delete;
MySQLPreparedStatement& operator=(MySQLPreparedStatement const& right) = delete;

View File

@@ -24,7 +24,8 @@
#include "QueryResult.h"
PreparedStatementBase::PreparedStatementBase(uint32 index, uint8 capacity) :
m_index(index), statement_data(capacity) { }
m_index(index),
statement_data(capacity) { }
PreparedStatementBase::~PreparedStatementBase() { }
@@ -36,20 +37,6 @@ Acore::Types::is_non_string_view_v<T> PreparedStatementBase::SetValidData(const
statement_data[index].data.emplace<T>(value);
}
template<>
void PreparedStatementBase::SetValidData(const uint8 index, std::string const& value)
{
ASSERT(index < statement_data.size());
statement_data[index].data.emplace<std::string>(value);
}
template<>
void PreparedStatementBase::SetValidData(const uint8 index, std::vector<uint8> const& value)
{
ASSERT(index < statement_data.size());
statement_data[index].data.emplace<std::vector<uint8>>(value);
}
// Non template functions
void PreparedStatementBase::SetValidData(const uint8 index)
{
@@ -73,103 +60,16 @@ template void PreparedStatementBase::SetValidData(const uint8 index, uint64 cons
template void PreparedStatementBase::SetValidData(const uint8 index, int64 const& value);
template void PreparedStatementBase::SetValidData(const uint8 index, bool const& value);
template void PreparedStatementBase::SetValidData(const uint8 index, float const& value);
// Old api
void PreparedStatementBase::setBool(const uint8 index, const bool value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setUInt8(const uint8 index, const uint8 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setUInt16(const uint8 index, const uint16 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setUInt32(const uint8 index, const uint32 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setUInt64(const uint8 index, const uint64 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setInt8(const uint8 index, const int8 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setInt16(const uint8 index, const int16 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setInt32(const uint8 index, const int32 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setInt64(const uint8 index, const int64 value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setFloat(const uint8 index, const float value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setDouble(const uint8 index, const double value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setString(const uint8 index, const std::string& value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setStringView(const uint8 index, const std::string_view value)
{
ASSERT(index < statement_data.size());
statement_data[index].data.emplace<std::string>(value);
}
void PreparedStatementBase::setBinary(const uint8 index, const std::vector<uint8>& value)
{
ASSERT(index < statement_data.size());
statement_data[index].data = value;
}
void PreparedStatementBase::setNull(const uint8 index)
{
ASSERT(index < statement_data.size());
statement_data[index].data = nullptr;
}
template void PreparedStatementBase::SetValidData(const uint8 index, std::string const& value);
template void PreparedStatementBase::SetValidData(const uint8 index, std::vector<uint8> const& value);
//- Execution
PreparedStatementTask::PreparedStatementTask(PreparedStatementBase* stmt, bool async) :
m_stmt(stmt), m_result(nullptr)
m_stmt(stmt),
m_result(nullptr)
{
m_has_result = async; // If it's async, then there's a result
if (async)
m_result = new PreparedQueryResultPromise();
}
@@ -177,7 +77,8 @@ m_stmt(stmt), m_result(nullptr)
PreparedStatementTask::~PreparedStatementTask()
{
delete m_stmt;
if (m_has_result && m_result != nullptr)
if (m_has_result && m_result)
delete m_result;
}
@@ -192,6 +93,7 @@ bool PreparedStatementTask::Execute()
m_result->set_value(PreparedQueryResult(nullptr));
return false;
}
m_result->set_value(PreparedQueryResult(result));
return true;
}
@@ -202,45 +104,29 @@ bool PreparedStatementTask::Execute()
template<typename T>
std::string PreparedStatementData::ToString(T value)
{
return fmt::format("{}", value);
return Acore::StringFormatFmt("{}", value);
}
std::string PreparedStatementData::ToString(bool value)
{
return ToString<uint32>(value);
}
std::string PreparedStatementData::ToString(uint8 value)
{
return ToString<uint32>(value);
}
template std::string PreparedStatementData::ToString<uint16>(uint16);
template std::string PreparedStatementData::ToString<uint32>(uint32);
template std::string PreparedStatementData::ToString<uint64>(uint64);
std::string PreparedStatementData::ToString(int8 value)
{
return ToString<int32>(value);
}
template std::string PreparedStatementData::ToString<int16>(int16);
template std::string PreparedStatementData::ToString<int32>(int32);
template std::string PreparedStatementData::ToString<int64>(int64);
template std::string PreparedStatementData::ToString<float>(float);
template std::string PreparedStatementData::ToString<double>(double);
std::string PreparedStatementData::ToString(std::string const& value)
{
return fmt::format("'{}'", value);
}
std::string PreparedStatementData::ToString(std::vector<uint8> const& /*value*/)
template<>
std::string PreparedStatementData::ToString(std::vector<uint8> /*value*/)
{
return "BINARY";
}
std::string PreparedStatementData::ToString(std::nullptr_t)
template std::string PreparedStatementData::ToString(uint8);
template std::string PreparedStatementData::ToString(uint16);
template std::string PreparedStatementData::ToString(uint32);
template std::string PreparedStatementData::ToString(uint64);
template std::string PreparedStatementData::ToString(int8);
template std::string PreparedStatementData::ToString(int16);
template std::string PreparedStatementData::ToString(int32);
template std::string PreparedStatementData::ToString(int64);
template std::string PreparedStatementData::ToString(std::string);
template std::string PreparedStatementData::ToString(float);
template std::string PreparedStatementData::ToString(double);
template std::string PreparedStatementData::ToString(bool);
std::string PreparedStatementData::ToString(std::nullptr_t /*value*/)
{
return "NULL";
}

View File

@@ -20,6 +20,7 @@
#include "Define.h"
#include "Duration.h"
#include "Optional.h"
#include "SQLOperation.h"
#include <future>
#include <tuple>
@@ -60,12 +61,7 @@ struct PreparedStatementData
template<typename T>
static std::string ToString(T value);
static std::string ToString(bool value);
static std::string ToString(uint8 value);
static std::string ToString(int8 value);
static std::string ToString(std::string const& value);
static std::string ToString(std::vector<uint8> const& value);
static std::string ToString(std::nullptr_t);
static std::string ToString(std::nullptr_t /*value*/);
};
//- Upper-level class that is used in code
@@ -77,28 +73,6 @@ public:
explicit PreparedStatementBase(uint32 index, uint8 capacity);
virtual ~PreparedStatementBase();
void setNull(const uint8 index);
void setBool(const uint8 index, const bool value);
void setUInt8(const uint8 index, const uint8 value);
void setUInt16(const uint8 index, const uint16 value);
void setUInt32(const uint8 index, const uint32 value);
void setUInt64(const uint8 index, const uint64 value);
void setInt8(const uint8 index, const int8 value);
void setInt16(const uint8 index, const int16 value);
void setInt32(const uint8 index, const int32 value);
void setInt64(const uint8 index, const int64 value);
void setFloat(const uint8 index, const float value);
void setDouble(const uint8 index, const double value);
void setString(const uint8 index, const std::string& value);
void setStringView(const uint8 index, const std::string_view value);
void setBinary(const uint8 index, const std::vector<uint8>& value);
template <size_t Size>
void setBinary(const uint8 index, std::array<uint8, Size> const& value)
{
std::vector<uint8> vec(value.begin(), value.end());
setBinary(index, vec);
}
// Set numerlic and default binary
template<typename T>
inline Acore::Types::is_default<T> SetData(const uint8 index, T value)

View File

@@ -16,6 +16,7 @@
*/
#include "QueryCallback.h"
#include "Duration.h"
#include "Errors.h"
template<typename T, typename... Args>
@@ -66,13 +67,15 @@ public:
QueryCallbackData(std::function<void(QueryCallback&, QueryResult)>&& callback) : _string(std::move(callback)), _isPrepared(false) { }
QueryCallbackData(std::function<void(QueryCallback&, PreparedQueryResult)>&& callback) : _prepared(std::move(callback)), _isPrepared(true) { }
QueryCallbackData(QueryCallbackData&& right)
QueryCallbackData(QueryCallbackData&& right) noexcept
{
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
MoveFrom(this, std::move(right));
}
QueryCallbackData& operator=(QueryCallbackData&& right)
QueryCallbackData& operator=(QueryCallbackData&& right) noexcept
{
if (this != &right)
{
@@ -82,10 +85,13 @@ public:
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
}
MoveFrom(this, std::move(right));
}
return *this;
}
~QueryCallbackData() { DestroyActiveMember(this); }
private:
@@ -105,19 +111,19 @@ private:
};
// Not using initialization lists to work around segmentation faults when compiling with clang without precompiled headers
QueryCallback::QueryCallback(std::future<QueryResult>&& result)
QueryCallback::QueryCallback(QueryResultFuture&& result)
{
_isPrepared = false;
Construct(_string, std::move(result));
}
QueryCallback::QueryCallback(std::future<PreparedQueryResult>&& result)
QueryCallback::QueryCallback(PreparedQueryResultFuture&& result)
{
_isPrepared = true;
Construct(_prepared, std::move(result));
}
QueryCallback::QueryCallback(QueryCallback&& right)
QueryCallback::QueryCallback(QueryCallback&& right) noexcept
{
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
@@ -125,7 +131,7 @@ QueryCallback::QueryCallback(QueryCallback&& right)
_callbacks = std::move(right._callbacks);
}
QueryCallback& QueryCallback::operator=(QueryCallback&& right)
QueryCallback& QueryCallback::operator=(QueryCallback&& right) noexcept
{
if (this != &right)
{
@@ -135,9 +141,11 @@ QueryCallback& QueryCallback::operator=(QueryCallback&& right)
_isPrepared = right._isPrepared;
ConstructActiveMember(this);
}
MoveFrom(this, std::move(right));
_callbacks = std::move(right._callbacks);
}
return *this;
}
@@ -198,7 +206,7 @@ bool QueryCallback::InvokeIfReady()
if (!_isPrepared)
{
if (_string.valid() && _string.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
if (_string.valid() && _string.wait_for(0s) == std::future_status::ready)
{
QueryResultFuture f(std::move(_string));
std::function<void(QueryCallback&, QueryResult)> cb(std::move(callback._string));
@@ -208,7 +216,7 @@ bool QueryCallback::InvokeIfReady()
}
else
{
if (_prepared.valid() && _prepared.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
if (_prepared.valid() && _prepared.wait_for(0s) == std::future_status::ready)
{
PreparedQueryResultFuture f(std::move(_prepared));
std::function<void(QueryCallback&, PreparedQueryResult)> cb(std::move(callback._prepared));

View File

@@ -31,8 +31,9 @@ class AC_DATABASE_API QueryCallback
public:
explicit QueryCallback(QueryResultFuture&& result);
explicit QueryCallback(PreparedQueryResultFuture&& result);
QueryCallback(QueryCallback&& right);
QueryCallback& operator=(QueryCallback&& right);
QueryCallback(QueryCallback&& right) noexcept;
QueryCallback& operator=(QueryCallback&& right) noexcept;
~QueryCallback();
QueryCallback&& WithCallback(std::function<void(QueryResult)>&& callback);
@@ -60,6 +61,7 @@ private:
QueryResultFuture _string;
PreparedQueryResultFuture _prepared;
};
bool _isPrepared;
struct QueryCallbackData;

View File

@@ -24,8 +24,7 @@
class AC_DATABASE_API SQLQueryHolderBase
{
friend class SQLQueryHolderTask;
private:
std::vector<std::pair<PreparedStatementBase*, PreparedQueryResult>> m_queries;
public:
SQLQueryHolderBase() = default;
virtual ~SQLQueryHolderBase();
@@ -35,6 +34,9 @@ public:
protected:
bool SetPreparedQueryImpl(size_t index, PreparedStatementBase* stmt);
private:
std::vector<std::pair<PreparedStatementBase*, PreparedQueryResult>> m_queries;
};
template<typename T>
@@ -49,10 +51,6 @@ public:
class AC_DATABASE_API SQLQueryHolderTask : public SQLOperation
{
private:
std::shared_ptr<SQLQueryHolderBase> m_holder;
QueryResultHolderPromise m_result;
public:
explicit SQLQueryHolderTask(std::shared_ptr<SQLQueryHolderBase> holder)
: m_holder(std::move(holder)) { }
@@ -61,6 +59,10 @@ public:
bool Execute() override;
QueryResultHolderFuture GetFuture() { return m_result.get_future(); }
private:
std::shared_ptr<SQLQueryHolderBase> m_holder;
QueryResultHolderPromise m_result;
};
class AC_DATABASE_API SQLQueryHolderCallback
@@ -70,7 +72,6 @@ public:
: m_holder(std::move(holder)), m_future(std::move(future)) { }
SQLQueryHolderCallback(SQLQueryHolderCallback&&) = default;
SQLQueryHolderCallback& operator=(SQLQueryHolderCallback&&) = default;
void AfterComplete(std::function<void(SQLQueryHolderBase const&)> callback) &

View File

@@ -24,154 +24,155 @@
namespace
{
static uint32 SizeForType(MYSQL_FIELD* field)
{
switch (field->type)
static uint32 SizeForType(MYSQL_FIELD* field)
{
case MYSQL_TYPE_NULL:
return 0;
case MYSQL_TYPE_TINY:
return 1;
case MYSQL_TYPE_YEAR:
case MYSQL_TYPE_SHORT:
return 2;
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_LONG:
case MYSQL_TYPE_FLOAT:
return 4;
case MYSQL_TYPE_DOUBLE:
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_BIT:
return 8;
switch (field->type)
{
case MYSQL_TYPE_NULL:
return 0;
case MYSQL_TYPE_TINY:
return 1;
case MYSQL_TYPE_YEAR:
case MYSQL_TYPE_SHORT:
return 2;
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_LONG:
case MYSQL_TYPE_FLOAT:
return 4;
case MYSQL_TYPE_DOUBLE:
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_BIT:
return 8;
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
case MYSQL_TYPE_DATETIME:
return sizeof(MYSQL_TIME);
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
case MYSQL_TYPE_DATETIME:
return sizeof(MYSQL_TIME);
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
return field->max_length + 1;
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
return field->max_length + 1;
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return 64;
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return 64;
case MYSQL_TYPE_GEOMETRY:
/*
Following types are not sent over the wire:
MYSQL_TYPE_ENUM:
MYSQL_TYPE_SET:
*/
default:
LOG_WARN("sql.sql", "SQL::SizeForType(): invalid field type {}", uint32(field->type));
return 0;
}
}
DatabaseFieldTypes MysqlTypeToFieldType(enum_field_types type)
{
switch (type)
{
case MYSQL_TYPE_NULL:
return DatabaseFieldTypes::Null;
case MYSQL_TYPE_TINY:
return DatabaseFieldTypes::Int8;
case MYSQL_TYPE_YEAR:
case MYSQL_TYPE_SHORT:
return DatabaseFieldTypes::Int16;
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_LONG:
return DatabaseFieldTypes::Int32;
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_BIT:
return DatabaseFieldTypes::Int64;
case MYSQL_TYPE_FLOAT:
return DatabaseFieldTypes::Float;
case MYSQL_TYPE_DOUBLE:
return DatabaseFieldTypes::Double;
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return DatabaseFieldTypes::Decimal;
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
case MYSQL_TYPE_DATETIME:
return DatabaseFieldTypes::Date;
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
return DatabaseFieldTypes::Binary;
default:
LOG_WARN("sql.sql", "MysqlTypeToFieldType(): invalid field type {}", uint32(type));
break;
case MYSQL_TYPE_GEOMETRY:
/*
Following types are not sent over the wire:
MYSQL_TYPE_ENUM:
MYSQL_TYPE_SET:
*/
default:
LOG_WARN("sql.sql", "SQL::SizeForType(): invalid field type {}", uint32(field->type));
return 0;
}
}
return DatabaseFieldTypes::Null;
}
static char const* FieldTypeToString(enum_field_types type)
{
switch (type)
DatabaseFieldTypes MysqlTypeToFieldType(enum_field_types type)
{
case MYSQL_TYPE_BIT: return "BIT";
case MYSQL_TYPE_BLOB: return "BLOB";
case MYSQL_TYPE_DATE: return "DATE";
case MYSQL_TYPE_DATETIME: return "DATETIME";
case MYSQL_TYPE_NEWDECIMAL: return "NEWDECIMAL";
case MYSQL_TYPE_DECIMAL: return "DECIMAL";
case MYSQL_TYPE_DOUBLE: return "DOUBLE";
case MYSQL_TYPE_ENUM: return "ENUM";
case MYSQL_TYPE_FLOAT: return "FLOAT";
case MYSQL_TYPE_GEOMETRY: return "GEOMETRY";
case MYSQL_TYPE_INT24: return "INT24";
case MYSQL_TYPE_LONG: return "LONG";
case MYSQL_TYPE_LONGLONG: return "LONGLONG";
case MYSQL_TYPE_LONG_BLOB: return "LONG_BLOB";
case MYSQL_TYPE_MEDIUM_BLOB: return "MEDIUM_BLOB";
case MYSQL_TYPE_NEWDATE: return "NEWDATE";
case MYSQL_TYPE_NULL: return "NULL";
case MYSQL_TYPE_SET: return "SET";
case MYSQL_TYPE_SHORT: return "SHORT";
case MYSQL_TYPE_STRING: return "STRING";
case MYSQL_TYPE_TIME: return "TIME";
case MYSQL_TYPE_TIMESTAMP: return "TIMESTAMP";
case MYSQL_TYPE_TINY: return "TINY";
case MYSQL_TYPE_TINY_BLOB: return "TINY_BLOB";
case MYSQL_TYPE_VAR_STRING: return "VAR_STRING";
case MYSQL_TYPE_YEAR: return "YEAR";
default: return "-Unknown-";
}
}
switch (type)
{
case MYSQL_TYPE_NULL:
return DatabaseFieldTypes::Null;
case MYSQL_TYPE_TINY:
return DatabaseFieldTypes::Int8;
case MYSQL_TYPE_YEAR:
case MYSQL_TYPE_SHORT:
return DatabaseFieldTypes::Int16;
case MYSQL_TYPE_INT24:
case MYSQL_TYPE_LONG:
return DatabaseFieldTypes::Int32;
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_BIT:
return DatabaseFieldTypes::Int64;
case MYSQL_TYPE_FLOAT:
return DatabaseFieldTypes::Float;
case MYSQL_TYPE_DOUBLE:
return DatabaseFieldTypes::Double;
case MYSQL_TYPE_DECIMAL:
case MYSQL_TYPE_NEWDECIMAL:
return DatabaseFieldTypes::Decimal;
case MYSQL_TYPE_TIMESTAMP:
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
case MYSQL_TYPE_DATETIME:
return DatabaseFieldTypes::Date;
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
return DatabaseFieldTypes::Binary;
default:
LOG_WARN("sql.sql", "MysqlTypeToFieldType(): invalid field type {}", uint32(type));
break;
}
void InitializeDatabaseFieldMetadata(QueryResultFieldMetadata* meta, MySQLField const* field, uint32 fieldIndex)
{
meta->TableName = field->org_table;
meta->TableAlias = field->table;
meta->Name = field->org_name;
meta->Alias = field->name;
meta->TypeName = FieldTypeToString(field->type);
meta->Index = fieldIndex;
meta->Type = MysqlTypeToFieldType(field->type);
}
return DatabaseFieldTypes::Null;
}
static std::string FieldTypeToString(enum_field_types type)
{
switch (type)
{
case MYSQL_TYPE_BIT: return "BIT";
case MYSQL_TYPE_BLOB: return "BLOB";
case MYSQL_TYPE_DATE: return "DATE";
case MYSQL_TYPE_DATETIME: return "DATETIME";
case MYSQL_TYPE_NEWDECIMAL: return "NEWDECIMAL";
case MYSQL_TYPE_DECIMAL: return "DECIMAL";
case MYSQL_TYPE_DOUBLE: return "DOUBLE";
case MYSQL_TYPE_ENUM: return "ENUM";
case MYSQL_TYPE_FLOAT: return "FLOAT";
case MYSQL_TYPE_GEOMETRY: return "GEOMETRY";
case MYSQL_TYPE_INT24: return "INT24";
case MYSQL_TYPE_LONG: return "LONG";
case MYSQL_TYPE_LONGLONG: return "LONGLONG";
case MYSQL_TYPE_LONG_BLOB: return "LONG_BLOB";
case MYSQL_TYPE_MEDIUM_BLOB: return "MEDIUM_BLOB";
case MYSQL_TYPE_NEWDATE: return "NEWDATE";
case MYSQL_TYPE_NULL: return "NULL";
case MYSQL_TYPE_SET: return "SET";
case MYSQL_TYPE_SHORT: return "SHORT";
case MYSQL_TYPE_STRING: return "STRING";
case MYSQL_TYPE_TIME: return "TIME";
case MYSQL_TYPE_TIMESTAMP: return "TIMESTAMP";
case MYSQL_TYPE_TINY: return "TINY";
case MYSQL_TYPE_TINY_BLOB: return "TINY_BLOB";
case MYSQL_TYPE_VAR_STRING: return "VAR_STRING";
case MYSQL_TYPE_YEAR: return "YEAR";
default: return "-Unknown-";
}
}
void InitializeDatabaseFieldMetadata(QueryResultFieldMetadata* meta, MySQLField const* field, uint32 fieldIndex)
{
meta->TableName = field->org_table;
meta->TableAlias = field->table;
meta->Name = field->org_name;
meta->Alias = field->name;
meta->TypeName = FieldTypeToString(field->type);
meta->Index = fieldIndex;
meta->Type = MysqlTypeToFieldType(field->type);
}
}
ResultSet::ResultSet(MySQLResult* result, MySQLField* fields, uint64 rowCount, uint32 fieldCount) :
_rowCount(rowCount),
_fieldCount(fieldCount),
_result(result),
_fields(fields)
_rowCount(rowCount),
_fieldCount(fieldCount),
_result(result),
_fields(fields)
{
_fieldMetadata.resize(_fieldCount);
_currentRow = new Field[_fieldCount];
for (uint32 i = 0; i < _fieldCount; i++)
{
InitializeDatabaseFieldMetadata(&_fieldMetadata[i], &_fields[i], i);
@@ -179,148 +180,11 @@ _fields(fields)
}
}
PreparedResultSet::PreparedResultSet(MySQLStmt* stmt, MySQLResult* result, uint64 rowCount, uint32 fieldCount) :
m_rowCount(rowCount),
m_rowPosition(0),
m_fieldCount(fieldCount),
m_rBind(nullptr),
m_stmt(stmt),
m_metadataResult(result)
{
if (!m_metadataResult)
return;
if (m_stmt->bind_result_done)
{
delete[] m_stmt->bind->length;
delete[] m_stmt->bind->is_null;
}
m_rBind = new MySQLBind[m_fieldCount];
//- for future readers wondering where this is freed - mysql_stmt_bind_result moves pointers to these
// from m_rBind to m_stmt->bind and it is later freed by the `if (m_stmt->bind_result_done)` block just above here
// MYSQL_STMT lifetime is equal to connection lifetime
MySQLBool* m_isNull = new MySQLBool[m_fieldCount];
unsigned long* m_length = new unsigned long[m_fieldCount];
memset(m_isNull, 0, sizeof(MySQLBool) * m_fieldCount);
memset(m_rBind, 0, sizeof(MySQLBind) * m_fieldCount);
memset(m_length, 0, sizeof(unsigned long) * m_fieldCount);
//- This is where we store the (entire) resultset
if (mysql_stmt_store_result(m_stmt))
{
LOG_WARN("sql.sql", "{}:mysql_stmt_store_result, cannot bind result from MySQL server. Error: {}", __FUNCTION__, mysql_stmt_error(m_stmt));
delete[] m_rBind;
delete[] m_isNull;
delete[] m_length;
return;
}
m_rowCount = mysql_stmt_num_rows(m_stmt);
//- This is where we prepare the buffer based on metadata
MySQLField* field = reinterpret_cast<MySQLField*>(mysql_fetch_fields(m_metadataResult));
m_fieldMetadata.resize(m_fieldCount);
std::size_t rowSize = 0;
for (uint32 i = 0; i < m_fieldCount; ++i)
{
uint32 size = SizeForType(&field[i]);
rowSize += size;
InitializeDatabaseFieldMetadata(&m_fieldMetadata[i], &field[i], i);
m_rBind[i].buffer_type = field[i].type;
m_rBind[i].buffer_length = size;
m_rBind[i].length = &m_length[i];
m_rBind[i].is_null = &m_isNull[i];
m_rBind[i].error = nullptr;
m_rBind[i].is_unsigned = field[i].flags & UNSIGNED_FLAG;
}
char* dataBuffer = new char[rowSize * m_rowCount];
for (uint32 i = 0, offset = 0; i < m_fieldCount; ++i)
{
m_rBind[i].buffer = dataBuffer + offset;
offset += m_rBind[i].buffer_length;
}
//- This is where we bind the bind the buffer to the statement
if (mysql_stmt_bind_result(m_stmt, m_rBind))
{
LOG_WARN("sql.sql", "{}:mysql_stmt_bind_result, cannot bind result from MySQL server. Error: {}", __FUNCTION__, mysql_stmt_error(m_stmt));
mysql_stmt_free_result(m_stmt);
CleanUp();
delete[] m_isNull;
delete[] m_length;
return;
}
m_rows.resize(uint32(m_rowCount) * m_fieldCount);
while (_NextRow())
{
for (uint32 fIndex = 0; fIndex < m_fieldCount; ++fIndex)
{
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetMetadata(&m_fieldMetadata[fIndex]);
unsigned long buffer_length = m_rBind[fIndex].buffer_length;
unsigned long fetched_length = *m_rBind[fIndex].length;
if (!*m_rBind[fIndex].is_null)
{
void* buffer = m_stmt->bind[fIndex].buffer;
switch (m_rBind[fIndex].buffer_type)
{
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
// warning - the string will not be null-terminated if there is no space for it in the buffer
// when mysql_stmt_fetch returned MYSQL_DATA_TRUNCATED
// we cannot blindly null-terminate the data either as it may be retrieved as binary blob and not specifically a string
// in this case using Field::GetCString will result in garbage
// TODO: remove Field::GetCString and use std::string_view in C++17
if (fetched_length < buffer_length)
*((char*)buffer + fetched_length) = '\0';
break;
default:
break;
}
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetByteValue(
(char const*)buffer,
fetched_length);
// move buffer pointer to next part
m_stmt->bind[fIndex].buffer = (char*)buffer + rowSize;
}
else
{
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetByteValue(
nullptr,
*m_rBind[fIndex].length);
}
}
m_rowPosition++;
}
m_rowPosition = 0;
/// All data is buffered, let go of mysql c api structures
mysql_stmt_free_result(m_stmt);
}
ResultSet::~ResultSet()
{
CleanUp();
}
PreparedResultSet::~PreparedResultSet()
{
CleanUp();
}
bool ResultSet::NextRow()
{
MYSQL_ROW row;
@@ -355,6 +219,169 @@ std::string ResultSet::GetFieldName(uint32 index) const
return _fields[index].name;
}
void ResultSet::CleanUp()
{
if (_currentRow)
{
delete[] _currentRow;
_currentRow = nullptr;
}
if (_result)
{
mysql_free_result(_result);
_result = nullptr;
}
}
Field const& ResultSet::operator[](std::size_t index) const
{
ASSERT(index < _fieldCount);
return _currentRow[index];
}
void ResultSet::AssertRows(std::size_t sizeRows)
{
ASSERT(sizeRows == _fieldCount);
}
PreparedResultSet::PreparedResultSet(MySQLStmt* stmt, MySQLResult* result, uint64 rowCount, uint32 fieldCount) :
m_rowCount(rowCount),
m_rowPosition(0),
m_fieldCount(fieldCount),
m_rBind(nullptr),
m_stmt(stmt),
m_metadataResult(result)
{
if (!m_metadataResult)
return;
if (m_stmt->bind_result_done)
{
delete[] m_stmt->bind->length;
delete[] m_stmt->bind->is_null;
}
m_rBind = new MySQLBind[m_fieldCount];
//- for future readers wondering where this is freed - mysql_stmt_bind_result moves pointers to these
// from m_rBind to m_stmt->bind and it is later freed by the `if (m_stmt->bind_result_done)` block just above here
// MYSQL_STMT lifetime is equal to connection lifetime
MySQLBool* m_isNull = new MySQLBool[m_fieldCount];
unsigned long* m_length = new unsigned long[m_fieldCount];
memset(m_isNull, 0, sizeof(MySQLBool) * m_fieldCount);
memset(m_rBind, 0, sizeof(MySQLBind) * m_fieldCount);
memset(m_length, 0, sizeof(unsigned long) * m_fieldCount);
//- This is where we store the (entire) resultset
if (mysql_stmt_store_result(m_stmt))
{
LOG_WARN("sql.sql", "{}:mysql_stmt_store_result, cannot bind result from MySQL server. Error: {}", __FUNCTION__, mysql_stmt_error(m_stmt));
delete[] m_rBind;
delete[] m_isNull;
delete[] m_length;
return;
}
m_rowCount = mysql_stmt_num_rows(m_stmt);
//- This is where we prepare the buffer based on metadata
MySQLField* field = reinterpret_cast<MySQLField*>(mysql_fetch_fields(m_metadataResult));
m_fieldMetadata.resize(m_fieldCount);
std::size_t rowSize = 0;
for (uint32 i = 0; i < m_fieldCount; ++i)
{
uint32 size = SizeForType(&field[i]);
rowSize += size;
InitializeDatabaseFieldMetadata(&m_fieldMetadata[i], &field[i], i);
m_rBind[i].buffer_type = field[i].type;
m_rBind[i].buffer_length = size;
m_rBind[i].length = &m_length[i];
m_rBind[i].is_null = &m_isNull[i];
m_rBind[i].error = nullptr;
m_rBind[i].is_unsigned = field[i].flags & UNSIGNED_FLAG;
}
char* dataBuffer = new char[rowSize * m_rowCount];
for (uint32 i = 0, offset = 0; i < m_fieldCount; ++i)
{
m_rBind[i].buffer = dataBuffer + offset;
offset += m_rBind[i].buffer_length;
}
//- This is where we bind the bind the buffer to the statement
if (mysql_stmt_bind_result(m_stmt, m_rBind))
{
LOG_WARN("sql.sql", "{}:mysql_stmt_bind_result, cannot bind result from MySQL server. Error: {}", __FUNCTION__, mysql_stmt_error(m_stmt));
mysql_stmt_free_result(m_stmt);
CleanUp();
delete[] m_isNull;
delete[] m_length;
return;
}
m_rows.resize(uint32(m_rowCount) * m_fieldCount);
while (_NextRow())
{
for (uint32 fIndex = 0; fIndex < m_fieldCount; ++fIndex)
{
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetMetadata(&m_fieldMetadata[fIndex]);
unsigned long buffer_length = m_rBind[fIndex].buffer_length;
unsigned long fetched_length = *m_rBind[fIndex].length;
if (!*m_rBind[fIndex].is_null)
{
void* buffer = m_stmt->bind[fIndex].buffer;
switch (m_rBind[fIndex].buffer_type)
{
case MYSQL_TYPE_TINY_BLOB:
case MYSQL_TYPE_MEDIUM_BLOB:
case MYSQL_TYPE_LONG_BLOB:
case MYSQL_TYPE_BLOB:
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
// warning - the string will not be null-terminated if there is no space for it in the buffer
// when mysql_stmt_fetch returned MYSQL_DATA_TRUNCATED
// we cannot blindly null-terminate the data either as it may be retrieved as binary blob and not specifically a string
// in this case using Field::GetCString will result in garbage
// TODO: remove Field::GetCString and use std::string_view in C++17
if (fetched_length < buffer_length)
*((char*)buffer + fetched_length) = '\0';
break;
default:
break;
}
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetByteValue((char const*)buffer, fetched_length);
// move buffer pointer to next part
m_stmt->bind[fIndex].buffer = (char*)buffer + rowSize;
}
else
{
m_rows[uint32(m_rowPosition) * m_fieldCount + fIndex].SetByteValue(nullptr, *m_rBind[fIndex].length);
}
}
m_rowPosition++;
}
m_rowPosition = 0;
/// All data is buffered, let go of mysql c api structures
mysql_stmt_free_result(m_stmt);
}
PreparedResultSet::~PreparedResultSet()
{
CleanUp();
}
bool PreparedResultSet::NextRow()
{
/// Only updates the m_rowPosition so upper level code knows in which element
@@ -376,27 +403,6 @@ bool PreparedResultSet::_NextRow()
return retval == 0 || retval == MYSQL_DATA_TRUNCATED;
}
void ResultSet::CleanUp()
{
if (_currentRow)
{
delete [] _currentRow;
_currentRow = nullptr;
}
if (_result)
{
mysql_free_result(_result);
_result = nullptr;
}
}
Field const& ResultSet::operator[](std::size_t index) const
{
ASSERT(index < _fieldCount);
return _currentRow[index];
}
Field* PreparedResultSet::Fetch() const
{
ASSERT(m_rowPosition < m_rowCount);
@@ -422,3 +428,9 @@ void PreparedResultSet::CleanUp()
m_rBind = nullptr;
}
}
void PreparedResultSet::AssertRows(std::size_t sizeRows)
{
ASSERT(m_rowPosition < m_rowCount);
ASSERT(sizeRows == m_fieldCount, "> Tuple size != count fields");
}

View File

@@ -20,6 +20,8 @@
#include "DatabaseEnvFwd.h"
#include "Define.h"
#include "Field.h"
#include <tuple>
#include <vector>
class AC_DATABASE_API ResultSet
@@ -36,6 +38,22 @@ public:
[[nodiscard]] Field* Fetch() const { return _currentRow; }
Field const& operator[](std::size_t index) const;
template<typename... Ts>
inline std::tuple<Ts...> FetchTuple()
{
AssertRows(sizeof...(Ts));
std::tuple<Ts...> theTuple = {};
std::apply([this](Ts&... args)
{
uint8 index{ 0 };
((args = _currentRow[index].Get<Ts>(), index++), ...);
}, theTuple);
return theTuple;
}
protected:
std::vector<QueryResultFieldMetadata> _fieldMetadata;
uint64 _rowCount;
@@ -44,6 +62,8 @@ protected:
private:
void CleanUp();
void AssertRows(std::size_t sizeRows);
MySQLResult* _result;
MySQLField* _fields;
@@ -64,6 +84,22 @@ public:
[[nodiscard]] Field* Fetch() const;
Field const& operator[](std::size_t index) const;
template<typename... Ts>
inline std::tuple<Ts...> FetchTuple()
{
AssertRows(sizeof...(Ts));
std::tuple<Ts...> theTuple = {};
std::apply([this](Ts&... args)
{
uint8 index{ 0 };
((args = m_rows[uint32(m_rowPosition) * m_fieldCount + index].Get<Ts>(), index++), ...);
}, theTuple);
return theTuple;
}
protected:
std::vector<QueryResultFieldMetadata> m_fieldMetadata;
std::vector<Field> m_rows;
@@ -79,6 +115,8 @@ private:
void CleanUp();
bool _NextRow();
void AssertRows(std::size_t sizeRows);
PreparedResultSet(PreparedResultSet const& right) = delete;
PreparedResultSet& operator=(PreparedResultSet const& right) = delete;
};

View File

@@ -20,13 +20,7 @@
#include "DatabaseEnvFwd.h"
#include "Define.h"
//- Union that holds element data
union SQLElementUnion
{
PreparedStatementBase* stmt;
char const* query;
};
#include <variant>
//- Type specifier of our element data
enum SQLElementDataType
@@ -38,7 +32,7 @@ enum SQLElementDataType
//- The element
struct SQLElementData
{
SQLElementUnion element;
std::variant<PreparedStatementBase*, std::string> element;
SQLElementDataType type;
};
@@ -55,6 +49,7 @@ public:
Execute();
return 0;
}
virtual bool Execute() = 0;
virtual void SetConnection(MySQLConnection* con) { m_conn = con; }

View File

@@ -16,6 +16,7 @@
*/
#include "Transaction.h"
#include "Errors.h"
#include "Log.h"
#include "MySQLConnection.h"
#include "PreparedStatement.h"
@@ -26,24 +27,24 @@
std::mutex TransactionTask::_deadlockLock;
#define DEADLOCK_MAX_RETRY_TIME_MS 60000
constexpr Milliseconds DEADLOCK_MAX_RETRY_TIME_MS = 1min;
//- Append a raw ad-hoc query to the transaction
void TransactionBase::Append(char const* sql)
void TransactionBase::Append(std::string_view sql)
{
SQLElementData data;
SQLElementData data = {};
data.type = SQL_ELEMENT_RAW;
data.element.query = strdup(sql);
m_queries.push_back(data);
data.element = std::string(sql);
m_queries.emplace_back(data);
}
//- Append a prepared statement to the transaction
void TransactionBase::AppendPreparedStatement(PreparedStatementBase* stmt)
{
SQLElementData data;
SQLElementData data = {};
data.type = SQL_ELEMENT_PREPARED;
data.element.stmt = stmt;
m_queries.push_back(data);
data.element = stmt;
m_queries.emplace_back(data);
}
void TransactionBase::Cleanup()
@@ -52,15 +53,38 @@ void TransactionBase::Cleanup()
if (_cleanedUp)
return;
for (SQLElementData const& data : m_queries)
for (SQLElementData& data : m_queries)
{
switch (data.type)
{
case SQL_ELEMENT_PREPARED:
delete data.element.stmt;
{
try
{
PreparedStatementBase* stmt = std::get<PreparedStatementBase*>(data.element);
ASSERT(stmt);
delete stmt;
}
catch (const std::bad_variant_access& ex)
{
LOG_FATAL("sql.sql", "> PreparedStatementBase not found in SQLElementData. {}", ex.what());
ABORT();
}
}
break;
case SQL_ELEMENT_RAW:
free((void*)(data.element.query));
{
try
{
std::get<std::string>(data.element).clear();
}
catch (const std::bad_variant_access& ex)
{
LOG_FATAL("sql.sql", "> std::string not found in SQLElementData. {}", ex.what());
ABORT();
}
}
break;
}
}
@@ -72,6 +96,7 @@ void TransactionBase::Cleanup()
bool TransactionTask::Execute()
{
int errorCode = TryExecute();
if (!errorCode)
return true;
@@ -81,15 +106,17 @@ bool TransactionTask::Execute()
threadIdStream << std::this_thread::get_id();
std::string threadId = threadIdStream.str();
// Make sure only 1 async thread retries a transaction so they don't keep dead-locking each other
std::lock_guard<std::mutex> lock(_deadlockLock);
for (uint32 loopDuration = 0, startMSTime = getMSTime(); loopDuration <= DEADLOCK_MAX_RETRY_TIME_MS; loopDuration = GetMSTimeDiffToNow(startMSTime))
{
if (!TryExecute())
return true;
// Make sure only 1 async thread retries a transaction so they don't keep dead-locking each other
std::lock_guard<std::mutex> lock(_deadlockLock);
LOG_WARN("sql.sql", "Deadlocked SQL Transaction, retrying. Loop timer: {} ms, Thread Id: {}", loopDuration, threadId);
for (Milliseconds loopDuration = 0s, startMSTime = GetTimeMS(); loopDuration <= DEADLOCK_MAX_RETRY_TIME_MS; loopDuration = GetMSTimeDiffToNow(startMSTime))
{
if (!TryExecute())
return true;
LOG_WARN("sql.sql", "Deadlocked SQL Transaction, retrying. Loop timer: {} ms, Thread Id: {}", loopDuration.count(), threadId);
}
}
LOG_ERROR("sql.sql", "Fatal deadlocked SQL Transaction, it will not be retried anymore. Thread Id: {}", threadId);
@@ -126,17 +153,20 @@ bool TransactionWithResultTask::Execute()
threadIdStream << std::this_thread::get_id();
std::string threadId = threadIdStream.str();
// Make sure only 1 async thread retries a transaction so they don't keep dead-locking each other
std::lock_guard<std::mutex> lock(_deadlockLock);
for (uint32 loopDuration = 0, startMSTime = getMSTime(); loopDuration <= DEADLOCK_MAX_RETRY_TIME_MS; loopDuration = GetMSTimeDiffToNow(startMSTime))
{
if (!TryExecute())
{
m_result.set_value(true);
return true;
}
// Make sure only 1 async thread retries a transaction so they don't keep dead-locking each other
std::lock_guard<std::mutex> lock(_deadlockLock);
LOG_WARN("sql.sql", "Deadlocked SQL Transaction, retrying. Loop timer: {} ms, Thread Id: {}", loopDuration, threadId);
for (Milliseconds loopDuration = 0s, startMSTime = GetTimeMS(); loopDuration <= DEADLOCK_MAX_RETRY_TIME_MS; loopDuration = GetMSTimeDiffToNow(startMSTime))
{
if (!TryExecute())
{
m_result.set_value(true);
return true;
}
LOG_WARN("sql.sql", "Deadlocked SQL Transaction, retrying. Loop timer: {} ms, Thread Id: {}", loopDuration.count(), threadId);
}
}
LOG_ERROR("sql.sql", "Fatal deadlocked SQL Transaction, it will not be retried anymore. Thread Id: {}", threadId);
@@ -151,7 +181,7 @@ bool TransactionWithResultTask::Execute()
bool TransactionCallback::InvokeIfReady()
{
if (m_future.valid() && m_future.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
if (m_future.valid() && m_future.wait_for(0s) == std::future_status::ready)
{
m_callback(m_future.get());
return true;

View File

@@ -30,21 +30,22 @@
/*! Transactions, high level class. */
class AC_DATABASE_API TransactionBase
{
friend class TransactionTask;
friend class MySQLConnection;
friend class TransactionTask;
friend class MySQLConnection;
template <typename T>
friend class DatabaseWorkerPool;
template <typename T>
friend class DatabaseWorkerPool;
public:
TransactionBase() = default;
virtual ~TransactionBase() { Cleanup(); }
void Append(char const* sql);
template<typename Format, typename... Args>
void PAppend(Format&& sql, Args&&... args)
void Append(std::string_view sql);
template<typename... Args>
void Append(std::string_view sql, Args&&... args)
{
Append(Acore::StringFormat(std::forward<Format>(sql), std::forward<Args>(args)...).c_str());
Append(Acore::StringFormatFmt(sql, std::forward<Args>(args)...));
}
[[nodiscard]] std::size_t GetSize() const { return m_queries.size(); }
@@ -63,6 +64,7 @@ class Transaction : public TransactionBase
{
public:
using TransactionBase::Append;
void Append(PreparedStatement<T>* statement)
{
AppendPreparedStatement(statement);
@@ -72,9 +74,11 @@ public:
/*! Low level class*/
class AC_DATABASE_API TransactionTask : public SQLOperation
{
template <class T> friend class DatabaseWorkerPool;
friend class DatabaseWorker;
friend class TransactionCallback;
template <class T>
friend class DatabaseWorkerPool;
friend class DatabaseWorker;
friend class TransactionCallback;
public:
TransactionTask(std::shared_ptr<TransactionBase> trans) : m_trans(std::move(trans)) { }

View File

@@ -27,16 +27,16 @@ AppenderDB::~AppenderDB() { }
void AppenderDB::_write(LogMessage const* message)
{
// Avoid infinite loop, PExecute triggers Logging with "sql.sql" type
// Avoid infinite loop, Execute triggers Logging with "sql.sql" type
if (!enabled || (message->type.find("sql") != std::string::npos))
return;
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_LOG);
stmt->setUInt64(0, message->mtime.count());
stmt->setUInt32(1, realmId);
stmt->setString(2, message->type);
stmt->setUInt8(3, uint8(message->level));
stmt->setString(4, message->text);
stmt->SetData(0, message->mtime.count());
stmt->SetData(1, realmId);
stmt->SetData(2, message->type);
stmt->SetData(3, uint8(message->level));
stmt->SetData(4, message->text);
LoginDatabase.Execute(stmt);
}

View File

@@ -132,8 +132,8 @@ UpdateFetcher::DirectoryStorage UpdateFetcher::ReceiveIncludedDirectories() cons
{
Field* fields = result->Fetch();
std::string path = fields[0].GetString();
std::string state = fields[1].GetString();
std::string path = fields[0].Get<std::string>();
std::string state = fields[1].Get<std::string>();
if (path.substr(0, 1) == "$")
path = _sourceDirectory->generic_string() + path.substr(1);
@@ -194,7 +194,7 @@ UpdateFetcher::AppliedFileStorage UpdateFetcher::ReceiveAppliedFiles() const
AppliedFileEntry const entry =
{
fields[0].GetString(), fields[1].GetString(), AppliedFileEntry::StateConvert(fields[2].GetString()), fields[3].GetUInt64()
fields[0].Get<std::string>(), fields[1].Get<std::string>(), AppliedFileEntry::StateConvert(fields[2].Get<std::string>()), fields[3].Get<uint64>()
};
map.emplace(entry.name, entry);