// Copyright Epic Games, Inc. All Rights Reserved. #include "zenhttp/auth/authmgr.h" #include #include #include #include #include #include #include #include #include #if ZEN_WITH_TESTS # include # include #endif #include #include #include #include #include #include namespace zen { using namespace std::literals; namespace details { // On-disk format for the GCM-encrypted authstate: // // [ 4 bytes: magic "ZEN1" ] // [12 bytes: GCM nonce ] — fresh random value per write // [ N bytes: ciphertext ] — same length as the plaintext // [16 bytes: GCM tag ] // // The magic is bound into the authentication by passing it as AAD to // AesGcm; that ties the tag to the exact header bytes, so an attacker // cannot strip the format indicator or swap to a different format without // failing authentication. // // The magic also doubles as the format version. If the framing ever // needs to change (e.g. adding a key-identifier for rotation), bump to // "ZEN2" and teach the reader to accept both. constexpr std::string_view kAuthStateMagic = "ZEN1"; constexpr size_t kAuthStateHeaderSz = 4 + AesGcm::NonceSize; IoBuffer ReadEncryptedFile(std::filesystem::path Path, const AesKey256Bit& Key, const AesIV128Bit& LegacyIV, std::optional& Reason, bool* OutWasLegacy = nullptr) { ZEN_TRACE_CPU("AuthMgr::ReadEncryptedFile"); if (OutWasLegacy) { *OutWasLegacy = false; } FileContents Result = ReadFile(Path); if (Result.ErrorCode) { return IoBuffer(); } IoBuffer EncryptedBuffer = Result.Flatten(); if (EncryptedBuffer.GetSize() == 0) { return IoBuffer(); } // Current format: magic-prefixed AES-256-GCM if (EncryptedBuffer.GetSize() >= kAuthStateHeaderSz + AesGcm::TagSize && memcmp(EncryptedBuffer.GetData(), kAuthStateMagic.data(), kAuthStateMagic.size()) == 0) { const uint8_t* Bytes = static_cast(EncryptedBuffer.GetData()); const size_t Total = EncryptedBuffer.GetSize(); const size_t CipherSize = Total - kAuthStateHeaderSz - AesGcm::TagSize; MemoryView Magic(Bytes, kAuthStateMagic.size()); MemoryView Nonce(Bytes + kAuthStateMagic.size(), AesGcm::NonceSize); MemoryView Cipher(Bytes + kAuthStateHeaderSz, CipherSize); MemoryView Tag(Bytes + kAuthStateHeaderSz + CipherSize, AesGcm::TagSize); std::vector PlainBuffer(CipherSize); MemoryView PlainView = AesGcm::Decrypt(Key, Nonce, /*Aad=*/Magic, Cipher, Tag, MakeMutableMemoryView(PlainBuffer), Reason); if (PlainView.IsEmpty()) { return IoBuffer(); } return IoBufferBuilder::MakeCloneFromMemory(PlainView); } // Legacy format: raw AES-256-CBC with the caller-configured IV. // Decrypt it once so we don't lose the user's cached state across // the upgrade; the next SaveState will rewrite in the GCM format. std::vector DecryptionBuffer; DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize); MemoryView DecryptedView = Aes::Decrypt(Key, LegacyIV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason); if (DecryptedView.IsEmpty()) { return IoBuffer(); } if (OutWasLegacy) { *OutWasLegacy = true; } return IoBufferBuilder::MakeCloneFromMemory(DecryptedView); } void WriteEncryptedFile(std::filesystem::path Path, IoBuffer FileData, const AesKey256Bit& Key, std::optional& Reason) { ZEN_TRACE_CPU("AuthMgr::WriteEncryptedFile"); if (FileData.GetSize() == 0) { return; } if (!Key.IsValid()) { Reason = "invalid key"; return; } // Fresh nonce per write. Never reuse a (Key, Nonce) pair — GCM is // catastrophically broken under nonce reuse. uint8_t Nonce[AesGcm::NonceSize]; if (!CryptoRandom::Fill(MakeMutableMemoryView(Nonce), &Reason)) { return; } std::vector FileBuffer; FileBuffer.resize(kAuthStateHeaderSz + FileData.GetSize() + AesGcm::TagSize); // [magic][nonce][cipher][tag] memcpy(FileBuffer.data(), kAuthStateMagic.data(), kAuthStateMagic.size()); memcpy(FileBuffer.data() + kAuthStateMagic.size(), Nonce, AesGcm::NonceSize); MutableMemoryView CipherOut(FileBuffer.data() + kAuthStateHeaderSz, FileData.GetSize()); MutableMemoryView TagOut(FileBuffer.data() + kAuthStateHeaderSz + FileData.GetSize(), AesGcm::TagSize); MemoryView CipherView = AesGcm::Encrypt(Key, MakeMemoryView(Nonce), /*Aad=*/MakeMemoryView(kAuthStateMagic), FileData.GetView(), CipherOut, TagOut, Reason); if (CipherView.IsEmpty()) { return; } TemporaryFile::SafeWriteFile(Path, MakeMemoryView(FileBuffer)); } } // namespace details class AuthMgrImpl final : public AuthMgr { using Clock = std::chrono::system_clock; using TimePoint = Clock::time_point; using Seconds = std::chrono::seconds; public: AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) { LoadState(); m_BackgroundThread.Interval = Config.UpdateInterval; m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this); } virtual ~AuthMgrImpl() { Shutdown(); } virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { ZEN_TRACE_CPU("AuthMgr::AddOpenIdProvider"); if (Params.Name.empty()) { ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); return; } { std::unique_lock _(m_ProviderMutex); if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end()) { OpenIdProvider& ExistingProvider = *It->second; if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url) { ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name); return; } else { m_OpenIdProviders.erase(It); m_OpenIdTokens.erase(std::string(Params.Name)); ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name); } } } Ref Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); return; } std::string NewProviderName = std::string(Params.Name); OpenIdProvider* NewProvider = nullptr; { std::unique_lock _(m_ProviderMutex); if (m_OpenIdProviders.contains(NewProviderName)) { return; } auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique()); NewProvider = InsertResult.first->second.get(); } NewProvider->Name = std::string(Params.Name); NewProvider->Url = std::string(Params.Url); NewProvider->ClientId = std::string(Params.ClientId); NewProvider->HttpClient = std::move(Client); ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); } virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final { ZEN_TRACE_CPU("AuthMgr::AddOpenIdToken"); if (Params.ProviderName.empty()) { ZEN_WARN("trying add OpenID token with invalid provider name"); return false; } if (Params.RefreshToken.empty()) { ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'"); return false; } auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); if (RefreshResult.Ok == false) { ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); return false; } bool IsNew = false; { auto Token = OpenIdToken{.RefreshToken = RefreshResult.RefreshToken, .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; std::unique_lock _(m_TokenMutex); const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token)); IsNew = InsertResult.second; } if (IsNew) { ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); } else { ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); } return true; } virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final { std::unique_lock _(m_TokenMutex); if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end()) { const OpenIdToken& Token = It->second; return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime}; } return {}; } private: struct OpenIdProvider { std::string Name; std::string Url; std::string ClientId; Ref HttpClient; }; struct OpenIdToken { std::string RefreshToken; std::string AccessToken; TimePoint ExpireTime{}; }; bool OpenIdProviderExist(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); return m_OpenIdProviders.contains(std::string(ProviderName)); } OpenIdProvider GetOpenIdProvider(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); return *m_OpenIdProviders[std::string(ProviderName)]; } OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) { ZEN_TRACE_CPU("AuthMgr::RefreshOpenIdToken"); Ref Client = GetOpenIdProvider(ProviderName).HttpClient; if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; } return Client->RefreshToken(RefreshToken); } void Shutdown() { BackgroundThread::Stop(m_BackgroundThread); SaveState(); } void LoadState() { ZEN_TRACE_CPU("AuthMgrImpl::LoadState"); try { std::optional Reason; bool WasLegacy = false; IoBuffer Buffer = details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason, &WasLegacy); if (Buffer && WasLegacy) { ZEN_INFO("authstate read via legacy AES-CBC fallback; next save will migrate to AES-GCM"); } if (!Buffer) { if (Reason) { ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value()); } return; } CbValidateError ValidationError; if (CbObject AuthState = ValidateAndReadCompactBinaryObject(std::move(Buffer), ValidationError); ValidationError != CbValidateError::None) { ZEN_WARN("load serialized state FAILED, reason '{}'", ToString(ValidationError)); return; } else { for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) { CbObjectView ProviderObj = ProviderView.AsObjectView(); std::string_view ProviderName = ProviderObj["Name"].AsString(); std::string_view Url = ProviderObj["Url"].AsString(); std::string_view ClientId = ProviderObj["ClientId"].AsString(); AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); } for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) { CbObjectView TokenObj = TokenView.AsObjectView(); std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken}); if (!Ok) { ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); } } } } catch (const std::exception& Err) { ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what()); { std::unique_lock _(m_ProviderMutex); m_OpenIdProviders.clear(); } { std::unique_lock _(m_TokenMutex); m_OpenIdTokens.clear(); } } } void SaveState() { ZEN_TRACE_CPU("AuthMgr::SaveState"); try { CbObjectWriter AuthState; { std::unique_lock _(m_ProviderMutex); if (m_OpenIdProviders.size() > 0) { AuthState.BeginArray("OpenIdProviders"); for (const auto& Kv : m_OpenIdProviders) { AuthState.BeginObject(); AuthState << "Name"sv << Kv.second->Name; AuthState << "Url"sv << Kv.second->Url; AuthState << "ClientId"sv << Kv.second->ClientId; AuthState.EndObject(); } AuthState.EndArray(); } } { std::unique_lock _(m_TokenMutex); AuthState.BeginArray("OpenIdTokens"); if (m_OpenIdTokens.size() > 0) { for (const auto& Kv : m_OpenIdTokens) { AuthState.BeginObject(); AuthState << "ProviderName"sv << Kv.first; AuthState << "RefreshToken"sv << Kv.second.RefreshToken; AuthState.EndObject(); } } AuthState.EndArray(); } CreateDirectories(m_Config.RootDirectory); std::optional Reason; details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer(), m_Config.EncryptionKey, Reason); if (Reason) { ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value()); } } catch (const std::exception& Err) { ZEN_WARN("serialize state FAILED, reason '{}'", Err.what()); } } void BackgroundThreadEntry() { SetCurrentThreadName("auth"); for (;;) { std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread); if (m_BackgroundThread.Running.load() == false) { break; } if (SignalStatus != std::cv_status::timeout) { continue; } { // Refresh Open ID token(s) std::vector ExpiredTokens; { std::unique_lock _(m_TokenMutex); for (const auto& Kv : m_OpenIdTokens) { const Seconds ExpiresIn = std::chrono::duration_cast(Kv.second.ExpireTime - Clock::now()); const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2); if (Expired) { ExpiredTokens.push_back(Kv); } } } if (ExpiredTokens.empty()) { continue; } ZEN_DEBUG("refreshing {} OpenID token(s)", ExpiredTokens.size()); for (const auto& Kv : ExpiredTokens) { OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken); if (RefreshResult.Ok) { ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first); auto Token = OpenIdToken{.RefreshToken = RefreshResult.RefreshToken, .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; { std::unique_lock _(m_TokenMutex); m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token)); } } else { ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason); } } } } } struct BackgroundThread { std::chrono::seconds Interval{10}; std::mutex Mutex; std::condition_variable Signal; std::atomic_bool Running{true}; std::thread Thread; static void Stop(BackgroundThread& State) { if (State.Running.load()) { State.Running.store(false); State.Signal.notify_one(); } if (State.Thread.joinable()) { State.Thread.join(); } } static std::cv_status WaitForSignal(BackgroundThread& State) { std::unique_lock Lock(State.Mutex); return State.Signal.wait_for(Lock, State.Interval); } }; using OpenIdProviderMap = std::unordered_map>; using OpenIdTokenMap = std::unordered_map; LoggerRef Log() { return m_Log; } AuthConfig m_Config; LoggerRef m_Log; BackgroundThread m_BackgroundThread; OpenIdProviderMap m_OpenIdProviders; OpenIdTokenMap m_OpenIdTokens; std::mutex m_ProviderMutex; std::shared_mutex m_TokenMutex; }; std::unique_ptr AuthMgr::Create(const AuthConfig& Config) { return std::make_unique(Config); } #if ZEN_WITH_TESTS TEST_SUITE_BEGIN("http.authmgr"); TEST_CASE("authmgr.authstate_gcm_roundtrip") { ScopedTemporaryDirectory TmpDir; std::filesystem::path Path = TmpDir.Path() / "authstate"; const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); const AesIV128Bit UnusedIv; // ignored on write, unused on GCM read const std::string_view Plain = "sensitive compact-binary payload representing cached auth state"sv; IoBuffer InBuf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Plain)); std::optional WriteReason; details::WriteEncryptedFile(Path, InBuf, Key, WriteReason); REQUIRE_FALSE(WriteReason.has_value()); std::optional ReadReason; bool WasLegacy = true; IoBuffer Out = details::ReadEncryptedFile(Path, Key, UnusedIv, ReadReason, &WasLegacy); REQUIRE_FALSE(ReadReason.has_value()); REQUIRE(Out.GetSize() == Plain.size()); CHECK(memcmp(Out.GetData(), Plain.data(), Plain.size()) == 0); CHECK_FALSE(WasLegacy); } TEST_CASE("authmgr.authstate_legacy_cbc_fallback") { // Hand-craft a legacy-format authstate file (raw AES-CBC, no magic) and // verify the reader falls back, returns the plaintext, and flags the file // as legacy so the caller can rewrite it in the new format. ScopedTemporaryDirectory TmpDir; std::filesystem::path Path = TmpDir.Path() / "authstate"; const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); const uint8_t IvBytes[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const AesIV128Bit LegacyIv = AesIV128Bit::FromMemoryView(MakeMemoryView(IvBytes)); const std::string_view Plain = "legacy authstate payload"sv; std::vector CbcBuf(Plain.size() + Aes::BlockSize); std::optional CbcReason; MemoryView CbcView = Aes::Encrypt(Key, LegacyIv, MakeMemoryView(Plain), MakeMutableMemoryView(CbcBuf), CbcReason); REQUIRE_FALSE(CbcReason.has_value()); TemporaryFile::SafeWriteFile(Path, CbcView); std::optional ReadReason; bool WasLegacy = false; IoBuffer Out = details::ReadEncryptedFile(Path, Key, LegacyIv, ReadReason, &WasLegacy); REQUIRE_FALSE(ReadReason.has_value()); REQUIRE(Out.GetSize() == Plain.size()); CHECK(memcmp(Out.GetData(), Plain.data(), Plain.size()) == 0); CHECK(WasLegacy); } TEST_CASE("authmgr.authstate_gcm_tamper_detection") { ScopedTemporaryDirectory TmpDir; std::filesystem::path Path = TmpDir.Path() / "authstate"; const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); const AesIV128Bit UnusedIv; const std::string_view Plain = "payload that should not decrypt after tampering"sv; IoBuffer InBuf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Plain)); std::optional WriteReason; details::WriteEncryptedFile(Path, InBuf, Key, WriteReason); REQUIRE_FALSE(WriteReason.has_value()); // Flip a single byte in the middle of the ciphertext region (past magic + // nonce) and verify the GCM tag check rejects the tampered file. Copy // the bytes out first and drop the FileContents / IoBuffer handles before // rewriting so the underlying file isn't still open when we overwrite it. std::vector Mutated; { FileContents FC = ReadFile(Path); IoBuffer Whole = FC.Flatten(); REQUIRE(Whole.GetSize() > 4 + AesGcm::NonceSize); Mutated.assign(static_cast(Whole.GetData()), static_cast(Whole.GetData()) + Whole.GetSize()); } Mutated[4 + AesGcm::NonceSize] ^= 0x40; TemporaryFile::SafeWriteFile(Path, MakeMemoryView(Mutated)); std::optional ReadReason; bool WasLegacy = false; IoBuffer Out = details::ReadEncryptedFile(Path, Key, UnusedIv, ReadReason, &WasLegacy); CHECK(ReadReason.has_value()); CHECK(Out.GetSize() == 0); } TEST_SUITE_END(); void authmgr_forcelink() { } #endif // ZEN_WITH_TESTS } // namespace zen