Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 62 additions & 21 deletions src/webserver/DOSGuard.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,24 @@ class BaseDOSGuard
template <typename SweepHandler>
class BasicDOSGuard : public BaseDOSGuard
{
mutable std::mutex mtx_; // protects ipFetchCount_
std::unordered_map<std::string, std::uint32_t> ipFetchCount_;
// Accumulated state per IP; all of it is discarded whenever clear() runs
struct ClientState
{
    // NOTE: member order matters - isOk() unpacks this struct with a
    // structured binding, so keep transferedByte first and requestsCount second.

    // Running total of the numObjects values handed to add() since the last
    // clear(); compared against maxFetches_. NOTE(review): despite the name,
    // the unit is whatever add() receives, not necessarily bytes.
    std::uint32_t transferedByte{0};

    // Number of request() calls recorded since the last clear(); compared
    // against maxRequestCount_.
    std::uint32_t requestsCount{0};
};

mutable std::mutex mtx_;
// accumulated states map
std::unordered_map<std::string, ClientState> ipState_;
std::unordered_map<std::string, std::uint32_t> ipConnCount_;
std::unordered_set<std::string> const whitelist_;

std::uint32_t const maxFetches_;
std::uint32_t const maxConnCount_;
std::uint32_t const maxRequestCount_;
clio::Logger log_{"RPC"};

public:
Expand All @@ -68,6 +79,7 @@ class BasicDOSGuard : public BaseDOSGuard
: whitelist_{getWhitelist(config)}
, maxFetches_{config.valueOr("dos_guard.max_fetches", 100000000u)}
, maxConnCount_{config.valueOr("dos_guard.max_connections", 1u)}
, maxRequestCount_{config.valueOr("dos_guard.max_requests", 10u)}
{
sweepHandler.setup(this);
}
Expand Down Expand Up @@ -98,25 +110,33 @@ class BasicDOSGuard : public BaseDOSGuard
if (whitelist_.contains(ip))
return true;

std::unique_lock lck(mtx_);
bool fetchesOk = true;
bool connsOk = true;
{
auto it = ipFetchCount_.find(ip);
if (it != ipFetchCount_.end())
fetchesOk = it->second <= maxFetches_;
}
{
std::unique_lock lck(mtx_);
if (ipState_.find(ip) != ipState_.end())
{
auto [transferedByte, requests] = ipState_.at(ip);
if (transferedByte > maxFetches_ || requests > maxRequestCount_)
{
log_.warn()
<< "Dosguard:Client surpassed the rate limit. ip = "
<< ip << " Transfered Byte:" << transferedByte
<< " Requests:" << requests;
return false;
}
}
auto it = ipConnCount_.find(ip);
if (it != ipConnCount_.end())
{
connsOk = it->second <= maxConnCount_;
if (it->second > maxConnCount_)
{
log_.warn()
<< "Dosguard:Client surpassed the rate limit. ip = "
<< ip << " Concurrent connection:" << it->second;
return false;
}
}
}
if (!fetchesOk || !connsOk)
log_.warn() << "Client surpassed the rate limit. ip = " << ip;

return fetchesOk && connsOk;
return true;
}

/**
Expand Down Expand Up @@ -170,11 +190,32 @@ class BasicDOSGuard : public BaseDOSGuard

{
std::unique_lock lck(mtx_);
auto it = ipFetchCount_.find(ip);
if (it == ipFetchCount_.end())
ipFetchCount_[ip] = numObjects;
else
it->second += numObjects;
ipState_[ip].transferedByte += numObjects;
}

return isOk(ip);
}

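/**
* @brief Records one request for the given ip address and re-checks its limits.
*
* Whitelisted addresses are never counted and are always allowed. Otherwise
* the request counter for the ip is incremented first, and the result of
* isOk(ip) is returned: once the accumulated count is strictly larger than
* maxRequestCount_ (or any other limit checked by isOk is exceeded) the
* operation is no longer allowed and false is returned; true is returned
* otherwise.
*
* @param ip The client ip address
* @return true if the client is still within its limits
* @return false if the client has surpassed a limit
*/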
[[maybe_unused]] bool
request(std::string const& ip) noexcept
{
if (whitelist_.contains(ip))
return true;

{
std::unique_lock lck(mtx_);
ipState_[ip].requestsCount++;
}

return isOk(ip);
Expand All @@ -188,7 +229,7 @@ class BasicDOSGuard : public BaseDOSGuard
clear() noexcept override
{
std::unique_lock lck(mtx_);
ipFetchCount_.clear();
ipState_.clear();
}

private:
Expand Down
57 changes: 36 additions & 21 deletions src/webserver/HttpBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,27 @@ class HttpBase : public util::Taggable
if (ec)
return httpFail(ec, "read");

auto ip = derived().ip();

if (!ip)
{
return;
}

auto const httpResponse = [&](http::status status,
std::string content_type,
std::string message) {
http::response<http::string_body> res{status, req_.version()};
res.set(
http::field::server,
"clio-server-" + Build::getClioVersionString());
res.set(http::field::content_type, content_type);
res.keep_alive(req_.keep_alive());
res.body() = std::string(message);
res.prepare_payload();
return res;
};

if (boost::beast::websocket::is_upgrade(req_))
{
upgraded_ = true;
Expand All @@ -260,10 +281,16 @@ class HttpBase : public util::Taggable
workQueue_);
}

auto ip = derived().ip();

if (!ip)
return;
// To avoid overwhelming the work queue, the request limit is checked
// before posting to the queue. WebSocket session creation is guarded by
// the connection limit instead.
if (!dosGuard_.request(ip.value()))
{
return lambda_(httpResponse(
http::status::service_unavailable,
"text/plain",
"Server is overloaded"));
}

perfLog_.debug() << tag() << "Received request from ip = " << *ip
<< " - posting to WorkQueue";
Expand Down Expand Up @@ -293,17 +320,11 @@ class HttpBase : public util::Taggable
{
// Non-whitelist connection rejected due to full connection
// queue
http::response<http::string_body> res{
http::status::ok, req_.version()};
res.set(
http::field::server,
"clio-server-" + Build::getClioVersionString());
res.set(http::field::content_type, "application/json");
res.keep_alive(req_.keep_alive());
res.body() = boost::json::serialize(
RPC::makeError(RPC::RippledError::rpcTOO_BUSY));
res.prepare_payload();
lambda_(std::move(res));
lambda_(httpResponse(
http::status::ok,
"application/json",
boost::json::serialize(
RPC::makeError(RPC::RippledError::rpcTOO_BUSY))));
}
}

Expand Down Expand Up @@ -380,12 +401,6 @@ handle_request(
return send(httpResponse(
http::status::bad_request, "text/html", "Expected a POST request"));

if (!dosGuard.isOk(ip))
return send(httpResponse(
http::status::service_unavailable,
"text/plain",
"Server is overloaded"));

try
{
perfLog.debug() << http->tag()
Expand Down
21 changes: 12 additions & 9 deletions src/webserver/WsBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,19 +452,22 @@ class WsSession : public WsBase,
}(std::move(msg));

boost::json::object request;
if (!raw.is_object())
return sendError(
RPC::RippledError::rpcINVALID_PARAMS, nullptr, request);
request = raw.as_object();

auto id = request.contains("id") ? request.at("id") : nullptr;

if (!dosGuard_.isOk(*ip))
// Count this request in dosGuard and check the ip's limits. This must
// happen before any other handling, including for invalid requests.
if (!dosGuard_.request(*ip))
{
sendError(RPC::RippledError::rpcSLOW_DOWN, id, request);
sendError(RPC::RippledError::rpcSLOW_DOWN, nullptr, request);
}
else if (!raw.is_object())
{
// handle invalid request and async read again
sendError(RPC::RippledError::rpcINVALID_PARAMS, nullptr, request);
}
else
{
request = raw.as_object();

auto id = request.contains("id") ? request.at("id") : nullptr;
perfLog_.debug() << tag() << "Adding to work queue";

if (!queue_.postCoro(
Expand Down
25 changes: 25 additions & 0 deletions unittests/DOSGuard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ constexpr static auto JSONData = R"JSON(
"max_fetches": 100,
"sweep_interval": 1,
"max_connections": 2,
"max_requests": 3,
"whitelist": ["127.0.0.1"]
}
}
Expand Down Expand Up @@ -126,6 +127,30 @@ TEST_F(DOSGuardTest, ClearFetchCountOnTimer)
EXPECT_TRUE(guard.isOk(IP)); // can fetch again
}

TEST_F(DOSGuardTest, RequestLimit)
{
    // JSONData sets max_requests to 3: the first three requests are allowed.
    EXPECT_TRUE(guard.request(IP));
    EXPECT_TRUE(guard.request(IP));
    EXPECT_TRUE(guard.request(IP));
    EXPECT_TRUE(guard.isOk(IP));

    // The fourth request pushes the counter past the limit.
    EXPECT_FALSE(guard.request(IP));
    EXPECT_FALSE(guard.isOk(IP));

    // clear() resets the counter, so new requests are accepted and counted
    // from zero again.
    guard.clear();
    EXPECT_TRUE(guard.isOk(IP));
    EXPECT_TRUE(guard.request(IP)); // can request again
}

TEST_F(DOSGuardTest, RequestLimitOnTimer)
{
    // JSONData sets max_requests to 3: the first three requests are allowed.
    EXPECT_TRUE(guard.request(IP));
    EXPECT_TRUE(guard.request(IP));
    EXPECT_TRUE(guard.request(IP));
    EXPECT_TRUE(guard.isOk(IP));

    // The fourth request pushes the counter past the limit.
    EXPECT_FALSE(guard.request(IP));
    EXPECT_FALSE(guard.isOk(IP));

    // A sweep resets the accumulated state, so new requests are accepted
    // and counted from zero again.
    sweepHandler.sweep();
    EXPECT_TRUE(guard.isOk(IP));
    EXPECT_TRUE(guard.request(IP)); // can request again
}

template <typename SweepHandler>
struct BasicDOSGuardMock : public BaseDOSGuard
{
Expand Down