// Copyright Epic Games, Inc. All Rights Reserved.

#include "zenhttp/security/passwordsecurity.h"

// NOTE(review): the bare #include directives below have no target in this copy of the file; the header names
// (presumably written in angle brackets) appear to have been lost. Restore them from the original source.
#include
#include
#include

#if ZEN_WITH_TESTS
#    include
#    include
#endif  // ZEN_WITH_TESTS

namespace zen {

using namespace std::literals;

// Constructs the password security policy from Config.
//
// Stores a copy of the configuration and builds m_UnprotectedUriHashes, which maps the Djb2 hash of each
// configured unprotected uri to the index of that uri in m_Config.UnprotectedUris.
//
// Throws std::runtime_error if two configured uris produce the same hash. An exact duplicate entry also
// triggers this. The error message reports both entries using 1-based positions.
PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config)
{
    m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size());
    for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++)
    {
        const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index];
        // insert() reports through Result.second whether the hash was newly added; on failure Result.first
        // refers to the entry that already owns this hash.
        if (auto Result = m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second)
        {
            // NOTE(review): uniqueness is enforced on the hash, not on the uri text, so two different uris
            // that happen to hash identically are rejected here as well.
            throw std::runtime_error(fmt::format(
                "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')",
                Index + 1,
                UnprotectedUri,
                Result.first->second + 1,
                m_Config.UnprotectedUris[Result.first->second]));
        }
    }
}

// Returns true when BaseUri followed by RelativeUri exactly equals one of the configured unprotected uris.
//
// The lookup goes through the hash table built in the constructor. A hash hit is then confirmed by comparing
// the actual text, so a request uri that merely shares a hash with a configured uri is not treated as
// unprotected.
bool
PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const
{
    if (!m_Config.UnprotectedUris.empty())
    {
        // NOTE(review): presumably hashing the two parts in sequence yields the same value as hashing the
        // concatenated string, which is what the constructor stores - confirm against HashStringDjb2.
        // The std::array may also have lost explicit template arguments in this copy of the file.
        uint32_t UriHash = HashStringDjb2(std::array{BaseUri, RelativeUri});
        if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end())
        {
            // NOTE(review): if UnprotectedUris holds std::string elements (as the constructor's
            // const std::string& suggests), this reference binds to a temporary string_view.
            const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second];
            // Check the total length first so the two substr() comparisons below cover the whole uri.
            if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length())
            {
                if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri)
                {
                    return true;
                }
            }
        }
    }
    return false;
}

// Decides whether a request is allowed. The checks run in this order and the first match wins:
//   1. The uri (BaseUri + RelativeUri) is configured as unprotected.
//   2. The request is machine local and machine local requests are not protected.
//   3. No password is configured (Password() is empty).
//   4. InPassword equals the configured password.
// Returns false if none of the above apply.
bool
PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest)
{
    if (IsUnprotectedUri(BaseUri, RelativeUri))
    {
        return true;
    }
    if (!ProtectMachineLocalRequests() && IsMachineLocalRequest)
    {
        return true;
    }
    if (Password().empty())
    {
        return true;
    }
    // NOTE(review): operator== is not guaranteed to run in constant time; consider a constant-time
    // comparison if timing side channels are a concern for this endpoint.
    if (Password() == InPassword)
    {
        return true;
    }
    return false;
}

#if ZEN_WITH_TESTS

// Configuration built from an empty initializer: every request is expected to be allowed, whatever the
// supplied password or request origin.
TEST_CASE("passwordsecurity.allowanything")
{
    PasswordSecurity Anything({});
    CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
}

// Only a password is configured: machine local requests are expected to pass with any password, while
// non-local requests with a missing or wrong password are expected to be rejected.
TEST_CASE("passwordsecurity.allowalllocal")
{
    PasswordSecurity AllLocal({.Password = "123456"});
    CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
}

// Password configured and machine local requests protected: only the correct password is expected to be
// accepted, for local and non-local requests alike.
TEST_CASE("passwordsecurity.allowonlypassword")
{
    PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true});
    CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
}

// Password configured, machine local requests not protected, two unprotected uris: the unprotected uris are
// expected to be reachable without a valid password from any origin; other uris need the password unless the
// request is machine local.
// NOTE(review): std::vector({...}) may have lost its explicit template argument in this copy of the file.
TEST_CASE("passwordsecurity.allowsomeexternaluris")
{
    PasswordSecurity AllLocal(
        {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector({"/free/access", "/ok"})});
    CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
}

// Same as above but with machine local requests protected: the unprotected uris are still expected to be
// reachable without a valid password, while every other uri needs the correct password regardless of origin.
TEST_CASE("passwordsecurity.allowsomelocaluris")
{
    PasswordSecurity AllLocal(
        {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector({"/free/access", "/ok"})});
    CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
    CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
    CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
    CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
}

// Listing the same unprotected uri twice is expected to make the constructor throw, with a message that
// names both entries by their 1-based position.
TEST_CASE("passwordsecurity.conflictingunprotecteduris")
{
    try
    {
        PasswordSecurity AllLocal({.Password                    = "123456",
                                   .ProtectMachineLocalRequests = true,
                                   .UnprotectedUris             = std::vector({"/free/access", "/free/access"})});
        // Reaching this line means the constructor did not throw, which is a test failure.
        CHECK(false);
    }
    catch (const std::runtime_error& Ex)
    {
        CHECK_EQ(Ex.what(),
                 std::string("password security unprotected uris does not generate unique hashes. Uri #2 ('/free/access') collides with "
                             "uri #1 ('/free/access')"));
    }
}

// Intentionally empty. NOTE(review): presumably referenced from elsewhere so the linker keeps this translation
// unit and its test cases - confirm against the callers.
void
passwordsecurity_forcelink()
{
}

#endif  // ZEN_WITH_TESTS

}  // namespace zen