feat(Core/Common): add Asio network threading (#6063)

This commit is contained in:
Kargatum
2021-05-27 21:09:31 +07:00
committed by GitHub
parent 2ae84e2faf
commit c1e96064e9
11 changed files with 1035 additions and 0 deletions

View File

@@ -0,0 +1,146 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef __ASYNCACCEPT_H_
#define __ASYNCACCEPT_H_
#include "IoContext.h"
#include "IpAddress.h"
#include "Log.h"
#include <boost/asio/ip/tcp.hpp>
#include <functional>
#include <atomic>
using boost::asio::ip::tcp;
#if BOOST_VERSION >= 106600
#define WARHEAD_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
#else
#define WARHEAD_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_connections
#endif
class AsyncAcceptor
{
public:
typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex);
AsyncAcceptor(acore::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
_acceptor(ioContext), _endpoint(acore::Net::make_address(bindIp), port),
_socket(ioContext), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this))
{
}
template<class T>
void AsyncAccept();
template<AcceptCallback acceptCallback>
void AsyncAcceptWithCallback()
{
tcp::socket* socket;
uint32 threadIndex;
std::tie(socket, threadIndex) = _socketFactory();
_acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
{
if (!error)
{
try
{
socket->non_blocking(true);
acceptCallback(std::move(*socket), threadIndex);
}
catch (boost::system::system_error const& err)
{
LOG_INFO("network", "Failed to initialize client's socket %s", err.what());
}
}
if (!_closed)
this->AsyncAcceptWithCallback<acceptCallback>();
});
}
bool Bind()
{
boost::system::error_code errorCode;
_acceptor.open(_endpoint.protocol(), errorCode);
if (errorCode)
{
LOG_INFO("network", "Failed to open acceptor %s", errorCode.message().c_str());
return false;
}
#if WARHEAD_PLATFORM != WARHEAD_PLATFORM_WINDOWS
_acceptor.set_option(boost::asio::ip::tcp::acceptor::reuse_address(true), errorCode);
if (errorCode)
{
LOG_INFO("network", "Failed to set reuse_address option on acceptor %s", errorCode.message().c_str());
return false;
}
#endif
_acceptor.bind(_endpoint, errorCode);
if (errorCode)
{
LOG_INFO("network", "Could not bind to %s:%u %s", _endpoint.address().to_string().c_str(), _endpoint.port(), errorCode.message().c_str());
return false;
}
_acceptor.listen(WARHEAD_MAX_LISTEN_CONNECTIONS, errorCode);
if (errorCode)
{
LOG_INFO("network", "Failed to start listening on %s:%u %s", _endpoint.address().to_string().c_str(), _endpoint.port(), errorCode.message().c_str());
return false;
}
return true;
}
void Close()
{
if (_closed.exchange(true))
return;
boost::system::error_code err;
_acceptor.close(err);
}
void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; }
private:
std::pair<tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
tcp::acceptor _acceptor;
tcp::endpoint _endpoint;
tcp::socket _socket;
std::atomic<bool> _closed;
std::function<std::pair<tcp::socket*, uint32>()> _socketFactory;
};
template<class T>
void AsyncAcceptor::AsyncAccept()
{
_acceptor.async_accept(_socket, [this](boost::system::error_code error)
{
if (!error)
{
try
{
// this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class
std::make_shared<T>(std::move(this->_socket))->Start();
}
catch (boost::system::system_error const& err)
{
LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what());
}
}
// lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
if (!_closed)
this->AsyncAccept<T>();
});
}
#endif /* __ASYNCACCEPT_H_ */

View File

