aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/clients/httpwsclient.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/clients/httpwsclient.cpp')
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp132
1 files changed, 123 insertions, 9 deletions
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index 4337fcb79..842bf9d49 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -23,6 +23,8 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
+using namespace std::literals;
+
//////////////////////////////////////////////////////////////////////////
struct HttpWsClient::Impl
@@ -271,6 +273,43 @@ struct HttpWsClient::Impl
});
}
+ // Trim ASCII LWS (space / tab) from both ends of a header value, along with
+ // a trailing CR if the caller didn't strip it.
+ static std::string_view TrimHeaderValue(std::string_view V)
+ {
+ while (!V.empty() && (V.front() == ' ' || V.front() == '\t'))
+ {
+ V.remove_prefix(1);
+ }
+ while (!V.empty() && (V.back() == ' ' || V.back() == '\t' || V.back() == '\r'))
+ {
+ V.remove_suffix(1);
+ }
+ return V;
+ }
+
+ // Return true if a comma-separated header value contains the given token,
+ // case-insensitively. Used for Connection header parsing where the value
+ // may legitimately be "Upgrade, keep-alive" etc.
+ static bool HeaderContainsToken(std::string_view Value, std::string_view Token)
+ {
+ while (!Value.empty())
+ {
+ auto CommaPos = Value.find(',');
+ std::string_view Part = TrimHeaderValue(Value.substr(0, CommaPos));
+ if (Part.size() == Token.size() && StrCaseCompare(Part, Token) == 0)
+ {
+ return true;
+ }
+ if (CommaPos == std::string_view::npos)
+ {
+ break;
+ }
+ Value.remove_prefix(CommaPos + 1);
+ }
+ return false;
+ }
+
void DoReadHandshakeResponse()
{
WithSocket([this](auto& Socket) {
@@ -284,30 +323,105 @@ struct HttpWsClient::Impl
return;
}
- // Parse the response
const auto& Data = m_ReadBuffer.data();
std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
// Consume the headers from the read buffer (any extra data stays for frame parsing)
auto HeaderEnd = Response.find("\r\n\r\n");
- if (HeaderEnd != std::string::npos)
+ if (HeaderEnd == std::string::npos)
{
- m_ReadBuffer.consume(HeaderEnd + 4);
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: incomplete headers");
+ m_Handler.OnWsClose(1006, "handshake incomplete");
+ return;
}
+ m_ReadBuffer.consume(HeaderEnd + 4);
+
+ // Parse the status line. Substring matching on "101" anywhere
+ // in the response is unsafe โ€” a server returning
+ // "HTTP/1.1 404 Not Found\r\nX-Retry-After: 101\r\n" would have
+ // satisfied it. We require the first line to start with
+ // "HTTP/1.x 101" followed by end-of-line or space.
+ //
+ // ResponseView spans up through the first "\r\n" of the
+ // terminating "\r\n\r\n" so that every header line โ€” including
+ // the last one โ€” is terminated by "\r\n" in the view.
+ std::string_view ResponseView(Response.data(), HeaderEnd + 2);
+ auto StatusLineEnd = ResponseView.find("\r\n");
+ if (StatusLineEnd == std::string_view::npos)
+ {
+ m_Handler.OnWsClose(1006, "handshake malformed");
+ return;
+ }
+ std::string_view StatusLine = ResponseView.substr(0, StatusLineEnd);
- // Validate 101 response
- if (Response.find("101") == std::string::npos)
+ // Expect: "HTTP/1.x 101" (12 chars min), with 'x' being '0' or '1'.
+ bool StatusOk = StatusLine.size() >= 12 && StatusLine.substr(0, 7) == "HTTP/1." &&
+ (StatusLine[7] == '0' || StatusLine[7] == '1') && StatusLine[8] == ' ' &&
+ StatusLine.substr(9, 3) == "101" && (StatusLine.size() == 12 || StatusLine[12] == ' ');
+ if (!StatusOk)
{
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected; status line: {}", StatusLine.substr(0, 80));
m_Handler.OnWsClose(1006, "handshake rejected");
return;
}
- // Validate Sec-WebSocket-Accept
+ // Parse headers and extract the three fields RFC 6455 ยง4.1
+ // requires a client to validate: Upgrade, Connection, and
+ // Sec-WebSocket-Accept. Case-insensitive on header names and
+ // on the Upgrade / Connection token values; exact-match on the
+ // Sec-WebSocket-Accept base64 value.
+ bool UpgradeOk = false;
+ bool ConnectionOk = false;
+ std::string_view AcceptValue;
+
+ std::string_view HeaderBlock = ResponseView.substr(StatusLineEnd + 2);
+ while (!HeaderBlock.empty())
+ {
+ auto NextLineEnd = HeaderBlock.find("\r\n");
+ if (NextLineEnd == std::string_view::npos)
+ {
+ break;
+ }
+ std::string_view Line = HeaderBlock.substr(0, NextLineEnd);
+ HeaderBlock = HeaderBlock.substr(NextLineEnd + 2);
+ if (Line.empty())
+ {
+ break;
+ }
+
+ auto ColonPos = Line.find(':');
+ if (ColonPos == std::string_view::npos)
+ {
+ continue;
+ }
+ std::string_view Name = Line.substr(0, ColonPos);
+ std::string_view Value = TrimHeaderValue(Line.substr(ColonPos + 1));
+
+ if (Name.size() == 7 && StrCaseCompare(Name, "Upgrade"sv) == 0)
+ {
+ UpgradeOk = (Value.size() == 9 && StrCaseCompare(Value, "websocket"sv) == 0);
+ }
+ else if (Name.size() == 10 && StrCaseCompare(Name, "Connection"sv) == 0)
+ {
+ ConnectionOk = HeaderContainsToken(Value, "upgrade"sv);
+ }
+ else if (Name.size() == 20 && StrCaseCompare(Name, "Sec-WebSocket-Accept"sv) == 0)
+ {
+ AcceptValue = Value;
+ }
+ }
+
+ if (!UpgradeOk || !ConnectionOk)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake missing required Upgrade/Connection headers");
+ m_Handler.OnWsClose(1006, "handshake missing headers");
+ return;
+ }
+
std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
- if (Response.find(ExpectedAccept) == std::string::npos)
+ if (AcceptValue.size() != ExpectedAccept.size() || AcceptValue != ExpectedAccept)
{
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid or missing Sec-WebSocket-Accept");
m_Handler.OnWsClose(1006, "invalid accept key");
return;
}