refactor(Core/Network): Port TrinityCore socket optimizations (#24384)

Co-authored-by: blinkysc <blinkysc@users.noreply.github.com>
Co-authored-by: Shauren <shauren@users.noreply.github.com>
This commit is contained in:
blinkysc
2026-01-15 07:47:58 -06:00
committed by GitHub
parent a8ce95ad71
commit d908b4c2fc
16 changed files with 242 additions and 75 deletions

View File

@@ -20,6 +20,7 @@
#include "IpAddress.h"
#include "Log.h"
#include "Socket.h"
#include "Systemd.h"
#include <atomic>
#include <boost/asio/ip/tcp.hpp>
@@ -32,7 +33,7 @@ constexpr auto ACORE_MAX_LISTEN_CONNECTIONS = boost::asio::socket_base::max_list
class AsyncAcceptor
{
public:
typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex);
typedef void(*AcceptCallback)(IoContextTcpSocket&& newSocket, uint32 threadIndex);
AsyncAcceptor(Acore::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, bool supportSocketActivation = false) :
_acceptor(ioContext), _endpoint(Acore::Net::make_address(bindIp), port),
@@ -56,7 +57,7 @@ public:
template<AcceptCallback acceptCallback>
void AsyncAcceptWithCallback()
{
tcp::socket* socket;
IoContextTcpSocket* socket;
uint32 threadIndex;
std::tie(socket, threadIndex) = _socketFactory();
_acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
@@ -129,16 +130,16 @@ public:
_acceptor.close(err);
}
void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; }
void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
private:
std::pair<tcp::socket*, uint32> DefaultSocketFactory() { return std::make_pair(&_socket, 0); }
std::pair<IoContextTcpSocket*, uint32> DefaultSocketFactory() { return std::make_pair(&_socket, 0); }
tcp::acceptor _acceptor;
tcp::endpoint _endpoint;
tcp::socket _socket;
boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
boost::asio::ip::tcp::endpoint _endpoint;
IoContextTcpSocket _socket;
std::atomic<bool> _closed;
std::function<std::pair<tcp::socket*, uint32>()> _socketFactory;
std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
bool _supportSocketActivation;
};

View File

