Improve confcall join.

This commit is contained in:
John Preston 2025-04-11 17:16:32 +04:00
parent 59abfcbd6d
commit 01d927aceb
5 changed files with 192 additions and 97 deletions

View file

@ -682,8 +682,11 @@ GroupCall::GroupCall(
setupMediaDevices();
setupOutgoingVideo();
if (_conferenceCall || conference.migrating || conference.show) {
setupConference();
if (_conferenceCall) {
setupConferenceCall();
initConferenceE2E();
} else if (conference.migrating || conference.show) {
initConferenceE2E();
}
if (conference.migrating || (conference.show && !_conferenceCall)) {
if (!conference.muted) {
@ -739,6 +742,7 @@ void GroupCall::processConferenceStart(StartConferenceInfo conference) {
}
GroupCall::~GroupCall() {
_e2e = nullptr;
destroyScreencast();
destroyController();
if (!_rtmp) {
@ -746,37 +750,54 @@ GroupCall::~GroupCall() {
}
}
void GroupCall::setupConference() {
if (!_e2e) {
_e2e = std::make_shared<TdE2E::Call>(
TdE2E::MakeUserId(_peer->session().user()));
void GroupCall::initConferenceE2E() {
if (!_e2eEncryptDecrypt) {
_e2eEncryptDecrypt = std::make_shared<TdE2E::EncryptDecrypt>();
}
for (auto &state : _subchains) {
_api.request(base::take(state.requestId)).cancel();
state = SubChainState();
}
_e2e = nullptr;
_pendingOutboundBlock = QByteArray();
const auto tde2eUserId = TdE2E::MakeUserId(_peer->session().user());
_e2e = std::make_unique<TdE2E::Call>(tde2eUserId);
_e2e->subchainRequests(
) | rpl::start_with_next([=](TdE2E::Call::SubchainRequest request) {
requestSubchainBlocks(request.subchain, request.height);
}, _lifetime);
}, _e2e->lifetime());
_e2e->sendOutboundBlock(
) | rpl::start_with_next([=](QByteArray &&block) {
sendOutboundBlock(std::move(block));
}, _lifetime);
}, _e2e->lifetime());
_e2e->failures() | rpl::start_with_next([=] {
LOG(("TdE2E: Got failure!"));
hangup();
}, _lifetime);
startRejoin();
}, _e2e->lifetime());
if (_conferenceCall) {
setupConferenceCall();
}
_e2e->registerEncryptDecrypt(_e2eEncryptDecrypt);
_emojiHash = _e2e->emojiHashValue();
}
void GroupCall::setupConferenceCall() {
Expects(_conferenceCall != nullptr && _e2e != nullptr);
Expects(_conferenceCall != nullptr);
_conferenceCall->staleParticipantIds(
) | rpl::start_with_next([=](const base::flat_set<UserId> &staleIds) {
removeConferenceParticipants(staleIds, true);
}, _lifetime);
}
void GroupCall::trackParticipantsWithAccess() {
if (!_conferenceCall || !_e2e) {
return;
}
_e2e->participantsSetValue(
) | rpl::start_with_next([=](const TdE2E::ParticipantsSet &set) {
@ -786,7 +807,7 @@ void GroupCall::setupConferenceCall() {
users.emplace(UserId(id.v));
}
_conferenceCall->setParticipantsWithAccess(std::move(users));
}, _lifetime);
}, _e2e->lifetime());
}
void GroupCall::removeConferenceParticipants(
@ -1233,9 +1254,7 @@ rpl::producer<not_null<Data::GroupCall*>> GroupCall::real() const {
}
rpl::producer<QByteArray> GroupCall::emojiHashValue() const {
Expects(_e2e != nullptr);
return _e2e->emojiHashValue();
return _emojiHash.value();
}
void GroupCall::start(TimeId scheduleDate, bool rtmp) {
@ -1482,9 +1501,16 @@ void GroupCall::markTrackPaused(const VideoEndpoint &endpoint, bool paused) {
}
void GroupCall::startRejoin() {
if (_joinState.action != JoinAction::None || _createRequestId) {
// Don't reset _e2e in that case, if rejoin() is a no-op.
return;
}
for (const auto &[task, part] : _broadcastParts) {
_api.request(part.requestId).cancel();
}
if (_conferenceCall || _startConferenceInfo) {
initConferenceE2E();
}
setState(State::Joining);
rejoin();
}
@ -1720,7 +1746,15 @@ void GroupCall::joinDone(
applyMeInCallLocally();
maybeSendMutedUpdate(wasMuteState);
for (auto &state : _subchains) {
// Accept initial join blocks.
_api.request(base::take(state.requestId)).cancel();
state.inShortPoll = true;
}
_peer->session().api().applyUpdates(result);
for (auto &state : _subchains) {
state.inShortPoll = false;
}
if (justCreated) {
subscribeToReal(_conferenceCall.get());
@ -1732,6 +1766,7 @@ void GroupCall::joinDone(
*_startConferenceInfo);
}
trackParticipantsWithAccess();
applyQueuedSelfUpdates();
checkFirstTimeJoined();
_screenJoinState.nextActionPending = true;
@ -1766,8 +1801,7 @@ void GroupCall::joinDone(
void GroupCall::joinFail(const QString &error) {
if (_e2e) {
if (error == u"BLOCK_INVALID"_q
|| error.startsWith(u"CONF_WRITE_CHAIN_INVALID"_q)) {
if (error.startsWith(u"CONF_WRITE_CHAIN_INVALID"_q)) {
if (_id) {
refreshLastBlockAndJoin();
} else {
@ -2443,10 +2477,12 @@ void GroupCall::applySubChainUpdate(
Expects(subchain >= 0 && subchain < kSubChainsCount);
auto &entry = _subchains[subchain];
auto now = next - int(blocks.size());
auto raw = std::vector<TdE2E::Block>();
raw.reserve(blocks.size());
for (const auto &block : blocks) {
_e2e->apply(subchain, now++, { block.v }, entry.inShortPoll);
raw.push_back({ block.v });
}
_e2e->apply(subchain, next, raw, entry.inShortPoll);
}
void GroupCall::applyQueuedSelfUpdates() {
@ -2955,7 +2991,9 @@ bool GroupCall::tryCreateController() {
});
return result;
},
.e2eEncryptDecrypt = _e2e ? _e2e->callbackEncryptDecrypt() : nullptr,
.e2eEncryptDecrypt = (_e2eEncryptDecrypt
? _e2eEncryptDecrypt->callback()
: nullptr),
};
if (Logs::DebugEnabled()) {
auto callLogFolder = cWorkingDir() + u"DebugLogs"_q;
@ -3008,7 +3046,9 @@ bool GroupCall::tryCreateScreencast() {
.videoCapture = _screenCapture,
.videoContentType = tgcalls::VideoContentType::Screencast,
.videoCodecPreferences = lookupVideoCodecPreferences(),
.e2eEncryptDecrypt = _e2e ? _e2e->callbackEncryptDecrypt() : nullptr,
.e2eEncryptDecrypt = (_e2eEncryptDecrypt
? _e2eEncryptDecrypt->callback()
: nullptr),
};
LOG(("Call Info: Creating group screen instance"));

View file

@ -45,6 +45,7 @@ class GroupCall;
namespace TdE2E {
class Call;
class EncryptDecrypt;
} // namespace TdE2E
namespace Calls {
@ -621,8 +622,9 @@ private:
void setupMediaDevices();
void setupOutgoingVideo();
void setupConference();
void initConferenceE2E();
void setupConferenceCall();
void trackParticipantsWithAccess();
void setScreenEndpoint(std::string endpoint);
void setCameraEndpoint(std::string endpoint);
void addVideoOutput(const std::string &endpoint, SinkPointer sink);
@ -648,7 +650,9 @@ private:
const not_null<Delegate*> _delegate;
std::shared_ptr<Data::GroupCall> _conferenceCall;
std::shared_ptr<TdE2E::Call> _e2e;
std::unique_ptr<TdE2E::Call> _e2e;
std::shared_ptr<TdE2E::EncryptDecrypt> _e2eEncryptDecrypt;
rpl::variable<QByteArray> _emojiHash;
QByteArray _pendingOutboundBlock;
std::shared_ptr<StartConferenceInfo> _startConferenceInfo;

View file

@ -83,9 +83,13 @@ constexpr auto kHideControlsTimeout = 5 * crl::time(1000);
const auto fp = bytes::make_span(hash).subspan(0, 32);
const auto emoji = Calls::ComputeEmojiFingerprint(fp);
result += QString::fromUtf8(" \xc2\xb7 ");
const auto base = result.size();
for (const auto &single : emoji) {
result += single->text();
}
MTP_LOG(0, ("Got Emoji: %1.").arg(result.mid(base)));
} else {
MTP_LOG(0, ("Cleared Emoji."));
}
return result;
}

View file

@ -74,6 +74,43 @@ constexpr auto kShortPollChainBlocksWaitFor = crl::time(1000);
} // namespace
auto EncryptDecrypt::callback()
-> Fn<EncryptionBuffer(const EncryptionBuffer&, int64_t, bool)> {
return [that = shared_from_this()](
const EncryptionBuffer &data,
int64_t userId,
bool encrypt) -> EncryptionBuffer {
const auto libId = that->_id.load();
if (!libId) {
return {};
}
const auto channelId = tde2e_api::CallChannelId(0);
const auto slice = Slice(data);
const auto result = encrypt
? tde2e_api::call_encrypt(libId, channelId, slice)
: tde2e_api::call_decrypt(libId, userId, channelId, slice);
if (!result.is_ok()) {
return {};
}
const auto &value = result.value();
const auto start = reinterpret_cast<const uint8_t*>(value.data());
const auto end = start + value.size();
return { start, end };
};
}
void EncryptDecrypt::setCallId(CallId id) {
Expects(id.v != 0);
_id.store(id.v);
}
void EncryptDecrypt::clearCallId(CallId fromId) {
Expects(fromId.v != 0);
_id.compare_exchange_strong(fromId.v, 0);
}
Call::Call(UserId myUserId)
: _myUserId(myUserId) {
const auto id = tde2e_api::key_generate_temporary_private_key();
@ -88,6 +125,9 @@ Call::Call(UserId myUserId)
Call::~Call() {
if (const auto id = libId()) {
if (const auto raw = _encryptDecrypt.get()) {
raw->clearCallId(_id);
}
tde2e_api::call_destroy(id);
}
}
@ -190,10 +230,13 @@ rpl::producer<ParticipantsSet> Call::participantsSetValue() const {
}
void Call::joined() {
shortPoll(0);
if (_id) {
shortPoll(1);
if (!_id) {
LOG(("TdE2E Error: Call::joined() without id."));
_failure = CallFailure::Unknown;
return;
}
shortPoll(0);
shortPoll(1);
}
void Call::apply(int subchain, const Block &last) {
@ -251,7 +294,6 @@ void Call::apply(int subchain, const Block &last) {
return;
}
setId({ uint64(id.value()) });
shortPoll(1);
for (auto i = 0; i != kSubChainsCount; ++i) {
auto &entry = _subchains[i];
@ -278,9 +320,8 @@ void Call::setId(CallId id) {
Expects(!_id);
_id = id;
if (const auto raw = _guardedId.get()) {
raw->value = id;
raw->exists = true;
if (const auto raw = _encryptDecrypt.get()) {
raw->setCallId(id);
}
}
@ -299,38 +340,50 @@ void Call::checkForOutboundMessages() {
void Call::apply(
int subchain,
int index,
const Block &block,
int indexAfterLast,
const std::vector<Block> &blocks,
bool fromShortPoll) {
Expects(subchain >= 0 && subchain < kSubChainsCount);
Expects(_id || !fromShortPoll || !subchain);
if (!subchain && index >= _lastBlock0Height) {
_lastBlock0 = block;
_lastBlock0Height = index;
}
if (failed()) {
return;
if (!subchain && !blocks.empty() && indexAfterLast > _lastBlock0Height) {
_lastBlock0 = blocks.back();
_lastBlock0Height = indexAfterLast;
}
auto &entry = _subchains[subchain];
if (!fromShortPoll) {
entry.lastUpdate = crl::now();
if (index > entry.height || (!_id && subchain != 0)) {
entry.waiting.emplace(index, block);
checkWaitingBlocks(subchain);
if (fromShortPoll) {
auto i = begin(entry.waiting);
while (i != end(entry.waiting) && i->first < indexAfterLast) {
++i;
}
entry.waiting.erase(begin(entry.waiting), i);
if (subchain && !_id && !blocks.empty()) {
LOG(("TdE2E Error: Broadcast shortpoll block without id."));
fail(CallFailure::Unknown);
return;
}
} else {
entry.lastUpdate = crl::now();
}
if (failed()) {
return;
} else if (!_id
|| (subchain && !entry.height && fromShortPoll)
|| (entry.height == index)) {
apply(subchain, block);
}
entry.height = std::max(entry.height, index + 1);
auto index = indexAfterLast - int(blocks.size());
if (!fromShortPoll && (index > entry.height || (!_id && subchain))) {
for (const auto &block : blocks) {
entry.waiting.emplace(index++, block);
}
} else {
for (const auto &block : blocks) {
if (!_id || (entry.height == index)) {
apply(subchain, block);
}
entry.height = std::max(entry.height, ++index);
}
entry.height = std::max(entry.height, indexAfterLast);
}
checkWaitingBlocks(subchain);
}
@ -451,37 +504,18 @@ rpl::producer<QByteArray> Call::emojiHashValue() const {
return _emojiHash.value();
}
auto Call::callbackEncryptDecrypt()
-> Fn<std::vector<uint8_t>(const std::vector<uint8_t>&, int64_t, bool)> {
if (!_guardedId) {
_guardedId = std::make_shared<GuardedCallId>();
if (const auto raw = _id ? _guardedId.get() : nullptr) {
raw->value = _id;
raw->exists = true;
}
void Call::registerEncryptDecrypt(std::shared_ptr<EncryptDecrypt> object) {
Expects(object != nullptr);
Expects(_encryptDecrypt == nullptr);
_encryptDecrypt = std::move(object);
if (_id) {
_encryptDecrypt->setCallId(_id);
}
return [id = _guardedId](
const std::vector<uint8_t> &data,
int64_t userId,
bool encrypt) {
const auto raw = id.get();
if (!raw->exists) {
return std::vector<uint8_t>();
}
const auto libId = std::int64_t(raw->value.v);
const auto channelId = tde2e_api::CallChannelId(0);
const auto slice = Slice(data);
const auto result = encrypt
? tde2e_api::call_encrypt(libId, channelId, slice)
: tde2e_api::call_decrypt(libId, userId, channelId, slice);
if (!result.is_ok()) {
return std::vector<uint8_t>();
}
const auto &value = result.value();
const auto start = reinterpret_cast<const uint8_t*>(value.data());
const auto end = start + value.size();
return std::vector<uint8_t>{ start, end };
};
}
rpl::lifetime &Call::lifetime() {
return _lifetime;
}
} // namespace TdE2E

View file

@ -12,6 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "base/timer.h"
#include <rpl/event_stream.h>
#include <rpl/lifetime.h>
#include <rpl/producer.h>
#include <rpl/variable.h>
@ -66,6 +67,22 @@ enum class CallFailure {
Unknown,
};
using EncryptionBuffer = std::vector<uint8_t>;
class EncryptDecrypt final
: public std::enable_shared_from_this<EncryptDecrypt> {
public:
[[nodiscard]] auto callback()
-> Fn<EncryptionBuffer(const EncryptionBuffer&, int64_t, bool)>;
void setCallId(CallId id);
void clearCallId(CallId fromId);
private:
std::atomic<uint64> _id = 0;
};
class Call final {
public:
explicit Call(UserId myUserId);
@ -76,8 +93,8 @@ public:
void joined();
void apply(
int subchain,
int index,
const Block &block,
int indexAfterLast,
const std::vector<Block> &blocks,
bool fromShortPoll);
struct SubchainRequest {
@ -100,22 +117,16 @@ public:
[[nodiscard]] Block makeJoinBlock();
[[nodiscard]] Block makeRemoveBlock(const base::flat_set<UserId> &ids);
[[nodiscard]] rpl::producer<ParticipantsSet> participantsSetValue() const;
[[nodiscard]] auto participantsSetValue() const
-> rpl::producer<ParticipantsSet>;
[[nodiscard]] auto callbackEncryptDecrypt()
-> Fn<std::vector<uint8_t>(
const std::vector<uint8_t>&,
int64_t,
bool)>;
void registerEncryptDecrypt(std::shared_ptr<EncryptDecrypt> object);
[[nodiscard]] rpl::lifetime &lifetime();
private:
static constexpr int kSubChainsCount = 2;
struct GuardedCallId {
CallId value;
std::atomic<bool> exists;
};
struct SubChainState {
base::Timer shortPollTimer;
base::Timer waitingTimer;
@ -141,7 +152,7 @@ private:
PublicKey _myKey;
std::optional<CallFailure> _failure;
rpl::event_stream<CallFailure> _failures;
std::shared_ptr<GuardedCallId> _guardedId;
std::shared_ptr<EncryptDecrypt> _encryptDecrypt;
SubChainState _subchains[kSubChainsCount];
rpl::event_stream<SubchainRequest> _subchainRequests;
@ -153,6 +164,8 @@ private:
rpl::variable<ParticipantsSet> _participantsSet;
rpl::variable<QByteArray> _emojiHash;
rpl::lifetime _lifetime;
};
} // namespace TdE2E