diff --git a/Telegram/SourceFiles/calls/group/calls_group_call.cpp b/Telegram/SourceFiles/calls/group/calls_group_call.cpp index 250ca5e49a..d171165f20 100644 --- a/Telegram/SourceFiles/calls/group/calls_group_call.cpp +++ b/Telegram/SourceFiles/calls/group/calls_group_call.cpp @@ -2691,14 +2691,6 @@ bool GroupCall::tryCreateController() { } }); }; - auto e2eEncryptDecrypt = Fn( - const std::vector&, - bool)>(); - if (_e2e) { - e2eEncryptDecrypt = [e2e = _e2e](const std::vector &data, bool encrypt) { - return encrypt ? e2e->encrypt(data) : e2e->decrypt(data); - }; - } tgcalls::GroupInstanceDescriptor descriptor = { .threads = tgcalls::StaticThreads::getThreads(), @@ -2785,7 +2777,7 @@ bool GroupCall::tryCreateController() { }); return result; }, - .e2eEncryptDecrypt = e2eEncryptDecrypt, + .e2eEncryptDecrypt = _e2e ? _e2e->callbackEncryptDecrypt() : nullptr, }; if (Logs::DebugEnabled()) { auto callLogFolder = cWorkingDir() + u"DebugLogs"_q; diff --git a/Telegram/SourceFiles/tde2e/tde2e_api.cpp b/Telegram/SourceFiles/tde2e/tde2e_api.cpp index b4e42a8f8d..89c6f490bf 100644 --- a/Telegram/SourceFiles/tde2e/tde2e_api.cpp +++ b/Telegram/SourceFiles/tde2e/tde2e_api.cpp @@ -222,7 +222,7 @@ void Call::apply(int subchain, const Block &last) { LOG_AND_FAIL(id.error(), CallFailure::Unknown); return; } - _id = CallId{ uint64(id.value()) }; + setId({ uint64(id.value()) }); for (auto i = 0; i != kSubChainsCount; ++i) { auto &entry = _subchains[i]; @@ -245,6 +245,16 @@ void Call::apply(int subchain, const Block &last) { _participantsSet = ParseParticipantsSet(state.value()); } +void Call::setId(CallId id) { + Expects(!_id); + + _id = id; + if (const auto raw = _guardedId.get()) { + raw->value = id; + raw->exists = true; + } +} + void Call::checkForOutboundMessages() { Expects(_id); @@ -408,26 +418,32 @@ rpl::producer Call::emojiHashValue() const { return _emojiHash.value(); } -std::vector Call::encrypt(const std::vector &data) const { - const auto result = tde2e_api::call_encrypt(libId(), Slice(data)); - if (!result.is_ok()) { - return {}; +auto Call::callbackEncryptDecrypt() +-> Fn(const std::vector&, bool)> { + if (!_guardedId) { + _guardedId = std::make_shared(); + if (const auto raw = _id ? _guardedId.get() : nullptr) { + raw->value = _id; + raw->exists = true; + } } - const auto &value = result.value(); - const auto start = reinterpret_cast(value.data()); - const auto end = start + value.size(); - return std::vector{ start, end }; -} - -std::vector Call::decrypt(const std::vector &data) const { - const auto result = tde2e_api::call_decrypt(libId(), Slice(data)); - if (!result.is_ok()) { - return {}; - } - const auto &value = result.value(); - const auto start = reinterpret_cast(value.data()); - const auto end = start + value.size(); - return std::vector{ start, end }; + return [v = _guardedId](const std::vector &data, bool encrypt) { + if (!v->exists) { + return std::vector(); + } + const auto libId = std::int64_t(v->value.v); + const auto slice = Slice(data); + const auto result = encrypt + ? tde2e_api::call_encrypt(libId, slice) + : tde2e_api::call_decrypt(libId, slice); + if (!result.is_ok()) { + return std::vector(); + } + const auto &value = result.value(); + const auto start = reinterpret_cast(value.data()); + const auto end = start + value.size(); + return std::vector{ start, end }; + }; } } // namespace TdE2E diff --git a/Telegram/SourceFiles/tde2e/tde2e_api.h b/Telegram/SourceFiles/tde2e/tde2e_api.h index 0ddb750576..0b0e518657 100644 --- a/Telegram/SourceFiles/tde2e/tde2e_api.h +++ b/Telegram/SourceFiles/tde2e/tde2e_api.h @@ -101,14 +101,17 @@ public: [[nodiscard]] rpl::producer participantsSetValue() const; - [[nodiscard]] std::vector encrypt( - const std::vector &data) const; - [[nodiscard]] std::vector decrypt( - const std::vector &data) const; + [[nodiscard]] auto callbackEncryptDecrypt() + -> Fn(const std::vector&, bool)>; private: static constexpr int kSubChainsCount = 2; + struct GuardedCallId { + CallId value; + std::atomic exists; + }; + struct SubChainState { base::Timer shortPollTimer; base::Timer waitingTimer; @@ -118,6 +121,7 @@ private: int height = 0; }; + void setId(CallId id); void apply(int subchain, const Block &last); void fail(CallFailure reason); @@ -133,6 +137,7 @@ private: PublicKey _myKey; std::optional _failure; rpl::event_stream _failures; + std::shared_ptr _guardedId; SubChainState _subchains[kSubChainsCount]; rpl::event_stream _subchainRequests;