@@ -0,0 +1,166 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef NetworkThread_h__
#define NetworkThread_h__
#include "Define.h"
#include "DeadlineTimer.h"
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
#include "Timer.h"
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <chrono>
#include <memory>
#include <mutex>
#include <set>
#include <thread>
using boost::asio::ip::tcp;
template<class SocketType>
class NetworkThread
{
public:
NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1),
_acceptSocket(_ioContext), _updateTimer(_ioContext) { }
virtual ~NetworkThread()
{
Stop();
if (_thread)
{
Wait();
delete _thread;
}
}
void Stop()
{
_stopped = true;
_ioContext.stop();
}
bool Start()
{
if (_thread)
return false;
_thread = new std::thread(&NetworkThread::Run, this);
return true;
}
void Wait()
{
ASSERT(_thread);
_thread->join();
delete _thread;
_thread = nullptr;
}
int32 GetConnectionCount() const
{
return _connections;
}
virtual void AddSocket(std::shared_ptr<SocketType> sock)
{
std::lock_guard<std::mutex> lock(_newSocketsLock);
++_connections;
_newSockets.push_back(sock);
SocketAdded(sock);
}
tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
protected:
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
void AddNewSockets()
{
std::lock_guard<std::mutex> lock(_newSocketsLock);
if (_newSockets.empty())
return;
for (std::shared_ptr<SocketType> sock : _newSockets)
{
if (!sock->IsOpen())
{
SocketRemoved(sock);
--_connections;
}
else
_sockets.push_back(sock);
}
_newSockets.clear();
}
void Run()
{
LOG_DEBUG("misc", "Network Thread Starting");
_updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
_updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
_ioContext.run();
LOG_DEBUG("misc", "Network Thread exits");
_newSockets.clear();
_sockets.clear();
}
void Update()
{
if (_stopped)
return;
_updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
_updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
AddNewSockets();
_sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock)
{
if (!sock->Update())
{
if (sock->IsOpen())
sock->CloseSocket();
this->SocketRemoved(sock);
--this->_connections;
return true;
}
return false;
}), _sockets.end());
}
private:
typedef std::vector<std::shared_ptr<SocketType>> SocketContainer;
std::atomic<int32> _connections;
std::atomic<bool> _stopped;
std::thread* _thread;
SocketContainer _sockets;
std::mutex _newSocketsLock;
SocketContainer _newSockets;
acore::Asio::IoContext _ioContext;
tcp::socket _acceptSocket;
acore::Asio::DeadlineTimer _updateTimer;
};
#endif // NetworkThread_h__

View File

@@ -0,0 +1,276 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef __SOCKET_H__
#define __SOCKET_H__
#include "MessageBuffer.h"
#include "Log.h"
#include <atomic>
#include <queue>
#include <memory>
#include <functional>
#include <type_traits>
#include <boost/asio/ip/tcp.hpp>
using boost::asio::ip::tcp;
#define READ_BLOCK_SIZE 4096
#ifdef BOOST_ASIO_HAS_IOCP
#define AC_SOCKET_USE_IOCP
#endif
template<class T>
class Socket : public std::enable_shared_from_this<T>
{
public:
explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
{
_readBuffer.Resize(READ_BLOCK_SIZE);
}
virtual ~Socket()
{
_closed = true;
boost::system::error_code error;
_socket.close(error);
}
virtual void Start() = 0;
virtual bool Update()
{
if (_closed)
{
return false;
}
#ifndef AC_SOCKET_USE_IOCP
if (_isWritingAsync || (_writeQueue.empty() && !_closing))
{
return true;
}
for (; HandleQueue();)
;
#endif
return true;
}
boost::asio::ip::address GetRemoteIpAddress() const
{
return _remoteAddress;
}
uint16 GetRemotePort() const
{
return _remotePort;
}
void AsyncRead()
{
if (!IsOpen())
{
return;
}
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}
void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
{
if (!IsOpen())
{
return;
}
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}
void QueuePacket(MessageBuffer&& buffer)
{
_writeQueue.push(std::move(buffer));
#ifdef AC_SOCKET_USE_IOCP
AsyncProcessQueue();
#endif
}
bool IsOpen() const { return !_closed && !_closing; }
void CloseSocket()
{
if (_closed.exchange(true))
return;
boost::system::error_code shutdownError;
_socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
if (shutdownError)
LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(),
shutdownError.value(), shutdownError.message().c_str());
OnClose();
}
/// Marks the socket for closing after write buffer becomes empty
void DelayedCloseSocket() { _closing = true; }
MessageBuffer& GetReadBuffer() { return _readBuffer; }
protected:
virtual void OnClose() { }
virtual void ReadHandler() = 0;
bool AsyncProcessQueue()
{
if (_isWritingAsync)
return false;
_isWritingAsync = true;
#ifdef AC_SOCKET_USE_IOCP
MessageBuffer& buffer = _writeQueue.front();
_socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler,
this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
#else
_socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket<T>::WriteHandlerWrapper,
this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
#endif
return false;
}
void SetNoDelay(bool enable)
{
boost::system::error_code err;
_socket.set_option(tcp::no_delay(enable), err);
if (err)
LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for %s - %d (%s)",
GetRemoteIpAddress().to_string().c_str(), err.value(), err.message().c_str());
}
private:
void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes)
{
if (error)
{
CloseSocket();
return;
}
_readBuffer.WriteCompleted(transferredBytes);
ReadHandler();
}
#ifdef AC_SOCKET_USE_IOCP
void WriteHandler(boost::system::error_code error, std::size_t transferedBytes)
{
if (!error)
{
_isWritingAsync = false;
_writeQueue.front().ReadCompleted(transferedBytes);
if (!_writeQueue.front().GetActiveSize())
_writeQueue.pop();
if (!_writeQueue.empty())
AsyncProcessQueue();
else if (_closing)
CloseSocket();
}
else
CloseSocket();
}
#else
void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/)
{
_isWritingAsync = false;
HandleQueue();
}
bool HandleQueue()
{
if (_writeQueue.empty())
return false;
MessageBuffer& queuedMessage = _writeQueue.front();
std::size_t bytesToSend = queuedMessage.GetActiveSize();
boost::system::error_code error;
std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);
if (error)
{
if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
{
return AsyncProcessQueue();
}
_writeQueue.pop();
if (_closing && _writeQueue.empty())
{
CloseSocket();
}
return false;
}
else if (bytesSent == 0)
{
_writeQueue.pop();
if (_closing && _writeQueue.empty())
{
CloseSocket();
}
return false;
}
else if (bytesSent < bytesToSend) // now n > 0
{
queuedMessage.ReadCompleted(bytesSent);
return AsyncProcessQueue();
}
_writeQueue.pop();
if (_closing && _writeQueue.empty())
{
CloseSocket();
}
return !_writeQueue.empty();
}
#endif
tcp::socket _socket;
boost::asio::ip::address _remoteAddress;
uint16 _remotePort;
MessageBuffer _readBuffer;
std::queue<MessageBuffer> _writeQueue;
std::atomic<bool> _closed;
std::atomic<bool> _closing;
bool _isWritingAsync;
};
#endif // __SOCKET_H__

