feat(Core/Network): Add Proxy Protocol v2 support. (#18839)

* feat(Core/Network): Add Proxy Protocol v2 support.

* Fix codestyle and build.

* Another codestyle fix.

* One more missing include.
This commit is contained in:
Anton Popovichenko
2024-05-04 18:38:32 +02:00
committed by GitHub
parent 715b290cb7
commit 9815025341
7 changed files with 247 additions and 12 deletions

View File

@@ -20,6 +20,7 @@
#include "DeadlineTimer.h"
#include "Define.h"
#include "Socket.h"
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
@@ -39,7 +40,7 @@ class NetworkThread
{
public:
NetworkThread() :
_ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext) { }
_ioContext(1), _acceptSocket(_ioContext), _updateTimer(_ioContext), _proxyHeaderReadingEnabled(false) { }
virtual ~NetworkThread()
{
@@ -94,6 +95,8 @@ public:
tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
void EnableProxyProtocol() { _proxyHeaderReadingEnabled = true; }
protected:
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
@@ -105,20 +108,73 @@ protected:
if (_newSockets.empty())
return;
for (std::shared_ptr<SocketType> sock : _newSockets)
if (!_proxyHeaderReadingEnabled)
{
for (std::shared_ptr<SocketType> sock : _newSockets)
{
if (!sock->IsOpen())
{
SocketRemoved(sock);
--_connections;
continue;
}
_sockets.emplace_back(sock);
sock->Start();
}
_newSockets.clear();
}
else
{
HandleNewSocketsProxyReadingOnConnect();
}
}
void HandleNewSocketsProxyReadingOnConnect()
{
size_t index = 0;
std::vector<int> newSocketsToRemoveIndexes;
for (auto sock_iter = _newSockets.begin(); sock_iter != _newSockets.end(); ++sock_iter, ++index)
{
std::shared_ptr<SocketType> sock = *sock_iter;
if (!sock->IsOpen())
{
newSocketsToRemoveIndexes.emplace_back(index);
SocketRemoved(sock);
--_connections;
continue;
}
else
{
_sockets.emplace_back(sock);
const auto proxyHeaderReadingState = sock->GetProxyHeaderReadingState();
if (proxyHeaderReadingState == PROXY_HEADER_READING_STATE_STARTED)
continue;
switch (proxyHeaderReadingState) {
case PROXY_HEADER_READING_STATE_NOT_STARTED:
sock->AsyncReadProxyHeader();
break;
case PROXY_HEADER_READING_STATE_FINISHED:
newSocketsToRemoveIndexes.emplace_back(index);
_sockets.emplace_back(sock);
sock->Start();
break;
default:
newSocketsToRemoveIndexes.emplace_back(index);
SocketRemoved(sock);
--_connections;
break;
}
}
_newSockets.clear();
for (int removeIndex : newSocketsToRemoveIndexes)
_newSockets.erase(_newSockets.begin() + removeIndex);
}
void Run()
@@ -177,6 +233,8 @@ private:
Acore::Asio::IoContext _ioContext;
tcp::socket _acceptSocket;
Acore::Asio::DeadlineTimer _updateTimer;
bool _proxyHeaderReadingEnabled;
};
#endif // NetworkThread_h__

View File

@@ -22,6 +22,7 @@
#include "MessageBuffer.h"
#include <atomic>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio.hpp>
#include <functional>
#include <memory>
#include <queue>
@@ -34,12 +35,25 @@ using boost::asio::ip::tcp;
#define AC_SOCKET_USE_IOCP
#endif
enum ProxyHeaderReadingState {
PROXY_HEADER_READING_STATE_NOT_STARTED,
PROXY_HEADER_READING_STATE_STARTED,
PROXY_HEADER_READING_STATE_FINISHED,
PROXY_HEADER_READING_STATE_FAILED,
};
enum ProxyHeaderAddressFamilyAndProtocol {
PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4 = 0x11,
PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6 = 0x21,
};
template<class T>
class Socket : public std::enable_shared_from_this<T>
{
public:
explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false),
_proxyHeaderReadingState(PROXY_HEADER_READING_STATE_NOT_STARTED)
{
_readBuffer.Resize(READ_BLOCK_SIZE);
}
@@ -92,11 +106,25 @@ public:
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}
void AsyncReadProxyHeader()
{
if (!IsOpen())
{
return;
}
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_STARTED;
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
std::bind(&Socket<T>::ProxyReadHeaderHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}
void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
{
if (!IsOpen())
@@ -120,6 +148,8 @@ public:
#endif
}
[[nodiscard]] ProxyHeaderReadingState GetProxyHeaderReadingState() const { return _proxyHeaderReadingState; }
[[nodiscard]] bool IsOpen() const { return !_closed && !_closing; }
void CloseSocket()
@@ -187,6 +217,118 @@ private:
ReadHandler();
}
// ProxyReadHeaderHandler reads Proxy Protocol v2 header (v1 is not supported).
// See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt (2.2. Binary header format (version 2)) for more details.
void ProxyReadHeaderHandler(boost::system::error_code error, size_t transferredBytes)
{
if (error)
{
CloseSocket();
return;
}
_readBuffer.WriteCompleted(transferredBytes);
MessageBuffer& packet = GetReadBuffer();
const int minimumProxyProtocolV2Size = 28;
if (packet.GetActiveSize() < minimumProxyProtocolV2Size)
{
AsyncReadProxyHeader();
return;
}
uint8* readPointer = packet.GetReadPointer();
const uint8 signatureSize = 12;
const uint8 expectedSignature[signatureSize] = {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
if (memcmp(packet.GetReadPointer(), expectedSignature, signatureSize) != 0)
{
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string());
return;
}
const uint8 version = (readPointer[signatureSize] & 0xF0) >> 4;
const uint8 command = (readPointer[signatureSize] & 0xF);
if (version != 2)
{
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string());
return;
}
const uint8 addressFamily = readPointer[13];
const uint16 len = (readPointer[14] << 8) | readPointer[15];
if (len+16 > packet.GetActiveSize())
{
AsyncReadProxyHeader();
return;
}
// Connection created by a proxy itself (health checks?), ignore and do nothing.
if (command == 0)
{
packet.ReadCompleted(len+16);
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED;
return;
}
auto remainingLen = packet.GetActiveSize() - 16;
readPointer += 16; // Skip strait to address.
switch (addressFamily) {
case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4:
{
if (remainingLen < 12)
{
AsyncReadProxyHeader();
return;
}
boost::asio::ip::address_v4::bytes_type b;
auto addressSize = sizeof(b);
std::copy(readPointer, readPointer+addressSize, b.begin());
_remoteAddress = boost::asio::ip::address_v4(b);
readPointer += 2 * addressSize; // Skip server address.
_remotePort = (readPointer[0] << 8) | readPointer[1];
break;
}
case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6:
{
if (remainingLen < 36)
{
AsyncReadProxyHeader();
return;
}
boost::asio::ip::address_v6::bytes_type b;
auto addressSize = sizeof(b);
std::copy(readPointer, readPointer+addressSize, b.begin());
_remoteAddress = boost::asio::ip::address_v6(b);
readPointer += 2 * addressSize; // Skip server address.
_remotePort = (readPointer[0] << 8) | readPointer[1];
break;
}
default:
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED;
LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: unsupported address family type {}", GetRemoteIpAddress().to_string());
return;
}
packet.ReadCompleted(len+16);
_proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED;
}
#ifdef AC_SOCKET_USE_IOCP
void WriteHandler(boost::system::error_code error, std::size_t transferedBytes)
{
@@ -283,6 +425,8 @@ private:
std::atomic<bool> _closing;
bool _isWritingAsync;
ProxyHeaderReadingState _proxyHeaderReadingState;
};
#endif // __SOCKET_H__

View File

@@ -94,8 +94,6 @@ public:
try
{
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
newSocket->Start();
_threads[threadIndex].AddSocket(newSocket);
}
catch (boost::system::system_error const& err)