diff options
Diffstat (limited to 'src/zenhttp/clients/httpwsclient.cpp')
| -rw-r--r-- | src/zenhttp/clients/httpwsclient.cpp | 132 |
1 files changed, 123 insertions, 9 deletions
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 4337fcb79..842bf9d49 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -23,6 +23,8 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { +using namespace std::literals; + ////////////////////////////////////////////////////////////////////////// struct HttpWsClient::Impl @@ -271,6 +273,43 @@ struct HttpWsClient::Impl }); } + // Trim ASCII LWS (space / tab) from both ends of a header value, along with + // a trailing CR if the caller didn't strip it. + static std::string_view TrimHeaderValue(std::string_view V) + { + while (!V.empty() && (V.front() == ' ' || V.front() == '\t')) + { + V.remove_prefix(1); + } + while (!V.empty() && (V.back() == ' ' || V.back() == '\t' || V.back() == '\r')) + { + V.remove_suffix(1); + } + return V; + } + + // Return true if a comma-separated header value contains the given token, + // case-insensitively. Used for Connection header parsing where the value + // may legitimately be "Upgrade, keep-alive" etc. + static bool HeaderContainsToken(std::string_view Value, std::string_view Token) + { + while (!Value.empty()) + { + auto CommaPos = Value.find(','); + std::string_view Part = TrimHeaderValue(Value.substr(0, CommaPos)); + if (Part.size() == Token.size() && StrCaseCompare(Part, Token) == 0) + { + return true; + } + if (CommaPos == std::string_view::npos) + { + break; + } + Value.remove_prefix(CommaPos + 1); + } + return false; + } + void DoReadHandshakeResponse() { WithSocket([this](auto& Socket) { @@ -284,30 +323,105 @@ struct HttpWsClient::Impl return; } - // Parse the response const auto& Data = m_ReadBuffer.data(); std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); // Consume the headers from the read buffer (any extra data stays for frame parsing) auto HeaderEnd = Response.find("\r\n\r\n"); - if (HeaderEnd != std::string::npos) + if (HeaderEnd == std::string::npos) { - m_ReadBuffer.consume(HeaderEnd + 4); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: incomplete headers"); + m_Handler.OnWsClose(1006, "handshake incomplete"); + return; } + m_ReadBuffer.consume(HeaderEnd + 4); + + // Parse the status line. Substring matching on "101" anywhere + // in the response is unsafe โ a server returning + // "HTTP/1.1 404 Not Found\r\nX-Retry-After: 101\r\n" would have + // satisfied it. We require the first line to start with + // "HTTP/1.x 101" followed by end-of-line or space. + // + // ResponseView spans up through the first "\r\n" of the + // terminating "\r\n\r\n" so that every header line โ including + // the last one โ is terminated by "\r\n" in the view. + std::string_view ResponseView(Response.data(), HeaderEnd + 2); + auto StatusLineEnd = ResponseView.find("\r\n"); + if (StatusLineEnd == std::string_view::npos) + { + m_Handler.OnWsClose(1006, "handshake malformed"); + return; + } + std::string_view StatusLine = ResponseView.substr(0, StatusLineEnd); - // Validate 101 response - if (Response.find("101") == std::string::npos) + // Expect: "HTTP/1.x 101" (12 chars min), with 'x' being '0' or '1'. + bool StatusOk = StatusLine.size() >= 12 && StatusLine.substr(0, 7) == "HTTP/1." && + (StatusLine[7] == '0' || StatusLine[7] == '1') && StatusLine[8] == ' ' && + StatusLine.substr(9, 3) == "101" && (StatusLine.size() == 12 || StatusLine[12] == ' '); + if (!StatusOk) { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected; status line: {}", StatusLine.substr(0, 80)); m_Handler.OnWsClose(1006, "handshake rejected"); return; } - // Validate Sec-WebSocket-Accept + // Parse headers and extract the three fields RFC 6455 ยง4.1 + // requires a client to validate: Upgrade, Connection, and + // Sec-WebSocket-Accept. Case-insensitive on header names and + // on the Upgrade / Connection token values; exact-match on the + // Sec-WebSocket-Accept base64 value. + bool UpgradeOk = false; + bool ConnectionOk = false; + std::string_view AcceptValue; + + std::string_view HeaderBlock = ResponseView.substr(StatusLineEnd + 2); + while (!HeaderBlock.empty()) + { + auto NextLineEnd = HeaderBlock.find("\r\n"); + if (NextLineEnd == std::string_view::npos) + { + break; + } + std::string_view Line = HeaderBlock.substr(0, NextLineEnd); + HeaderBlock = HeaderBlock.substr(NextLineEnd + 2); + if (Line.empty()) + { + break; + } + + auto ColonPos = Line.find(':'); + if (ColonPos == std::string_view::npos) + { + continue; + } + std::string_view Name = Line.substr(0, ColonPos); + std::string_view Value = TrimHeaderValue(Line.substr(ColonPos + 1)); + + if (Name.size() == 7 && StrCaseCompare(Name, "Upgrade"sv) == 0) + { + UpgradeOk = (Value.size() == 9 && StrCaseCompare(Value, "websocket"sv) == 0); + } + else if (Name.size() == 10 && StrCaseCompare(Name, "Connection"sv) == 0) + { + ConnectionOk = HeaderContainsToken(Value, "upgrade"sv); + } + else if (Name.size() == 20 && StrCaseCompare(Name, "Sec-WebSocket-Accept"sv) == 0) + { + AcceptValue = Value; + } + } + + if (!UpgradeOk || !ConnectionOk) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake missing required Upgrade/Connection headers"); + m_Handler.OnWsClose(1006, "handshake missing headers"); + return; + } + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); - if (Response.find(ExpectedAccept) == std::string::npos) + if (AcceptValue.size() != ExpectedAccept.size() || AcceptValue != ExpectedAccept) { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid or missing Sec-WebSocket-Accept"); m_Handler.OnWsClose(1006, "invalid accept key"); return; } |