View File

@@ -0,0 +1,131 @@
/*
* Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
* Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
*/
#ifndef SocketMgr_h__
#define SocketMgr_h__
#include "AsyncAcceptor.h"
#include "Errors.h"
#include "NetworkThread.h"
#include <boost/asio/ip/tcp.hpp>
#include <memory>
using boost::asio::ip::tcp;
template<class SocketType>
class SocketMgr
{
public:
virtual ~SocketMgr()
{
ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
}
virtual bool StartNetwork(acore::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
ASSERT(threadCount > 0);
AsyncAcceptor* acceptor = nullptr;
try
{
acceptor = new AsyncAcceptor(ioContext, bindIp, port);
}
catch (boost::system::system_error const& err)
{
LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork (%s:%u): %s", bindIp.c_str(), port, err.what());
return false;
}
if (!acceptor->Bind())
{
LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
delete acceptor;
return false;
}
_acceptor = acceptor;
_threadCount = threadCount;
_threads = CreateThreads();
ASSERT(_threads);
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Start();
_acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
return true;
}
virtual void StopNetwork()
{
_acceptor->Close();
if (_threadCount != 0)
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Stop();
Wait();
delete _acceptor;
_acceptor = nullptr;
delete[] _threads;
_threads = nullptr;
_threadCount = 0;
}
void Wait()
{
if (_threadCount != 0)
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Wait();
}
virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
{
try
{
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
newSocket->Start();
_threads[threadIndex].AddSocket(newSocket);
}
catch (boost::system::system_error const& err)
{
LOG_WARN("network", "Failed to retrieve client's remote address %s", err.what());
}
}
int32 GetNetworkThreadCount() const { return _threadCount; }
uint32 SelectThreadWithMinConnections() const
{
uint32 min = 0;
for (int32 i = 1; i < _threadCount; ++i)
if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
min = i;
return min;
}
std::pair<tcp::socket*, uint32> GetSocketForAccept()
{
uint32 threadIndex = SelectThreadWithMinConnections();
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
}
protected:
SocketMgr() :
_acceptor(nullptr), _threads(nullptr), _threadCount(0) { }
virtual NetworkThread<SocketType>* CreateThreads() const = 0;
AsyncAcceptor* _acceptor;
NetworkThread<SocketType>* _threads;
int32 _threadCount;
};
#endif // SocketMgr_h__