@@ -91,13 +91,13 @@ public:
SocketAdded(sock);
}
tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
void EnableProxyProtocol() { _proxyHeaderReadingEnabled = true; }
protected:
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { }
void AddNewSockets()
{
@@ -229,7 +229,7 @@ private:
SocketContainer _newSockets;
Acore::Asio::IoContext _ioContext;
tcp::socket _acceptSocket;
IoContextTcpSocket _acceptSocket;
boost::asio::steady_timer _updateTimer;
bool _proxyHeaderReadingEnabled;

View File

@@ -21,9 +21,8 @@
#include "Log.h"
#include "MessageBuffer.h"
#include <atomic>
#include <boost/asio.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <functional>
#include <memory>
#include <queue>
#include <type_traits>
@@ -35,6 +34,23 @@ using boost::asio::ip::tcp;
#define AC_SOCKET_USE_IOCP
#endif
// Specialize boost socket for io_context executor instead of type-erased any_io_executor
// This avoids the type-erasure overhead of any_io_executor
using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
enum class SocketReadCallbackResult
{
KeepReading,
Stop
};
enum class SocketState : uint8
{
Open = 0,
Closing = 1,
Closed = 2
};
enum ProxyHeaderReadingState {
PROXY_HEADER_READING_STATE_NOT_STARTED,
PROXY_HEADER_READING_STATE_STARTED,
@@ -51,8 +67,8 @@ template<class T>
class Socket : public std::enable_shared_from_this<T>
{
public:
explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false),
explicit Socket(IoContextTcpSocket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _state(SocketState::Open), _isWritingAsync(false),
_proxyHeaderReadingState(PROXY_HEADER_READING_STATE_NOT_STARTED)
{
_readBuffer.Resize(READ_BLOCK_SIZE);
@@ -60,7 +76,7 @@ public:
virtual ~Socket()
{
_closed = true;
_state = SocketState::Closed;
boost::system::error_code error;
_socket.close(error);
}
@@ -69,13 +85,14 @@ public:
virtual bool Update()
{
if (_closed)
SocketState state = _state.load();
if (state == SocketState::Closed)
{
return false;
}
#ifndef AC_SOCKET_USE_IOCP
if (_isWritingAsync || (_writeQueue.empty() && !_closing))
if (_isWritingAsync || (_writeQueue.empty() && state != SocketState::Closing))
{
return true;
}
@@ -150,12 +167,18 @@ public:
[[nodiscard]] ProxyHeaderReadingState GetProxyHeaderReadingState() const { return _proxyHeaderReadingState; }
[[nodiscard]] bool IsOpen() const { return !_closed && !_closing; }
[[nodiscard]] bool IsOpen() const { return _state.load() == SocketState::Open; }
void CloseSocket()
{
if (_closed.exchange(true))
return;
SocketState expected = SocketState::Open;
if (!_state.compare_exchange_strong(expected, SocketState::Closed))
{
// If it was Closing, try to transition to Closed
expected = SocketState::Closing;
if (!_state.compare_exchange_strong(expected, SocketState::Closed))
return; // Already closed
}
boost::system::error_code shutdownError;
_socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
@@ -168,13 +191,17 @@ public:
}
/// Marks the socket for closing after write buffer becomes empty
void DelayedCloseSocket() { _closing = true; }
void DelayedCloseSocket()
{
SocketState expected = SocketState::Open;
_state.compare_exchange_strong(expected, SocketState::Closing);
}
MessageBuffer& GetReadBuffer() { return _readBuffer; }
protected:
virtual void OnClose() { }
virtual void ReadHandler() = 0;
virtual SocketReadCallbackResult ReadHandler() = 0;
bool AsyncProcessQueue()
{
@@ -188,7 +215,7 @@ protected:
_socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler,
this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
#else
_socket.async_wait(tcp::socket::wait_write, [self = this->shared_from_this()](boost::system::error_code error)
_socket.async_wait(boost::asio::socket_base::wait_write, [self = this->shared_from_this()](boost::system::error_code error)
{
self->WriteHandlerWrapper(error, 0);
});
@@ -216,7 +243,8 @@ private:
}
_readBuffer.WriteCompleted(transferredBytes);
ReadHandler();
if (ReadHandler() == SocketReadCallbackResult::KeepReading)
AsyncRead();
}
// ProxyReadHeaderHandler reads Proxy Protocol v2 header (v1 is not supported).
@@ -344,7 +372,7 @@ private:
if (!_writeQueue.empty())
AsyncProcessQueue();
else if (_closing)
else if (_state.load() == SocketState::Closing)
CloseSocket();
}
else
@@ -380,7 +408,7 @@ private:
_writeQueue.pop();
if (_closing && _writeQueue.empty())
if (_state.load() == SocketState::Closing && _writeQueue.empty())
{
CloseSocket();
}
@@ -391,7 +419,7 @@ private:
{
_writeQueue.pop();
if (_closing && _writeQueue.empty())
if (_state.load() == SocketState::Closing && _writeQueue.empty())
{
CloseSocket();
}
@@ -406,7 +434,7 @@ private:
_writeQueue.pop();
if (_closing && _writeQueue.empty())
if (_state.load() == SocketState::Closing && _writeQueue.empty())
{
CloseSocket();
}
@@ -415,7 +443,7 @@ private:
}
#endif
tcp::socket _socket;
IoContextTcpSocket _socket;
boost::asio::ip::address _remoteAddress;
uint16 _remotePort;
@@ -423,8 +451,7 @@ private:
MessageBuffer _readBuffer;
std::queue<MessageBuffer> _writeQueue;
std::atomic<bool> _closed;
std::atomic<bool> _closing;
std::atomic<SocketState> _state;
bool _isWritingAsync;

View File

@@ -91,7 +91,7 @@ public:
_threads[i].Wait();
}
virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
{
try
{
@@ -117,7 +117,7 @@ public:
return min;
}
std::pair<tcp::socket*, uint32> GetSocketForAccept()
std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
{
uint32 threadIndex = SelectThreadWithMinConnections();
return { _threads[threadIndex].GetSocketForAccept(), threadIndex };